Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions datafusion/optimizer/src/unwrap_cast_in_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::{
DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
};
use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
use datafusion_expr::utils::from_plan;
use datafusion_expr::{
binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator,
};
use std::sync::Arc;

/// The rule can be used to the numeric binary comparison with literal expr, like below pattern:
/// `cast(left_expr as data_type) comparison_op literal_expr` or `literal_expr comparison_op cast(right_expr as data_type)`.
Expand Down Expand Up @@ -79,10 +80,18 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
.map(|input| optimize(input))
.collect::<Result<Vec<_>>>()?;

let schema = plan.schema();
let mut schema = new_inputs.iter().map(|input| input.schema()).fold(
DFSchema::empty(),
|mut lhs, rhs| {
lhs.merge(rhs);
lhs
},
);

schema.merge(plan.schema());

let mut expr_rewriter = UnwrapCastExprRewriter {
schema: schema.clone(),
schema: Arc::new(schema),
};

let new_exprs = plan
Expand Down Expand Up @@ -273,6 +282,11 @@ fn try_cast_literal_to_type(
lit_value: &ScalarValue,
target_type: &DataType,
) -> Result<Option<ScalarValue>> {
let lit_data_type = lit_value.get_datatype();
// the rule just support the signed numeric data type now
if !is_support_data_type(&lit_data_type) || !is_support_data_type(target_type) {
return Ok(None);
}
if lit_value.is_null() {
// null value can be cast to any type of null value
return Ok(Some(ScalarValue::try_from(target_type)?));
Expand Down Expand Up @@ -373,7 +387,7 @@ mod tests {
use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue};
use datafusion_expr::expr_rewriter::ExprRewritable;
use datafusion_expr::{cast, col, lit, try_cast, Expr};
use datafusion_expr::{cast, col, in_list, lit, try_cast, Expr};
use std::collections::HashMap;
use std::sync::Arc;

Expand Down Expand Up @@ -585,6 +599,21 @@ mod tests {
assert_eq!(optimize_test(expr_lt, &schema), expected);
}

#[test]
fn test_not_support_data_type() {
// "c6 > 0" will be cast to `cast(c6 as int64) > 0
// but the type of c6 is uint32
// the rewriter will not throw error and just return the original expr
let schema = expr_test_schema();
let expr_input = cast(col("c6"), DataType::Int64).eq(lit(0i64));
assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);

// inlist for unsupported data type
let expr_input =
in_list(cast(col("c6"), DataType::Int64), vec![lit(0i64)], false);
assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
}

fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
let mut expr_rewriter = UnwrapCastExprRewriter {
schema: schema.clone(),
Expand All @@ -601,6 +630,7 @@ mod tests {
DFField::new(None, "c3", DataType::Decimal128(18, 2), false),
DFField::new(None, "c4", DataType::Decimal128(38, 37), false),
DFField::new(None, "c5", DataType::Float32, false),
DFField::new(None, "c6", DataType::UInt32, false),
],
HashMap::new(),
)
Expand Down
28 changes: 28 additions & 0 deletions datafusion/optimizer/tests/integration-test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,33 @@ use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;

#[test]
fn case_when() -> Result<()> {
let sql = "SELECT CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END FROM test";
let plan = test_sql(sql)?;
let expected = "Projection: CASE WHEN #test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END\
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This plan doesn't look ideal to me -- since col_int32 is int32 I think the exprs could be rewritten to

CASE WHEN #test.col_int32 > Int32(0) THEN Int32(1) ELSE Int32(0) END

Could be done as a follow on PR for sure

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alamb yes,
we can support the unsigned numeric data type in the UnwrapCastInBinaryComparison rule

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

\n TableScan: test projection=[col_int32]";
assert_eq!(expected, format!("{:?}", plan));

let sql = "SELECT CASE WHEN col_uint32 > 0 THEN 1 ELSE 0 END FROM test";
let plan = test_sql(sql)?;
let expected = "Projection: CASE WHEN CAST(#test.col_uint32 AS Int64) > Int64(0) THEN Int64(1) ELSE Int64(0) END\
\n TableScan: test projection=[col_uint32]";
assert_eq!(expected, format!("{:?}", plan));
Ok(())
}

#[test]
fn unsigned_target_type() -> Result<()> {
let sql = "SELECT * FROM test WHERE col_uint32 > 0";
let plan = test_sql(sql)?;
let expected = "Projection: #test.col_int32, #test.col_uint32, #test.col_utf8, #test.col_date32, #test.col_date64\
\n Filter: CAST(#test.col_uint32 AS Int64) > Int64(0)\
\n TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64]";
assert_eq!(expected, format!("{:?}", plan));
Ok(())
}

#[test]
fn distribute_by() -> Result<()> {
// regression test for https://github.com/apache/arrow-datafusion/issues/3234
Expand Down Expand Up @@ -114,6 +141,7 @@ impl ContextProvider for MySchemaProvider {
let schema = Schema::new_with_metadata(
vec![
Field::new("col_int32", DataType::Int32, true),
Field::new("col_uint32", DataType::UInt32, true),
Field::new("col_utf8", DataType::Utf8, true),
Field::new("col_date32", DataType::Date32, true),
Field::new("col_date64", DataType::Date64, true),
Expand Down