Skip to content

Commit 3062cd5

Browse files
committed
do type coercion in the simplify expression
1 parent bb7f0ac commit 3062cd5

10 files changed

Lines changed: 177 additions & 160 deletions

File tree

datafusion/core/src/execution/context.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1433,12 +1433,6 @@ impl SessionState {
14331433
// of applying other optimizations
14341434
Arc::new(SimplifyExpressions::new()),
14351435
Arc::new(PreCastLitInComparisonExpressions::new()),
1436-
// Do the type coercion first
1437-
// TODO: https://github.com/apache/arrow-datafusion/issues/3556
1438-
Arc::new(TypeCoercion::new()),
1439-
// The first simplify expression will fail, if the type is not right
1440-
// This simplify expression will done after the type coercion
1441-
Arc::new(SimplifyExpressions::new()),
14421436
Arc::new(DecorrelateWhereExists::new()),
14431437
Arc::new(DecorrelateWhereIn::new()),
14441438
Arc::new(ScalarSubqueryToJoin::new()),
@@ -1461,6 +1455,8 @@ impl SessionState {
14611455
// TODO: https://github.com/apache/arrow-datafusion/issues/3557
14621456
// remove this, after the issue fixed.
14631457
rules.push(Arc::new(TypeCoercion::new()));
1458+
// after the type coercion, can do simplify expression again
1459+
rules.push(Arc::new(SimplifyExpressions::new()));
14641460
rules.push(Arc::new(FilterPushDown::new()));
14651461
rules.push(Arc::new(LimitPushDown::new()));
14661462
rules.push(Arc::new(SingleDistinctToGroupBy::new()));

datafusion/core/tests/sql/aggregates.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1834,11 +1834,11 @@ async fn aggregate_avg_add() -> Result<()> {
18341834
assert_eq!(results.len(), 1);
18351835

18361836
let expected = vec![
1837-
"+--------------+---------------------------+---------------------------+---------------------------+",
1838-
"| AVG(test.c1) | AVG(test.c1) + Float64(1) | AVG(test.c1) + Float64(2) | Float64(1) + AVG(test.c1) |",
1839-
"+--------------+---------------------------+---------------------------+---------------------------+",
1840-
"| 1.5 | 2.5 | 3.5 | 2.5 |",
1841-
"+--------------+---------------------------+---------------------------+---------------------------+",
1837+
"+--------------+-------------------------+-------------------------+-------------------------+",
1838+
"| AVG(test.c1) | AVG(test.c1) + Int64(1) | AVG(test.c1) + Int64(2) | Int64(1) + AVG(test.c1) |",
1839+
"+--------------+-------------------------+-------------------------+-------------------------+",
1840+
"| 1.5 | 2.5 | 3.5 | 2.5 |",
1841+
"+--------------+-------------------------+-------------------------+-------------------------+",
18421842
];
18431843
assert_batches_sorted_eq!(expected, &results);
18441844

datafusion/core/tests/sql/decimal.rs

Lines changed: 57 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -376,25 +376,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
376376
actual[0].schema().field(0).data_type()
377377
);
378378
let expected = vec![
379-
"+----------------------------------------------------+",
380-
"| decimal_simple.c1 + Decimal128(Some(1000000),27,6) |",
381-
"+----------------------------------------------------+",
382-
"| 1.000010 |",
383-
"| 1.000020 |",
384-
"| 1.000020 |",
385-
"| 1.000030 |",
386-
"| 1.000030 |",
387-
"| 1.000030 |",
388-
"| 1.000040 |",
389-
"| 1.000040 |",
390-
"| 1.000040 |",
391-
"| 1.000040 |",
392-
"| 1.000050 |",
393-
"| 1.000050 |",
394-
"| 1.000050 |",
395-
"| 1.000050 |",
396-
"| 1.000050 |",
397-
"+----------------------------------------------------+",
379+
"+------------------------------+",
380+
"| decimal_simple.c1 + Int64(1) |",
381+
"+------------------------------+",
382+
"| 1.000010 |",
383+
"| 1.000020 |",
384+
"| 1.000020 |",
385+
"| 1.000030 |",
386+
"| 1.000030 |",
387+
"| 1.000030 |",
388+
"| 1.000040 |",
389+
"| 1.000040 |",
390+
"| 1.000040 |",
391+
"| 1.000040 |",
392+
"| 1.000050 |",
393+
"| 1.000050 |",
394+
"| 1.000050 |",
395+
"| 1.000050 |",
396+
"| 1.000050 |",
397+
"+------------------------------+",
398398
];
399399
assert_batches_eq!(expected, &actual);
400400
// array decimal(10,6) + array decimal(12,7) => decimal(13,7)
@@ -434,25 +434,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
434434
actual[0].schema().field(0).data_type()
435435
);
436436
let expected = vec![
437-
"+----------------------------------------------------+",
438-
"| decimal_simple.c1 - Decimal128(Some(1000000),27,6) |",
439-
"+----------------------------------------------------+",
440-
"| -0.999990 |",
441-
"| -0.999980 |",
442-
"| -0.999980 |",
443-
"| -0.999970 |",
444-
"| -0.999970 |",
445-
"| -0.999970 |",
446-
"| -0.999960 |",
447-
"| -0.999960 |",
448-
"| -0.999960 |",
449-
"| -0.999960 |",
450-
"| -0.999950 |",
451-
"| -0.999950 |",
452-
"| -0.999950 |",
453-
"| -0.999950 |",
454-
"| -0.999950 |",
455-
"+----------------------------------------------------+",
437+
"+------------------------------+",
438+
"| decimal_simple.c1 - Int64(1) |",
439+
"+------------------------------+",
440+
"| -0.999990 |",
441+
"| -0.999980 |",
442+
"| -0.999980 |",
443+
"| -0.999970 |",
444+
"| -0.999970 |",
445+
"| -0.999970 |",
446+
"| -0.999960 |",
447+
"| -0.999960 |",
448+
"| -0.999960 |",
449+
"| -0.999960 |",
450+
"| -0.999950 |",
451+
"| -0.999950 |",
452+
"| -0.999950 |",
453+
"| -0.999950 |",
454+
"| -0.999950 |",
455+
"+------------------------------+",
456456
];
457457
assert_batches_eq!(expected, &actual);
458458

@@ -492,25 +492,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
492492
actual[0].schema().field(0).data_type()
493493
);
494494
let expected = vec![
495-
"+-----------------------------------------------------+",
496-
"| decimal_simple.c1 * Decimal128(Some(20000000),31,6) |",
497-
"+-----------------------------------------------------+",
498-
"| 0.000200 |",
499-
"| 0.000400 |",
500-
"| 0.000400 |",
501-
"| 0.000600 |",
502-
"| 0.000600 |",
503-
"| 0.000600 |",
504-
"| 0.000800 |",
505-
"| 0.000800 |",
506-
"| 0.000800 |",
507-
"| 0.000800 |",
508-
"| 0.001000 |",
509-
"| 0.001000 |",
510-
"| 0.001000 |",
511-
"| 0.001000 |",
512-
"| 0.001000 |",
513-
"+-----------------------------------------------------+",
495+
"+-------------------------------+",
496+
"| decimal_simple.c1 * Int64(20) |",
497+
"+-------------------------------+",
498+
"| 0.000200 |",
499+
"| 0.000400 |",
500+
"| 0.000400 |",
501+
"| 0.000600 |",
502+
"| 0.000600 |",
503+
"| 0.000600 |",
504+
"| 0.000800 |",
505+
"| 0.000800 |",
506+
"| 0.000800 |",
507+
"| 0.000800 |",
508+
"| 0.001000 |",
509+
"| 0.001000 |",
510+
"| 0.001000 |",
511+
"| 0.001000 |",
512+
"| 0.001000 |",
513+
"+-------------------------------+",
514514
];
515515
assert_batches_eq!(expected, &actual);
516516

