Skip to content

Commit 2369457

Browse files
committed
implement new vector function: list_element ARROW-12669
add python test and C++ docs for the new function: list_element convert list_element into a scalar function minor changes format support scalar inputs for list_element, improve tests minor changes less generated code thanks to some template tricks less generated code, again using template tricks
1 parent 7bf5609 commit 2369457

8 files changed

Lines changed: 219 additions & 16 deletions

File tree

cpp/src/arrow/compute/kernels/codegen_internal.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ Result<ValueDescr> FirstType(KernelContext*, const std::vector<ValueDescr>& desc
6666
return result;
6767
}
6868

69+
Result<ValueDescr> ListValuesType(KernelContext*, const std::vector<ValueDescr>& args) {
70+
const auto& list_type = checked_cast<const BaseListType&>(*args[0].type);
71+
return ValueDescr(list_type.value_type(), GetBroadcastShape(args));
72+
}
73+
6974
void EnsureDictionaryDecoded(std::vector<ValueDescr>* descrs) {
7075
for (ValueDescr& descr : *descrs) {
7176
if (descr.type->id() == Type::DICTIONARY) {

cpp/src/arrow/compute/kernels/codegen_internal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,7 @@ static void VisitTwoArrayValuesInline(const ArrayData& arr0, const ArrayData& ar
395395
// Reusable type resolvers
396396

397397
Result<ValueDescr> FirstType(KernelContext*, const std::vector<ValueDescr>& descrs);
398+
Result<ValueDescr> ListValuesType(KernelContext*, const std::vector<ValueDescr>& args);
398399

399400
// ----------------------------------------------------------------------
400401
// Generate an array kernel given template classes

cpp/src/arrow/compute/kernels/scalar_nested.cc

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,114 @@ const FunctionDoc list_value_length_doc{
8080
"Null values emit a null in the output."),
8181
{"lists"}};
8282

83+
template <typename Type, typename IndexType>
84+
struct ListElementArray {
85+
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
86+
using ListArrayType = typename TypeTraits<Type>::ArrayType;
87+
using IndexScalarType = typename TypeTraits<IndexType>::ScalarType;
88+
const auto& index_scalar = batch[1].scalar_as<IndexScalarType>();
89+
if (ARROW_PREDICT_FALSE(!index_scalar.is_valid)) {
90+
return Status::Invalid("Index must not be null");
91+
}
92+
ListArrayType list_array(batch[0].array());
93+
auto index = index_scalar.value;
94+
if (ARROW_PREDICT_FALSE(index < 0)) {
95+
return Status::Invalid("Index ", index,
96+
" is out of bounds: should be greater than or equal to 0");
97+
}
98+
std::unique_ptr<ArrayBuilder> builder;
99+
RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), list_array.value_type(), &builder));
100+
RETURN_NOT_OK(builder->Reserve(list_array.length()));
101+
for (int i = 0; i < list_array.length(); ++i) {
102+
if (list_array.IsNull(i)) {
103+
RETURN_NOT_OK(builder->AppendNull());
104+
continue;
105+
}
106+
std::shared_ptr<arrow::Array> value_array = list_array.value_slice(i);
107+
auto len = value_array->length();
108+
if (ARROW_PREDICT_FALSE(index >= static_cast<typename IndexType::c_type>(len))) {
109+
return Status::Invalid("Index ", index, " is out of bounds: should be in [0, ",
110+
len, ")");
111+
}
112+
RETURN_NOT_OK(builder->AppendArraySlice(*value_array->data(), index, 1));
113+
}
114+
ARROW_ASSIGN_OR_RAISE(auto result, builder->Finish());
115+
out->value = result->data();
116+
return Status::OK();
117+
}
118+
};
119+
120+
template <typename, typename IndexType>
121+
struct ListElementScalar {
122+
static Status Exec(KernelContext* /*ctx*/, const ExecBatch& batch, Datum* out) {
123+
using IndexScalarType = typename TypeTraits<IndexType>::ScalarType;
124+
const auto& index_scalar = batch[1].scalar_as<IndexScalarType>();
125+
if (ARROW_PREDICT_FALSE(!index_scalar.is_valid)) {
126+
return Status::Invalid("Index must not be null");
127+
}
128+
const auto& list_scalar = batch[0].scalar_as<BaseListScalar>();
129+
if (ARROW_PREDICT_FALSE(!list_scalar.is_valid)) {
130+
out->value = MakeNullScalar(
131+
checked_cast<const BaseListType&>(*batch[0].type()).value_type());
132+
return Status::OK();
133+
}
134+
auto list = list_scalar.value;
135+
auto index = index_scalar.value;
136+
auto len = list->length();
137+
if (ARROW_PREDICT_FALSE(index < 0 ||
138+
index >= static_cast<typename IndexType::c_type>(len))) {
139+
return Status::Invalid("Index ", index, " is out of bounds: should be in [0, ", len,
140+
")");
141+
}
142+
ARROW_ASSIGN_OR_RAISE(out->value, list->GetScalar(index));
143+
return Status::OK();
144+
}
145+
};
146+
147+
template <typename InListType>
148+
void AddListElementArrayKernels(ScalarFunction* func) {
149+
for (const auto& index_type : IntTypes()) {
150+
auto inputs = {InputType::Array(InListType::type_id), InputType::Scalar(index_type)};
151+
auto output = OutputType{ListValuesType};
152+
auto sig = KernelSignature::Make(std::move(inputs), std::move(output),
153+
/*is_varargs=*/false);
154+
auto scalar_exec = GenerateInteger<ListElementArray, InListType>({index_type->id()});
155+
ScalarKernel kernel{std::move(sig), std::move(scalar_exec)};
156+
kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
157+
kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
158+
DCHECK_OK(func->AddKernel(std::move(kernel)));
159+
}
160+
}
161+
162+
void AddListElementArrayKernels(ScalarFunction* func) {
163+
AddListElementArrayKernels<ListType>(func);
164+
AddListElementArrayKernels<LargeListType>(func);
165+
AddListElementArrayKernels<FixedSizeListType>(func);
166+
}
167+
168+
void AddListElementScalarKernels(ScalarFunction* func) {
169+
for (const auto list_type_id : {Type::LIST, Type::LARGE_LIST, Type::FIXED_SIZE_LIST}) {
170+
for (const auto& index_type : IntTypes()) {
171+
auto inputs = {InputType::Scalar(list_type_id), InputType::Scalar(index_type)};
172+
auto output = OutputType{ListValuesType};
173+
auto sig = KernelSignature::Make(std::move(inputs), std::move(output),
174+
/*is_varargs=*/false);
175+
auto scalar_exec = GenerateInteger<ListElementScalar, void>({index_type->id()});
176+
ScalarKernel kernel{std::move(sig), std::move(scalar_exec)};
177+
kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
178+
kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
179+
DCHECK_OK(func->AddKernel(std::move(kernel)));
180+
}
181+
}
182+
}
183+
184+
const FunctionDoc list_element_doc(
185+
"Compute elements using of nested list values using an index",
186+
("`lists` must have a list-like type.\n"
187+
"For each value in each list of `lists`, the element at `index`\n"
188+
"is emitted. Null values emit a null in the output."),
189+
{"lists", "index"});
190+
83191
Result<ValueDescr> MakeStructResolve(KernelContext* ctx,
84192
const std::vector<ValueDescr>& descrs) {
85193
auto names = OptionsWrapper<MakeStructOptions>::Get(ctx).field_names;
@@ -185,6 +293,12 @@ void RegisterScalarNested(FunctionRegistry* registry) {
185293
ListValueLength<LargeListType>));
186294
DCHECK_OK(registry->AddFunction(std::move(list_value_length)));
187295

