Skip to content

Commit 34834f7

Browse files
andrioniemkornfield
authored andcommitted
First try at implementing a CountValues kernel
1 parent 7ddad36 commit 34834f7

3 files changed

Lines changed: 208 additions & 1 deletion

File tree

cpp/src/arrow/compute/kernels/hash-test.cc

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,23 @@ void CheckUnique(FunctionContext* ctx, const shared_ptr<DataType>& type,
6464
ASSERT_ARRAYS_EQUAL(*expected, *result);
6565
}
6666

67+
template <typename Type, typename T>
68+
void CheckCountValues(FunctionContext* ctx, const shared_ptr<DataType>& type,
69+
const vector<T>& in_values, const vector<bool>& in_is_valid,
70+
const vector<T>& out_values, const vector<bool>& out_is_valid,
71+
const vector<int64_t>& out_counts) {
72+
shared_ptr<Array> input = _MakeArray<Type, T>(type, in_values, in_is_valid);
73+
shared_ptr<Array> ex_values = _MakeArray<Type, T>(type, out_values, out_is_valid);
74+
shared_ptr<Array> ex_counts =
75+
_MakeArray<Int64Type, int64_t>(int64(), out_counts, out_is_valid);
76+
77+
shared_ptr<Array> result_values;
78+
shared_ptr<Array> result_counts;
79+
ASSERT_OK(CountValues(ctx, Datum(input), &result_values, &result_counts));
80+
ASSERT_ARRAYS_EQUAL(*ex_values, *result_values);
81+
ASSERT_ARRAYS_EQUAL(*ex_counts, *result_counts);
82+
}
83+
6784
template <typename Type, typename T>
6885
void CheckDictEncode(FunctionContext* ctx, const shared_ptr<DataType>& type,
6986
const vector<T>& in_values, const vector<bool>& in_is_valid,
@@ -104,6 +121,14 @@ TYPED_TEST(TestHashKernelPrimitive, Unique) {
104121
{3, 1}, {});
105122
}
106123

124+
TYPED_TEST(TestHashKernelPrimitive, CountValues) {
125+
using T = typename TypeParam::c_type;
126+
auto type = TypeTraits<TypeParam>::type_singleton();
127+
CheckCountValues<TypeParam, T>(&this->ctx_, type, {2, 1, 2, 1, 2, 3, 4},
128+
{true, false, true, true, true, true, false}, {2, 1, 3},
129+
{}, {3, 1, 1});
130+
}
131+
107132
TYPED_TEST(TestHashKernelPrimitive, DictEncode) {
108133
using T = typename TypeParam::c_type;
109134
auto type = TypeTraits<TypeParam>::type_singleton();
@@ -121,19 +146,21 @@ TYPED_TEST(TestHashKernelPrimitive, PrimitiveResizeTable) {
121146
vector<T> values;
122147
vector<T> uniques;
123148
vector<int32_t> indices;
149+
vector<int64_t> counts;
124150
for (int64_t i = 0; i < kTotalValues * kRepeats; i++) {
125151
const auto val = static_cast<T>(i % kTotalValues);
126152
values.push_back(val);
127153

128154
if (i < kTotalValues) {
129155
uniques.push_back(val);
156+
counts.push_back(kRepeats);
130157
}
131158
indices.push_back(static_cast<int32_t>(i % kTotalValues));
132159
}
133160

134161
auto type = TypeTraits<TypeParam>::type_singleton();
135162
CheckUnique<TypeParam, T>(&this->ctx_, type, values, {}, uniques, {});
136-
163+
CheckCountValues<TypeParam, T>(&this->ctx_, type, values, {}, uniques, {}, counts);
137164
CheckDictEncode<TypeParam, T>(&this->ctx_, type, values, {}, uniques, {}, indices);
138165
}
139166

@@ -149,6 +176,19 @@ TEST_F(TestHashKernel, UniqueTimeTimestamp) {
149176
{});
150177
}
151178