datafusion/core/tests/sql/expr.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -605,23 +605,23 @@ async fn test_string_concat_operator() -> Result<()> {
605605
let sql = "SELECT 'aa' || NULL || 'd'";
606606
let actual = execute_to_batches(&ctx, sql).await;
607607
let expected = vec![
608-
"+------------+",
609-
"| Utf8(NULL) |",
610-
"+------------+",
611-
"| |",
612-
"+------------+",
608+
"+---------------------------------+",
609+
"| Utf8(\"aa\") || NULL || Utf8(\"d\") |",
610+
"+---------------------------------+",
611+
"| |",
612+
"+---------------------------------+",
613613
];
614614
assert_batches_eq!(expected, &actual);
615615

616616
// concat 1 strings and 2 numeric
617617
let sql = "SELECT 'a' || 42 || 23.3";
618618
let actual = execute_to_batches(&ctx, sql).await;
619619
let expected = vec![
620-
"+-----------------+",
621-
"| Utf8(\"a4223.3\") |",
622-
"+-----------------+",
623-
"| a4223.3 |",
624-
"+-----------------+",
620+
"+-----------------------------------------+",
621+
"| Utf8(\"a\") || Int64(42) || Float64(23.3) |",
622+
"+-----------------------------------------+",
623+
"| a4223.3 |",
624+
"+-----------------------------------------+",
625625
];
626626
assert_batches_eq!(expected, &actual);
627627
Ok(())

datafusion/core/tests/sql/joins.rs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1423,8 +1423,6 @@ async fn hash_join_with_dictionary() -> Result<()> {
14231423
}
14241424

14251425
#[tokio::test]
1426-
#[ignore]
1427-
// https://github.com/apache/arrow-datafusion/issues/3565
14281426
async fn reduce_left_join_1() -> Result<()> {
14291427
let ctx = create_join_context("t1_id", "t2_id")?;
14301428

@@ -1469,8 +1467,6 @@ async fn reduce_left_join_1() -> Result<()> {
14691467
}
14701468

14711469
#[tokio::test]
1472-
#[ignore]
1473-
// https://github.com/apache/arrow-datafusion/issues/3565
14741470
async fn reduce_left_join_2() -> Result<()> {
14751471
let ctx = create_join_context("t1_id", "t2_id")?;
14761472

@@ -1514,8 +1510,6 @@ async fn reduce_left_join_2() -> Result<()> {
15141510
}
15151511

15161512
#[tokio::test]
1517-
#[ignore]
1518-
// https://github.com/apache/arrow-datafusion/issues/3565
15191513
async fn reduce_left_join_3() -> Result<()> {
15201514
let ctx = create_join_context("t1_id", "t2_id")?;
15211515

datafusion/core/tests/sql/predicates.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,7 @@ async fn csv_in_set_test() -> Result<()> {
389389

390390
#[tokio::test]
391391
async fn multiple_or_predicates() -> Result<()> {
392+
// TODO https://github.com/apache/arrow-datafusion/issues/3587
392393
let ctx = SessionContext::new();
393394
register_tpch_csv(&ctx, "lineitem").await?;
394395
register_tpch_csv(&ctx, "part").await?;
@@ -424,15 +425,13 @@ async fn multiple_or_predicates() -> Result<()> {
424425
let plan = state.optimize(&plan)?;
425426
// Note that we expect `#part.p_partkey = #lineitem.l_partkey` to have been
426427
// factored out and appear only once in the following plan
427-
let expected =vec![
428+
let expected = vec![
428429
"Explain [plan_type:Utf8, plan:Utf8]",
429430
" Projection: #lineitem.l_partkey [l_partkey:Int64]",
430-
" Projection: #part.p_size >= Int32(1) AS #part.p_size >= Int32(1)Int32(1)#part.p_size, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [#part.p_size >= Int32(1)Int32(1)#part.p_size:Boolean;N, l_partkey:Int64, l_quantity:Decimal128(15, 2), p_brand:Utf8, p_size:Int32]",
431-
" Filter: #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Decimal128(Some(100),15,2) AND #lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND #part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
432-
" Inner Join: #lineitem.l_partkey = #part.p_partkey [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
433-
" TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
434-
" Filter: #part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
435-
" TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[#part.p_size >= Int32(1)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
431+
" Filter: #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Decimal128(Some(100),15,2) AND #lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND CAST(#part.p_size AS Int64) BETWEEN Int64(1) AND Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND CAST(#part.p_size AS Int64) BETWEEN Int64(1) AND Int64(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND CAST(#part.p_size AS Int64) BETWEEN Int64(1) AND Int64(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
432+
" Inner Join: #lineitem.l_partkey = #part.p_partkey [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
433+
" TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
434+
" TableScan: part projection=[p_partkey, p_brand, p_size] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
436435
];
437436
let formatted = plan.display_indent_schema().to_string();
438437
let actual: Vec<&str> = formatted.trim().lines().collect();

datafusion/core/tests/sql/select.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -523,12 +523,12 @@ async fn use_between_expression_in_select_query() -> Result<()> {
523523
.unwrap()
524524
.to_string();
525525

526+
// TODO https://github.com/apache/arrow-datafusion/issues/3587
526527
// Only test that the projection exprs are correct, rather than entire output
527528
let needle = "ProjectionExec: expr=[c1@0 >= 2 AND c1@0 <= 3 as test.c1 BETWEEN Int64(2) AND Int64(3)]";
528529
assert_contains!(&formatted, needle);
529-
let needle = "Projection: #test.c1 >= Int64(2) AND #test.c1 <= Int64(3)";
530+
let needle = "Projection: #test.c1 BETWEEN Int64(2) AND Int64(3)";
530531
assert_contains!(&formatted, needle);
531-
532532
Ok(())
533533
}
534534

0 commit comments

Comments
 (0)