@@ -221,7 +221,11 @@ bool Expression::Equals(const Expression& other) const {
221221 }
222222
223223 if (auto lit = literal ()) {
224- return lit->Equals (*other.literal ());
224+ // The scalar NaN is not equal to the scalar NaN but the literal NaN
225+ // is equal to the literal NaN (e.g. the expressions are equal even if
226+ // the values are not)
227+ EqualOptions equal_options = EqualOptions::Defaults ().nans_equal (true );
228+ return lit->scalar ()->Equals (other.literal ()->scalar (), equal_options);
225229 }
226230
227231 if (auto ref = field_ref ()) {
@@ -368,6 +372,158 @@ bool Expression::IsSatisfiable() const {
368372
369373namespace {
370374
375+ TypeHolder SmallestTypeFor (const arrow::Datum& value) {
376+ switch (value.type ()->id ()) {
377+ case Type::INT8:
378+ return int8 ();
379+ case Type::UINT8:
380+ return uint8 ();
381+ case Type::INT16: {
382+ int16_t i16 = value.scalar_as <Int16Scalar>().value ;
383+ if (i16 <= std::numeric_limits<int8_t >::max () &&
384+ i16 >= std::numeric_limits<int8_t >::min ()) {
385+ return int8 ();
386+ }
387+ return int16 ();
388+ }
389+ case Type::UINT16: {
390+ uint16_t ui16 = value.scalar_as <UInt16Scalar>().value ;
391+ if (ui16 <= std::numeric_limits<uint8_t >::max ()) {
392+ return uint8 ();
393+ }
394+ return uint16 ();
395+ }
396+ case Type::INT32: {
397+ int32_t i32 = value.scalar_as <Int32Scalar>().value ;
398+ if (i32 <= std::numeric_limits<int8_t >::max () &&
399+ i32 >= std::numeric_limits<int8_t >::min ()) {
400+ return int8 ();
401+ }
402+ if (i32 <= std::numeric_limits<int16_t >::max () &&
403+ i32 >= std::numeric_limits<int16_t >::min ()) {
404+ return int16 ();
405+ }
406+ return int32 ();
407+ }
408+ case Type::UINT32: {
409+ uint32_t ui32 = value.scalar_as <UInt32Scalar>().value ;
410+ if (ui32 <= std::numeric_limits<uint8_t >::max ()) {
411+ return uint8 ();
412+ }
413+ if (ui32 <= std::numeric_limits<uint16_t >::max ()) {
414+ return uint16 ();
415+ }
416+ return uint32 ();
417+ }
418+ case Type::INT64: {
419+ int64_t i64 = value.scalar_as <Int64Scalar>().value ;
420+ if (i64 <= std::numeric_limits<int8_t >::max () &&
421+ i64 >= std::numeric_limits<int8_t >::min ()) {
422+ return int8 ();
423+ }
424+ if (i64 <= std::numeric_limits<int16_t >::max () &&
425+ i64 >= std::numeric_limits<int16_t >::min ()) {
426+ return int16 ();
427+ }
428+ if (i64 <= std::numeric_limits<int32_t >::max () &&
429+ i64 >= std::numeric_limits<int32_t >::min ()) {
430+ return int32 ();
431+ }
432+ return int64 ();
433+ }
434+ case Type::UINT64: {
435+ uint64_t ui64 = value.scalar_as <UInt64Scalar>().value ;
436+ if (ui64 <= std::numeric_limits<uint8_t >::max ()) {
437+ return uint8 ();
438+ }
439+ if (ui64 <= std::numeric_limits<uint16_t >::max ()) {
440+ return uint16 ();
441+ }
442+ if (ui64 <= std::numeric_limits<uint32_t >::max ()) {
443+ return uint32 ();
444+ }
445+ return uint64 ();
446+ }
447+ case Type::DOUBLE: {
448+ double doub = value.scalar_as <DoubleScalar>().value ;
449+ if (!std::isfinite (doub)) {
450+ // Special values can be float
451+ return float32 ();
452+ }
453+ // Test if float representation is the same
454+ if (static_cast <double >(static_cast <float >(doub)) == doub) {
455+ return float32 ();
456+ }
457+ return float64 ();
458+ }
459+ case Type::LARGE_STRING: {
460+ if (value.scalar_as <LargeStringScalar>().value ->size () <=
461+ std::numeric_limits<int32_t >::max ()) {
462+ return utf8 ();
463+ }
464+ return large_utf8 ();
465+ }
466+ case Type::LARGE_BINARY:
467+ if (value.scalar_as <LargeBinaryScalar>().value ->size () <=
468+ std::numeric_limits<int32_t >::max ()) {
469+ return binary ();
470+ }
471+ return large_binary ();
472+ case Type::TIMESTAMP: {
473+ const auto & ts_type = checked_pointer_cast<TimestampType>(value.type ());
474+ uint64_t ts = value.scalar_as <TimestampScalar>().value ;
475+ switch (ts_type->unit ()) {
476+ case TimeUnit::SECOND:
477+ return value.type ();
478+ case TimeUnit::MILLI:
479+ if (ts % 1000 == 0 ) {
480+ return timestamp (TimeUnit::SECOND);
481+ }
482+ return value.type ();
483+ case TimeUnit::MICRO:
484+ if (ts % 1000000 == 0 ) {
485+ return timestamp (TimeUnit::SECOND);
486+ }
487+ if (ts % 1000 == 0 ) {
488+ return timestamp (TimeUnit::MILLI);
489+ }
490+ return value.type ();
491+ case TimeUnit::NANO:
492+ if (ts % 1000000000 == 0 ) {
493+ return timestamp (TimeUnit::SECOND);
494+ }
495+ if (ts % 1000000 == 0 ) {
496+ return timestamp (TimeUnit::MILLI);
497+ }
498+ if (ts % 1000 == 0 ) {
499+ return timestamp (TimeUnit::MICRO);
500+ }
501+ return value.type ();
502+ default :
503+ return value.type ();
504+ }
505+ }
506+ default :
507+ return value.type ();
508+ }
509+ }
510+
511+ inline std::vector<TypeHolder> GetTypesWithSmallestLiteralRepresentation (
512+ const std::vector<Expression>& exprs) {
513+ std::vector<TypeHolder> types (exprs.size ());
514+ for (size_t i = 0 ; i < exprs.size (); ++i) {
515+ DCHECK (exprs[i].IsBound ());
516+ if (const Datum* literal = exprs[i].literal ()) {
517+ if (literal->is_scalar ()) {
518+ types[i] = SmallestTypeFor (*literal);
519+ }
520+ } else {
521+ types[i] = exprs[i].type ();
522+ }
523+ }
524+ return types;
525+ }
526+
371527// Produce a bound Expression from unbound Call and bound arguments.
372528Result<Expression> BindNonRecursive (Expression::Call call, bool insert_implicit_casts,
373529 compute::ExecContext* exec_context) {
@@ -377,9 +533,18 @@ Result<Expression> BindNonRecursive(Expression::Call call, bool insert_implicit_
377533 std::vector<TypeHolder> types = GetTypes (call.arguments );
378534 ARROW_ASSIGN_OR_RAISE (call.function , GetFunction (call, exec_context));
379535
380- if (!insert_implicit_casts) {
381- ARROW_ASSIGN_OR_RAISE (call.kernel , call.function ->DispatchExact (types));
536+ // First try and bind exactly
537+ Result<const Kernel*> maybe_exact_match = call.function ->DispatchExact (types);
538+ if (maybe_exact_match.ok ()) {
539+ call.kernel = *maybe_exact_match;
382540 } else {
541+ if (!insert_implicit_casts) {
542+ return maybe_exact_match.status ();
543+ }
544+ // If exact binding fails, and we are allowed to cast, then prefer casting literals
545+ // first. Since DispatchBest generally prefers up-casting the best way to do this is
546+ // first down-cast the literals as much as possible
547+ types = GetTypesWithSmallestLiteralRepresentation (call.arguments );
383548 ARROW_ASSIGN_OR_RAISE (call.kernel , call.function ->DispatchBest (&types));
384549
385550 for (size_t i = 0 ; i < types.size (); ++i) {
0 commit comments