Skip to content

Commit 638338f

Browse files
lidavidm authored and nealrichardson committed
ARROW-13298: [C++] Implement any/all hash aggregate kernels
Closes #10791 from lidavidm/arrow-13298 Authored-by: David Li <li.davidm96@gmail.com> Signed-off-by: Neal Richardson <neal.p.richardson@gmail.com>
1 parent 5f5b803 commit 638338f

2 files changed

Lines changed: 188 additions & 0 deletions

File tree

cpp/src/arrow/compute/kernels/hash_aggregate.cc

Lines changed: 139 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1149,6 +1149,123 @@ struct GroupedMinMaxFactory {
11491149
InputType argument_type;
11501150
};
11511151

1152+
// ----------------------------------------------------------------------
1153+
// Any/All implementation
1154+
1155+
struct GroupedAnyImpl : public GroupedAggregator {
1156+
Status Init(ExecContext* ctx, const FunctionOptions*) override {
1157+
seen_ = TypedBufferBuilder<bool>(ctx->memory_pool());
1158+
return Status::OK();
1159+
}
1160+
1161+
Status Resize(int64_t new_num_groups) override {
1162+
auto added_groups = new_num_groups - num_groups_;
1163+
num_groups_ = new_num_groups;
1164+
return seen_.Append(added_groups, false);
1165+
}
1166+
1167+
Status Merge(GroupedAggregator&& raw_other,
1168+
const ArrayData& group_id_mapping) override {
1169+
auto other = checked_cast<GroupedAnyImpl*>(&raw_other);
1170+
1171+
auto seen = seen_.mutable_data();
1172+
auto other_seen = other->seen_.data();
1173+
1174+
auto g = group_id_mapping.GetValues<uint32_t>(1);
1175+
for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
1176+
if (BitUtil::GetBit(other_seen, other_g)) BitUtil::SetBitTo(seen, *g, true);
1177+
}
1178+
return Status::OK();
1179+
}
1180+
1181+
Status Consume(const ExecBatch& batch) override {
1182+
auto seen = seen_.mutable_data();
1183+
1184+
const auto& input = *batch[0].array();
1185+
1186+
auto g = batch[1].array()->GetValues<uint32_t>(1);
1187+
arrow::internal::VisitTwoBitBlocksVoid(
1188+
input.buffers[0], input.offset, input.buffers[1], input.offset, input.length,
1189+
[&](int64_t) { BitUtil::SetBitTo(seen, *g++, true); }, [&]() { g++; });
1190+
return Status::OK();
1191+
}
1192+
1193+
Result<Datum> Finalize() override {
1194+
ARROW_ASSIGN_OR_RAISE(auto seen, seen_.Finish());
1195+
return std::make_shared<BooleanArray>(num_groups_, std::move(seen));
1196+
}
1197+
1198+
std::shared_ptr<DataType> out_type() const override { return boolean(); }
1199+
1200+
int64_t num_groups_ = 0;
1201+
ScalarAggregateOptions options_;
1202+
TypedBufferBuilder<bool> seen_;
1203+
};
1204+
1205+
struct GroupedAllImpl : public GroupedAggregator {
1206+
Status Init(ExecContext* ctx, const FunctionOptions*) override {
1207+
seen_ = TypedBufferBuilder<bool>(ctx->memory_pool());
1208+
return Status::OK();
1209+
}
1210+
1211+
Status Resize(int64_t new_num_groups) override {
1212+
auto added_groups = new_num_groups - num_groups_;
1213+
num_groups_ = new_num_groups;
1214+
return seen_.Append(added_groups, true);
1215+
}
1216+
1217+
Status Merge(GroupedAggregator&& raw_other,
1218+
const ArrayData& group_id_mapping) override {
1219+
auto other = checked_cast<GroupedAllImpl*>(&raw_other);
1220+
1221+
auto seen = seen_.mutable_data();
1222+
auto other_seen = other->seen_.data();
1223+
1224+
auto g = group_id_mapping.GetValues<uint32_t>(1);
1225+
for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
1226+
BitUtil::SetBitTo(
1227+
seen, *g, BitUtil::GetBit(seen, *g) && BitUtil::GetBit(other_seen, other_g));
1228+
}
1229+
return Status::OK();
1230+
}
1231+
1232+
Status Consume(const ExecBatch& batch) override {
1233+
auto seen = seen_.mutable_data();
1234+
1235+
const auto& input = *batch[0].array();
1236+
1237+
auto g = batch[1].array()->GetValues<uint32_t>(1);
1238+
if (input.MayHaveNulls()) {
1239+
const uint8_t* bitmap = input.buffers[1]->data();
1240+
arrow::internal::VisitBitBlocksVoid(
1241+
input.buffers[0], input.offset, input.length,
1242+
[&](int64_t position) {
1243+
BitUtil::SetBitTo(seen, *g,
1244+
BitUtil::GetBit(seen, *g) &&
1245+
BitUtil::GetBit(bitmap, input.offset + position));
1246+
g++;
1247+
},
1248+
[&]() { g++; });
1249+
} else {
1250+
arrow::internal::VisitBitBlocksVoid(
1251+
input.buffers[1], input.offset, input.length, [&](int64_t) { g++; },
1252+
[&]() { BitUtil::SetBitTo(seen, *g++, false); });
1253+
}
1254+
return Status::OK();
1255+
}
1256+
1257+
Result<Datum> Finalize() override {
1258+
ARROW_ASSIGN_OR_RAISE(auto seen, seen_.Finish());
1259+
return std::make_shared<BooleanArray>(num_groups_, std::move(seen));
1260+
}
1261+
1262+
std::shared_ptr<DataType> out_type() const override { return boolean(); }
1263+
1264+
int64_t num_groups_ = 0;
1265+
ScalarAggregateOptions options_;
1266+
TypedBufferBuilder<bool> seen_;
1267+
};
1268+
11521269
} // namespace
11531270

