2222#include < functional>
2323#include < limits>
2424#include < memory>
25+ #include < string>
2526#include < type_traits>
2627#include < utility>
2728#include < vector>
@@ -1271,6 +1272,62 @@ class ZeroCopyCast : public CastKernelBase {
12711272 }
12721273};
12731274
1275+ class ExtensionCastKernel : public CastKernelBase {
1276+ public:
1277+ static Status Make (const DataType& in_type, std::shared_ptr<DataType> out_type,
1278+ const CastOptions& options,
1279+ std::unique_ptr<CastKernelBase>* kernel) {
1280+ const auto storage_type = checked_cast<const ExtensionType&>(in_type).storage_type ();
1281+
1282+ std::unique_ptr<UnaryKernel> storage_caster;
1283+ RETURN_NOT_OK (GetCastFunction (*storage_type, out_type, options, &storage_caster));
1284+ kernel->reset (
1285+ new ExtensionCastKernel (std::move (storage_caster), std::move (out_type)));
1286+
1287+ return Status::OK ();
1288+ }
1289+
1290+ Status Init (const DataType& in_type) override {
1291+ auto & type = checked_cast<const ExtensionType&>(in_type);
1292+ storage_type_ = type.storage_type ();
1293+ extension_name_ = type.extension_name ();
1294+ return Status::OK ();
1295+ }
1296+
1297+ Status Call (FunctionContext* ctx, const Datum& input, Datum* out) override {
1298+ DCHECK_EQ (input.kind (), Datum::ARRAY);
1299+
1300+ // validate: type is the same as the type the kernel was constructed with
1301+ const auto & input_type = checked_cast<const ExtensionType&>(*input.type ());
1302+ if (input_type.extension_name () != extension_name_) {
1303+ return Status::TypeError (
1304+ " The cast kernel was constructed to cast from the extension type named '" ,
1305+ extension_name_, " ' but input has extension type named '" ,
1306+ input_type.extension_name (), " '" );
1307+ }
1308+ if (!input_type.storage_type ()->Equals (storage_type_)) {
1309+ return Status::TypeError (" The cast kernel was constructed with a storage type: " ,
1310+ storage_type_->ToString (),
1311+ " , but it is called with a different storage type:" ,
1312+ input_type.storage_type ()->ToString ());
1313+ }
1314+
1315+ // construct an ArrayData object with the underlying storage type
1316+ auto new_input = input.array ()->Copy ();
1317+ new_input->type = storage_type_;
1318+ return InvokeWithAllocation (ctx, storage_caster_.get (), new_input, out);
1319+ }
1320+
1321+ protected:
1322+ ExtensionCastKernel (std::unique_ptr<UnaryKernel> storage_caster,
1323+ std::shared_ptr<DataType> out_type)
1324+ : CastKernelBase(std::move(out_type)), storage_caster_(std::move(storage_caster)) {}
1325+
1326+ std::string extension_name_;
1327+ std::shared_ptr<DataType> storage_type_;
1328+ std::unique_ptr<UnaryKernel> storage_caster_;
1329+ };
1330+
12741331class CastKernel : public CastKernelBase {
12751332 public:
12761333 CastKernel (const CastOptions& options, const CastFunction& func,
@@ -1420,11 +1477,6 @@ Status GetCastFunction(const DataType& in_type, std::shared_ptr<DataType> out_ty
14201477 return Status::OK ();
14211478 }
14221479
1423- if (in_type.id () == Type::NA) {
1424- kernel->reset (new FromNullCastKernel (std::move (out_type)));
1425- return Status::OK ();
1426- }
1427-
14281480 std::unique_ptr<CastKernelBase> cast_kernel;
14291481 switch (in_type.id ()) {
14301482 CAST_FUNCTION_CASE (BooleanType);
@@ -1450,6 +1502,9 @@ Status GetCastFunction(const DataType& in_type, std::shared_ptr<DataType> out_ty
14501502 CAST_FUNCTION_CASE (LargeBinaryType);
14511503 CAST_FUNCTION_CASE (LargeStringType);
14521504 CAST_FUNCTION_CASE (DictionaryType);
1505+ case Type::NA:
1506+ cast_kernel.reset (new FromNullCastKernel (std::move (out_type)));
1507+ break ;
14531508 case Type::LIST:
14541509 RETURN_NOT_OK (
14551510 GetListCastFunc<ListType>(in_type, std::move (out_type), options, &cast_kernel));
@@ -1458,6 +1513,10 @@ Status GetCastFunction(const DataType& in_type, std::shared_ptr<DataType> out_ty
14581513 RETURN_NOT_OK (GetListCastFunc<LargeListType>(in_type, std::move (out_type), options,
14591514 &cast_kernel));
14601515 break ;
1516+ case Type::EXTENSION:
1517+ RETURN_NOT_OK (ExtensionCastKernel::Make (std::move (in_type), std::move (out_type),
1518+ options, &cast_kernel));
1519+ break ;
14611520 default :
14621521 break ;
14631522 }
0 commit comments