diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 85d2557acb7..ef0c80de98b 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -30,6 +30,7 @@ #include "arrow/util/bit_block_counter.h" #include "arrow/util/checked_cast.h" #include "arrow/util/optional.h" +#include "arrow/visitor_inline.h" namespace arrow { @@ -38,6 +39,21 @@ using internal::checked_cast; namespace compute { namespace internal { +// Visit all physical types for which sorting is implemented. +#define VISIT_PHYSICAL_TYPES(VISIT) \ + VISIT(Int8Type) \ + VISIT(Int16Type) \ + VISIT(Int32Type) \ + VISIT(Int64Type) \ + VISIT(UInt8Type) \ + VISIT(UInt16Type) \ + VISIT(UInt32Type) \ + VISIT(UInt64Type) \ + VISIT(FloatType) \ + VISIT(DoubleType) \ + VISIT(BinaryType) \ + VISIT(LargeBinaryType) + namespace { // The target chunk in a chunked array. @@ -142,15 +158,20 @@ struct ChunkedArrayResolver { // (such as cached raw values pointer) in a separate hierarchy of // physical accessors, but doing so ends up too cumbersome. // Instead, we simply create the desired concrete Array objects. +std::shared_ptr GetPhysicalArray(const Array& array, + const std::shared_ptr& physical_type) { + auto new_data = array.data()->Copy(); + new_data->type = physical_type; + return MakeArray(std::move(new_data)); +} + ArrayVector GetPhysicalChunks(const ChunkedArray& chunked_array, const std::shared_ptr& physical_type) { const auto& chunks = chunked_array.chunks(); ArrayVector physical(chunks.size()); std::transform(chunks.begin(), chunks.end(), physical.begin(), [&](const std::shared_ptr& array) { - auto new_data = array->data()->Copy(); - new_data->type = physical_type; - return MakeArray(std::move(new_data)); + return GetPhysicalArray(*array, physical_type); }); return physical; } @@ -634,20 +655,9 @@ class ChunkedArraySorter : public TypeVisitor { Status Sort() { return physical_type_->Accept(this); } #define VISIT(TYPE) \ - Status Visit(const TYPE##Type& type) override { return SortInternal(); } - - VISIT(Int8) - VISIT(Int16) - VISIT(Int32) - VISIT(Int64) - VISIT(UInt8) - VISIT(UInt16) - VISIT(UInt32) - VISIT(UInt64) - VISIT(Float) - VISIT(Double) - VISIT(Binary) - VISIT(LargeBinary) + Status Visit(const TYPE& type) override { return SortInternal(); } + + VISIT_PHYSICAL_TYPES(VISIT) #undef VISIT @@ -804,6 +814,517 @@ class ChunkedArraySorter : public TypeVisitor { ExecContext* ctx_; }; +// ---------------------------------------------------------------------- +// Record batch sorting implementation(s) + +// Visit contiguous ranges of equal values. All entries are assumed +// to be non-null. +template +void VisitConstantRanges(const ArrayType& array, uint64_t* indices_begin, + uint64_t* indices_end, Visitor&& visit) { + if (indices_begin == indices_end) { + return; + } + auto range_start = indices_begin; + auto range_cur = range_start; + auto last_value = array.GetView(*range_cur); + while (++range_cur != indices_end) { + auto v = array.GetView(*range_cur); + if (v != last_value) { + visit(range_start, range_cur); + range_start = range_cur; + last_value = v; + } + } + if (range_start != range_cur) { + visit(range_start, range_cur); + } +} + +// A sorter for a single column of a RecordBatch, deferring to the next column +// for ranges of equal values. +class RecordBatchColumnSorter { + public: + explicit RecordBatchColumnSorter(RecordBatchColumnSorter* next_column = nullptr) + : next_column_(next_column) {} + virtual ~RecordBatchColumnSorter() {} + + virtual void SortRange(uint64_t* indices_begin, uint64_t* indices_end) = 0; + + protected: + RecordBatchColumnSorter* next_column_; +}; + +template +class ConcreteRecordBatchColumnSorter : public RecordBatchColumnSorter { + public: + using ArrayType = typename TypeTraits::ArrayType; + + ConcreteRecordBatchColumnSorter(std::shared_ptr array, SortOrder order, + RecordBatchColumnSorter* next_column = nullptr) + : RecordBatchColumnSorter(next_column), + owned_array_(std::move(array)), + array_(checked_cast(*owned_array_)), + order_(order), + null_count_(array_.null_count()) {} + + void SortRange(uint64_t* indices_begin, uint64_t* indices_end) { + constexpr int64_t offset = 0; + uint64_t* nulls_begin; + if (null_count_ == 0) { + nulls_begin = indices_end; + } else { + // NOTE that null_count_ is merely an upper bound on the number of nulls + // in this particular range. + nulls_begin = PartitionNullsOnly(indices_begin, indices_end, + array_, offset); + DCHECK_LE(indices_end - nulls_begin, null_count_); + } + uint64_t* null_likes_begin = PartitionNullLikes( + indices_begin, nulls_begin, array_, offset); + + // TODO This is roughly the same as ArrayCompareSorter. + // Also, we would like to use a counting sort if possible. This requires + // a counting sort compatible with indirect indexing. + if (order_ == SortOrder::Ascending) { + std::stable_sort( + indices_begin, null_likes_begin, [&](uint64_t left, uint64_t right) { + return array_.GetView(left - offset) < array_.GetView(right - offset); + }); + } else { + std::stable_sort( + indices_begin, null_likes_begin, [&](uint64_t left, uint64_t right) { + // We don't use 'left > right' here to reduce required operator. + // If we use 'right < left' here, '<' is only required. + return array_.GetView(right - offset) < array_.GetView(left - offset); + }); + } + + if (next_column_ != nullptr) { + // Visit all ranges of equal values in this column and sort them on + // the next column. + SortNextColumn(null_likes_begin, nulls_begin); + SortNextColumn(nulls_begin, indices_end); + VisitConstantRanges(array_, indices_begin, null_likes_begin, + [&](uint64_t* range_start, uint64_t* range_end) { + SortNextColumn(range_start, range_end); + }); + } + } + + void SortNextColumn(uint64_t* indices_begin, uint64_t* indices_end) { + // Avoid the cost of a virtual method call in trivial cases + if (indices_end - indices_begin > 1) { + next_column_->SortRange(indices_begin, indices_end); + } + } + + protected: + const std::shared_ptr owned_array_; + const ArrayType& array_; + const SortOrder order_; + const int64_t null_count_; +}; + +// Sort a batch using a single-pass left-to-right radix sort. +class RadixRecordBatchSorter { + public: + RadixRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end, + const RecordBatch& batch, const SortOptions& options) + : batch_(batch), + options_(options), + indices_begin_(indices_begin), + indices_end_(indices_end) {} + + Status Sort() { + ARROW_ASSIGN_OR_RAISE(const auto sort_keys, + ResolveSortKeys(batch_, options_.sort_keys)); + + // Create column sorters from right to left + std::vector> column_sorts(sort_keys.size()); + RecordBatchColumnSorter* next_column = nullptr; + for (int64_t i = static_cast(sort_keys.size() - 1); i >= 0; --i) { + ColumnSortFactory factory(sort_keys[i], next_column); + ARROW_ASSIGN_OR_RAISE(column_sorts[i], factory.MakeColumnSort()); + next_column = column_sorts[i].get(); + } + + // Sort from left to right + column_sorts.front()->SortRange(indices_begin_, indices_end_); + return Status::OK(); + } + + protected: + struct ResolvedSortKey { + std::shared_ptr array; + SortOrder order; + }; + + struct ColumnSortFactory { + ColumnSortFactory(const ResolvedSortKey& sort_key, + RecordBatchColumnSorter* next_column) + : physical_type(GetPhysicalType(sort_key.array->type())), + array(GetPhysicalArray(*sort_key.array, physical_type)), + order(sort_key.order), + next_column(next_column) {} + + Result> MakeColumnSort() { + RETURN_NOT_OK(VisitTypeInline(*physical_type, this)); + DCHECK_NE(result, nullptr); + return std::move(result); + } + +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { return VisitGeneric(type); } + + VISIT_PHYSICAL_TYPES(VISIT) + +#undef VISIT + + Status Visit(const DataType& type) { + return Status::TypeError("Unsupported type for RecordBatch sorting: ", + type.ToString()); + } + + template + Status VisitGeneric(const Type&) { + result.reset(new ConcreteRecordBatchColumnSorter(array, order, next_column)); + return Status::OK(); + } + + std::shared_ptr physical_type; + std::shared_ptr array; + SortOrder order; + RecordBatchColumnSorter* next_column; + std::unique_ptr result; + }; + + static Result> ResolveSortKeys( + const RecordBatch& batch, const std::vector& sort_keys) { + std::vector resolved; + resolved.reserve(sort_keys.size()); + for (const auto& sort_key : sort_keys) { + auto array = batch.GetColumnByName(sort_key.name); + if (!array) { + return Status::Invalid("Nonexistent sort key column: ", sort_key.name); + } + resolved.push_back({std::move(array), sort_key.order}); + } + return resolved; + } + + const RecordBatch& batch_; + const SortOptions& options_; + uint64_t* indices_begin_; + uint64_t* indices_end_; +}; + +// Compare two records in the same RecordBatch or Table +// (indexing is handled through ResolvedSortKey) +template +class MultipleKeyComparator { + public: + explicit MultipleKeyComparator(const std::vector& sort_keys) + : sort_keys_(sort_keys) {} + + Status status() const { return status_; } + + // Returns true if the left-th value should be ordered before the + // right-th value, false otherwise. The start_sort_key_index-th + // sort key and subsequent sort keys are used for comparison. + bool Compare(uint64_t left, uint64_t right, size_t start_sort_key_index) { + current_left_ = left; + current_right_ = right; + current_compared_ = 0; + auto num_sort_keys = sort_keys_.size(); + for (size_t i = start_sort_key_index; i < num_sort_keys; ++i) { + current_sort_key_index_ = i; + status_ = VisitTypeInline(*sort_keys_[i].type, this); + // If the left value equals to the right value, we need to + // continue to sort. + if (current_compared_ != 0) { + break; + } + } + return current_compared_ < 0; + } + +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { \ + current_compared_ = CompareType(); \ + return Status::OK(); \ + } + + VISIT_PHYSICAL_TYPES(VISIT) + +#undef VISIT + + Status Visit(const DataType& type) { + return Status::TypeError("Unsupported type for RecordBatch sorting: ", + type.ToString()); + } + + private: + // Compares two records in the same table and returns -1, 0 or 1. + // + // -1: The left is less than the right. + // 0: The left equals to the right. + // 1: The left is greater than the right. + // + // This supports null and NaN. Null is processed in this and NaN + // is processed in CompareTypeValue(). + template + int32_t CompareType() { + using ArrayType = typename TypeTraits::ArrayType; + const auto& sort_key = sort_keys_[current_sort_key_index_]; + auto order = sort_key.order; + const auto chunk_left = sort_key.template GetChunk(current_left_); + const auto chunk_right = sort_key.template GetChunk(current_right_); + if (sort_key.null_count > 0) { + auto is_null_left = chunk_left.IsNull(); + auto is_null_right = chunk_right.IsNull(); + if (is_null_left && is_null_right) { + return 0; + } else if (is_null_left) { + return 1; + } else if (is_null_right) { + return -1; + } + } + return CompareTypeValue(chunk_left, chunk_right, order); + } + + // For non-float types. Value is never NaN. + template + enable_if_t::value, int32_t> CompareTypeValue( + const ResolvedChunk::ArrayType>& chunk_left, + const ResolvedChunk::ArrayType>& chunk_right, + const SortOrder order) { + const auto left = chunk_left.GetView(); + const auto right = chunk_right.GetView(); + int32_t compared; + if (left == right) { + compared = 0; + } else if (left > right) { + compared = 1; + } else { + compared = -1; + } + if (order == SortOrder::Descending) { + compared = -compared; + } + return compared; + } + + // For float types. Value may be NaN. + template + enable_if_t::value, int32_t> CompareTypeValue( + const ResolvedChunk::ArrayType>& chunk_left, + const ResolvedChunk::ArrayType>& chunk_right, + const SortOrder order) { + const auto left = chunk_left.GetView(); + const auto right = chunk_right.GetView(); + auto is_nan_left = std::isnan(left); + auto is_nan_right = std::isnan(right); + if (is_nan_left && is_nan_right) { + return 0; + } else if (is_nan_left) { + return 1; + } else if (is_nan_right) { + return -1; + } + int32_t compared; + if (left == right) { + compared = 0; + } else if (left > right) { + compared = 1; + } else { + compared = -1; + } + if (order == SortOrder::Descending) { + compared = -compared; + } + return compared; + } + + const std::vector& sort_keys_; + Status status_; + int64_t current_left_; + int64_t current_right_; + size_t current_sort_key_index_; + int32_t current_compared_; +}; + +// Sort a batch using a single sort and multiple-key comparisons. +class MultipleKeyRecordBatchSorter : public TypeVisitor { + private: + // Preprocessed sort key. + struct ResolvedSortKey { + ResolvedSortKey(const std::shared_ptr& array, const SortOrder order) + : type(GetPhysicalType(array->type())), + owned_array(GetPhysicalArray(*array, type)), + array(*owned_array), + order(order), + null_count(array->null_count()) {} + + template + ResolvedChunk GetChunk(int64_t index) const { + return {&checked_cast(array), index}; + } + + const std::shared_ptr type; + std::shared_ptr owned_array; + const Array& array; + SortOrder order; + int64_t null_count; + }; + + using Comparator = MultipleKeyComparator; + + public: + MultipleKeyRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end, + const RecordBatch& batch, const SortOptions& options) + : indices_begin_(indices_begin), + indices_end_(indices_end), + sort_keys_(ResolveSortKeys(batch, options.sort_keys, &status_)), + comparator_(sort_keys_) {} + + // This is optimized for the first sort key. The first sort key sort + // is processed in this class. The second and following sort keys + // are processed in Comparator. + Status Sort() { + RETURN_NOT_OK(status_); + return sort_keys_[0].type->Accept(this); + } + +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) override { return SortInternal(); } + + VISIT_PHYSICAL_TYPES(VISIT) + +#undef VISIT + + private: + static std::vector ResolveSortKeys( + const RecordBatch& batch, const std::vector& sort_keys, Status* status) { + std::vector resolved; + for (const auto& sort_key : sort_keys) { + auto array = batch.GetColumnByName(sort_key.name); + if (!array) { + *status = Status::Invalid("Nonexistent sort key column: ", sort_key.name); + break; + } + resolved.emplace_back(array, sort_key.order); + } + return resolved; + } + + template + Status SortInternal() { + using ArrayType = typename TypeTraits::ArrayType; + + auto& comparator = comparator_; + const auto& first_sort_key = sort_keys_[0]; + const ArrayType& array = checked_cast(first_sort_key.array); + auto nulls_begin = indices_end_; + nulls_begin = PartitionNullsInternal(first_sort_key); + // Sort first-key non-nulls + std::stable_sort(indices_begin_, nulls_begin, [&](uint64_t left, uint64_t right) { + // Both values are never null nor NaN + // (otherwise they've been partitioned away above). + const auto value_left = array.GetView(left); + const auto value_right = array.GetView(right); + if (value_left != value_right) { + bool compared = value_left < value_right; + if (first_sort_key.order == SortOrder::Ascending) { + return compared; + } else { + return !compared; + } + } + // If the left value equals to the right value, + // we need to compare the second and following + // sort keys. + return comparator.Compare(left, right, 1); + }); + return comparator_.status(); + } + + // Behaves like PatitionNulls() but this supports multiple sort keys. + // + // For non-float types. + template + enable_if_t::value, uint64_t*> PartitionNullsInternal( + const ResolvedSortKey& first_sort_key) { + using ArrayType = typename TypeTraits::ArrayType; + if (first_sort_key.null_count == 0) { + return indices_end_; + } + const ArrayType& array = checked_cast(first_sort_key.array); + StablePartitioner partitioner; + auto nulls_begin = partitioner(indices_begin_, indices_end_, + [&](uint64_t index) { return !array.IsNull(index); }); + // Sort all nulls by second and following sort keys + // TODO: could we instead run an independent sort from the second key on + // this slice? + if (nulls_begin != indices_end_) { + auto& comparator = comparator_; + std::stable_sort(nulls_begin, indices_end_, + [&comparator](uint64_t left, uint64_t right) { + return comparator.Compare(left, right, 1); + }); + } + return nulls_begin; + } + + // Behaves like PatitionNulls() but this supports multiple sort keys. + // + // For float types. + template + enable_if_t::value, uint64_t*> PartitionNullsInternal( + const ResolvedSortKey& first_sort_key) { + using ArrayType = typename TypeTraits::ArrayType; + const ArrayType& array = checked_cast(first_sort_key.array); + StablePartitioner partitioner; + uint64_t* nulls_begin; + if (first_sort_key.null_count == 0) { + nulls_begin = indices_end_; + } else { + nulls_begin = partitioner(indices_begin_, indices_end_, + [&](uint64_t index) { return !array.IsNull(index); }); + } + uint64_t* nans_and_nulls_begin = + partitioner(indices_begin_, nulls_begin, + [&](uint64_t index) { return !std::isnan(array.GetView(index)); }); + auto& comparator = comparator_; + if (nans_and_nulls_begin != nulls_begin) { + // Sort all NaNs by the second and following sort keys. + // TODO: could we instead run an independent sort from the second key on + // this slice? + std::stable_sort(nans_and_nulls_begin, nulls_begin, + [&comparator](uint64_t left, uint64_t right) { + return comparator.Compare(left, right, 1); + }); + } + if (nulls_begin != indices_end_) { + // Sort all nulls by the second and following sort keys. + // TODO: could we instead run an independent sort from the second key on + // this slice? + std::stable_sort(nulls_begin, indices_end_, + [&comparator](uint64_t left, uint64_t right) { + return comparator.Compare(left, right, 1); + }); + } + return nans_and_nulls_begin; + } + + uint64_t* indices_begin_; + uint64_t* indices_end_; + Status status_; + std::vector sort_keys_; + Comparator comparator_; +}; + // ---------------------------------------------------------------------- // Table sorting implementations @@ -834,6 +1355,10 @@ class TableRadixSorter { // Sort a table using a single sort and multiple-key comparisons. class MultipleKeyTableSorter : public TypeVisitor { private: + // TODO instead of resolving chunks for each column independently, we could + // split the table into RecordBatches and pay the cost of chunked indexing + // at the first column only. + // Preprocessed sort key. struct ResolvedSortKey { ResolvedSortKey(const ChunkedArray& chunked_array, const SortOrder order) @@ -861,229 +1386,76 @@ class MultipleKeyTableSorter : public TypeVisitor { const ChunkedArrayResolver resolver; }; - // Compare two records in the same table. - class Comparer : public TypeVisitor { - public: - Comparer(const Table& table, const std::vector& sort_keys) - : TypeVisitor(), - status_(Status::OK()), - sort_keys_(ResolveSortKeys(table, sort_keys, &status_)) {} - - Status status() { return status_; } - - const std::vector& sort_keys() { return sort_keys_; } - - // Returns true if the left-th value should be ordered before the - // right-th value, false otherwise. The start_sort_key_index-th - // sort key and subsequent sort keys are used for comparison. - bool Compare(uint64_t left, uint64_t right, size_t start_sort_key_index) { - current_left_ = left; - current_right_ = right; - current_compared_ = 0; - auto num_sort_keys = sort_keys_.size(); - for (size_t i = start_sort_key_index; i < num_sort_keys; ++i) { - current_sort_key_index_ = i; - status_ = sort_keys_[i].type->Accept(this); - // If the left value equals to the right value, we need to - // continue to sort. - if (current_compared_ != 0) { - break; - } - } - return current_compared_ < 0; - } - -#define VISIT(TYPE) \ - Status Visit(const TYPE##Type& type) override { \ - current_compared_ = CompareType(); \ - return Status::OK(); \ - } - - VISIT(Int8) - VISIT(Int16) - VISIT(Int32) - VISIT(Int64) - VISIT(UInt8) - VISIT(UInt16) - VISIT(UInt32) - VISIT(UInt64) - VISIT(Float) - VISIT(Double) - VISIT(Binary) - VISIT(LargeBinary) - -#undef VISIT - - private: - // Compares two records in the same table and returns -1, 0 or 1. - // - // -1: The left is less than the right. - // 0: The left equals to the right. - // 1: The left is greater than the right. - // - // This supports null and NaN. Null is processed in this and NaN - // is processed in CompareTypeValue(). - template - int32_t CompareType() { - using ArrayType = typename TypeTraits::ArrayType; - const auto& sort_key = sort_keys_[current_sort_key_index_]; - auto order = sort_key.order; - const auto chunk_left = sort_key.GetChunk(current_left_); - const auto chunk_right = sort_key.GetChunk(current_right_); - if (sort_key.null_count > 0) { - auto is_null_left = chunk_left.IsNull(); - auto is_null_right = chunk_right.IsNull(); - if (is_null_left && is_null_right) { - return 0; - } else if (is_null_left) { - return 1; - } else if (is_null_right) { - return -1; - } - } - return CompareTypeValue(chunk_left, chunk_right, order); - } - - // For non-float types. Value is never NaN. - template - enable_if_t::value, int32_t> CompareTypeValue( - const ResolvedChunk::ArrayType>& chunk_left, - const ResolvedChunk::ArrayType>& chunk_right, - const SortOrder order) { - const auto left = chunk_left.GetView(); - const auto right = chunk_right.GetView(); - int32_t compared; - if (left == right) { - compared = 0; - } else if (left > right) { - compared = 1; - } else { - compared = -1; - } - if (order == SortOrder::Descending) { - compared = -compared; - } - return compared; - } - - // For float types. Value may be NaN. - template - enable_if_t::value, int32_t> CompareTypeValue( - const ResolvedChunk::ArrayType>& chunk_left, - const ResolvedChunk::ArrayType>& chunk_right, - const SortOrder order) { - const auto left = chunk_left.GetView(); - const auto right = chunk_right.GetView(); - auto is_nan_left = std::isnan(left); - auto is_nan_right = std::isnan(right); - if (is_nan_left && is_nan_right) { - return 0; - } else if (is_nan_left) { - return 1; - } else if (is_nan_right) { - return -1; - } - int32_t compared; - if (left == right) { - compared = 0; - } else if (left > right) { - compared = 1; - } else { - compared = -1; - } - if (order == SortOrder::Descending) { - compared = -compared; - } - return compared; - } - - static std::vector ResolveSortKeys( - const Table& table, const std::vector& sort_keys, Status* status) { - std::vector resolved; - resolved.reserve(sort_keys.size()); - for (const auto& sort_key : sort_keys) { - const auto& chunked_array = table.GetColumnByName(sort_key.name); - if (!chunked_array) { - *status = Status::Invalid("Nonexistent sort key column: ", sort_key.name); - break; - } - resolved.emplace_back(*chunked_array, sort_key.order); - } - return resolved; - } - - Status status_; - const std::vector sort_keys_; - int64_t current_left_; - int64_t current_right_; - size_t current_sort_key_index_; - int32_t current_compared_; - }; + using Comparator = MultipleKeyComparator; public: MultipleKeyTableSorter(uint64_t* indices_begin, uint64_t* indices_end, const Table& table, const SortOptions& options) : indices_begin_(indices_begin), indices_end_(indices_end), - comparer_(table, options.sort_keys) {} + sort_keys_(ResolveSortKeys(table, options.sort_keys, &status_)), + comparator_(sort_keys_) {} // This is optimized for the first sort key. The first sort key sort // is processed in this class. The second and following sort keys - // are processed in Comparer. + // are processed in Comparator. Status Sort() { - ARROW_RETURN_NOT_OK(comparer_.status()); - return comparer_.sort_keys()[0].type->Accept(this); + ARROW_RETURN_NOT_OK(status_); + return sort_keys_[0].type->Accept(this); } #define VISIT(TYPE) \ - Status Visit(const TYPE##Type& type) override { return SortInternal(); } - - VISIT(Int8) - VISIT(Int16) - VISIT(Int32) - VISIT(Int64) - VISIT(UInt8) - VISIT(UInt16) - VISIT(UInt32) - VISIT(UInt64) - VISIT(Float) - VISIT(Double) - VISIT(Binary) - VISIT(LargeBinary) + Status Visit(const TYPE& type) override { return SortInternal(); } + + VISIT_PHYSICAL_TYPES(VISIT) #undef VISIT private: + static std::vector ResolveSortKeys( + const Table& table, const std::vector& sort_keys, Status* status) { + std::vector resolved; + resolved.reserve(sort_keys.size()); + for (const auto& sort_key : sort_keys) { + const auto& chunked_array = table.GetColumnByName(sort_key.name); + if (!chunked_array) { + *status = Status::Invalid("Nonexistent sort key column: ", sort_key.name); + break; + } + resolved.emplace_back(*chunked_array, sort_key.order); + } + return resolved; + } + template Status SortInternal() { using ArrayType = typename TypeTraits::ArrayType; - auto& comparer = comparer_; - const auto& first_sort_key = comparer.sort_keys()[0]; + auto& comparator = comparator_; + const auto& first_sort_key = sort_keys_[0]; auto nulls_begin = indices_end_; nulls_begin = PartitionNullsInternal(first_sort_key); - std::stable_sort(indices_begin_, nulls_begin, - [&first_sort_key, &comparer](uint64_t left, uint64_t right) { - // Both values are never null nor NaN. - auto chunk_left = first_sort_key.GetChunk(left); - auto chunk_right = first_sort_key.GetChunk(right); - auto value_left = chunk_left.GetView(); - auto value_right = chunk_right.GetView(); - if (value_left == value_right) { - // If the left value equals to the right value, - // we need to compare the second and following - // sort keys. - return comparer.Compare(left, right, 1); - } else { - auto compared = value_left < value_right; - if (first_sort_key.order == SortOrder::Ascending) { - return compared; - } else { - return !compared; - } - } - }); - return Status::OK(); + std::stable_sort(indices_begin_, nulls_begin, [&](uint64_t left, uint64_t right) { + // Both values are never null nor NaN. + auto chunk_left = first_sort_key.GetChunk(left); + auto chunk_right = first_sort_key.GetChunk(right); + auto value_left = chunk_left.GetView(); + auto value_right = chunk_right.GetView(); + if (value_left == value_right) { + // If the left value equals to the right value, + // we need to compare the second and following + // sort keys. + return comparator.Compare(left, right, 1); + } else { + auto compared = value_left < value_right; + if (first_sort_key.order == SortOrder::Ascending) { + return compared; + } else { + return !compared; + } + } + }); + return comparator_.status(); } // Behaves like PatitionNulls() but this supports multiple sort keys. @@ -1102,11 +1474,11 @@ class MultipleKeyTableSorter : public TypeVisitor { const auto chunk = first_sort_key.GetChunk(index); return !chunk.IsNull(); }); - auto& comparer = comparer_; - std::stable_sort(nulls_begin, indices_end_, - [&comparer](uint64_t left, uint64_t right) { - return comparer.Compare(left, right, 1); - }); + DCHECK_EQ(indices_end_ - nulls_begin, first_sort_key.null_count); + auto& comparator = comparator_; + std::stable_sort(nulls_begin, indices_end_, [&](uint64_t left, uint64_t right) { + return comparator.Compare(left, right, 1); + }); return nulls_begin; } @@ -1118,45 +1490,37 @@ class MultipleKeyTableSorter : public TypeVisitor { const ResolvedSortKey& first_sort_key) { using ArrayType = typename TypeTraits::ArrayType; StablePartitioner partitioner; + uint64_t* nulls_begin; if (first_sort_key.null_count == 0) { - return partitioner(indices_begin_, indices_end_, [&first_sort_key](uint64_t index) { + nulls_begin = indices_end_; + } else { + nulls_begin = partitioner(indices_begin_, indices_end_, [&](uint64_t index) { const auto chunk = first_sort_key.GetChunk(index); - return !std::isnan(chunk.GetView()); + return !chunk.IsNull(); }); } - auto nans_and_nulls_begin = - partitioner(indices_begin_, indices_end_, [&first_sort_key](uint64_t index) { - const auto chunk = first_sort_key.GetChunk(index); - return !chunk.IsNull() && !std::isnan(chunk.GetView()); - }); - auto nulls_begin = nans_and_nulls_begin; - if (first_sort_key.null_count < static_cast(indices_end_ - nulls_begin)) { - // move nulls after NaN - nulls_begin = partitioner( - nans_and_nulls_begin, indices_end_, [&first_sort_key](uint64_t index) { - const auto chunk = first_sort_key.GetChunk(index); - return !chunk.IsNull(); - }); - } - auto& comparer = comparer_; - if (nans_and_nulls_begin != nulls_begin) { - // Sort all NaNs by the second and following sort keys. - std::stable_sort(nans_and_nulls_begin, nulls_begin, - [&comparer](uint64_t left, uint64_t right) { - return comparer.Compare(left, right, 1); - }); - } + DCHECK_EQ(indices_end_ - nulls_begin, first_sort_key.null_count); + uint64_t* nans_begin = partitioner(indices_begin_, nulls_begin, [&](uint64_t index) { + const auto chunk = first_sort_key.GetChunk(index); + return !std::isnan(chunk.GetView()); + }); + auto& comparator = comparator_; + // Sort all NaNs by the second and following sort keys. + std::stable_sort(nans_begin, nulls_begin, [&](uint64_t left, uint64_t right) { + return comparator.Compare(left, right, 1); + }); // Sort all nulls by the second and following sort keys. - std::stable_sort(nulls_begin, indices_end_, - [&comparer](uint64_t left, uint64_t right) { - return comparer.Compare(left, right, 1); - }); - return nans_and_nulls_begin; + std::stable_sort(nulls_begin, indices_end_, [&](uint64_t left, uint64_t right) { + return comparator.Compare(left, right, 1); + }); + return nans_begin; } uint64_t* indices_begin_; uint64_t* indices_end_; - Comparer comparer_; + Status status_; + std::vector sort_keys_; + Comparator comparator_; }; // ---------------------------------------------------------------------- @@ -1188,9 +1552,7 @@ class SortIndicesMetaFunction : public MetaFunction { return SortIndices(*args[0].chunked_array(), sort_options, ctx); break; case Datum::RECORD_BATCH: { - ARROW_ASSIGN_OR_RAISE(auto table, - Table::FromRecordBatches({args[0].record_batch()})); - return SortIndices(*table, sort_options, ctx); + return SortIndices(*args[0].record_batch(), sort_options, ctx); } break; case Datum::TABLE: return SortIndices(*args[0].table(), sort_options, ctx); @@ -1239,6 +1601,46 @@ class SortIndicesMetaFunction : public MetaFunction { return Datum(out); } + Result SortIndices(const RecordBatch& batch, const SortOptions& options, + ExecContext* ctx) const { + auto n_sort_keys = options.sort_keys.size(); + if (n_sort_keys == 0) { + return Status::Invalid("Must specify one or more sort keys"); + } + if (n_sort_keys == 1) { + auto array = batch.GetColumnByName(options.sort_keys[0].name); + if (!array) { + return Status::Invalid("Nonexistent sort key column: ", + options.sort_keys[0].name); + } + return SortIndices(*array, options, ctx); + } + + auto out_type = uint64(); + auto length = batch.num_rows(); + auto buffer_size = BitUtil::BytesForBits( + length * std::static_pointer_cast(out_type)->bit_width()); + BufferVector buffers(2); + ARROW_ASSIGN_OR_RAISE(buffers[1], + AllocateResizableBuffer(buffer_size, ctx->memory_pool())); + auto out = std::make_shared(out_type, length, buffers, 0); + auto out_begin = out->GetMutableValues(1); + auto out_end = out_begin + length; + std::iota(out_begin, out_end, 0); + + // Radix sorting is consistently faster except when there is a large number + // of sort keys, in which case it can end up degrading catastrophically. + // Cut off above 8 sort keys. + if (n_sort_keys <= 8) { + RadixRecordBatchSorter sorter(out_begin, out_end, batch, options); + ARROW_RETURN_NOT_OK(sorter.Sort()); + } else { + MultipleKeyRecordBatchSorter sorter(out_begin, out_end, batch, options); + ARROW_RETURN_NOT_OK(sorter.Sort()); + } + return Datum(out); + } + Result SortIndices(const Table& table, const SortOptions& options, ExecContext* ctx) const { auto n_sort_keys = options.sort_keys.size(); @@ -1330,6 +1732,8 @@ void RegisterVectorSort(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunction(std::move(part_indices))); } +#undef VISIT_PHYSICAL_TYPES + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc index f48d69e5a24..820c51ba8ec 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc @@ -23,6 +23,7 @@ #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" #include "arrow/util/benchmark_util.h" +#include "arrow/util/logging.h" namespace arrow { namespace compute { @@ -90,16 +91,15 @@ static void ChunkedArraySortIndicesInt64Wide(benchmark::State& state) { ChunkedArraySortIndicesInt64Benchmark(state, min, max); } -static void TableSortIndicesBenchmark(benchmark::State& state, - const std::shared_ptr& table, +static void DatumSortIndicesBenchmark(benchmark::State& state, const Datum& datum, const SortOptions& options) { for (auto _ : state) { - ABORT_NOT_OK(SortIndices(Datum(*table), options).status()); + ABORT_NOT_OK(SortIndices(datum, options).status()); } } // Extract benchmark args from benchmark::State -struct TableSortIndicesArgs { +struct RecordBatchSortIndicesArgs { // the number of records const int64_t num_records; @@ -109,20 +109,18 @@ struct TableSortIndicesArgs { // the number of columns const int64_t num_columns; - // the number of chunks in each generated column - const int64_t num_chunks; - // Extract args - explicit TableSortIndicesArgs(benchmark::State& state) + explicit RecordBatchSortIndicesArgs(benchmark::State& state) : num_records(state.range(0)), null_proportion(ComputeNullProportion(state.range(1))), num_columns(state.range(2)), - num_chunks(state.range(3)), state_(state) {} - ~TableSortIndicesArgs() { state_.SetItemsProcessed(state_.iterations() * num_records); } + ~RecordBatchSortIndicesArgs() { + state_.SetItemsProcessed(state_.iterations() * num_records); + } - private: + protected: double ComputeNullProportion(int64_t inverse_null_proportion) { if (inverse_null_proportion == 0) { return 0.0; @@ -134,37 +132,86 @@ struct TableSortIndicesArgs { benchmark::State& state_; }; -static void TableSortIndicesInt64(benchmark::State& state, int64_t min, int64_t max) { - TableSortIndicesArgs args(state); +struct TableSortIndicesArgs : public RecordBatchSortIndicesArgs { + // the number of chunks in each generated column + const int64_t num_chunks; - auto rand = random::RandomArrayGenerator(kSeed); - std::vector> fields; + // Extract args + explicit TableSortIndicesArgs(benchmark::State& state) + : RecordBatchSortIndicesArgs(state), num_chunks(state.range(3)) {} +}; + +struct BatchOrTableBenchmarkData { + std::shared_ptr schema; std::vector sort_keys; - std::vector> columns; + ChunkedArrayVector columns; +}; + +BatchOrTableBenchmarkData MakeBatchOrTableBenchmarkDataInt64( + const RecordBatchSortIndicesArgs& args, int64_t num_chunks, int64_t min_value, + int64_t max_value) { + auto rand = random::RandomArrayGenerator(kSeed); + FieldVector fields; + BatchOrTableBenchmarkData data; + for (int64_t i = 0; i < args.num_columns; ++i) { auto name = std::to_string(i); fields.push_back(field(name, int64())); auto order = (i % 2) == 0 ? SortOrder::Ascending : SortOrder::Descending; - sort_keys.emplace_back(name, order); - std::vector> arrays; - if ((args.num_records % args.num_chunks) != 0) { - Status::Invalid("The number of chunks (", args.num_chunks, + data.sort_keys.emplace_back(name, order); + ArrayVector chunks; + if ((args.num_records % num_chunks) != 0) { + Status::Invalid("The number of chunks (", num_chunks, ") must be " "a multiple of the number of records (", args.num_records, ")") .Abort(); } - auto num_records_in_array = args.num_records / args.num_chunks; - for (int64_t j = 0; j < args.num_chunks; ++j) { - arrays.push_back(rand.Int64(num_records_in_array, min, max, args.null_proportion)); + auto num_records_in_array = args.num_records / num_chunks; + for (int64_t j = 0; j < num_chunks; ++j) { + chunks.push_back( + rand.Int64(num_records_in_array, min_value, max_value, args.null_proportion)); } - ASSIGN_OR_ABORT(auto chunked_array, ChunkedArray::Make(arrays, int64())); - columns.push_back(chunked_array); + ASSIGN_OR_ABORT(auto chunked_array, ChunkedArray::Make(chunks, int64())); + data.columns.push_back(chunked_array); + } + + data.schema = schema(fields); + return data; +} + +static void RecordBatchSortIndicesInt64(benchmark::State& state, int64_t min, + int64_t max) { + RecordBatchSortIndicesArgs args(state); + + auto data = MakeBatchOrTableBenchmarkDataInt64(args, /*num_chunks=*/1, min, max); + ArrayVector columns; + for (const auto& chunked : data.columns) { + ARROW_CHECK_EQ(chunked->num_chunks(), 1); + columns.push_back(chunked->chunk(0)); } - auto table = Table::Make(schema(fields), columns, args.num_records); - SortOptions options(sort_keys); - TableSortIndicesBenchmark(state, table, options); + auto batch = RecordBatch::Make(data.schema, args.num_records, columns); + SortOptions options(data.sort_keys); + DatumSortIndicesBenchmark(state, Datum(*batch), options); +} + +static void TableSortIndicesInt64(benchmark::State& state, int64_t min, int64_t max) { + TableSortIndicesArgs args(state); + + auto data = MakeBatchOrTableBenchmarkDataInt64(args, args.num_chunks, min, max); + auto table = Table::Make(data.schema, data.columns, args.num_records); + SortOptions options(data.sort_keys); + DatumSortIndicesBenchmark(state, Datum(*table), options); +} + +static void RecordBatchSortIndicesInt64Narrow(benchmark::State& state) { + RecordBatchSortIndicesInt64(state, -100, 100); +} + +static void RecordBatchSortIndicesInt64Wide(benchmark::State& state) { + RecordBatchSortIndicesInt64(state, std::numeric_limits::min(), + std::numeric_limits::max()); } static void TableSortIndicesInt64Narrow(benchmark::State& state) { @@ -180,28 +227,40 @@ BENCHMARK(ArraySortIndicesInt64Narrow) ->Apply(RegressionSetArgs) ->Args({1 << 20, 100}) ->Args({1 << 23, 100}) - ->MinTime(1.0) ->Unit(benchmark::TimeUnit::kNanosecond); BENCHMARK(ArraySortIndicesInt64Wide) ->Apply(RegressionSetArgs) ->Args({1 << 20, 100}) ->Args({1 << 23, 100}) - ->MinTime(1.0) ->Unit(benchmark::TimeUnit::kNanosecond); BENCHMARK(ChunkedArraySortIndicesInt64Narrow) ->Apply(RegressionSetArgs) ->Args({1 << 20, 100}) ->Args({1 << 23, 100}) - ->MinTime(1.0) ->Unit(benchmark::TimeUnit::kNanosecond); BENCHMARK(ChunkedArraySortIndicesInt64Wide) ->Apply(RegressionSetArgs) ->Args({1 << 20, 100}) ->Args({1 << 23, 100}) - ->MinTime(1.0) + ->Unit(benchmark::TimeUnit::kNanosecond); + +BENCHMARK(RecordBatchSortIndicesInt64Narrow) + ->ArgsProduct({ + {1 << 20}, // the number of records + {100, 0}, // inverse null proportion + {16, 8, 2, 1}, // the number of columns + }) + ->Unit(benchmark::TimeUnit::kNanosecond); + +BENCHMARK(RecordBatchSortIndicesInt64Wide) + ->ArgsProduct({ + {1 << 20}, // the number of records + {100, 0}, // inverse null proportion + {16, 8, 2, 1}, // the number of columns + }) ->Unit(benchmark::TimeUnit::kNanosecond); BENCHMARK(TableSortIndicesInt64Narrow) @@ -211,7 +270,6 @@ BENCHMARK(TableSortIndicesInt64Narrow) {16, 8, 2, 1}, // the number of columns {32, 4, 1}, // the number of chunks }) - ->MinTime(1.0) ->Unit(benchmark::TimeUnit::kNanosecond); BENCHMARK(TableSortIndicesInt64Wide) @@ -221,7 +279,6 @@ BENCHMARK(TableSortIndicesInt64Wide) {16, 8, 2, 1}, // the number of columns {32, 4, 1}, // the number of chunks }) - ->MinTime(1.0) ->Unit(benchmark::TimeUnit::kNanosecond); } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc index 4c42cffc80c..0c9cad508ef 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc @@ -63,6 +63,9 @@ TypeToDataType() { return time64(TimeUnit::NANO); } +// ---------------------------------------------------------------------- +// Tests for NthToIndices + template class NthComparator { public: @@ -226,6 +229,32 @@ class Random : public RandomImpl { } }; +template <> +class Random : public RandomImpl { + using CType = float; + + public: + explicit Random(random::SeedType seed) : RandomImpl(seed) {} + + std::shared_ptr Generate(uint64_t count, double null_prob, double nan_prob = 0) { + return generator.Float32(count, std::numeric_limits::min(), + std::numeric_limits::max(), null_prob, nan_prob); + } +}; + +template <> +class Random : public RandomImpl { + using CType = double; + + public: + explicit Random(random::SeedType seed) : RandomImpl(seed) {} + + std::shared_ptr Generate(uint64_t count, double null_prob, double nan_prob = 0) { + return generator.Float64(count, std::numeric_limits::min(), + std::numeric_limits::max(), null_prob, nan_prob); + } +}; + template <> class Random : public RandomImpl { public: @@ -267,24 +296,41 @@ TYPED_TEST(TestNthToIndicesRandom, RandomValues) { } } -using arrow::internal::checked_pointer_cast; +// ---------------------------------------------------------------------- +// Tests for SortToIndices + +template +void AssertSortIndices(const std::shared_ptr& input, SortOrder order, + const std::shared_ptr& expected) { + ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(*input, order)); + ASSERT_OK(actual->ValidateFull()); + AssertArraysEqual(*expected, *actual, /*verbose=*/true); +} + +template +void AssertSortIndices(const std::shared_ptr& input, const SortOptions& options, + const std::shared_ptr& expected) { + ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(Datum(*input), options)); + ASSERT_OK(actual->ValidateFull()); + AssertArraysEqual(*expected, *actual, /*verbose=*/true); +} + +// `Options` may be both SortOptions or SortOrder +template +void AssertSortIndices(const std::shared_ptr& input, Options&& options, + const std::string& expected) { + AssertSortIndices(input, std::forward(options), + ArrayFromJSON(uint64(), expected)); +} template class TestArraySortIndicesKernel : public TestBase { - private: - void AssertArraysSortIndices(const std::shared_ptr values, SortOrder order, - const std::shared_ptr expected) { - ASSERT_OK_AND_ASSIGN(std::shared_ptr actual, SortIndices(*values, order)); - ASSERT_OK(actual->ValidateFull()); - AssertArraysEqual(*expected, *actual); - } - - protected: + public: virtual void AssertSortIndices(const std::string& values, SortOrder order, const std::string& expected) { auto type = TypeToDataType(); - AssertArraysSortIndices(ArrayFromJSON(type, values), order, - ArrayFromJSON(uint64(), expected)); + arrow::compute::AssertSortIndices(ArrayFromJSON(type, values), order, + ArrayFromJSON(uint64(), expected)); } virtual void AssertSortIndices(const std::string& values, const std::string& expected) { @@ -494,19 +540,7 @@ TYPED_TEST(TestArraySortIndicesKernelRandomCompare, SortRandomValuesCompare) { } // Test basic cases for chunked array. -class TestChunkedArraySortIndices : public ::testing::Test { - protected: - void AssertSortIndices(const std::shared_ptr chunked_array, - SortOrder order, const std::shared_ptr expected) { - ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(*chunked_array, order)); - AssertArraysEqual(*expected, *actual, /*verbose=*/true); - } - - void AssertSortIndices(const std::shared_ptr chunked_array, - SortOrder order, const std::string expected) { - AssertSortIndices(chunked_array, order, ArrayFromJSON(uint64(), expected)); - } -}; +class TestChunkedArraySortIndices : public ::testing::Test {}; TEST_F(TestChunkedArraySortIndices, Null) { auto chunked_array = ChunkedArrayFromJSON(uint8(), { @@ -514,8 +548,8 @@ TEST_F(TestChunkedArraySortIndices, Null) { "[3, null, 2]", "[1]", }); - this->AssertSortIndices(chunked_array, SortOrder::Ascending, "[1, 5, 4, 2, 0, 3]"); - this->AssertSortIndices(chunked_array, SortOrder::Descending, "[2, 4, 1, 5, 0, 3]"); + AssertSortIndices(chunked_array, SortOrder::Ascending, "[1, 5, 4, 2, 0, 3]"); + AssertSortIndices(chunked_array, SortOrder::Descending, "[2, 4, 1, 5, 0, 3]"); } TEST_F(TestChunkedArraySortIndices, NaN) { @@ -524,8 +558,8 @@ TEST_F(TestChunkedArraySortIndices, NaN) { "[3, null, NaN]", "[NaN, 1]", }); - this->AssertSortIndices(chunked_array, SortOrder::Ascending, "[1, 6, 2, 4, 5, 0, 3]"); - this->AssertSortIndices(chunked_array, SortOrder::Descending, "[2, 1, 6, 4, 5, 0, 3]"); + AssertSortIndices(chunked_array, SortOrder::Ascending, "[1, 6, 2, 4, 5, 0, 3]"); + AssertSortIndices(chunked_array, SortOrder::Descending, "[2, 1, 6, 4, 5, 0, 3]"); } // Tests for temporal types @@ -543,8 +577,8 @@ TYPED_TEST(TestChunkedArraySortIndicesForTemporal, NoNull) { "[3, 2, 1]", "[5, 0]", }); - this->AssertSortIndices(chunked_array, SortOrder::Ascending, "[0, 6, 1, 4, 3, 2, 5]"); - this->AssertSortIndices(chunked_array, SortOrder::Descending, "[5, 2, 3, 1, 4, 0, 6]"); + AssertSortIndices(chunked_array, SortOrder::Ascending, "[0, 6, 1, 4, 3, 2, 5]"); + AssertSortIndices(chunked_array, SortOrder::Descending, "[5, 2, 3, 1, 4, 0, 6]"); } // Base class for testing against random chunked array. @@ -622,57 +656,206 @@ class TestChunkedArrayRandomNarrow : public TestChunkedArrayRandomBase { TYPED_TEST_SUITE(TestChunkedArrayRandomNarrow, IntegralArrowTypes); TYPED_TEST(TestChunkedArrayRandomNarrow, SortIndices) { this->TestSortIndices(1000); } -// Test basic cases for table. -class TestTableSortIndices : public ::testing::Test { - protected: - void AssertSortIndices(const std::shared_ptr
table, const SortOptions& options, - const std::shared_ptr expected) { - ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(Datum(*table), options)); - AssertArraysEqual(*expected, *actual); - } +// Test basic cases for record batch. +class TestRecordBatchSortIndices : public ::testing::Test {}; - void AssertSortIndices(const std::shared_ptr
table, const SortOptions& options, - const std::string expected) { - AssertSortIndices(table, options, ArrayFromJSON(uint64(), expected)); - } -}; +TEST_F(TestRecordBatchSortIndices, NoNull) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + SortOptions options( + {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); + + auto batch = RecordBatchFromJSON(schema, + R"([{"a": 3, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": 4}, + {"a": 0, "b": 6}, + {"a": 2, "b": 5}, + {"a": 1, "b": 5}, + {"a": 1, "b": 3} + ])"); + AssertSortIndices(batch, options, "[3, 5, 1, 6, 4, 0, 2]"); +} + +TEST_F(TestRecordBatchSortIndices, Null) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + SortOptions options( + {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); + + auto batch = RecordBatchFromJSON(schema, + R"([{"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null}, + {"a": null, "b": null}, + {"a": 2, "b": 5}, + {"a": 1, "b": 5} + ])"); + AssertSortIndices(batch, options, "[5, 1, 4, 2, 0, 3]"); +} + +TEST_F(TestRecordBatchSortIndices, NaN) { + auto schema = ::arrow::schema({ + {field("a", float32())}, + {field("b", float64())}, + }); + SortOptions options( + {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); + + auto batch = RecordBatchFromJSON(schema, + R"([{"a": 3, "b": 5}, + {"a": 1, "b": NaN}, + {"a": 3, "b": 4}, + {"a": 0, "b": 6}, + {"a": NaN, "b": 5}, + {"a": NaN, "b": NaN}, + {"a": NaN, "b": 5}, + {"a": 1, "b": 5} + ])"); + AssertSortIndices(batch, options, "[3, 7, 1, 0, 2, 4, 6, 5]"); +} + +TEST_F(TestRecordBatchSortIndices, NaNAndNull) { + auto schema = ::arrow::schema({ + {field("a", float32())}, + {field("b", float64())}, + }); + SortOptions options( + {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); + + auto batch = RecordBatchFromJSON(schema, + R"([{"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null}, + {"a": null, "b": null}, + {"a": NaN, "b": null}, + {"a": NaN, "b": NaN}, + {"a": NaN, "b": 5}, + {"a": 1, "b": 5} + ])"); + AssertSortIndices(batch, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); +} + +TEST_F(TestRecordBatchSortIndices, MoreTypes) { + auto schema = ::arrow::schema({ + {field("a", timestamp(TimeUnit::MICRO))}, + {field("b", large_utf8())}, + }); + SortOptions options( + {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); + + auto batch = RecordBatchFromJSON(schema, + R"([{"a": 3, "b": "05"}, + {"a": 1, "b": "03"}, + {"a": 3, "b": "04"}, + {"a": 0, "b": "06"}, + {"a": 2, "b": "05"}, + {"a": 1, "b": "05"} + ])"); + AssertSortIndices(batch, options, "[3, 5, 1, 4, 0, 2]"); +} + +// Test basic cases for table. +class TestTableSortIndices : public ::testing::Test {}; TEST_F(TestTableSortIndices, Null) { - auto table = TableFromJSON(schema({ - {field("a", uint8())}, - {field("b", uint8())}, - }), - {"[" - "{\"a\": null, \"b\": 5}," - "{\"a\": 1, \"b\": 3}," - "{\"a\": 3, \"b\": null}," - "{\"a\": null, \"b\": null}," - "{\"a\": 2, \"b\": 5}," - "{\"a\": 1, \"b\": 5}" - "]"}); + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); SortOptions options( {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); - this->AssertSortIndices(table, options, "[5, 1, 4, 2, 0, 3]"); + std::shared_ptr
table; + + table = TableFromJSON(schema, {R"([{"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null}, + {"a": null, "b": null}, + {"a": 2, "b": 5}, + {"a": 1, "b": 5} + ])"}); + AssertSortIndices(table, options, "[5, 1, 4, 2, 0, 3]"); + + // Same data, several chunks + table = TableFromJSON(schema, {R"([{"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null} + ])", + R"([{"a": null, "b": null}, + {"a": 2, "b": 5}, + {"a": 1, "b": 5} + ])"}); + AssertSortIndices(table, options, "[5, 1, 4, 2, 0, 3]"); } TEST_F(TestTableSortIndices, NaN) { - auto table = TableFromJSON(schema({ - {field("a", float32())}, - {field("b", float32())}, - }), - {"[" - "{\"a\": null, \"b\": 5}," - "{\"a\": 1, \"b\": 3}," - "{\"a\": 3, \"b\": null}," - "{\"a\": null, \"b\": null}," - "{\"a\": NaN, \"b\": null}," - "{\"a\": NaN, \"b\": NaN}," - "{\"a\": NaN, \"b\": 5}," - "{\"a\": 1, \"b\": 5}" - "]"}); + auto schema = ::arrow::schema({ + {field("a", float32())}, + {field("b", float64())}, + }); SortOptions options( {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); - this->AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); + std::shared_ptr
table; + table = TableFromJSON(schema, {R"([{"a": 3, "b": 5}, + {"a": 1, "b": NaN}, + {"a": 3, "b": 4}, + {"a": 0, "b": 6}, + {"a": NaN, "b": 5}, + {"a": NaN, "b": NaN}, + {"a": NaN, "b": 5}, + {"a": 1, "b": 5} + ])"}); + AssertSortIndices(table, options, "[3, 7, 1, 0, 2, 4, 6, 5]"); + + // Same data, several chunks + table = TableFromJSON(schema, {R"([{"a": 3, "b": 5}, + {"a": 1, "b": NaN}, + {"a": 3, "b": 4}, + {"a": 0, "b": 6} + ])", + R"([{"a": NaN, "b": 5}, + {"a": NaN, "b": NaN}, + {"a": NaN, "b": 5}, + {"a": 1, "b": 5} + ])"}); + AssertSortIndices(table, options, "[3, 7, 1, 0, 2, 4, 6, 5]"); +} + +TEST_F(TestTableSortIndices, NaNAndNull) { + auto schema = ::arrow::schema({ + {field("a", float32())}, + {field("b", float64())}, + }); + SortOptions options( + {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); + std::shared_ptr
table; + table = TableFromJSON(schema, {R"([{"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null}, + {"a": null, "b": null}, + {"a": NaN, "b": null}, + {"a": NaN, "b": NaN}, + {"a": NaN, "b": 5}, + {"a": 1, "b": 5} + ])"}); + AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); + + // Same data, several chunks + table = TableFromJSON(schema, {R"([{"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null}, + {"a": null, "b": null} + ])", + R"([{"a": NaN, "b": null}, + {"a": NaN, "b": NaN}, + {"a": NaN, "b": 5}, + {"a": 1, "b": 5} + ])"}); + AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); } // Tests for temporal types @@ -701,7 +884,7 @@ TYPED_TEST(TestTableSortIndicesForTemporal, NoNull) { "]"}); SortOptions options( {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); - this->AssertSortIndices(table, options, "[0, 6, 1, 4, 7, 3, 2, 5]"); + AssertSortIndices(table, options, "[0, 6, 1, 4, 7, 3, 2, 5]"); } // For random table tests. @@ -733,7 +916,8 @@ class TestTableSortIndicesRandom : public testing::TestWithParam { if (lhs_array_->IsNull(lhs_index_)) return false; status_ = lhs_array_->type()->Accept(this); if (compared_ == 0) continue; - if (pair.second == SortOrder::Ascending) { + // If either value is NaN, it must sort after the other regardless of order + if (pair.second == SortOrder::Ascending || lhs_isnan_ || rhs_isnan_) { return compared_ < 0; } else { return compared_ > 0; @@ -791,11 +975,14 @@ class TestTableSortIndicesRandom : public testing::TestWithParam { auto lhs_value = checked_cast(lhs_array_)->GetView(lhs_index_); auto rhs_value = checked_cast(rhs_array_)->GetView(rhs_index_); if (is_floating_type::value) { - const bool lhs_isnan = lhs_value != lhs_value; - const bool rhs_isnan = rhs_value != rhs_value; - if (lhs_isnan && rhs_isnan) return 0; - if (rhs_isnan) return 1; - if (lhs_isnan) return -1; + lhs_isnan_ = lhs_value != lhs_value; + rhs_isnan_ = rhs_value != rhs_value; + if (lhs_isnan_ && rhs_isnan_) return 0; + // NaN is considered greater than non-NaN + if (rhs_isnan_) return -1; + if (lhs_isnan_) return 1; + } else { + lhs_isnan_ = rhs_isnan_ = false; } if (lhs_value == rhs_value) { return 0; @@ -814,6 +1001,7 @@ class TestTableSortIndicesRandom : public testing::TestWithParam { int64_t rhs_; const Array* rhs_array_; int64_t rhs_index_; + bool lhs_isnan_, rhs_isnan_; int compared_; Status status_; }; @@ -826,8 +1014,8 @@ class TestTableSortIndicesRandom : public testing::TestWithParam { for (int i = 1; i < table.num_rows(); i++) { uint64_t lhs = offsets.Value(i - 1); uint64_t rhs = offsets.Value(i); - ASSERT_TRUE(comparator(lhs, rhs)); ASSERT_OK(comparator.status()); + ASSERT_TRUE(comparator(lhs, rhs)) << "lhs = " << lhs << ", rhs = " << rhs; } } }; @@ -851,15 +1039,15 @@ TEST_P(TestTableSortIndicesRandom, Sort) { const auto length = 200; std::vector> columns = { Random(seed).Generate(length, null_probability), - Random(seed).Generate(length, null_probability), + Random(seed).Generate(length, 0.0), Random(seed).Generate(length, null_probability), - Random(seed).Generate(length, null_probability), - Random(seed).Generate(length, null_probability), + Random(seed).Generate(length, 0.0), + Random(seed).Generate(length, 0.0), Random(seed).Generate(length, null_probability), - Random(seed).Generate(length, null_probability), + Random(seed).Generate(length, 0.0), Random(seed).Generate(length, null_probability), - Random(seed).Generate(length, null_probability), - Random(seed).Generate(length, null_probability), + Random(seed).Generate(length, null_probability, 1 - null_probability), + Random(seed).Generate(length, 0.0, null_probability), Random(seed).Generate(length, null_probability), }; const auto table = Table::Make(schema(fields), columns, length); @@ -884,6 +1072,13 @@ TEST_P(TestTableSortIndicesRandom, Sort) { ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(*chunked_table), options)); Validate(*table, options, *checked_pointer_cast(offsets)); } + // Also validate RecordBatch sorting + TableBatchReader reader(*table); + RecordBatchVector batches; + ASSERT_OK(reader.ReadAll(&batches)); + ASSERT_EQ(batches.size(), 1); + ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(*batches[0]), options)); + Validate(*table, options, *checked_pointer_cast(offsets)); } INSTANTIATE_TEST_SUITE_P(NoNull, TestTableSortIndicesRandom, diff --git a/cpp/src/arrow/testing/random.cc b/cpp/src/arrow/testing/random.cc index 5d2525041c1..92ac5f8d07d 100644 --- a/cpp/src/arrow/testing/random.cc +++ b/cpp/src/arrow/testing/random.cc @@ -18,8 +18,10 @@ #include "arrow/testing/random.h" #include +#include #include #include +#include #include #include @@ -46,20 +48,48 @@ namespace { template struct GenerateOptions { - GenerateOptions(SeedType seed, ValueType min, ValueType max, double probability) - : min_(min), max_(max), seed_(seed), probability_(probability) {} + GenerateOptions(SeedType seed, ValueType min, ValueType max, double probability, + double nan_probability = 0.0) + : min_(min), + max_(max), + seed_(seed), + probability_(probability), + nan_probability_(nan_probability) {} void GenerateData(uint8_t* buffer, size_t n) { GenerateTypedData(reinterpret_cast(buffer), n); } - void GenerateTypedData(ValueType* data, size_t n) { + template + typename std::enable_if::value>::type GenerateTypedData( + V* data, size_t n) { + GenerateTypedDataNoNan(data, n); + } + + template + typename std::enable_if::value>::type GenerateTypedData( + V* data, size_t n) { + if (nan_probability_ == 0.0) { + GenerateTypedDataNoNan(data, n); + return; + } std::default_random_engine rng(seed_++); DistributionType dist(min_, max_); + std::bernoulli_distribution nan_dist(nan_probability_); + const ValueType nan_value = std::numeric_limits::quiet_NaN(); // A static cast is required due to the int16 -> int8 handling. - std::generate(data, data + n, - [&dist, &rng] { return static_cast(dist(rng)); }); + std::generate(data, data + n, [&] { + return nan_dist(rng) ? nan_value : static_cast(dist(rng)); + }); + } + + void GenerateTypedDataNoNan(ValueType* data, size_t n) { + std::default_random_engine rng(seed_++); + DistributionType dist(min_, max_); + + // A static cast is required due to the int16 -> int8 handling. + std::generate(data, data + n, [&] { return static_cast(dist(rng)); }); } void GenerateBitmap(uint8_t* buffer, size_t n, int64_t* null_count) { @@ -82,6 +112,7 @@ struct GenerateOptions { ValueType max_; SeedType seed_; double probability_; + double nan_probability_; }; } // namespace @@ -170,14 +201,23 @@ PRIMITIVE_RAND_INTEGER_IMPL(Int64, int64_t, Int64Type) // Generate 16bit values for half-float PRIMITIVE_RAND_INTEGER_IMPL(Float16, int16_t, HalfFloatType) -#define PRIMITIVE_RAND_FLOAT_IMPL(Name, CType, ArrowType) \ - PRIMITIVE_RAND_IMPL(Name, CType, ArrowType, std::uniform_real_distribution) +std::shared_ptr RandomArrayGenerator::Float32(int64_t size, float min, float max, + double null_probability, + double nan_probability) { + using OptionType = GenerateOptions>; + OptionType options(seed(), min, max, null_probability, nan_probability); + return GenerateNumericArray(size, options); +} -PRIMITIVE_RAND_FLOAT_IMPL(Float32, float, FloatType) -PRIMITIVE_RAND_FLOAT_IMPL(Float64, double, DoubleType) +std::shared_ptr RandomArrayGenerator::Float64(int64_t size, double min, double max, + double null_probability, + double nan_probability) { + using OptionType = GenerateOptions>; + OptionType options(seed(), min, max, null_probability, nan_probability); + return GenerateNumericArray(size, options); +} #undef PRIMITIVE_RAND_INTEGER_IMPL -#undef PRIMITIVE_RAND_FLOAT_IMPL #undef PRIMITIVE_RAND_IMPL template diff --git a/cpp/src/arrow/testing/random.h b/cpp/src/arrow/testing/random.h index a1a16c07cf2..f874fae15b7 100644 --- a/cpp/src/arrow/testing/random.h +++ b/cpp/src/arrow/testing/random.h @@ -165,10 +165,11 @@ class ARROW_TESTING_EXPORT RandomArrayGenerator { /// \param[in] min the lower bound of the uniform distribution /// \param[in] max the upper bound of the uniform distribution /// \param[in] null_probability the probability of a row being null + /// \param[in] nan_probability the probability of a row being NaN /// /// \return a generated Array std::shared_ptr Float32(int64_t size, float min, float max, - double null_probability = 0); + double null_probability = 0, double nan_probability = 0); /// \brief Generate a random DoubleArray /// @@ -176,10 +177,11 @@ class ARROW_TESTING_EXPORT RandomArrayGenerator { /// \param[in] min the lower bound of the uniform distribution /// \param[in] max the upper bound of the uniform distribution /// \param[in] null_probability the probability of a row being null + /// \param[in] nan_probability the probability of a row being NaN /// /// \return a generated Array std::shared_ptr Float64(int64_t size, double min, double max, - double null_probability = 0); + double null_probability = 0, double nan_probability = 0); template std::shared_ptr Numeric(int64_t size, CType min, CType max,