Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
| Expr::GroupingSet(_)
| Expr::Case { .. } => VisitRecursion::Continue,

Expr::Unnest { .. } => {
is_applicable = false;
VisitRecursion::Stop
}
Expr::ScalarFunction(scalar_function) => {
match scalar_function.fun.volatility() {
Volatility::Immutable => VisitRecursion::Continue,
Expand Down
12 changes: 8 additions & 4 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use crate::datasource::source_as_provider;
use crate::execution::context::{ExecutionProps, SessionState};
use crate::logical_expr::utils::generate_sort_key;
use crate::logical_expr::{
Aggregate, EmptyRelation, Join, Projection, Sort, SubqueryAlias, TableScan, Unnest,
Window,
Aggregate, EmptyRelation, Join, Projection, Sort, SubqueryAlias, TableScan,
Unnest as UnnestPlan, Window,
};
use crate::logical_expr::{
CrossJoin, Expr, LogicalPlan, Partitioning as LogicalPartitioning, PlanType,
Expand Down Expand Up @@ -83,8 +83,9 @@ use datafusion_common::{
use datafusion_expr::expr::{
self, AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Cast,
GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, ScalarUDF, TryCast,
WindowFunction,
Unnest, WindowFunction,
};

use datafusion_expr::expr_rewriter::{unalias, unnormalize_cols};
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
use datafusion_expr::{DescribeTable, DmlStatement, StringifiedPlan, WriteOp};
Expand Down Expand Up @@ -216,6 +217,9 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {

Ok(name)
}
Expr::Unnest(Unnest { array_exprs, .. }) => {
create_function_physical_name("unnest", false, array_exprs)
}
Expr::ScalarFunction(func) => {
create_function_physical_name(&func.fun.to_string(), false, &func.args)
}
Expand Down Expand Up @@ -1226,7 +1230,7 @@ impl DefaultPhysicalPlanner {

Ok(Arc::new(GlobalLimitExec::new(input, *skip, *fetch)))
}
LogicalPlan::Unnest(Unnest { input, column, schema, options }) => {
LogicalPlan::Unnest(UnnestPlan { input, column, schema, options }) => {
let input = self.create_initial_plan(input, session_state).await?;
let column_exec = schema.index_of_column(column)
.map(|idx| Column::new(&column.name, idx))?;
Expand Down
134 changes: 133 additions & 1 deletion datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ use crate::window_function;
use crate::Operator;
use arrow::datatypes::DataType;
use datafusion_common::internal_err;
use datafusion_common::UnnestOptions;
use datafusion_common::not_impl_err;
use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue};
use std::collections::HashSet;
use std::fmt;
Expand Down Expand Up @@ -147,6 +149,8 @@ pub enum Expr {
TryCast(TryCast),
/// A sort expression, that can be used to sort values.
Sort(Sort),
/// Unnest expression
Unnest(Unnest),
/// Represents the call of a built-in scalar function with a set of arguments.
ScalarFunction(ScalarFunction),
/// Represents the call of a user-defined scalar function with arguments.
Expand Down Expand Up @@ -328,6 +332,24 @@ impl Between {
}
}

/// Unnest expression: expands the elements of the given array
/// expressions, producing one value per array element.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Unnest {
    /// Arrays to unnest
    pub array_exprs: Vec<Expr>,
    /// Options controlling unnest behavior (see [`UnnestOptions`])
    pub options: UnnestOptions,
}

impl Unnest {
    /// Build an [`Unnest`] from the array expressions to expand and
    /// the options controlling how the expansion is performed.
    pub fn new(array_exprs: Vec<Expr>, options: UnnestOptions) -> Self {
        Self { array_exprs, options }
    }
}

/// ScalarFunction expression
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct ScalarFunction {
Expand Down Expand Up @@ -728,6 +750,7 @@ impl Expr {
Expr::TryCast { .. } => "TryCast",
Expr::WindowFunction { .. } => "WindowFunction",
Expr::Wildcard => "Wildcard",
Expr::Unnest(..) => "Unnest",
}
}

Expand Down Expand Up @@ -1030,6 +1053,47 @@ impl Expr {
pub fn contains_outer(&self) -> bool {
!find_out_reference_exprs(self).is_empty()
}

/// Flatten the nested array expressions until the base array is reached.
/// For example:
/// [[1, 2, 3], [4, 5, 6]] -> [1, 2, 3, 4, 5, 6]
/// [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] -> [1, 2, 3, 4, 5, 6, 7, 8]
///
/// # Panics
/// Panics if [`Self::try_flatten`] returns an error, i.e. when this
/// expression is not a `MakeArray` scalar function call.
pub fn flatten(&self) -> Self {
    self.try_flatten().unwrap()
}

/// Flatten the nested array expressions until the base array is reached.
/// For example:
/// [[1, 2, 3], [4, 5, 6]] => [1, 2, 3, 4, 5, 6]
/// [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] => [1, 2, 3, 4, 5, 6, 7, 8]
///
/// # Errors
/// Returns a "not implemented" error for any expression that is not a
/// `MakeArray` scalar function call.
pub fn try_flatten(&self) -> Result<Self> {
    if let Self::ScalarFunction(ScalarFunction {
        fun: built_in_function::BuiltinScalarFunction::MakeArray,
        args,
    }) = self
    {
        // Recursively pull every leaf element out of nested MakeArray calls.
        let mut flattened_args: Vec<Expr> = Vec::new();
        for arg in args {
            flattened_args.extend(arg.flatten_internal());
        }
        Ok(Self::ScalarFunction(ScalarFunction {
            fun: built_in_function::BuiltinScalarFunction::MakeArray,
            args: flattened_args,
        }))
    } else {
        not_impl_err!("flatten() is not implemented for {self}")
    }
}

/// Recursive helper for [`Self::try_flatten`]: a nested `MakeArray`
/// call yields the flattened leaves of all its arguments; any other
/// expression is itself a leaf and is returned as a single element.
fn flatten_internal(&self) -> Vec<Self> {
    if let Self::ScalarFunction(ScalarFunction {
        fun: built_in_function::BuiltinScalarFunction::MakeArray,
        args,
    }) = self
    {
        args.iter().flat_map(Self::flatten_internal).collect()
    } else {
        vec![self.clone()]
    }
}
}

#[macro_export]
Expand Down Expand Up @@ -1118,6 +1182,9 @@ impl fmt::Display for Expr {
write!(f, " NULLS LAST")
}
}
Expr::Unnest(Unnest { array_exprs, .. }) => {
fmt_function(f, "unnest", false, array_exprs, false)
}
Expr::ScalarFunction(func) => {
fmt_function(f, &func.fun.to_string(), false, &func.args, true)
}
Expand Down Expand Up @@ -1286,7 +1353,6 @@ fn fmt_function(
false => args.iter().map(|arg| format!("{arg:?}")).collect(),
};

// let args: Vec<String> = args.iter().map(|arg| format!("{:?}", arg)).collect();
let distinct_str = match distinct {
true => "DISTINCT ",
false => "",
Expand Down Expand Up @@ -1452,6 +1518,9 @@ fn create_name(e: &Expr) -> Result<String> {
}
}
}
Expr::Unnest(Unnest { array_exprs, .. }) => {
create_function_name("unnest", false, array_exprs)
}
Expr::ScalarFunction(func) => {
create_function_name(&func.fun.to_string(), false, &func.args)
}
Expand Down Expand Up @@ -1583,6 +1652,69 @@ mod test {
use datafusion_common::Column;
use datafusion_common::{Result, ScalarValue};

use super::ScalarFunction;

/// Test helper: wrap the given expressions in a `MakeArray` scalar
/// function call, i.e. build the expression `make_array(args...)`.
fn create_make_array_expr(args: &[Expr]) -> Expr {
    let make_array =
        ScalarFunction::new(crate::BuiltinScalarFunction::MakeArray, args.to_vec());
    Expr::ScalarFunction(make_array)
}

#[test]
fn test_flatten() {
    // A typed NULL literal (Int64), used to check NULLs survive flattening.
    let i64_none = ScalarValue::try_from(&DataType::Int64).unwrap();

    // One level of nesting:
    // [[10, 20, 30], [1, NULL, 10], [4, 5, 6]] -> [10, 20, 30, 1, NULL, 10, 4, 5, 6]
    let arr = create_make_array_expr(&[
        create_make_array_expr(&[lit(10i64), lit(20i64), lit(30i64)]),
        create_make_array_expr(&[lit(1i64), lit(i64_none.clone()), lit(10i64)]),
        create_make_array_expr(&[lit(4i64), lit(5i64), lit(6i64)]),
    ]);

    let flattened = arr.flatten();
    assert_eq!(
        flattened,
        create_make_array_expr(&[
            lit(10i64),
            lit(20i64),
            lit(30i64),
            lit(1i64),
            lit(i64_none),
            lit(10i64),
            lit(4i64),
            lit(5i64),
            lit(6i64),
        ])
    );

    // Two levels of nesting — flatten recurses all the way to the leaves:
    // [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] -> [1, 2, 3, 4, 5, 6, 7, 8]
    let arr = create_make_array_expr(&[
        create_make_array_expr(&[
            create_make_array_expr(&[lit(1i64), lit(2i64)]),
            create_make_array_expr(&[lit(3i64), lit(4i64)]),
        ]),
        create_make_array_expr(&[
            create_make_array_expr(&[lit(5i64), lit(6i64)]),
            create_make_array_expr(&[lit(7i64), lit(8i64)]),
        ]),
    ]);

    let flattened = arr.flatten();
    assert_eq!(
        flattened,
        create_make_array_expr(&[
            lit(1i64),
            lit(2i64),
            lit(3i64),
            lit(4i64),
            lit(5i64),
            lit(6i64),
            lit(7i64),
            lit(8i64),
        ])
    );
}

#[test]
fn format_case_when() -> Result<()> {
let expr = case(col("a"))
Expand Down
53 changes: 37 additions & 16 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use super::{Between, Expr, Like};
use crate::expr::{
AggregateFunction, AggregateUDF, Alias, BinaryExpr, Cast, GetFieldAccess,
GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, ScalarUDF, Sort,
TryCast, WindowFunction,
TryCast, Unnest, WindowFunction,
};
use crate::field_util::GetFieldAccessSchema;
use crate::type_coercion::binary::get_result_type;
Expand All @@ -28,9 +28,9 @@ use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field};
use datafusion_common::{
internal_err, plan_err, Column, DFField, DFSchema, DataFusionError, ExprSchema,
Result,
Result, not_impl_err,
};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

/// trait to allow expr to typable with respect to a schema
Expand Down Expand Up @@ -88,6 +88,28 @@ impl ExprSchemable for Expr {
.collect::<Result<Vec<_>>>()?;
Ok((fun.return_type)(&data_types)?.as_ref().clone())
}
Expr::Unnest(Unnest { array_exprs, .. }) => {
let data_types = array_exprs
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;

if data_types.is_empty() {
return internal_err!("Empty expression is not allowed")
}

// Use a HashSet to efficiently check for unique data types
let unique_data_types: HashSet<_> = data_types.iter().collect();

// If there is more than one unique data type, return an error
if unique_data_types.len() > 1 {
return not_impl_err!("Unnest does not support inconsistent data types: {data_types:?}");
}

// Extract the common data type since there is only one unique data type
let return_type = data_types[0].to_owned();
Ok(return_type)
}
Expr::ScalarFunction(ScalarFunction { fun, args }) => {
let data_types = args
.iter()
Expand Down Expand Up @@ -129,7 +151,9 @@ impl ExprSchemable for Expr {
| Expr::IsUnknown(_)
| Expr::IsNotTrue(_)
| Expr::IsNotFalse(_)
| Expr::IsNotUnknown(_) => Ok(DataType::Boolean),
| Expr::IsNotUnknown(_)
| Expr::Like { .. }
| Expr::SimilarTo { .. } => Ok(DataType::Boolean),
Expr::ScalarSubquery(subquery) => {
Ok(subquery.subquery.schema().field(0).data_type().clone())
}
Expand All @@ -138,28 +162,24 @@ impl ExprSchemable for Expr {
ref right,
ref op,
}) => get_result_type(&left.get_type(schema)?, op, &right.get_type(schema)?),
Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean),
Expr::Placeholder(Placeholder { data_type, .. }) => {
data_type.clone().ok_or_else(|| {
DataFusionError::Plan(
"Placeholder type could not be resolved".to_owned(),
)
})
}
Expr::Wildcard => {
// Wildcard do not really have a type and do not appear in projections
Ok(DataType::Null)
}
Expr::QualifiedWildcard { .. } => internal_err!(
"QualifiedWildcard expressions are not valid in a logical query plan"
),
Expr::GroupingSet(_) => {
// grouping sets do not really have a type and do not appear in projections
Ok(DataType::Null)
}
Expr::GetIndexedField(GetIndexedField { expr, field }) => {
field_for_index(expr, field, schema).map(|x| x.data_type().clone())
}
Expr::Wildcard | Expr::GroupingSet(_) => {
// They do not really have a type and do not appear in projections
Ok(DataType::Null)
}
Expr::QualifiedWildcard { .. } => Err(DataFusionError::Internal(
"QualifiedWildcard expressions are not valid in a logical query plan"
.to_owned(),
)),
}
}

Expand Down Expand Up @@ -231,6 +251,7 @@ impl ExprSchemable for Expr {
Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
Expr::ScalarVariable(_, _)
| Expr::TryCast { .. }
| Expr::Unnest(..)
| Expr::ScalarFunction(..)
| Expr::ScalarUDF(..)
| Expr::WindowFunction { .. }
Expand Down
Loading