@@ -28,13 +28,15 @@ use datafusion::from_slice::FromSlice;
2828use std:: sync:: Arc ;
2929
3030use datafusion:: dataframe:: DataFrame ;
31+ use datafusion:: datasource:: MemTable ;
3132use datafusion:: error:: Result ;
3233use datafusion:: execution:: context:: SessionContext ;
3334use datafusion:: prelude:: JoinType ;
3435use datafusion:: prelude:: { CsvReadOptions , ParquetReadOptions } ;
3536use datafusion:: test_util:: parquet_test_data;
3637use datafusion:: { assert_batches_eq, assert_batches_sorted_eq} ;
37- use datafusion_common:: ScalarValue ;
38+ use datafusion_common:: { DataFusionError , ScalarValue } ;
39+ use datafusion_execution:: config:: SessionConfig ;
3840use datafusion_expr:: expr:: { GroupingSet , Sort } ;
3941use datafusion_expr:: utils:: COUNT_STAR_EXPANSION ;
4042use datafusion_expr:: Expr :: Wildcard ;
@@ -43,6 +45,7 @@ use datafusion_expr::{
4345 sum, AggregateFunction , Expr , ExprSchemable , WindowFrame , WindowFrameBound ,
4446 WindowFrameUnits , WindowFunction ,
4547} ;
48+ use datafusion_physical_expr:: var_provider:: { VarProvider , VarType } ;
4649
4750#[ tokio:: test]
4851async fn test_count_wildcard_on_sort ( ) -> Result < ( ) > {
@@ -1230,3 +1233,39 @@ pub async fn register_alltypes_tiny_pages_parquet(ctx: &SessionContext) -> Resul
12301233 . await ?;
12311234 Ok ( ( ) )
12321235}
1236+ #[ derive( Debug ) ]
1237+ struct HardcodedIntProvider { }
1238+
1239+ impl VarProvider for HardcodedIntProvider {
1240+ fn get_value ( & self , _var_names : Vec < String > ) -> Result < ScalarValue , DataFusionError > {
1241+ Ok ( ScalarValue :: Int64 ( Some ( 1234 ) ) )
1242+ }
1243+
1244+ fn get_type ( & self , _: & [ String ] ) -> Option < DataType > {
1245+ Some ( DataType :: Int64 )
1246+ }
1247+ }
1248+
1249+ #[ tokio:: test]
1250+ async fn use_var_provider ( ) -> Result < ( ) > {
1251+ let schema = Arc :: new ( Schema :: new ( vec ! [
1252+ Field :: new( "foo" , DataType :: Int64 , false ) ,
1253+ Field :: new( "bar" , DataType :: Int64 , false ) ,
1254+ ] ) ) ;
1255+
1256+ let mem_table = Arc :: new ( MemTable :: try_new ( schema, vec ! [ ] ) ?) ;
1257+
1258+ let config = SessionConfig :: new ( )
1259+ . with_target_partitions ( 4 )
1260+ . set_bool ( "datafusion.optimizer.skip_failed_rules" , false ) ;
1261+ let ctx = SessionContext :: with_config ( config) ;
1262+
1263+ ctx. register_table ( "csv_table" , mem_table) ?;
1264+ ctx. register_variable ( VarType :: UserDefined , Arc :: new ( HardcodedIntProvider { } ) ) ;
1265+
1266+ let dataframe = ctx
1267+ . sql ( "SELECT foo FROM csv_table WHERE bar > @var" )
1268+ . await ?;
1269+ dataframe. collect ( ) . await ?;
1270+ Ok ( ( ) )
1271+ }
0 commit comments