Skip to content

Commit fdf8fdc

Browse files
kszucs authored and wesm committed
ARROW-7858: [C++][Python] Support casting from ExtensionArray
This adds support for casting from an ExtensionType. We should probably also add support for casting to an ExtensionType, but that can be handled in a follow-up. Closes #6633 from kszucs/ext-cast. Authored-by: Krisztián Szűcs <szucs.krisztian@gmail.com> Signed-off-by: Wes McKinney <wesm+git@apache.org>
1 parent c15637d commit fdf8fdc

5 files changed

Lines changed: 223 additions & 10 deletions

File tree

cpp/src/arrow/compute/kernels/cast.cc

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <functional>
2323
#include <limits>
2424
#include <memory>
25+
#include <string>
2526
#include <type_traits>
2627
#include <utility>
2728
#include <vector>
@@ -1271,6 +1272,62 @@ class ZeroCopyCast : public CastKernelBase {
12711272
}
12721273
};
12731274

1275+
class ExtensionCastKernel : public CastKernelBase {
1276+
public:
1277+
static Status Make(const DataType& in_type, std::shared_ptr<DataType> out_type,
1278+
const CastOptions& options,
1279+
std::unique_ptr<CastKernelBase>* kernel) {
1280+
const auto storage_type = checked_cast<const ExtensionType&>(in_type).storage_type();
1281+
1282+
std::unique_ptr<UnaryKernel> storage_caster;
1283+
RETURN_NOT_OK(GetCastFunction(*storage_type, out_type, options, &storage_caster));
1284+
kernel->reset(
1285+
new ExtensionCastKernel(std::move(storage_caster), std::move(out_type)));
1286+
1287+
return Status::OK();
1288+
}
1289+
1290+
Status Init(const DataType& in_type) override {
1291+
auto& type = checked_cast<const ExtensionType&>(in_type);
1292+
storage_type_ = type.storage_type();
1293+
extension_name_ = type.extension_name();
1294+
return Status::OK();
1295+
}
1296+
1297+
Status Call(FunctionContext* ctx, const Datum& input, Datum* out) override {
1298+
DCHECK_EQ(input.kind(), Datum::ARRAY);
1299+
1300+
// validate: type is the same as the type the kernel was constructed with
1301+
const auto& input_type = checked_cast<const ExtensionType&>(*input.type());
1302+
if (input_type.extension_name() != extension_name_) {
1303+
return Status::TypeError(
1304+
"The cast kernel was constructed to cast from the extension type named '",
1305+
extension_name_, "' but input has extension type named '",
1306+
input_type.extension_name(), "'");
1307+
}
1308+
if (!input_type.storage_type()->Equals(storage_type_)) {
1309+
return Status::TypeError("The cast kernel was constructed with a storage type: ",
1310+
storage_type_->ToString(),
1311+
", but it is called with a different storage type:",
1312+
input_type.storage_type()->ToString());
1313+
}
1314+
1315+
// construct an ArrayData object with the underlying storage type
1316+
auto new_input = input.array()->Copy();
1317+
new_input->type = storage_type_;
1318+
return InvokeWithAllocation(ctx, storage_caster_.get(), new_input, out);
1319+
}
1320+
1321+
protected:
1322+
ExtensionCastKernel(std::unique_ptr<UnaryKernel> storage_caster,
1323+
std::shared_ptr<DataType> out_type)
1324+
: CastKernelBase(std::move(out_type)), storage_caster_(std::move(storage_caster)) {}
1325+
1326+
std::string extension_name_;
1327+
std::shared_ptr<DataType> storage_type_;
1328+
std::unique_ptr<UnaryKernel> storage_caster_;
1329+
};
1330+
12741331
class CastKernel : public CastKernelBase {
12751332
public:
12761333
CastKernel(const CastOptions& options, const CastFunction& func,
@@ -1420,11 +1477,6 @@ Status GetCastFunction(const DataType& in_type, std::shared_ptr<DataType> out_ty
14201477
return Status::OK();
14211478
}
14221479

1423-
if (in_type.id() == Type::NA) {
1424-
kernel->reset(new FromNullCastKernel(std::move(out_type)));
1425-
return Status::OK();
1426-
}
1427-
14281480
std::unique_ptr<CastKernelBase> cast_kernel;
14291481
switch (in_type.id()) {
14301482
CAST_FUNCTION_CASE(BooleanType);
@@ -1450,6 +1502,9 @@ Status GetCastFunction(const DataType& in_type, std::shared_ptr<DataType> out_ty
14501502
CAST_FUNCTION_CASE(LargeBinaryType);
14511503
CAST_FUNCTION_CASE(LargeStringType);
14521504
CAST_FUNCTION_CASE(DictionaryType);
1505+
case Type::NA:
1506+
cast_kernel.reset(new FromNullCastKernel(std::move(out_type)));
1507+
break;
14531508
case Type::LIST:
14541509
RETURN_NOT_OK(
14551510
GetListCastFunc<ListType>(in_type, std::move(out_type), options, &cast_kernel));
@@ -1458,6 +1513,10 @@ Status GetCastFunction(const DataType& in_type, std::shared_ptr<DataType> out_ty
14581513
RETURN_NOT_OK(GetListCastFunc<LargeListType>(in_type, std::move(out_type), options,
14591514
&cast_kernel));
14601515
break;
1516+
case Type::EXTENSION:
1517+
RETURN_NOT_OK(ExtensionCastKernel::Make(std::move(in_type), std::move(out_type),
1518+
options, &cast_kernel));
1519+
break;
14611520
default:
14621521
break;
14631522
}

cpp/src/arrow/compute/kernels/cast_test.cc

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@
2626

2727
#include "arrow/array.h"
2828
#include "arrow/buffer.h"
29+
#include "arrow/extension_type.h"
2930
#include "arrow/memory_pool.h"
3031
#include "arrow/status.h"
3132
#include "arrow/table.h"
33+
#include "arrow/testing/extension_type.h"
3234
#include "arrow/testing/gtest_common.h"
3335
#include "arrow/testing/gtest_util.h"
3436
#include "arrow/testing/random.h"
@@ -1645,5 +1647,54 @@ TYPED_TEST(TestDictionaryCast, OutTypeError) {
16451647
this->CheckPass(*plain_array, *dict_array, dict_array->type(), options);
16461648
}*/
16471649

1650+
std::shared_ptr<Array> SmallintArrayFromJSON(const std::string& json_data) {
1651+
auto arr = ArrayFromJSON(int16(), json_data);
1652+
auto ext_data = arr->data()->Copy();
1653+
ext_data->type = smallint();
1654+
return MakeArray(ext_data);
1655+
}
1656+
1657+
// Casting a Smallint (int16 storage) extension array to integer types must
// behave like casting its storage, including the overflow checks.
TEST_F(TestCast, ExtensionTypeToIntDowncast) {
  // Named so that it does not shadow the smallint() factory function.
  auto smallint_type = std::make_shared<SmallintType>();
  ASSERT_OK(RegisterExtensionType(smallint_type));

  CastOptions options;
  options.allow_int_overflow = false;

  std::shared_ptr<Array> result;

  // Smallint(int16) to int16
  auto v0 = SmallintArrayFromJSON("[0, 100, 200, 1, 2]");
  CheckZeroCopy(*v0, int16());

  // Smallint(int16) to uint8, no overflow/underrun
  auto v1 = SmallintArrayFromJSON("[0, 100, 200, 1, 2]");
  auto e1 = ArrayFromJSON(uint8(), "[0, 100, 200, 1, 2]");
  CheckPass(*v1, *e1, uint8(), options);

  // Smallint(int16) to uint8, with overflow
  auto v2 = SmallintArrayFromJSON("[0, null, 256, 1, 3]");
  auto e2 = ArrayFromJSON(uint8(), "[0, null, 0, 1, 3]");
  // allow overflow
  options.allow_int_overflow = true;
  CheckPass(*v2, *e2, uint8(), options);
  // disallow overflow
  options.allow_int_overflow = false;
  ASSERT_RAISES(Invalid, Cast(&ctx_, *v2, uint8(), options, &result));

  // Smallint(int16) to uint8, with underflow
  auto v3 = SmallintArrayFromJSON("[0, null, -1, 1, 0]");
  auto e3 = ArrayFromJSON(uint8(), "[0, null, 255, 1, 0]");
  // allow overflow
  options.allow_int_overflow = true;
  CheckPass(*v3, *e3, uint8(), options);
  // disallow overflow
  options.allow_int_overflow = false;
  ASSERT_RAISES(Invalid, Cast(&ctx_, *v3, uint8(), options, &result));

  // Undo the registration done at the top of the test.
  ASSERT_OK(UnregisterExtensionType("smallint"));
}
1698+
16481699
} // namespace compute
16491700
} // namespace arrow

cpp/src/arrow/testing/extension_type.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,38 @@ class ARROW_EXPORT UUIDType : public ExtensionType {
4848
std::string Serialize() const override { return "uuid-type-unique-code"; }
4949
};
5050

51+
/// \brief Array class returned by SmallintType::MakeArray.
///
/// Adds nothing to ExtensionArray; it only inherits its constructors.
class ARROW_EXPORT SmallintArray : public ExtensionArray {
 public:
  using ExtensionArray::ExtensionArray;
};
55+
56+
/// \brief Extension type named "smallint" whose storage type is int16.
class ARROW_EXPORT SmallintType : public ExtensionType {
 public:
  /// Construct the type on top of int16() storage.
  SmallintType() : ExtensionType(int16()) {}

  /// Name of this extension type.
  std::string extension_name() const override { return "smallint"; }

  /// Compare this type with another extension type.
  bool ExtensionEquals(const ExtensionType& other) const override;

  /// Create an array for `data`.
  std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;

  /// Re-create the type from its serialized form and storage type.
  ///
  /// \param[in] storage_type storage type accompanying the serialized data
  /// \param[in] serialized string as produced by Serialize()
  /// \param[out] out receives the deserialized type
  Status Deserialize(std::shared_ptr<DataType> storage_type,
                     const std::string& serialized,
                     std::shared_ptr<DataType>* out) const override;

  /// Serialized form of the type: the constant string "smallint".
  std::string Serialize() const override { return "smallint"; }
};
72+
5173
ARROW_EXPORT
5274
std::shared_ptr<DataType> uuid();
5375

76+
ARROW_EXPORT
77+
std::shared_ptr<DataType> smallint();
78+
5479
ARROW_EXPORT
5580
std::shared_ptr<Array> ExampleUUID();
5681

82+
ARROW_EXPORT
83+
std::shared_ptr<Array> ExampleSmallint();
84+
5785
} // namespace arrow

cpp/src/arrow/testing/gtest_util.cc

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -382,11 +382,7 @@ void SleepFor(double seconds) {
382382
// Extension types
383383

384384
bool UUIDType::ExtensionEquals(const ExtensionType& other) const {
385-
const auto& other_ext = static_cast<const ExtensionType&>(other);
386-
if (other_ext.extension_name() != this->extension_name()) {
387-
return false;
388-
}
389-
return true;
385+
return (other.extension_name() == this->extension_name());
390386
}
391387

392388
std::shared_ptr<Array> UUIDType::MakeArray(std::shared_ptr<ArrayData> data) const {
@@ -423,4 +419,38 @@ std::shared_ptr<Array> ExampleUUID() {
423419
return MakeArray(ext_data);
424420
}
425421

422+
bool SmallintType::ExtensionEquals(const ExtensionType& other) const {
423+
return (other.extension_name() == this->extension_name());
424+
}
425+
426+
std::shared_ptr<Array> SmallintType::MakeArray(std::shared_ptr<ArrayData> data) const {
427+
DCHECK_EQ(data->type->id(), Type::EXTENSION);
428+
DCHECK_EQ("smallint", static_cast<const ExtensionType&>(*data->type).extension_name());
429+
return std::make_shared<SmallintArray>(data);
430+
}
431+
432+
// Rebuild a SmallintType from its serialized form.
//
// Returns Invalid when `serialized` is not "smallint" or, after that check
// passed, when `storage_type` is not int16. On success `*out` is set to a new
// SmallintType.
Status SmallintType::Deserialize(std::shared_ptr<DataType> storage_type,
                                 const std::string& serialized,
                                 std::shared_ptr<DataType>* out) const {
  // The identifier is checked first; the storage type is only inspected afterwards.
  const bool identifier_matches = (serialized == "smallint");
  if (!identifier_matches) {
    return Status::Invalid("Type identifier did not match");
  }

  const bool storage_is_int16 = storage_type->Equals(*int16());
  if (!storage_is_int16) {
    return Status::Invalid("Invalid storage type for SmallintType");
  }

  *out = std::make_shared<SmallintType>();
  return Status::OK();
}
444+
445+
// Factory returning a new SmallintType instance on every call.
std::shared_ptr<DataType> smallint() {
  std::shared_ptr<DataType> type = std::make_shared<SmallintType>();
  return type;
}
446+
447+
std::shared_ptr<Array> ExampleSmallint() {
448+
auto storage_type = int16();
449+
auto ext_type = smallint();
450+
auto arr = ArrayFromJSON(storage_type, "[-32768, null, 1, 2, 3, 4, 32767]");
451+
auto ext_data = arr->data()->Copy();
452+
ext_data->type = ext_type;
453+
return MakeArray(ext_data);
454+
}
455+
426456
} // namespace arrow

python/pyarrow/tests/test_extension_type.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,15 @@
2424
import pytest
2525

2626

27+
class IntegerType(pa.PyExtensionType):
    """Parameter-less extension type whose storage type is ``pa.int64()``."""

    def __init__(self):
        storage_type = pa.int64()
        pa.PyExtensionType.__init__(self, storage_type)

    def __reduce__(self):
        # Reconstruct by calling the class with no arguments.
        return (IntegerType, ())
34+
35+
2736
class UuidType(pa.PyExtensionType):
2837

2938
def __init__(self):
@@ -168,6 +177,42 @@ def test_ext_array_pickling():
168177
assert arr.storage.to_pylist() == [b"foo", b"bar"]
169178

170179

180+
def test_cast_kernel_on_extension_arrays():
    """Cast an IntegerType extension array (and a chunked array of it) to
    several integer types and check the resulting type and array class."""
    # plain array casting
    int64_storage = pa.array([1, 2, 3, 4], pa.int64())
    ext_arr = pa.ExtensionArray.from_storage(IntegerType(), int64_storage)

    # the identity cast (to the storage type) must not allocate
    bytes_before = pa.total_allocated_bytes()
    casted = ext_arr.cast(pa.int64())
    assert pa.total_allocated_bytes() == bytes_before

    expected_classes = [
        (pa.int64(), pa.Int64Array),
        (pa.int32(), pa.Int32Array),
        (pa.int16(), pa.Int16Array),
        (pa.uint64(), pa.UInt64Array),
        (pa.uint32(), pa.UInt32Array),
        (pa.uint16(), pa.UInt16Array),
    ]
    for target_type, array_class in expected_classes:
        casted = ext_arr.cast(target_type)
        assert casted.type == target_type
        assert isinstance(casted, array_class)

    # chunked array casting
    chunked = pa.chunked_array([ext_arr, ext_arr])
    casted = chunked.cast(pa.int16())
    assert casted.type == pa.int16()
    assert isinstance(casted, pa.ChunkedArray)
208+
209+
210+
def test_casting_to_extension_type_raises():
    """Casting a plain int64 array *to* an extension type is expected to raise
    ArrowNotImplementedError."""
    plain = pa.array([1, 2, 3, 4], pa.int64())
    target_type = IntegerType()
    with pytest.raises(pa.ArrowNotImplementedError):
        plain.cast(target_type)
214+
215+
171216
def example_batch():
172217
ty = ParamExtType(3)
173218
storage = pa.array([b"foo", b"bar"], type=pa.binary(3))

0 commit comments

Comments
 (0)