179+
TEST_F(TestHashKernel, CountValuesTimeTimestamp) {
180+
CheckCountValues<Time32Type, int32_t>(&this->ctx_, time32(TimeUnit::SECOND),
181+
{2, 1, 2, 1}, {true, false, true, true}, {2, 1},
182+
{}, {2, 1});
183+
184+
CheckCountValues<Time64Type, int64_t>(&this->ctx_, time64(TimeUnit::NANO), {2, 1, 2, 1},
185+
{true, false, true, true}, {2, 1}, {}, {2, 1});
186+
187+
CheckCountValues<TimestampType, int64_t>(&this->ctx_, timestamp(TimeUnit::NANO),
188+
{2, 1, 2, 1}, {true, false, true, true},
189+
{2, 1}, {}, {2, 1});
190+
}
191+
152192
TEST_F(TestHashKernel, UniqueBoolean) {
153193
CheckUnique<BooleanType, bool>(&this->ctx_, boolean(), {true, true, false, true},
154194
{true, false, true, true}, {true, false}, {});
@@ -164,6 +204,23 @@ TEST_F(TestHashKernel, UniqueBoolean) {
164204
{false, true}, {});
165205
}
166206

207+
TEST_F(TestHashKernel, CountValuesBoolean) {
208+
CheckCountValues<BooleanType, bool>(&this->ctx_, boolean(), {true, true, false, true},
209+
{true, false, true, true}, {true, false}, {},
210+
{2, 1});
211+
212+
CheckCountValues<BooleanType, bool>(&this->ctx_, boolean(), {false, true, false, true},
213+
{true, false, true, true}, {false, true}, {},
214+
{2, 1});
215+
216+
// No nulls
217+
CheckCountValues<BooleanType, bool>(&this->ctx_, boolean(), {true, true, false, true},
218+
{}, {true, false}, {}, {3, 1});
219+
220+
CheckCountValues<BooleanType, bool>(&this->ctx_, boolean(), {false, true, false, true},
221+
{}, {false, true}, {}, {2, 2});
222+
}
223+
167224
TEST_F(TestHashKernel, DictEncodeBoolean) {
168225
CheckDictEncode<BooleanType, bool>(
169226
&this->ctx_, boolean(), {true, true, false, true, false},
@@ -192,6 +249,16 @@ TEST_F(TestHashKernel, UniqueBinary) {
192249
{true, false, true, true}, {"test", "test2"}, {});
193250
}
194251

252+
TEST_F(TestHashKernel, CountValuesBinary) {
253+
CheckCountValues<BinaryType, std::string>(
254+
&this->ctx_, binary(), {"test", "", "test2", "test"}, {true, false, true, true},
255+
{"test", "test2"}, {}, {2, 1});
256+
257+
CheckCountValues<StringType, std::string>(
258+
&this->ctx_, utf8(), {"test", "", "test2", "test"}, {true, false, true, true},
259+
{"test", "test2"}, {}, {2, 1});
260+
}
261+
195262
TEST_F(TestHashKernel, DictEncodeBinary) {
196263
CheckDictEncode<BinaryType, std::string>(
197264
&this->ctx_, binary(), {"test", "", "test2", "test", "baz"},
@@ -214,6 +281,7 @@ TEST_F(TestHashKernel, BinaryResizeTable) {
214281
vector<std::string> values;
215282
vector<std::string> uniques;
216283
vector<int32_t> indices;
284+
vector<int64_t> counts;
217285
char buf[20] = "test";
218286

219287
for (int32_t i = 0; i < kTotalValues * kRepeats; i++) {
@@ -224,6 +292,7 @@ TEST_F(TestHashKernel, BinaryResizeTable) {
224292

225293
if (i < kTotalValues) {
226294
uniques.push_back(values.back());
295+
counts.push_back(kRepeats);
227296
}
228297
indices.push_back(index);
229298
}
@@ -233,6 +302,8 @@ TEST_F(TestHashKernel, BinaryResizeTable) {
233302
indices);
234303

235304
CheckUnique<StringType, std::string>(&this->ctx_, utf8(), values, {}, uniques, {});
305+
CheckCountValues<FixedSizeBinaryType, std::string>(&this->ctx_, type, values, {},
306+
uniques, {}, counts);
236307
CheckDictEncode<StringType, std::string>(&this->ctx_, utf8(), values, {}, uniques, {},
237308
indices);
238309
}
@@ -291,6 +362,15 @@ TEST_F(TestHashKernel, UniqueDecimal) {
291362
{true, false, true, true}, expected, {});
292363
}
293364

365+
TEST_F(TestHashKernel, CountValuesDecimal) {
366+
vector<Decimal128> values{12, 12, 11, 12};
367+
vector<Decimal128> expected{12, 11};
368+
369+
CheckCountValues<Decimal128Type, Decimal128>(&this->ctx_, decimal(2, 0), values,
370+
{true, false, true, true}, expected, {},
371+
{2, 1});
372+
}
373+
294374
TEST_F(TestHashKernel, DictEncodeDecimal) {
295375
vector<Decimal128> values{12, 12, 11, 12, 13};
296376
vector<Decimal128> expected{12, 11, 13};
@@ -311,6 +391,9 @@ TEST_F(TestHashKernel, ChunkedArrayInvoke) {
311391
vector<std::string> dict_values = {"foo", "bar", "baz", "quuux"};
312392
auto ex_dict = _MakeArray<StringType, std::string>(type, dict_values, {});
313393

394+
vector<int64_t> counts = {3, 2, 1, 1};
395+
auto ex_counts = _MakeArray<Int64Type, int64_t>(int64(), counts, {});
396+
314397
ArrayVector arrays = {a1, a2};
315398
auto carr = std::make_shared<ChunkedArray>(arrays);
316399

@@ -329,6 +412,14 @@ TEST_F(TestHashKernel, ChunkedArrayInvoke) {
329412
std::make_shared<DictionaryArray>(dict_type, i2)};
330413
auto dict_carr = std::make_shared<ChunkedArray>(dict_arrays);
331414

415+
// Unique counts
416+
shared_ptr<Array> cv_uniques;
417+
shared_ptr<Array> cv_counts;
418+
ASSERT_OK(CountValues(&this->ctx_, Datum(carr), &cv_uniques, &cv_counts));
419+
ASSERT_ARRAYS_EQUAL(*ex_dict, *cv_uniques);
420+
ASSERT_ARRAYS_EQUAL(*ex_counts, *cv_counts);
421+
422+
// Dictionary encode
332423
Datum encoded_out;
333424
ASSERT_OK(DictionaryEncode(&this->ctx_, carr, &encoded_out));
334425
ASSERT_EQ(Datum::CHUNKED_ARRAY, encoded_out.kind());

cpp/src/arrow/compute/kernels/hash.cc

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,50 @@ class UniqueAction : public ActionBase {
9595
std::shared_ptr<DataType> out_type() const { return type_; }
9696
};
9797

98+
// ----------------------------------------------------------------------
99+
// Count values implementation
100+
101+
template <typename Type>
102+
class CountValuesImpl : public HashTableKernel<Type, CountValuesImpl<Type>> {
103+
public:
104+
static constexpr bool allow_expand = true;
105+
using Base = HashTableKernel<Type, CountValuesImpl>;
106+
107+
CountValuesImpl(const std::shared_ptr<DataType>& type, MemoryPool* pool)
108+
: Base(type, pool) {}
109+
110+
Status Reserve(const int64_t length) {
111+
counts_.reserve(length);
112+
return Status::OK();
113+
}
114+
115+
void ObserveNull() {}
116+
117+
void ObserveFound(const hash_slot_t slot) { counts_[slot]++; }
118+
119+
void ObserveNotFound(const hash_slot_t slot) { counts_.emplace_back(1); }
120+
121+
Status DoubleSize() { return Base::DoubleTableSize(); }
122+
123+
Status Flush(Datum* out) override {
124+
Int64Builder builder(Base::pool_);
125+
std::shared_ptr<ArrayData> result;
126+
127+
for (const int64_t value : counts_) {
128+
RETURN_NOT_OK(builder.Append(value));
129+
}
130+
131+
RETURN_NOT_OK(builder.FinishInternal(&result));
132+
out->value = std::move(result);
133+
return Status::OK();
134+
}
135+
136+
using Base::Append;
137+
138+
private:
139+
std::vector<int64_t> counts_;
140+
};
141+
98142
// ----------------------------------------------------------------------
99143
// Dictionary encode implementation
100144

@@ -368,6 +412,48 @@ Status GetDictionaryEncodeKernel(FunctionContext* ctx,
368412
return Status::OK();
369413
}
370414

415+
Status GetCountValuesKernel(FunctionContext* ctx, const std::shared_ptr<DataType>& type,
416+
std::unique_ptr<HashKernel>* out) {
417+
std::unique_ptr<HashTable> hasher;
418+
419+
#define COUNT_VALUES_CASE(InType) \
420+
case InType::type_id: \
421+
hasher.reset(new CountValuesImpl<InType>(type, ctx->memory_pool())); \
422+
break
423+
424+
switch (type->id()) {
425+
COUNT_VALUES_CASE(NullType);
426+
COUNT_VALUES_CASE(BooleanType);
427+
COUNT_VALUES_CASE(UInt8Type);
428+
COUNT_VALUES_CASE(Int8Type);
429+
COUNT_VALUES_CASE(UInt16Type);
430+
COUNT_VALUES_CASE(Int16Type);
431+
COUNT_VALUES_CASE(UInt32Type);
432+
COUNT_VALUES_CASE(Int32Type);
433+
COUNT_VALUES_CASE(UInt64Type);
434+
COUNT_VALUES_CASE(Int64Type);
435+
COUNT_VALUES_CASE(FloatType);
436+
COUNT_VALUES_CASE(DoubleType);
437+
COUNT_VALUES_CASE(Date32Type);
438+
COUNT_VALUES_CASE(Date64Type);
439+
COUNT_VALUES_CASE(Time32Type);
440+
COUNT_VALUES_CASE(Time64Type);
441+
COUNT_VALUES_CASE(TimestampType);
442+
COUNT_VALUES_CASE(BinaryType);
443+
COUNT_VALUES_CASE(StringType);
444+
COUNT_VALUES_CASE(FixedSizeBinaryType);
445+
COUNT_VALUES_CASE(Decimal128Type);
446+
default:
447+
break;
448+
}
449+
450+
#undef COUNT_VALUES_CASE
451+
452+
CHECK_IMPLEMENTED(hasher, "count-values", type);
453+
out->reset(new HashKernelImpl(std::move(hasher)));
454+
return Status::OK();
455+
}
456+
371457
namespace {
372458

373459
Status InvokeHash(FunctionContext* ctx, HashKernel* func, const Datum& value,
@@ -415,5 +501,18 @@ Status DictionaryEncode(FunctionContext* ctx, const Datum& value, Datum* out) {
415501
return Status::OK();
416502
}
417503

504+
Status CountValues(FunctionContext* ctx, const Datum& value,
505+
std::shared_ptr<Array>* out_uniques,
506+
std::shared_ptr<Array>* out_counts) {
507+
std::unique_ptr<HashKernel> func;
508+
RETURN_NOT_OK(GetCountValuesKernel(ctx, value.type(), &func));
509+
510+
std::vector<Datum> counts_datum;
511+
RETURN_NOT_OK(InvokeHash(ctx, func.get(), value, &counts_datum, out_uniques));
512+
513+
*out_counts = MakeArray(counts_datum.back().array());
514+
return Status::OK();
515+
}
516+
418517
} // namespace compute
419518
} // namespace arrow

cpp/src/arrow/compute/kernels/hash.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ Status GetDictionaryEncodeKernel(FunctionContext* ctx,
5656
const std::shared_ptr<DataType>& type,
5757
std::unique_ptr<HashKernel>* kernel);
5858

59+
ARROW_EXPORT
60+
Status GetCountValuesKernel(FunctionContext* ctx, const std::shared_ptr<DataType>& type,
61+
std::unique_ptr<HashKernel>* kernel);
62+
5963
/// \brief Compute unique elements from an array-like object
6064
/// \param[in] context the FunctionContext
6165
/// \param[in] datum array-like input
@@ -76,6 +80,19 @@ Status Unique(FunctionContext* context, const Datum& datum, std::shared_ptr<Arra
7680
ARROW_EXPORT
7781
Status DictionaryEncode(FunctionContext* context, const Datum& data, Datum* out);
7882

83+
/// \brief Return counts of unique elements from an array-like object
84+
/// \param[in] context the FunctionContext
85+
/// \param[in] value array-like input
86+
/// \param[out] out_uniques unique elements as Array
87+
/// \param[out] out_counts counts per element as Array, same shape as out_uniques
88+
///
89+
/// \since 0.10.0
90+
/// \note API not yet finalized
91+
ARROW_EXPORT
92+
Status CountValues(FunctionContext* context, const Datum& value,
93+
std::shared_ptr<Array>* out_uniques,
94+
std::shared_ptr<Array>* out_counts);
95+
7996
// TODO(wesm): Define API for incremental dictionary encoding
8097

8198
// TODO(wesm): Define API for regularizing DictionaryArray objects with

0 commit comments

Comments
 (0)