From 3202ab06eb17bd76ec15a954f2a2ec39fa715f02 Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Tue, 4 Oct 2022 16:22:28 +0800 Subject: [PATCH 1/2] the unwrap rule: don't throw error when meet unsupported data type --- .../optimizer/src/unwrap_cast_in_comparison.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 0f7238d33cd00..baf9021506bd5 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -273,6 +273,11 @@ fn try_cast_literal_to_type( lit_value: &ScalarValue, target_type: &DataType, ) -> Result> { + let lit_data_type = lit_value.get_datatype(); + // the rule just support the signed numeric data type now + if !is_support_data_type(&lit_data_type) || !is_support_data_type(target_type) { + return Ok(None); + } if lit_value.is_null() { // null value can be cast to any type of null value return Ok(Some(ScalarValue::try_from(target_type)?)); @@ -585,6 +590,17 @@ mod tests { assert_eq!(optimize_test(expr_lt, &schema), expected); } + #[test] + fn test_not_support_data_type() { + // "c6 > 0" will be cast to `cast(c6 as int64) > 0 + // but the type of c6 is uint32 + // the rewriter will not throw error and just return the original expr + let schema = expr_test_schema(); + let expr_input = cast(col("c6"), DataType::Int64).eq(lit(0i64)); + assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); + // TODO: add case when case + } + fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { let mut expr_rewriter = UnwrapCastExprRewriter { schema: schema.clone(), @@ -601,6 +617,7 @@ mod tests { DFField::new(None, "c3", DataType::Decimal128(18, 2), false), DFField::new(None, "c4", DataType::Decimal128(38, 37), false), DFField::new(None, "c5", DataType::Float32, false), + DFField::new(None, "c6", DataType::UInt32, false), ], HashMap::new(), ) From 6f8e4a3dff4e0477ec3bfe8ae8bb636a9095f0bd Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Tue, 4 Oct 2022 18:56:41 +0800 Subject: [PATCH 2/2] support data type in unwrap cast rule --- .../src/unwrap_cast_in_comparison.rs | 23 +++++++++++---- .../optimizer/tests/integration-test.rs | 28 +++++++++++++++++++ 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index baf9021506bd5..7d6858362cadf 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -22,12 +22,13 @@ use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::{ DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; -use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue}; +use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; use datafusion_expr::utils::from_plan; use datafusion_expr::{ binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, }; +use std::sync::Arc; /// The rule can be used to the numeric binary comparison with literal expr, like below pattern: /// `cast(left_expr as data_type) comparison_op literal_expr` or `literal_expr comparison_op cast(right_expr as data_type)`. @@ -79,10 +80,18 @@ fn optimize(plan: &LogicalPlan) -> Result { .map(|input| optimize(input)) .collect::>>()?; - let schema = plan.schema(); + let mut schema = new_inputs.iter().map(|input| input.schema()).fold( + DFSchema::empty(), + |mut lhs, rhs| { + lhs.merge(rhs); + lhs + }, + ); + + schema.merge(plan.schema()); let mut expr_rewriter = UnwrapCastExprRewriter { - schema: schema.clone(), + schema: Arc::new(schema), }; let new_exprs = plan @@ -378,7 +387,7 @@ mod tests { use arrow::datatypes::DataType; use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue}; use datafusion_expr::expr_rewriter::ExprRewritable; - use datafusion_expr::{cast, col, lit, try_cast, Expr}; + use datafusion_expr::{cast, col, in_list, lit, try_cast, Expr}; use std::collections::HashMap; use std::sync::Arc; @@ -598,7 +607,11 @@ mod tests { let schema = expr_test_schema(); let expr_input = cast(col("c6"), DataType::Int64).eq(lit(0i64)); assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); - // TODO: add case when case + + // inlist for unsupported data type + let expr_input = + in_list(cast(col("c6"), DataType::Int64), vec![lit(0i64)], false); + assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); } fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index 86f55e698505f..e7245c06c1021 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -29,6 +29,33 @@ use std::any::Any; use std::collections::HashMap; use std::sync::Arc; +#[test] +fn case_when() -> Result<()> { + let sql = "SELECT CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END FROM test"; + let plan = test_sql(sql)?; + let expected = "Projection: CASE WHEN #test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END\ + \n TableScan: test projection=[col_int32]"; + assert_eq!(expected, format!("{:?}", plan)); + + let sql = "SELECT CASE WHEN col_uint32 > 0 THEN 1 ELSE 0 END FROM test"; + let plan = test_sql(sql)?; + let expected = "Projection: CASE WHEN CAST(#test.col_uint32 AS Int64) > Int64(0) THEN Int64(1) ELSE Int64(0) END\ + \n TableScan: test projection=[col_uint32]"; + assert_eq!(expected, format!("{:?}", plan)); + Ok(()) +} + +#[test] +fn unsigned_target_type() -> Result<()> { + let sql = "SELECT * FROM test WHERE col_uint32 > 0"; + let plan = test_sql(sql)?; + let expected = "Projection: #test.col_int32, #test.col_uint32, #test.col_utf8, #test.col_date32, #test.col_date64\ + \n Filter: CAST(#test.col_uint32 AS Int64) > Int64(0)\ + \n TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64]"; + assert_eq!(expected, format!("{:?}", plan)); + Ok(()) +} + #[test] fn distribute_by() -> Result<()> { // regression test for https://github.com/apache/arrow-datafusion/issues/3234 @@ -114,6 +141,7 @@ impl ContextProvider for MySchemaProvider { let schema = Schema::new_with_metadata( vec![ Field::new("col_int32", DataType::Int32, true), + Field::new("col_uint32", DataType::UInt32, true), Field::new("col_utf8", DataType::Utf8, true), Field::new("col_date32", DataType::Date32, true), Field::new("col_date64", DataType::Date64, true),