diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 30b3c6e2bbeb3..679e7394a1573 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -270,7 +270,7 @@ async fn scalar_udf_zero_params() -> Result<()> { let get_100_udf = Simple0ArgsScalarUDF { name: "get_100".to_string(), - signature: Signature::exact(vec![], Volatility::Immutable), + signature: Signature::nullary(Volatility::Immutable), return_type: DataType::Int32, }; @@ -1119,6 +1119,61 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_valid_zero_argument_signatures() { + let signatures = vec![Signature::nullary(Volatility::Immutable)]; + for signature in signatures { + let ctx = SessionContext::new(); + let udf = ScalarFunctionWrapper { + name: "good_signature".to_string(), + expr: lit(1), + signature, + return_type: DataType::Int32, + }; + ctx.register_udf(ScalarUDF::from(udf)); + let results = ctx + .sql("select good_signature()") + .await + .unwrap() + .collect() + .await + .unwrap(); + let expected = [ + "+------------------+", + "| good_signature() |", + "+------------------+", + "| 1 |", + "+------------------+", + ]; + assert_batches_eq!(expected, &results); + } +} + +#[tokio::test] +async fn test_invalid_zero_argument_signatures() { + let signatures = vec![ + Signature::variadic(vec![], Volatility::Immutable), + Signature::variadic_any(Volatility::Immutable), + Signature::uniform(0, vec![], Volatility::Immutable), + Signature::coercible(vec![], Volatility::Immutable), + Signature::comparable(0, Volatility::Immutable), + Signature::any(0, Volatility::Immutable), + Signature::nullary(Volatility::Immutable), + ]; + for signature in signatures { + let ctx = SessionContext::new(); + let udf = ScalarFunctionWrapper { + name: "bad_signature".to_string(), + expr: lit(1), + signature, + return_type: DataType::Int32, + }; + ctx.register_udf(ScalarUDF::from(udf)); + let results = ctx.sql("select bad_signature()").await.unwrap_err(); + assert_contains!(results.to_string(), "Error during planning: Error during planning: bad_signature does not support zero arguments"); + } +} + /// Saves whatever is passed to it as a scalar function #[derive(Debug, Default)] struct RecordingFunctionFactory { diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 77ba1858e35b0..6a83a9e62bad4 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -342,7 +342,6 @@ impl TypeSignature { /// Check whether 0 input argument is valid for given `TypeSignature` pub fn supports_zero_argument(&self) -> bool { match &self { - TypeSignature::Exact(vec) => vec.is_empty(), TypeSignature::Nullary => true, TypeSignature::OneOf(types) => types .iter() @@ -613,48 +612,6 @@ mod tests { use super::*; - #[test] - fn supports_zero_argument_tests() { - // Testing `TypeSignature`s which supports 0 arg - let positive_cases = vec![ - TypeSignature::Exact(vec![]), - TypeSignature::OneOf(vec![ - TypeSignature::Exact(vec![DataType::Int8]), - TypeSignature::Nullary, - TypeSignature::Uniform(1, vec![DataType::Int8]), - ]), - TypeSignature::Nullary, - ]; - - for case in positive_cases { - assert!( - case.supports_zero_argument(), - "Expected {:?} to support zero arguments", - case - ); - } - - // Testing `TypeSignature`s which doesn't support 0 arg - let negative_cases = vec![ - TypeSignature::Exact(vec![DataType::Utf8]), - TypeSignature::Uniform(1, vec![DataType::Float64]), - TypeSignature::Any(1), - TypeSignature::VariadicAny, - TypeSignature::OneOf(vec![ - TypeSignature::Exact(vec![DataType::Int8]), - TypeSignature::Uniform(1, vec![DataType::Int8]), - ]), - ]; - - for case in negative_cases { - assert!( - !case.supports_zero_argument(), - "Expected {:?} not to support zero arguments", - case - ); - } - } - #[test] fn type_signature_partial_ord() { // Test validates that partial ord is defined for TypeSignature and Signature. diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 7d2906e1731be..6b43632c4045d 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -55,7 +55,7 @@ pub fn data_types_with_scalar_udf( if signature.type_signature.supports_zero_argument() { return Ok(vec![]); } else { - return plan_err!("{} does not support zero arguments.", func.name()); + return plan_err!("{} does not support zero arguments. Please add TypeSignature::Nullary to your function's signature", func.name()); } } @@ -88,21 +88,19 @@ pub fn data_types_with_aggregate_udf( current_types: &[DataType], func: &AggregateUDF, ) -> Result> { - let signature = func.signature(); + let type_signature = &func.signature().type_signature; if current_types.is_empty() { - if signature.type_signature.supports_zero_argument() { + if type_signature.supports_zero_argument() { return Ok(vec![]); } else { return plan_err!("{} does not support zero arguments.", func.name()); } } - let valid_types = get_valid_types_with_aggregate_udf( - &signature.type_signature, - current_types, - func, - )?; + let valid_types = + get_valid_types_with_aggregate_udf(type_signature, current_types, func)?; + if valid_types .iter() .any(|data_type| data_type == current_types) @@ -110,12 +108,7 @@ pub fn data_types_with_aggregate_udf( return Ok(current_types.to_vec()); } - try_coerce_types( - func.name(), - valid_types, - current_types, - &signature.type_signature, - ) + try_coerce_types(func.name(), valid_types, current_types, type_signature) } /// Performs type coercion for window function arguments. @@ -129,10 +122,10 @@ pub fn data_types_with_window_udf( current_types: &[DataType], func: &WindowUDF, ) -> Result> { - let signature = func.signature(); + let type_signature = &func.signature().type_signature; if current_types.is_empty() { - if signature.type_signature.supports_zero_argument() { + if type_signature.supports_zero_argument() { return Ok(vec![]); } else { return plan_err!("{} does not support zero arguments.", func.name()); @@ -140,7 +133,8 @@ pub fn data_types_with_window_udf( } let valid_types = - get_valid_types_with_window_udf(&signature.type_signature, current_types, func)?; + get_valid_types_with_window_udf(type_signature, current_types, func)?; + if valid_types .iter() .any(|data_type| data_type == current_types) @@ -148,12 +142,7 @@ pub fn data_types_with_window_udf( return Ok(current_types.to_vec()); } - try_coerce_types( - func.name(), - valid_types, - current_types, - &signature.type_signature, - ) + try_coerce_types(func.name(), valid_types, current_types, type_signature) } /// Performs type coercion for function arguments. @@ -168,18 +157,20 @@ pub fn data_types( current_types: &[DataType], signature: &Signature, ) -> Result> { + let type_signature = &signature.type_signature; + if current_types.is_empty() { - if signature.type_signature.supports_zero_argument() { + if type_signature.supports_zero_argument() { return Ok(vec![]); } else { return plan_err!( - "signature {:?} does not support zero arguments.", - &signature.type_signature + "{} does not support zero arguments.", + function_name.as_ref() ); } } - let valid_types = get_valid_types(&signature.type_signature, current_types)?; + let valid_types = get_valid_types(type_signature, current_types)?; if valid_types .iter() .any(|data_type| data_type == current_types) @@ -187,12 +178,7 @@ pub fn data_types( return Ok(current_types.to_vec()); } - try_coerce_types( - function_name, - valid_types, - current_types, - &signature.type_signature, - ) + try_coerce_types(function_name, valid_types, current_types, type_signature) } fn is_well_supported_signature(type_signature: &TypeSignature) -> bool { @@ -335,6 +321,7 @@ fn get_valid_types_with_window_udf( } /// Returns a Vec of all possible valid argument types for the given signature. +/// Empty argument is checked by the caller so no need to re-check here. fn get_valid_types( signature: &TypeSignature, current_types: &[DataType], @@ -441,12 +428,6 @@ fn get_valid_types( } fn function_length_check(length: usize, expected_length: usize) -> Result<()> { - if length < 1 { - return plan_err!( - "The signature expected at least one argument but received {expected_length}" - ); - } - if length != expected_length { return plan_err!( "The signature expected {length} arguments but received {expected_length}" @@ -645,27 +626,16 @@ fn get_valid_types( vec![new_types] } - TypeSignature::Uniform(number, valid_types) => { - if *number == 0 { - return plan_err!("The function expected at least one argument"); - } - - valid_types - .iter() - .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) - .collect() - } + TypeSignature::Uniform(number, valid_types) => valid_types + .iter() + .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) + .collect(), TypeSignature::UserDefined => { return internal_err!( "User-defined signature should be handled by function-specific coerce_types." ) } TypeSignature::VariadicAny => { - if current_types.is_empty() { - return plan_err!( - "The function expected at least one argument but received 0" - ); - } vec![current_types.to_vec()] } TypeSignature::Exact(valid_types) => vec![valid_types.clone()], @@ -716,28 +686,13 @@ fn get_valid_types( } }, TypeSignature::Nullary => { - if !current_types.is_empty() { - return plan_err!( - "The function expected zero argument but received {}", - current_types.len() - ); - } - vec![vec![]] + return plan_err!( + "Nullary expects zero arguments, but received {}", + current_types.len() + ); } TypeSignature::Any(number) => { - if current_types.is_empty() { - return plan_err!( - "The function expected at least one argument but received 0" - ); - } - - if current_types.len() != *number { - return plan_err!( - "The function expected {} arguments but received {}", - number, - current_types.len() - ); - } + function_length_check(current_types.len(), *number)?; vec![(0..*number).map(|i| current_types[i].clone()).collect()] } TypeSignature::OneOf(types) => types diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index bfc87f28ebebe..2548b36f669e6 100644 --- a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -40,7 +40,7 @@ impl Default for VersionFunc { impl VersionFunc { pub fn new() -> Self { Self { - signature: Signature::exact(vec![], Volatility::Immutable), + signature: Signature::nullary(Volatility::Immutable), } } } diff --git a/datafusion/functions/src/string/uuid.rs b/datafusion/functions/src/string/uuid.rs index 6048a70bd8c57..5f0b24232215a 100644 --- a/datafusion/functions/src/string/uuid.rs +++ b/datafusion/functions/src/string/uuid.rs @@ -42,7 +42,7 @@ impl Default for UuidFunc { impl UuidFunc { pub fn new() -> Self { Self { - signature: Signature::exact(vec![], Volatility::Volatile), + signature: Signature::nullary(Volatility::Volatile), } } } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index ff75a6a60f4b8..31c2dac3f0328 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -1683,7 +1683,7 @@ mod test { impl RandomStub { fn new() -> Self { Self { - signature: Signature::exact(vec![], Volatility::Volatile), + signature: Signature::nullary(Volatility::Volatile), } } } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 0b328ad39f558..d07c3407424f7 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -3315,7 +3315,7 @@ Projection: a, b // SELECT t.a, t.r FROM (SELECT a, sum(b), TestScalarUDF()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5; let table_scan = test_table_scan_with_name("test1")?; let fun = ScalarUDF::new_from_impl(TestScalarUDF { - signature: Signature::exact(vec![], Volatility::Volatile), + signature: Signature::nullary(Volatility::Volatile), }); let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![])); @@ -3349,7 +3349,7 @@ Projection: a, b // SELECT t.a, t.r FROM (SELECT test1.a AS a, TestScalarUDF() AS r FROM test1 join test2 ON test1.a = test2.a) AS t WHERE t.r > 0.5; let table_scan = test_table_scan_with_name("test1")?; let fun = ScalarUDF::new_from_impl(TestScalarUDF { - signature: Signature::exact(vec![], Volatility::Volatile), + signature: Signature::nullary(Volatility::Volatile), }); let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![])); let left = LogicalPlanBuilder::from(table_scan).build()?; @@ -3395,7 +3395,7 @@ Projection: a, b // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1; let table_scan = test_table_scan()?; let fun = ScalarUDF::new_from_impl(TestScalarUDF { - signature: Signature::exact(vec![], Volatility::Volatile), + signature: Signature::nullary(Volatility::Volatile), }); let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![])); let plan = LogicalPlanBuilder::from(table_scan) @@ -3419,7 +3419,7 @@ Projection: a, b // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10; let table_scan = test_table_scan()?; let fun = ScalarUDF::new_from_impl(TestScalarUDF { - signature: Signature::exact(vec![], Volatility::Volatile), + signature: Signature::nullary(Volatility::Volatile), }); let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![])); let plan = LogicalPlanBuilder::from(table_scan) @@ -3446,7 +3446,7 @@ Projection: a, b fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> { // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10; let fun = ScalarUDF::new_from_impl(TestScalarUDF { - signature: Signature::exact(vec![], Volatility::Volatile), + signature: Signature::nullary(Volatility::Volatile), }); let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![])); let plan = table_scan_with_pushdown_provider_builder( diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index e3bcb6da8e533..a1cfdb2dab8f1 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -4131,7 +4131,7 @@ mod tests { impl VolatileUdf { pub fn new() -> Self { Self { - signature: Signature::exact(vec![], Volatility::Volatile), + signature: Signature::nullary(Volatility::Volatile), } } }