diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 3052ccf2b68f9..e45d2ef17bff1 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -71,6 +71,8 @@ mod runtime_config; pub mod select; mod sql_api; +mod union_comparison; + async fn register_aggregate_csv_by_sql(ctx: &SessionContext) { let testdata = test_util::arrow_test_data(); diff --git a/datafusion/core/tests/sql/union_comparison.rs b/datafusion/core/tests/sql/union_comparison.rs new file mode 100644 index 0000000000000..783f39fbd8ed0 --- /dev/null +++ b/datafusion/core/tests/sql/union_comparison.rs @@ -0,0 +1,495 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/* +These are tests for union type comparison coercion + +when comparing a union type with a scalar type, it will: +1. check if the scalar type matches any Union variant (exact type match preferred) +2. if no exact match, check if any variant can be cast to the scalar type +3. if yes, coerce to the scalar type (triggers cast from Union to scalar) +4. 
cast will extract values from the matching variant (and cast if needed), non-matching variants become null
+
+It's important to note that the scalar type matching is greedy, meaning if you have a Union type
+with several variants of the same type, it will match by the smaller type id (since variants are sorted by type id)
+
+
+Some of the *current* limitations are:
+1. numeric literals default to int64
+2. if multiple variants can cast to target type, picks the first castable variant (by type id order)
+
+*/
+
+use arrow::array::*;
+use arrow::buffer::ScalarBuffer;
+use arrow::compute::can_cast_types;
+use arrow::datatypes::{DataType, Field, Schema, UnionFields, UnionMode};
+use datafusion::assert_batches_eq;
+use datafusion::prelude::*;
+use datafusion_common::Result;
+use std::sync::Arc;
+
+/// Builds a sparse `Union(Int32, Utf8)` array with one row per entry of `values`.
+///
+/// Type id 0 is the `int` variant and type id 1 is the `str` variant. Both child arrays
+/// receive one entry per row; the child that is not active for a row gets a null placeholder.
+///
+/// # Panics
+///
+/// Panics if `UnionArray::try_new` rejects the generated buffers.
+fn create_sparse_union_array(values: Vec<UnionValue>) -> UnionArray {
+    let union_fields = UnionFields::new(
+        vec![0, 1],
+        vec![
+            Field::new("int", DataType::Int32, true),
+            Field::new("str", DataType::Utf8, true),
+        ],
+    );
+
+    let mut int_values: Vec<Option<i32>> = Vec::with_capacity(values.len());
+    let mut str_values: Vec<Option<&'static str>> = Vec::with_capacity(values.len());
+    let mut type_ids: Vec<i8> = Vec::with_capacity(values.len());
+
+    for value in values {
+        match value {
+            UnionValue::Int(v) => {
+                int_values.push(v);
+                str_values.push(None);
+                type_ids.push(0);
+            }
+            UnionValue::Str(v) => {
+                int_values.push(None);
+                str_values.push(v);
+                type_ids.push(1);
+            }
+        }
+    }
+
+    let int_array = Int32Array::from(int_values);
+    let str_array = StringArray::from(str_values);
+    let type_ids = ScalarBuffer::<i8>::from(type_ids);
+
+    UnionArray::try_new(
+        union_fields,
+        type_ids,
+        // no offsets buffer: the union is sparse
+        None,
+        vec![Arc::new(int_array) as Arc<dyn Array>, Arc::new(str_array)],
+    )
+    .expect("type ids and children are built in lockstep, one entry per row")
+}
+
+/// One row of the test union: either an (optional) int or an (optional) string.
+#[derive(Debug)]
+enum UnionValue {
+    Int(Option<i32>),
+    Str(Option<&'static str>),
+}
+
+// right now arrow does not support casting union arrays
+// maybe the right thing to do is add functionality there...
+#[test]
+fn test_arrow_union_cast_support() {
+    let union_type = DataType::Union(int_str_union_fields(), UnionMode::Sparse);
+
+    // Arrow doesn't support casting Union types natively
+    assert!(!can_cast_types(&union_type, &DataType::Int64));
+    assert!(!can_cast_types(&union_type, &DataType::Int32));
+    assert!(!can_cast_types(&union_type, &DataType::Utf8));
+}
+
+/// `Union(Int32, Utf8)` fields with type ids 0 (`int`) and 1 (`str`), the same layout
+/// that `create_sparse_union_array` produces.
+fn int_str_union_fields() -> UnionFields {
+    UnionFields::new(
+        vec![0, 1],
+        vec![
+            Field::new("int", DataType::Int32, true),
+            Field::new("str", DataType::Utf8, true),
+        ],
+    )
+}
+
+/// Returns a fresh `SessionContext` with a table named `test` registered.
+///
+/// The table has a non-nullable Int32 column `id` filled from `ids` and a nullable sparse
+/// union column `val` built from `values`. `ids` and `values` must have the same length,
+/// otherwise building the record batch fails.
+fn union_test_context(ids: Vec<i32>, values: Vec<UnionValue>) -> Result<SessionContext> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int32, false),
+        Field::new(
+            "val",
+            DataType::Union(int_str_union_fields(), UnionMode::Sparse),
+            true,
+        ),
+    ]));
+
+    let batch = RecordBatch::try_new(
+        schema,
+        vec![
+            Arc::new(Int32Array::from(ids)),
+            Arc::new(create_sparse_union_array(values)),
+        ],
+    )?;
+
+    let ctx = SessionContext::new();
+    ctx.register_batch("test", batch)?;
+    Ok(ctx)
+}
+
+#[tokio::test]
+async fn test_union_eq_int32() -> Result<()> {
+    let ctx = union_test_context(
+        vec![1, 2, 3],
+        vec![
+            UnionValue::Int(Some(67)),
+            UnionValue::Str(Some("hello")),
+            UnionValue::Int(Some(123)),
+        ],
+    )?;
+
+    let df = ctx
+        .sql("SELECT id FROM test WHERE val = CAST(67 AS INT)")
+        .await?;
+    let results = df.collect().await?;
+
+    let expected = ["+----+", "| id |", "+----+", "| 1  |", "+----+"];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_union_eq_string() -> Result<()> {
+    let ctx = union_test_context(
+        vec![1, 2, 3],
+        vec![
+            UnionValue::Int(Some(67)),
+            UnionValue::Str(Some("hello")),
+            UnionValue::Str(Some("world")),
+        ],
+    )?;
+
+    let df = ctx.sql("SELECT id FROM test WHERE val = 'hello'").await?;
+    let results = df.collect().await?;
+
+    let expected = ["+----+", "| id |", "+----+", "| 2  |", "+----+"];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_union_comparison_operators() -> Result<()> {
+    let ctx = union_test_context(
+        vec![1, 2, 3, 4],
+        vec![
+            UnionValue::Int(Some(10)),
+            UnionValue::Int(Some(20)),
+            UnionValue::Int(Some(30)),
+            UnionValue::Str(Some("foo")),
+        ],
+    )?;
+
+    // test > - cast literals to Int32
+    let df = ctx
+        .sql("SELECT id FROM test WHERE val > CAST(15 AS INT)")
+        .await?;
+    let results = df.collect().await?;
+    let expected = ["+----+", "| id |", "+----+", "| 2  |", "| 3  |", "+----+"];
+    assert_batches_eq!(expected, &results);
+
+    // test <
+    let df = ctx
+        .sql("SELECT id FROM test WHERE val < CAST(15 AS INT)")
+        .await?;
+    let results = df.collect().await?;
+    let expected = ["+----+", "| id |", "+----+", "| 1  |", "+----+"];
+    assert_batches_eq!(expected, &results);
+
+    // test !=
+    let df = ctx
+        .sql("SELECT id FROM test WHERE val != CAST(20 AS INT)")
+        .await?;
+    let results = df.collect().await?;
+    let expected = ["+----+", "| id |", "+----+", "| 1  |", "| 3  |", "+----+"];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_union_with_null_values() -> Result<()> {
+    let ctx = union_test_context(
+        vec![1, 2, 3, 4],
+        vec![
+            UnionValue::Int(Some(10)),
+            UnionValue::Int(None), // null int
+            UnionValue::Str(Some("foo")),
+            UnionValue::Str(None), // null string
+        ],
+    )?;
+
+    let df = ctx
+        .sql("SELECT id FROM test WHERE val = CAST(10 AS INT)")
+        .await?;
+    let results = df.collect().await?;
+    let expected = ["+----+", "| id |", "+----+", "| 1  |", "+----+"];
+    assert_batches_eq!(expected, &results);
+
+    let df = ctx.sql("SELECT id FROM test WHERE val IS NULL").await?;
+    let results = df.collect().await?;
+
+    // row 2 holds a null int and row 4 a null string,
+    // so both rows are expected to satisfy IS NULL
+    let expected = ["+----+", "| id |", "+----+", "| 2  |", "| 4  |", "+----+"];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_union_non_matching_variants_are_null() -> Result<()> {
+    let ctx = union_test_context(
+        vec![1, 2, 3],
+        vec![
+            UnionValue::Int(Some(10)),
+            UnionValue::Str(Some("hello")),
+            UnionValue::Int(Some(30)),
+        ],
+    )?;
+
+    // When casting to Int32, the string variant becomes NULL
+    let df = ctx
+        .sql("SELECT id, CAST(val AS INT) as val_int FROM test")
+        .await?;
+    let results = df.collect().await?;
+
+    let expected = [
+        "+----+---------+",
+        "| id | val_int |",
+        "+----+---------+",
+        "| 1  | 10      |",
+        "| 2  |         |", // null because it's a string
+        "| 3  | 30      |",
+        "+----+---------+",
+    ];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+// tests cast-compatible variant matching
+// when comparing Union(Int32, Utf8) with Int64, it finds the Int32 variant and casts it
+#[tokio::test]
+async fn test_union_cast_compatible_variant() -> Result<()> {
+    let ctx = union_test_context(
+        vec![1, 2],
+        vec![UnionValue::Int(Some(10)), UnionValue::Str(Some("hello"))],
+    )?;
+
+    // Int32 variant can be cast to Int64, so this should work
+    let df = ctx
+        .sql("SELECT id FROM test WHERE val = CAST(10 AS BIGINT)")
+        .await?;
+    let results = df.collect().await?;
+
+    let expected = ["+----+", "| id |", "+----+", "| 1  |", "+----+"];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+// todo: this should also be fixed since arrow-ord now has support for union arrays...
+// test this with arrow pointed to main...
+#[tokio::test]
+async fn test_union_eq_same_union() -> Result<()> {
+    let union_fields = UnionFields::new(
+        vec![0, 1],
+        vec![
+            Field::new("int", DataType::Int32, true),
+            Field::new("str", DataType::Utf8, true),
+        ],
+    );
+
+    let union_array1 = create_sparse_union_array(vec![
+        UnionValue::Int(Some(10)),
+        UnionValue::Str(Some("hello")),
+    ]);
+
+    let union_array2 = create_sparse_union_array(vec![
+        UnionValue::Int(Some(10)),
+        UnionValue::Str(Some("world")),
+    ]);
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int32, false),
+        Field::new(
+            "val1",
+            DataType::Union(union_fields.clone(), UnionMode::Sparse),
+            true,
+        ),
+        Field::new(
+            "val2",
+            DataType::Union(union_fields, UnionMode::Sparse),
+            true,
+        ),
+    ]));
+
+    let batch = RecordBatch::try_new(
+        schema,
+        vec![
+            Arc::new(Int32Array::from(vec![1, 2])),
+            Arc::new(union_array1),
+            Arc::new(union_array2),
+        ],
+    )?;
+
+    let ctx = SessionContext::new();
+    ctx.register_batch("test", batch)?;
+
+    // planning is expected to succeed; only execution of the comparison should fail
+    let df = ctx.sql("SELECT id FROM test WHERE val1 = val2").await?;
+
+    let err = df
+        .collect()
+        .await
+        .expect_err("comparing two union columns is expected to fail at execution");
+    assert!(
+        err.to_string().contains("no natural order"),
+        "unexpected error: {err}"
+    );
+
+    Ok(())
+}
diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs
index 99c21d4abdb6e..dccb6a2fcd55a 100644
--- a/datafusion/expr-common/src/columnar_value.rs
+++ b/datafusion/expr-common/src/columnar_value.rs
@@ -17,10 +17,12 @@
 //! [`ColumnarValue`] represents the result of evaluating an expression.
+use arrow::buffer::NullBuffer;
 use arrow::{
     array::{Array, ArrayRef, Date32Array, Date64Array, NullArray},
-    compute::{CastOptions, kernels, max, min},
+    compute::{CastOptions, can_cast_types, cast_with_options, max, min},
     datatypes::DataType,
+    error::ArrowError,
     util::pretty::pretty_format_columns,
 };
 use datafusion_common::internal_datafusion_err;
@@ -284,11 +286,16 @@ impl ColumnarValue {
         match self {
             ColumnarValue::Array(array) => {
                 ensure_date_array_timestamp_bounds(array, cast_type)?;
-                Ok(ColumnarValue::Array(kernels::cast::cast_with_options(
-                    array,
-                    cast_type,
-                    &cast_options,
-                )?))
+
+                let out = match array.data_type() {
+                    // todo: upstream this to arrow
+                    DataType::Union(_, _) => {
+                        cast_union_array(array, cast_type, &cast_options)?
+                    }
+                    _ => cast_with_options(array, cast_type, &cast_options)?,
+                };
+
+                Ok(ColumnarValue::Array(out))
             }
             ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar(
                 scalar.cast_to_with_options(cast_type, &cast_options)?,
@@ -297,6 +304,93 @@ impl ColumnarValue {
     }
 }

+/// Casts a union array to `to_type` by extracting the values of one matching variant.
+///
+/// Variant selection is done in field order: an exact type match is preferred, otherwise
+/// the first variant that `can_cast_types` reports as castable to `to_type` is used and
+/// its values are cast with `cast_options`.
+///
+/// Rows whose active variant is not the selected one become NULL, and a null stored in the
+/// selected variant stays NULL. If no variant matches, an all-null array is returned.
+///
+/// # Errors
+///
+/// Fails if `array` is not a `UnionArray`, if a row offset does not fit in `u32`, or if the
+/// `take` / cast kernels fail.
+fn cast_union_array(
+    array: &dyn Array,
+    to_type: &DataType,
+    cast_options: &CastOptions,
+) -> Result<ArrayRef, ArrowError> {
+    use arrow::array::*;
+    use arrow::datatypes::UnionMode;
+
+    let union_array = array
+        .as_any()
+        .downcast_ref::<UnionArray>()
+        .ok_or_else(|| ArrowError::CastError("expected UnionArray".to_string()))?;
+
+    let DataType::Union(fields, mode) = union_array.data_type() else {
+        return Err(ArrowError::CastError(
+            "expected Union data type".to_string(),
+        ));
+    };
+
+    let len = union_array.len();
+    let type_ids = union_array.type_ids();
+
+    // we do 2 separate passes, first to find any exact matches and second to find any cast-compatible matches
+    let matching_type_id = fields
+        .iter()
+        .find_map(|(i, f)| (f.data_type() == to_type).then_some(i))
+        .or_else(|| {
+            fields
+                .iter()
+                .find_map(|(i, f)| can_cast_types(f.data_type(), to_type).then_some(i))
+        });
+
+    let Some(match_id) = matching_type_id else {
+        return Ok(new_null_array(to_type, len));
+    };
+
+    let matching_child = union_array.child(match_id);
+    let needs_cast = matching_child.data_type() != to_type;
+
+    // build the take indices and their validity mask in one pass
+    let mut indices = Vec::with_capacity(len);
+    let mut null_mask = Vec::with_capacity(len);
+
+    for (i, &type_id) in type_ids.iter().enumerate() {
+        if type_id == match_id {
+            let o = match mode {
+                UnionMode::Sparse => i,
+                UnionMode::Dense => union_array.value_offset(i),
+            };
+            let o = u32::try_from(o).map_err(|_| {
+                ArrowError::CastError(format!("union offset {o} does not fit in u32"))
+            })?;
+            indices.push(o);
+            null_mask.push(true);
+        } else {
+            // placeholder value: this slot is marked null in the index array below
+            indices.push(0);
+            null_mask.push(false);
+        }
+    }
+
+    // rows of other variants get a null index: `take` emits NULL for a null index and
+    // keeps the selected child's own validity for every other row
+    let validity = NullBuffer::from(null_mask);
+    let indices_array = UInt32Array::new(indices.into(), Some(validity));
+    let result = arrow::compute::take(matching_child, &indices_array, None)?;
+
+    if needs_cast {
+        cast_with_options(&result, to_type, cast_options)
+    } else {
+        Ok(result)
+    }
+}
+
 fn ensure_date_array_timestamp_bounds(
     array: &ArrayRef,
     cast_type: &DataType,
diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs
index 18603991aea1a..02b39441b2577 100644
--- a/datafusion/expr-common/src/type_coercion/binary.rs
+++ b/datafusion/expr-common/src/type_coercion/binary.rs
@@ -854,6 +854,7 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option {
     }
 }

+fn union_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
+    use arrow::datatypes::DataType::*;
+
+    match (lhs_type, rhs_type) {
+        (Union(_, _), Union(_, _)) => {
+            // note: we'll
start with a simple equality check, deferring complex cases in later work
+            // for example: when a union is a subset of another union...
+            lhs_type
+                .equals_datatype(rhs_type)
+                .then_some(lhs_type.clone())
+        }
+        (Union(fields, _), opaque) | (opaque, Union(fields, _)) => fields
+            .iter()
+            .any(|(_, f)| can_cast_types(f.data_type(), opaque))
+            .then_some(opaque.clone()),
+        _ => None,
+    }
+}
+
 /// Returns the output type of applying mathematics operations such as
 /// `+` to arguments of `lhs_type` and `rhs_type`.
 fn mathematics_numerical_coercion(
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index 691a8c508f801..0fddd8a5829f9 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -672,7 +672,20 @@ impl ExprSchemable for Expr {
         // like all of the binary expressions below. Perhaps Expr should track the
         // type of the expression?

-        if can_cast_types(&this_type, cast_to_type) {
+        // special handling for union types
+        // arrow's `can_cast_types` has no notion of casting a Union to one of its variants,
+        // so a Union is accepted when at least one variant can be cast to the target type
+        // (e.g. a Union with an Int32 variant may be cast to Int64)
+        // NOTE(review): which variant is actually used is decided when the cast executes
+        let can_cast = if let DataType::Union(fields, _) = &this_type {
+            fields
+                .iter()
+                .any(|(_, field)| can_cast_types(field.data_type(), cast_to_type))
+        } else {
+            can_cast_types(&this_type, cast_to_type)
+        };
+
+        if can_cast {
             match self {
                 Expr::ScalarSubquery(subquery) => {
                     Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?))
diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs
index bd5c63a69979f..adfb44d327288 100644
--- a/datafusion/physical-expr/src/expressions/cast.rs
+++ b/datafusion/physical-expr/src/expressions/cast.rs
@@ -234,12 +234,25 @@ pub fn cast_with_options(
 ) -> Result<Arc<dyn PhysicalExpr>> {
     let expr_type = expr.data_type(input_schema)?;
     if expr_type == cast_type {
-        Ok(Arc::clone(&expr))
-    } else if can_cast_types(&expr_type, &cast_type) {
-        Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options)))
+        return Ok(Arc::clone(&expr));
+    }
+
+    // special handling for union types
+    // `can_cast_types` does not accept Union inputs, so a union is treated as castable
+    // when at least one of its variants can be cast to the target type
+    let can_cast = if let Union(fields, _) = &expr_type {
+        fields
+            .iter()
+            .any(|(_, f)| can_cast_types(f.data_type(), &cast_type))
     } else {
-        not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}")
+        can_cast_types(&expr_type, &cast_type)
+    };
+
+    if !can_cast {
+        return not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}");
     }
+
+    Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options)))
 }

 /// Return a PhysicalExpression representing `expr` casted to