diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index f914b62a1452d..f15ffdf423747 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -365,7 +365,7 @@ pub async fn from_substrait_rel( )), }, _ => Err(DataFusionError::Internal( - "invalid join condition expresssion".to_string(), + "invalid join condition expression".to_string(), )), } } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 785bfa4ea6a7e..2283415488139 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, @@ -32,7 +32,7 @@ use datafusion::logical_expr::expr::{ BinaryExpr, Case, Cast, ScalarFunction as DFScalarFunction, Sort, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; -use datafusion::prelude::{binary_expr, Expr}; +use datafusion::prelude::Expr; use prost_types::Any as ProtoAny; use substrait::{ proto::{ @@ -156,7 +156,7 @@ pub fn to_substrait_rel( let expressions = p .expr .iter() - .map(|e| to_substrait_rex(e, p.input.schema(), extension_info)) + .map(|e| to_substrait_rex(e, p.input.schema(), 0, extension_info)) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Project(Box::new(ProjectRel { @@ -172,6 +172,7 @@ pub fn to_substrait_rel( let filter_expr = to_substrait_rex( &filter.predicate, filter.input.schema(), + 0, extension_info, )?; Ok(Box::new(Rel { @@ -218,7 +219,7 @@ pub fn to_substrait_rel( let grouping = agg .group_expr .iter() - .map(|e| to_substrait_rex(e, agg.input.schema(), extension_info)) + .map(|e| to_substrait_rex(e, agg.input.schema(), 0, extension_info)) .collect::>>()?; let measures = agg .aggr_expr @@ -281,45 +282,24 @@ pub fn to_substrait_rel( } else { Operator::Eq }; - let join_expression = join - .on - .iter() - .map(|(l, r)| binary_expr(l.clone(), eq_op, r.clone())) - .reduce(|acc: Expr, expr: Expr| acc.and(expr)); - // join schema from left and right to maintain all nececesary columns from inputs - // note that we cannot simple use join.schema here since we discard some input columns - // when performing semi and anti joins - let join_schema = match join.left.schema().join(join.right.schema()) { - Ok(schema) => Ok(schema), - Err(DataFusionError::SchemaError( - datafusion::common::SchemaError::DuplicateQualifiedField { - qualifier: _, - name: _, - }, - )) => Ok(join.schema.as_ref().clone()), - Err(e) => Err(e), - }; - if let Some(e) = join_expression { - Ok(Box::new(Rel { - rel_type: Some(RelType::Join(Box::new(JoinRel { - common: None, - left: Some(left), - right: Some(right), - r#type: join_type as i32, - expression: Some(Box::new(to_substrait_rex( - &e, - &Arc::new(join_schema?), - extension_info, - )?)), - post_join_filter: None, - advanced_extension: None, - }))), - })) - } else { - Err(DataFusionError::NotImplemented( - "Empty join condition".to_string(), - )) - } + + Ok(Box::new(Rel { + rel_type: Some(RelType::Join(Box::new(JoinRel { + common: None, + left: Some(left), + right: Some(right), + r#type: join_type as i32, + expression: Some(Box::new(to_substrait_join_expr( + &join.on, + eq_op, + join.left.schema(), + join.right.schema(), + extension_info, + )?)), + post_join_filter: None, + advanced_extension: None, + }))), + })) } LogicalPlan::SubqueryAlias(alias) => { // Do nothing if encounters SubqueryAlias @@ -353,6 +333,7 @@ pub fn to_substrait_rel( window_exprs.push(to_substrait_rex( expr, window.input.schema(), + 0, extension_info, )?); } @@ -403,6 +384,40 @@ pub fn to_substrait_rel( } } +fn to_substrait_join_expr( + join_conditions: &Vec<(Expr, Expr)>, + eq_op: Operator, + left_schema: &DFSchemaRef, + right_schema: &DFSchemaRef, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result { + // Only support AND conjunction for each binary expression in join conditions + let mut exprs: Vec = vec![]; + for (left, right) in join_conditions { + // Parse left + let l = to_substrait_rex(left, left_schema, 0, extension_info)?; + // Parse right + let r = to_substrait_rex( + right, + right_schema, + left_schema.fields().len(), // offset to return the correct index + extension_info, + )?; + // AND with existing expression + exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extension_info)); + } + let join_expr: Expression = exprs + .into_iter() + .reduce(|acc: Expression, e: Expression| { + make_binary_op_scalar_func(&acc, &e, Operator::And, extension_info) + }) + .unwrap(); + Ok(join_expr) +} + fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType { match join_type { JoinType::Inner => join_rel::JoinType::Inner, @@ -459,7 +474,7 @@ pub fn to_substrait_agg_measure( Expr::AggregateFunction(expr::AggregateFunction { fun, args, distinct, filter, order_by: _order_by }) => { let mut arguments: Vec = vec![]; for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, extension_info)?)) }); + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); } let function_name = fun.to_string().to_lowercase(); let function_anchor = _register_function(function_name, extension_info); @@ -478,7 +493,7 @@ pub fn to_substrait_agg_measure( options: vec![], }), filter: match filter { - Some(f) => Some(to_substrait_rex(f, schema, extension_info)?), + Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), None => None } }) @@ -566,10 +581,33 @@ pub fn make_binary_op_scalar_func( } /// Convert DataFusion Expr to Substrait Rex +/// +/// # Arguments +/// +/// * `expr` - DataFusion expression to be parse into a Substrait expression +/// * `schema` - DataFusion input schema for looking up field qualifiers +/// * `col_ref_offset` - Offset for caculating Substrait field reference indices. +/// This should only be set by caller with more than one input relations i.e. Join. +/// Substrait expects one set of indices when joining two relations. +/// Let's say `left` and `right` have `m` and `n` columns, respectively. The `right` +/// relation will have column indices from `0` to `n-1`, however, Substrait will expect +/// the `right` indices to be offset by the `left`. This means Substrait will expect to +/// evaluate the join condition expression on indices [0 .. n-1, n .. n+m-1]. For example: +/// ```SELECT * +/// FROM t1 +/// JOIN t2 +/// ON t1.c1 = t2.c0;``` +/// where t1 consists of columns [c0, c1, c2], and t2 = columns [c0, c1] +/// the join condition should become +/// `col_ref(1) = col_ref(3 + 0)` +/// , where `3` is the number of `left` columns (`col_ref_offset`) and `0` is the index +/// of the join key column from `right` +/// * `extension_info` - Substrait extension info. Contains registered function information #[allow(deprecated)] pub fn to_substrait_rex( expr: &Expr, schema: &DFSchemaRef, + col_ref_offset: usize, extension_info: &mut ( Vec, HashMap, @@ -583,6 +621,7 @@ pub fn to_substrait_rex( arg_type: Some(ArgType::Value(to_substrait_rex( arg, schema, + col_ref_offset, extension_info, )?)), }); @@ -607,9 +646,12 @@ pub fn to_substrait_rex( }) => { if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) - let substrait_expr = to_substrait_rex(expr, schema, extension_info)?; - let substrait_low = to_substrait_rex(low, schema, extension_info)?; - let substrait_high = to_substrait_rex(high, schema, extension_info)?; + let substrait_expr = + to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + let substrait_low = + to_substrait_rex(low, schema, col_ref_offset, extension_info)?; + let substrait_high = + to_substrait_rex(high, schema, col_ref_offset, extension_info)?; let l_expr = make_binary_op_scalar_func( &substrait_expr, @@ -632,9 +674,12 @@ pub fn to_substrait_rex( )) } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) - let substrait_expr = to_substrait_rex(expr, schema, extension_info)?; - let substrait_low = to_substrait_rex(low, schema, extension_info)?; - let substrait_high = to_substrait_rex(high, schema, extension_info)?; + let substrait_expr = + to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + let substrait_low = + to_substrait_rex(low, schema, col_ref_offset, extension_info)?; + let substrait_high = + to_substrait_rex(high, schema, col_ref_offset, extension_info)?; let l_expr = make_binary_op_scalar_func( &substrait_low, @@ -659,11 +704,11 @@ pub fn to_substrait_rex( } Expr::Column(col) => { let index = schema.index_of_column(col)?; - substrait_field_ref(index) + substrait_field_ref(index + col_ref_offset) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = to_substrait_rex(left, schema, extension_info)?; - let r = to_substrait_rex(right, schema, extension_info)?; + let l = to_substrait_rex(left, schema, col_ref_offset, extension_info)?; + let r = to_substrait_rex(right, schema, col_ref_offset, extension_info)?; Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info)) } @@ -677,21 +722,41 @@ pub fn to_substrait_rex( if let Some(e) = expr { // Base expression exists ifs.push(IfClause { - r#if: Some(to_substrait_rex(e, schema, extension_info)?), + r#if: Some(to_substrait_rex( + e, + schema, + col_ref_offset, + extension_info, + )?), then: None, }); } // Parse `when`s for (r#if, then) in when_then_expr { ifs.push(IfClause { - r#if: Some(to_substrait_rex(r#if, schema, extension_info)?), - then: Some(to_substrait_rex(then, schema, extension_info)?), + r#if: Some(to_substrait_rex( + r#if, + schema, + col_ref_offset, + extension_info, + )?), + then: Some(to_substrait_rex( + then, + schema, + col_ref_offset, + extension_info, + )?), }); } // Parse outer `else` let r#else: Option> = match else_expr { - Some(e) => Some(Box::new(to_substrait_rex(e, schema, extension_info)?)), + Some(e) => Some(Box::new(to_substrait_rex( + e, + schema, + col_ref_offset, + extension_info, + )?)), None => None, }; @@ -707,6 +772,7 @@ pub fn to_substrait_rex( input: Some(Box::new(to_substrait_rex( expr, schema, + col_ref_offset, extension_info, )?)), failure_behavior: 0, // FAILURE_BEHAVIOR_UNSPECIFIED @@ -715,7 +781,9 @@ pub fn to_substrait_rex( }) } Expr::Literal(value) => to_substrait_literal(value), - Expr::Alias(expr, _alias) => to_substrait_rex(expr, schema, extension_info), + Expr::Alias(expr, _alias) => { + to_substrait_rex(expr, schema, col_ref_offset, extension_info) + } Expr::WindowFunction(WindowFunction { fun, args, @@ -733,6 +801,7 @@ pub fn to_substrait_rex( arg_type: Some(ArgType::Value(to_substrait_rex( arg, schema, + col_ref_offset, extension_info, )?)), }); @@ -740,7 +809,7 @@ pub fn to_substrait_rex( // partition by expressions let partition_by = partition_by .iter() - .map(|e| to_substrait_rex(e, schema, extension_info)) + .map(|e| to_substrait_rex(e, schema, col_ref_offset, extension_info)) .collect::>>()?; // order by expressions let order_by = order_by @@ -1325,7 +1394,7 @@ fn substrait_sort_field( asc, nulls_first, }) => { - let e = to_substrait_rex(expr, schema, extension_info)?; + let e = to_substrait_rex(expr, schema, 0, extension_info)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, (true, false) => SortDirection::AscNullsLast, diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs b/datafusion/substrait/tests/roundtrip_logical_plan.rs index 8cdf89b294730..e209ebedc0f38 100644 --- a/datafusion/substrait/tests/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs @@ -412,6 +412,30 @@ mod tests { roundtrip("SELECT a,b,c,d,e FROM datafusion.public.data;").await } + #[tokio::test] + async fn roundtrip_inner_join_table_reuse_zero_index() -> Result<()> { + assert_expected_plan( + "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.a = d2.a", + "Projection: data.b, data.c\ + \n Inner Join: data.a = data.a\ + \n TableScan: data projection=[a, b]\ + \n TableScan: data projection=[a, c]", + ) + .await + } + + #[tokio::test] + async fn roundtrip_inner_join_table_reuse_non_zero_index() -> Result<()> { + assert_expected_plan( + "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.b = d2.b", + "Projection: data.b, data.c\ + \n Inner Join: data.b = data.b\ + \n TableScan: data projection=[b]\ + \n TableScan: data projection=[b, c]", + ) + .await + } + /// Construct a plan that contains several literals of types that are currently supported. /// This case ignores: /// - Date64, for this literal is not supported