Skip to content

Commit b56b91e

Browse files
authored
GH-18481: [C++] prefer casting literal over casting field ref (#15180)
I ran into this problem while trying to work out partition pruning in the new scan node. I feel like this is a somewhat naive approach but it seems to work. I think it would fail if a `DispatchBest` existed where a n-ary kernel existed with non-equal types. For example, if there was a function foo(int8, int32) and it had a dispatch best of some kind. Authored-by: Weston Pace <weston.pace@gmail.com> Signed-off-by: Weston Pace <weston.pace@gmail.com>
1 parent 838d0da commit b56b91e

7 files changed

Lines changed: 227 additions & 15 deletions

File tree

cpp/src/arrow/compute/exec/expression.cc

Lines changed: 168 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,11 @@ bool Expression::Equals(const Expression& other) const {
221221
}
222222

223223
if (auto lit = literal()) {
224-
return lit->Equals(*other.literal());
224+
// The scalar NaN is not equal to the scalar NaN but the literal NaN
225+
// is equal to the literal NaN (e.g. the expressions are equal even if
226+
// the values are not)
227+
EqualOptions equal_options = EqualOptions::Defaults().nans_equal(true);
228+
return lit->scalar()->Equals(other.literal()->scalar(), equal_options);
225229
}
226230

227231
if (auto ref = field_ref()) {
@@ -368,6 +372,158 @@ bool Expression::IsSatisfiable() const {
368372

369373
namespace {
370374

375+
TypeHolder SmallestTypeFor(const arrow::Datum& value) {
376+
switch (value.type()->id()) {
377+
case Type::INT8:
378+
return int8();
379+
case Type::UINT8:
380+
return uint8();
381+
case Type::INT16: {
382+
int16_t i16 = value.scalar_as<Int16Scalar>().value;
383+
if (i16 <= std::numeric_limits<int8_t>::max() &&
384+
i16 >= std::numeric_limits<int8_t>::min()) {
385+
return int8();
386+
}
387+
return int16();
388+
}
389+
case Type::UINT16: {
390+
uint16_t ui16 = value.scalar_as<UInt16Scalar>().value;
391+
if (ui16 <= std::numeric_limits<uint8_t>::max()) {
392+
return uint8();
393+
}
394+
return uint16();
395+
}
396+
case Type::INT32: {
397+
int32_t i32 = value.scalar_as<Int32Scalar>().value;
398+
if (i32 <= std::numeric_limits<int8_t>::max() &&
399+
i32 >= std::numeric_limits<int8_t>::min()) {
400+
return int8();
401+
}
402+
if (i32 <= std::numeric_limits<int16_t>::max() &&
403+
i32 >= std::numeric_limits<int16_t>::min()) {
404+
return int16();
405+
}
406+
return int32();
407+
}
408+
case Type::UINT32: {
409+
uint32_t ui32 = value.scalar_as<UInt32Scalar>().value;
410+
if (ui32 <= std::numeric_limits<uint8_t>::max()) {
411+
return uint8();
412+
}
413+
if (ui32 <= std::numeric_limits<uint16_t>::max()) {
414+
return uint16();
415+
}
416+
return uint32();
417+
}
418+
case Type::INT64: {
419+
int64_t i64 = value.scalar_as<Int64Scalar>().value;
420+
if (i64 <= std::numeric_limits<int8_t>::max() &&
421+
i64 >= std::numeric_limits<int8_t>::min()) {
422+
return int8();
423+
}
424+
if (i64 <= std::numeric_limits<int16_t>::max() &&
425+
i64 >= std::numeric_limits<int16_t>::min()) {
426+
return int16();
427+
}
428+
if (i64 <= std::numeric_limits<int32_t>::max() &&
429+
i64 >= std::numeric_limits<int32_t>::min()) {
430+
return int32();
431+
}
432+
return int64();
433+
}
434+
case Type::UINT64: {
435+
uint64_t ui64 = value.scalar_as<UInt64Scalar>().value;
436+
if (ui64 <= std::numeric_limits<uint8_t>::max()) {
437+
return uint8();
438+
}
439+
if (ui64 <= std::numeric_limits<uint16_t>::max()) {
440+
return uint16();
441+
}
442+
if (ui64 <= std::numeric_limits<uint32_t>::max()) {
443+
return uint32();
444+
}
445+
return uint64();
446+
}
447+
case Type::DOUBLE: {
448+
double doub = value.scalar_as<DoubleScalar>().value;
449+
if (!std::isfinite(doub)) {
450+
// Special values can be float
451+
return float32();
452+
}
453+
// Test if float representation is the same
454+
if (static_cast<double>(static_cast<float>(doub)) == doub) {
455+
return float32();
456+
}
457+
return float64();
458+
}
459+
case Type::LARGE_STRING: {
460+
if (value.scalar_as<LargeStringScalar>().value->size() <=
461+
std::numeric_limits<int32_t>::max()) {
462+
return utf8();
463+
}
464+
return large_utf8();
465+
}
466+
case Type::LARGE_BINARY:
467+
if (value.scalar_as<LargeBinaryScalar>().value->size() <=
468+
std::numeric_limits<int32_t>::max()) {
469+
return binary();
470+
}
471+
return large_binary();
472+
case Type::TIMESTAMP: {
473+
const auto& ts_type = checked_pointer_cast<TimestampType>(value.type());
474+
uint64_t ts = value.scalar_as<TimestampScalar>().value;
475+
switch (ts_type->unit()) {
476+
case TimeUnit::SECOND:
477+
return value.type();
478+
case TimeUnit::MILLI:
479+
if (ts % 1000 == 0) {
480+
return timestamp(TimeUnit::SECOND);
481+
}
482+
return value.type();
483+
case TimeUnit::MICRO:
484+
if (ts % 1000000 == 0) {
485+
return timestamp(TimeUnit::SECOND);
486+
}
487+
if (ts % 1000 == 0) {
488+
return timestamp(TimeUnit::MILLI);
489+
}
490+
return value.type();
491+
case TimeUnit::NANO:
492+
if (ts % 1000000000 == 0) {
493+
return timestamp(TimeUnit::SECOND);
494+
}
495+
if (ts % 1000000 == 0) {
496+
return timestamp(TimeUnit::MILLI);
497+
}
498+
if (ts % 1000 == 0) {
499+
return timestamp(TimeUnit::MICRO);
500+
}
501+
return value.type();
502+
default:
503+
return value.type();
504+
}
505+
}
506+
default:
507+
return value.type();
508+
}
509+
}
510+
511+
inline std::vector<TypeHolder> GetTypesWithSmallestLiteralRepresentation(
512+
const std::vector<Expression>& exprs) {
513+
std::vector<TypeHolder> types(exprs.size());
514+
for (size_t i = 0; i < exprs.size(); ++i) {
515+
DCHECK(exprs[i].IsBound());
516+
if (const Datum* literal = exprs[i].literal()) {
517+
if (literal->is_scalar()) {
518+
types[i] = SmallestTypeFor(*literal);
519+
}
520+
} else {
521+
types[i] = exprs[i].type();
522+
}
523+
}
524+
return types;
525+
}
526+
371527
// Produce a bound Expression from unbound Call and bound arguments.
372528
Result<Expression> BindNonRecursive(Expression::Call call, bool insert_implicit_casts,
373529
compute::ExecContext* exec_context) {
@@ -377,9 +533,18 @@ Result<Expression> BindNonRecursive(Expression::Call call, bool insert_implicit_
377533
std::vector<TypeHolder> types = GetTypes(call.arguments);
378534
ARROW_ASSIGN_OR_RAISE(call.function, GetFunction(call, exec_context));
379535

380-
if (!insert_implicit_casts) {
381-
ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchExact(types));
536+
// First try and bind exactly
537+
Result<const Kernel*> maybe_exact_match = call.function->DispatchExact(types);
538+
if (maybe_exact_match.ok()) {
539+
call.kernel = *maybe_exact_match;
382540
} else {
541+
if (!insert_implicit_casts) {
542+
return maybe_exact_match.status();
543+
}
544+
// If exact binding fails, and we are allowed to cast, then prefer casting literals
545+
// first. Since DispatchBest generally prefers up-casting the best way to do this is
546+
// first down-cast the literals as much as possible
547+
types = GetTypesWithSmallestLiteralRepresentation(call.arguments);
383548
ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchBest(&types));
384549

385550
for (size_t i = 0; i < types.size(); ++i) {

cpp/src/arrow/compute/exec/expression_test.cc

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ const std::shared_ptr<Schema> kBoringSchema = schema({
6060
field("dict_i32", dictionary(int32(), int32())),
6161
field("ts_ns", timestamp(TimeUnit::NANO)),
6262
field("ts_s", timestamp(TimeUnit::SECOND)),
63+
field("binary", binary()),
6364
});
6465

6566
#define EXPECT_OK ARROW_EXPECT_OK
@@ -330,6 +331,23 @@ TEST(Expression, Equality) {
330331
EXPECT_EQ(literal(1), literal(1));
331332
EXPECT_NE(literal(1), literal(2));
332333

334+
// NaN literals (of the same type) should be equal. This allows, for example,
335+
// the expression x == NaN to equal itself.
336+
auto double_nan_literal = literal(std::numeric_limits<double>::quiet_NaN());
337+
auto float_nan_literal = literal(std::numeric_limits<float>::quiet_NaN());
338+
EXPECT_EQ(double_nan_literal, double_nan_literal);
339+
EXPECT_NE(double_nan_literal, float_nan_literal);
340+
// The literals may be equal but the values should not be
341+
Expression nans_eq = equal(double_nan_literal, double_nan_literal);
342+
ASSERT_OK_AND_ASSIGN(nans_eq, nans_eq.Bind(*kBoringSchema));
343+
ASSERT_OK_AND_ASSIGN(Datum nans_eq_rsp, ExecuteScalarExpression(nans_eq, ExecBatch()));
344+
EXPECT_FALSE(nans_eq_rsp.scalar_as<BooleanScalar>().value);
345+
if (std::numeric_limits<double>::has_signaling_NaN) {
346+
// We intentionally do not care about signaling and may even discard it on conversion.
347+
EXPECT_EQ(literal(std::numeric_limits<double>::quiet_NaN()),
348+
literal(std::numeric_limits<double>::signaling_NaN()));
349+
}
350+
333351
EXPECT_EQ(field_ref("a"), field_ref("a"));
334352
EXPECT_NE(field_ref("a"), field_ref("b"));
335353
EXPECT_NE(field_ref("a"), literal(2));
@@ -593,8 +611,36 @@ TEST(Expression, BindWithImplicitCasts) {
593611
ExpectBindsTo(cmp(field_ref("dict_str"), field_ref("str")),
594612
cmp(cast(field_ref("dict_str"), utf8()), field_ref("str")));
595613

614+
// Should prefer the literal
615+
ExpectBindsTo(cmp(field_ref("dict_i32"), literal(int64_t(4))),
616+
cmp(field_ref("dict_i32"), literal(int32_t(4))));
596617
ExpectBindsTo(cmp(field_ref("dict_i32"), literal(int64_t(4))),
597-
cmp(cast(field_ref("dict_i32"), int64()), literal(int64_t(4))));
618+
cmp(field_ref("dict_i32"), literal(int32_t(4))));
619+
ExpectBindsTo(cmp(field_ref("ts_s"),
620+
literal(std::make_shared<TimestampScalar>(0, TimeUnit::NANO))),
621+
cmp(field_ref("ts_s"),
622+
literal(std::make_shared<TimestampScalar>(0, TimeUnit::SECOND))));
623+
ExpectBindsTo(
624+
cmp(field_ref("binary"), literal(std::make_shared<LargeBinaryScalar>("foo"))),
625+
cmp(field_ref("binary"), literal(std::make_shared<BinaryScalar>("foo"))));
626+
627+
// We will not implicitly cast a literal from signed to unsigned or vice versa
628+
ExpectBindsTo(cmp(field_ref("i8"), literal(uint8_t(4))),
629+
cmp(cast(field_ref("i8"), int16()), literal(int16_t(4))));
630+
ExpectBindsTo(cmp(field_ref("u32"), literal(int64_t(4))),
631+
cmp(cast(field_ref("u32"), int64()), literal(int64_t(4))));
632+
633+
// NaN / Inf can be float or double as needed
634+
ExpectBindsTo(
635+
cmp(field_ref("f32"), literal(std::numeric_limits<double>::quiet_NaN())),
636+
cmp(field_ref("f32"), literal(std::numeric_limits<float>::quiet_NaN())));
637+
ExpectBindsTo(cmp(field_ref("f32"), literal(std::numeric_limits<double>::infinity())),
638+
cmp(field_ref("f32"), literal(std::numeric_limits<float>::infinity())));
639+
640+
// Bit of an odd case, both fields are cast
641+
ExpectBindsTo(cmp(field_ref("i32"), literal(std::make_shared<DoubleScalar>(10.0))),
642+
cmp(cast(field_ref("i32"), float32()),
643+
literal(std::make_shared<FloatScalar>(10.0f))));
598644
}
599645

600646
compute::SetLookupOptions in_a{ArrayFromJSON(utf8(), R"(["a"])")};

cpp/src/arrow/compute/kernels/codegen_internal.cc

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -338,12 +338,11 @@ Status CastBinaryDecimalArgs(DecimalPromotion promotion, std::vector<TypeHolder>
338338
const DataType& right_type = *(*types)[1];
339339
DCHECK(is_decimal(left_type.id()) || is_decimal(right_type.id()));
340340

341-
// decimal + float = float
342-
if (is_floating(left_type.id())) {
343-
(*types)[1] = (*types)[0];
344-
return Status::OK();
345-
} else if (is_floating(right_type.id())) {
346-
(*types)[0] = (*types)[1];
341+
// decimal + float64 = float64
342+
// decimal + float32 is roughly float64 + float32 so we choose float64
343+
if (is_floating(left_type.id()) || is_floating(right_type.id())) {
344+
(*types)[0] = float64();
345+
(*types)[1] = float64();
347346
return Status::OK();
348347
}
349348

cpp/src/arrow/compute/kernels/codegen_internal_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ TEST(TestDispatchBest, CastBinaryDecimalArgs) {
3434

3535
// Any float -> all float
3636
for (auto mode : modes) {
37-
args = {decimal128(3, 2), float64()};
37+
args = {decimal128(3, 2), float32(), float64()};
3838
ASSERT_OK(CastBinaryDecimalArgs(mode, &args));
3939
AssertTypeEqual(*args[0], *float64());
4040
AssertTypeEqual(*args[1], *float64());

cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1763,9 +1763,9 @@ TEST_F(TestBinaryArithmeticDecimal, DispatchBest) {
17631763
for (std::string suffix : {"", "_checked"}) {
17641764
name += suffix;
17651765

1766-
CheckDispatchBest(name, {decimal128(1, 0), float32()}, {float32(), float32()});
1766+
CheckDispatchBest(name, {decimal128(1, 0), float32()}, {float64(), float64()});
17671767
CheckDispatchBest(name, {decimal256(1, 0), float64()}, {float64(), float64()});
1768-
CheckDispatchBest(name, {float32(), decimal256(1, 0)}, {float32(), float32()});
1768+
CheckDispatchBest(name, {float32(), decimal256(1, 0)}, {float64(), float64()});
17691769
CheckDispatchBest(name, {float64(), decimal128(1, 0)}, {float64(), float64()});
17701770
}
17711771
}

cpp/src/arrow/engine/substrait/serde_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3757,7 +3757,7 @@ TEST(Substrait, NestedProjectWithMultiFieldExpressions) {
37573757

37583758
ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json));
37593759

3760-
auto output_schema = schema({field("A", float64()), field("B", float64())});
3760+
auto output_schema = schema({field("A", float32()), field("B", float32())});
37613761
auto expected_table = TableFromJSON(output_schema, {R"([
37623762
[20, 20],
37633763
[30, 30],

r/tests/testthat/test-type.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,9 @@ test_that("infer_type() gets the right type for Expression", {
280280
expect_equal(y$type(), infer_type(y))
281281
expect_equal(infer_type(y), float64())
282282
expect_equal(add_xy$type(), infer_type(add_xy))
283-
expect_equal(infer_type(add_xy), float64())
283+
# even though 10 is a float64, arrow will clamp it to the narrowest
284+
# type that can exactly represent it when building expressions
285+
expect_equal(infer_type(add_xy), float32())
284286
})
285287

286288
test_that("infer_type() infers type for POSIXlt", {

0 commit comments

Comments
 (0)