11541271
Result<std::vector<const HashAggregateKernel*>> GetKernels(
@@ -1426,6 +1543,14 @@ const FunctionDoc hash_min_max_doc{
14261543
"This can be changed through ScalarAggregateOptions."),
14271544
{"array", "group_id_array"},
14281545
"ScalarAggregateOptions"};
1546+
1547+
// User-facing documentation for the "hash_any" grouped aggregation:
// summary, description, then the names of its two arguments.
const FunctionDoc hash_any_doc{
    "Test whether any element evaluates to true",
    "Null values are ignored.",
    {"array", "group_id_array"}};
1550+
1551+
// User-facing documentation for the "hash_all" grouped aggregation:
// summary, description, then the names of its two arguments.
const FunctionDoc hash_all_doc{
    "Test whether all elements evaluate to true",
    "Null values are ignored.",
    {"array", "group_id_array"}};
14291554
} // namespace
14301555

14311556
void RegisterHashAggregateBasic(FunctionRegistry* registry) {
@@ -1460,6 +1585,20 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
14601585
DCHECK_OK(AddHashAggKernels(NumericTypes(), GroupedMinMaxFactory::Make, func.get()));
14611586
DCHECK_OK(registry->AddFunction(std::move(func)));
14621587
}
1588+
1589+
{
  // Register "hash_any": a binary (argument, group ids) hash aggregate with a
  // single kernel accepting a boolean argument, backed by GroupedAnyImpl.
  auto any_func = std::make_shared<HashAggregateFunction>("hash_any", Arity::Binary(),
                                                          &hash_any_doc);
  auto any_kernel = MakeKernel(boolean(), HashAggregateInit<GroupedAnyImpl>);
  DCHECK_OK(any_func->AddKernel(std::move(any_kernel)));
  DCHECK_OK(registry->AddFunction(std::move(any_func)));
}
1595+
1596+
{
  // Register "hash_all": a binary (argument, group ids) hash aggregate with a
  // single kernel accepting a boolean argument, backed by GroupedAllImpl.
  auto all_func = std::make_shared<HashAggregateFunction>("hash_all", Arity::Binary(),
                                                          &hash_all_doc);
  auto all_kernel = MakeKernel(boolean(), HashAggregateInit<GroupedAllImpl>);
  DCHECK_OK(all_func->AddKernel(std::move(all_kernel)));
  DCHECK_OK(registry->AddFunction(std::move(all_func)));
}
14631602
}
14641603

14651604
} // namespace internal

cpp/src/arrow/compute/kernels/hash_aggregate_test.cc

Lines changed: 49 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -705,6 +705,55 @@ TEST(GroupBy, MinMaxOnly) {
705705
}
706706
}
707707

708+
// Runs "hash_any" and "hash_all" over the same boolean column, grouped by an
// int64 key, and checks the per-key results. The fixture covers a key with
// only true/null values (1), a key mixing false and true (2), a key with only
// null values (3), and a null key with only false values.
TEST(GroupBy, AnyAndAll) {
709+
// Run once with use_threads = true and once with use_threads = false.
// NOTE(review): the trace label suggests the threaded run goes through the
// aggregators' Merge path -- confirm against GroupBy.
for (bool use_threads : {true, false}) {
710+
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
711+
712+
// Input table supplied as three separate JSON chunks; rows for the same key
// are spread across chunks.
auto table =
713+
TableFromJSON(schema({field("argument", boolean()), field("key", int64())}), {R"([
714+
[true, 1],
715+
[null, 1]
716+
])",
717+
R"([
718+
[false, 2],
719+
[null, 3],
720+
[false, null],
721+
[true, 1],
722+
[true, 2]
723+
])",
724+
R"([
725+
[true, 2],
726+
[false, null],
727+
[null, 3]
728+
])"});
729+
730+
// The "argument" column is passed twice, once per aggregate; both aggregates
// are requested with null options.
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
731+
internal::GroupBy({table->GetColumnByName("argument"),
732+
table->GetColumnByName("argument")},
733+
{table->GetColumnByName("key")},
734+
{
735+
{"hash_any", nullptr},
736+
{"hash_all", nullptr},
737+
},
738+
use_threads));
739+
// Sort on the key column before comparing.
// NOTE(review): presumably needed because group output order is not
// deterministic -- confirm.
SortBy({"key_0"}, &aggregated_and_grouped);
740+
741+
// Expected: key 1 -> any/all both true; key 2 -> any true, all false;
// key 3 (only nulls) -> any false, all true; null key (only false) -> both
// false.
AssertDatumsEqual(ArrayFromJSON(struct_({
742+
field("hash_any", boolean()),
743+
field("hash_all", boolean()),
744+
field("key_0", int64()),
745+
}),
746+
R"([
747+
[true, true, 1],
748+
[true, false, 2],
749+
[false, true, 3],
750+
[false, false, null]
751+
])"),
752+
aggregated_and_grouped,
753+
/*verbose=*/true);
754+
}
755+
}
756+
708757
TEST(GroupBy, CountAndSum) {
709758
auto batch = RecordBatchFromJSON(
710759
schema({field("argument", float64()), field("key", int64())}), R"([

0 commit comments

Comments
 (0)