296+
auto list_element = std::make_shared<ScalarFunction>("list_element", Arity::Binary(),
297+
&list_element_doc);
298+
AddListElementArrayKernels(list_element.get());
299+
AddListElementScalarKernels(list_element.get());
300+
DCHECK_OK(registry->AddFunction(std::move(list_element)));
301+
188302
static MakeStructOptions kDefaultMakeStructOptions;
189303
auto make_struct_function = std::make_shared<ScalarFunction>(
190304
"make_struct", Arity::VarArgs(), &make_struct_doc, &kDefaultMakeStructOptions);

cpp/src/arrow/compute/kernels/scalar_nested_test.cc

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,70 @@ TEST(TestScalarNested, ListValueLength) {
4343
"[3, null, 3, 3]");
4444
}
4545

46+
TEST(TestScalarNested, ListElementNonFixedListWithNulls) {
47+
auto sample = "[[7, 5, 81], [6, null, 4, 7, 8], [3, 12, 2, 0], [1, 9], null]";
48+
for (auto ty : NumericTypes()) {
49+
for (auto list_type : {list(ty), large_list(ty)}) {
50+
auto input = ArrayFromJSON(list_type, sample);
51+
auto null_input = ArrayFromJSON(list_type, "[null]");
52+
for (auto index_type : IntTypes()) {
53+
auto index = ScalarFromJSON(index_type, "1");
54+
auto expected = ArrayFromJSON(ty, "[5, null, 12, 9, null]");
55+
auto expected_null = ArrayFromJSON(ty, "[null]");
56+
CheckScalar("list_element", {input, index}, expected);
57+
CheckScalar("list_element", {null_input, index}, expected_null);
58+
}
59+
}
60+
}
61+
}
62+
63+
TEST(TestScalarNested, ListElementFixedList) {
64+
auto sample = "[[7, 5, 81], [6, 4, 8], [3, 12, 2], [1, 43, 87]]";
65+
for (auto ty : NumericTypes()) {
66+
auto input = ArrayFromJSON(fixed_size_list(ty, 3), sample);
67+
for (auto index_type : IntTypes()) {
68+
auto index = ScalarFromJSON(index_type, "0");
69+
auto expected = ArrayFromJSON(ty, "[7, 6, 3, 1]");
70+
CheckScalar("list_element", {input, index}, expected);
71+
}
72+
}
73+
}
74+
75+
TEST(TestScalarNested, ListElementInvalid) {
76+
auto input_array = ArrayFromJSON(list(float32()), "[[0.1, 1.1], [0.2, 1.2]]");
77+
auto input_scalar = ScalarFromJSON(list(float32()), "[0.1, 0.2]");
78+
79+
// invalid index: null
80+
auto index = ScalarFromJSON(int32(), "null");
81+
EXPECT_THAT(CallFunction("list_element", {input_array, index}),
82+
Raises(StatusCode::Invalid));
83+
EXPECT_THAT(CallFunction("list_element", {input_scalar, index}),
84+
Raises(StatusCode::Invalid));
85+
86+
// invalid index: < 0
87+
index = ScalarFromJSON(int32(), "-1");
88+
EXPECT_THAT(CallFunction("list_element", {input_array, index}),
89+
Raises(StatusCode::Invalid));
90+
EXPECT_THAT(CallFunction("list_element", {input_scalar, index}),
91+
Raises(StatusCode::Invalid));
92+
93+
// invalid index: >= list.length
94+
index = ScalarFromJSON(int32(), "2");
95+
EXPECT_THAT(CallFunction("list_element", {input_array, index}),
96+
Raises(StatusCode::Invalid));
97+
EXPECT_THAT(CallFunction("list_element", {input_scalar, index}),
98+
Raises(StatusCode::Invalid));
99+
100+
// invalid input
101+
input_array = ArrayFromJSON(list(float32()), "[[41, 6, 93], [], [2]]");
102+
input_scalar = ScalarFromJSON(list(float32()), "[]");
103+
index = ScalarFromJSON(int32(), "0");
104+
EXPECT_THAT(CallFunction("list_element", {input_array, index}),
105+
Raises(StatusCode::Invalid));
106+
EXPECT_THAT(CallFunction("list_element", {input_scalar, index}),
107+
Raises(StatusCode::Invalid));
108+
}
109+
46110
struct {
47111
Result<Datum> operator()(std::vector<Datum> args) {
48112
return CallFunction("make_struct", args);

cpp/src/arrow/compute/kernels/vector_nested.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,6 @@ struct ListParentIndicesArray {
110110
}
111111
};
112112

113-
Result<ValueDescr> ListValuesType(KernelContext*, const std::vector<ValueDescr>& args) {
114-
const auto& list_type = checked_cast<const BaseListType&>(*args[0].type);
115-
return ValueDescr::Array(list_type.value_type());
116-
}
117-
118113
Result<std::shared_ptr<DataType>> ListParentIndicesType(const DataType& input_type) {
119114
switch (input_type.id()) {
120115
case Type::LIST:

docs/source/cpp/compute.rst

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1492,19 +1492,24 @@ value, but smaller than nulls.
14921492
Structural transforms
14931493
~~~~~~~~~~~~~~~~~~~~~
14941494

1495-
+--------------------------+------------+--------------------+---------------------+---------+
1496-
| Function name | Arity | Input types | Output type | Notes |
1497-
+==========================+============+====================+=====================+=========+
1498-
| list_flatten | Unary | List-like | List value type | \(1) |
1499-
+--------------------------+------------+--------------------+---------------------+---------+
1500-
| list_parent_indices | Unary | List-like | Int32 or Int64 | \(2) |
1501-
+--------------------------+------------+--------------------+---------------------+---------+
1502-
1503-
* \(1) The top level of nesting is removed: all values in the list child array,
1495+
+--------------------------+------------+------------------------------------+---------------------+---------+
1496+
| Function name | Arity | Input types | Output type | Notes |
1497+
+==========================+============+====================================+=====================+=========+
1498+
| list_element | Binary | List-like (Arg 0), Integral (Arg 1)| List value type | \(1) |
1499+
+--------------------------+------------+------------------------------------+---------------------+---------+
1500+
| list_flatten | Unary | List-like | List value type | \(2) |
1501+
+--------------------------+------------+------------------------------------+---------------------+---------+
1502+
| list_parent_indices | Unary | List-like | Int32 or Int64 | \(3) |
1503+
+--------------------------+------------+------------------------------------+---------------------+---------+
1504+
1505+
* \(1) Output is an array of the same length as the input list array. The
1506+
output values are the values at the specified index of each child list.
1507+
1508+
* \(2) The top level of nesting is removed: all values in the list child array,
15041509
including nulls, are appended to the output. However, nulls in the parent
15051510
list array are discarded.
15061511

1507-
* \(2) For each value in the list child array, the index at which it is found
1512+
* \(3) For each value in the list child array, the index at which it is found
15081513
in the list array is appended to the output. Nulls in the parent list array
15091514
are discarded. Output type is Int32 for List and FixedSizeList, Int64 for
15101515
LargeList.

docs/source/python/api/compute.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,7 @@ Structural Transforms
371371
is_nan
372372
is_null
373373
is_valid
374-
list_value_length
374+
list_element
375375
list_flatten
376376
list_parent_indices
377+
list_value_length

python/pyarrow/tests/test_compute.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2133,3 +2133,21 @@ def test_case_when():
21332133
[False, True, None]),
21342134
[1, 2, 3],
21352135
[11, 12, 13]) == pa.array([1, 12, None])
2136+
2137+
2138+
def test_list_element():
2139+
element_type = pa.struct([('a', pa.float64()), ('b', pa.int8())])
2140+
list_type = pa.list_(element_type)
2141+
l1 = [{'a': .4, 'b': 2}, None, {'a': .2, 'b': 4}, None, {'a': 5.6, 'b': 6}]
2142+
l2 = [None, {'a': .52, 'b': 3}, {'a': .7, 'b': 4}, None, {'a': .6, 'b': 8}]
2143+
lists = pa.array([l1, l2], list_type)
2144+
2145+
index = 1
2146+
result = pa.compute.list_element(lists, index)
2147+
expected = pa.array([None, {'a': 0.52, 'b': 3}], element_type)
2148+
assert result.equals(expected)
2149+
2150+
index = 4
2151+
result = pa.compute.list_element(lists, index)
2152+
expected = pa.array([{'a': 5.6, 'b': 6}, {'a': .6, 'b': 8}], element_type)
2153+
assert result.equals(expected)

0 commit comments

Comments
 (0)