Harden ExecBatch::ToRecordBatch and add a unit test for it.

* exec.cc: return Status::Invalid when the schema has more fields than the
  batch has values, and Status::TypeError when a value is neither an array
  nor a scalar (previously every non-array value was dereferenced as a
  scalar).
* exec_test.cc: cover the all-array case, the scalar-broadcast case, a schema
  with too many fields, and a batch holding a Datum of unsupported kind.

Note: the "index" lines below carry the blob IDs recorded with the original
change; the post-image ID for exec.cc is informational only here, because the
function name in the Status::Invalid message has been corrected.

diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc
index 466b3d5dd4a7..8f577539e34a 100644
--- a/cpp/src/arrow/compute/exec.cc
+++ b/cpp/src/arrow/compute/exec.cc
@@ -156,6 +156,9 @@ Result<ExecBatch> ExecBatch::Make(std::vector<Datum> values) {
 
 Result<std::shared_ptr<RecordBatch>> ExecBatch::ToRecordBatch(
     std::shared_ptr<Schema> schema, MemoryPool* pool) const {
+  if (static_cast<size_t>(schema->num_fields()) > values.size()) {
+    return Status::Invalid("ExecBatch::ToRecordBatch mismatching schema size");
+  }
   ArrayVector columns(schema->num_fields());
 
   for (size_t i = 0; i < columns.size(); ++i) {
@@ -163,8 +166,13 @@ Result<std::shared_ptr<RecordBatch>> ExecBatch::ToRecordBatch(
     if (value.is_array()) {
       columns[i] = value.make_array();
       continue;
+    } else if (value.is_scalar()) {
+      ARROW_ASSIGN_OR_RAISE(columns[i],
+                            MakeArrayFromScalar(*value.scalar(), length, pool));
+    } else {
+      return Status::TypeError("ExecBatch::ToRecordBatch value ", i, " with unsupported ",
+                               "value kind ", ::arrow::ToString(value.kind()));
     }
-    ARROW_ASSIGN_OR_RAISE(columns[i], MakeArrayFromScalar(*value.scalar(), length, pool));
   }
 
   return RecordBatch::Make(std::move(schema), length, std::move(columns));
diff --git a/cpp/src/arrow/compute/exec_test.cc b/cpp/src/arrow/compute/exec_test.cc
index eac18f194d25..7f29a673d935 100644
--- a/cpp/src/arrow/compute/exec_test.cc
+++ b/cpp/src/arrow/compute/exec_test.cc
@@ -35,6 +35,7 @@
 #include "arrow/compute/kernel.h"
 #include "arrow/compute/registry.h"
 #include "arrow/memory_pool.h"
+#include "arrow/record_batch.h"
 #include "arrow/scalar.h"
 #include "arrow/status.h"
 #include "arrow/type.h"
@@ -55,6 +56,44 @@ using ::arrow::internal::BitmapEquals;
 using ::arrow::internal::CopyBitmap;
 using ::arrow::internal::CountSetBits;
 
+TEST(ExecBatch, ToRecordBatch) {
+  auto i32_array = ArrayFromJSON(int32(), "[0, 1, 2]");
+  auto utf8_array = ArrayFromJSON(utf8(), R"(["a", "b", "c"])");
+  ExecBatch exec_batch({Datum(i32_array), Datum(utf8_array)}, 3);
+
+  auto right_schema = schema({field("a", int32()), field("b", utf8())});
+  ASSERT_OK_AND_ASSIGN(auto right_record_batch, exec_batch.ToRecordBatch(right_schema));
+  ASSERT_OK(right_record_batch->ValidateFull());
+  auto expected_batch = RecordBatchFromJSON(right_schema, R"([
+    {"a": 0, "b": "a"},
+    {"a": 1, "b": "b"},
+    {"a": 2, "b": "c"}
+  ])");
+  AssertBatchesEqual(*right_record_batch, *expected_batch);
+
+  // With a scalar column
+  auto utf8_scalar = ScalarFromJSON(utf8(), R"("z")");
+  exec_batch = ExecBatch({Datum(i32_array), Datum(utf8_scalar)}, 3);
+  ASSERT_OK_AND_ASSIGN(right_record_batch, exec_batch.ToRecordBatch(right_schema));
+  ASSERT_OK(right_record_batch->ValidateFull());
+  expected_batch = RecordBatchFromJSON(right_schema, R"([
+    {"a": 0, "b": "z"},
+    {"a": 1, "b": "z"},
+    {"a": 2, "b": "z"}
+  ])");
+  AssertBatchesEqual(*right_record_batch, *expected_batch);
+
+  // Wrong number of fields in schema
+  auto reject_schema =
+      schema({field("a", int32()), field("b", utf8()), field("c", float64())});
+  ASSERT_RAISES(Invalid, exec_batch.ToRecordBatch(reject_schema));
+
+  // Wrong-kind exec batch (not really valid, but test it here anyway)
+  ExecBatch miskinded_batch({Datum()}, 0);
+  auto null_schema = schema({field("a", null())});
+  ASSERT_RAISES(TypeError, miskinded_batch.ToRecordBatch(null_schema));
+}
+
 TEST(ExecContext, BasicWorkings) {
   {
     ExecContext ctx;