Skip to content

Commit a40d9ba

Browse files
committed
fix: projection_push_down should not consider VarProvider in columns.
1 parent 439863f commit a40d9ba

2 files changed

Lines changed: 42 additions & 5 deletions

File tree

datafusion/core/tests/dataframe.rs

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,15 @@ use datafusion::from_slice::FromSlice;
2828
use std::sync::Arc;
2929

3030
use datafusion::dataframe::DataFrame;
31+
use datafusion::datasource::MemTable;
3132
use datafusion::error::Result;
3233
use datafusion::execution::context::SessionContext;
3334
use datafusion::prelude::JoinType;
3435
use datafusion::prelude::{CsvReadOptions, ParquetReadOptions};
3536
use datafusion::test_util::parquet_test_data;
3637
use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
37-
use datafusion_common::ScalarValue;
38+
use datafusion_common::{DataFusionError, ScalarValue};
39+
use datafusion_execution::config::SessionConfig;
3840
use datafusion_expr::expr::{GroupingSet, Sort};
3941
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
4042
use datafusion_expr::Expr::Wildcard;
@@ -43,6 +45,7 @@ use datafusion_expr::{
4345
sum, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFrameBound,
4446
WindowFrameUnits, WindowFunction,
4547
};
48+
use datafusion_physical_expr::var_provider::{VarProvider, VarType};
4649

4750
#[tokio::test]
4851
async fn test_count_wildcard_on_sort() -> Result<()> {
@@ -1230,3 +1233,39 @@ pub async fn register_alltypes_tiny_pages_parquet(ctx: &SessionContext) -> Resul
12301233
.await?;
12311234
Ok(())
12321235
}
1236+
/// Field-less test fixture whose `VarProvider` implementation answers every
/// variable lookup with a fixed `Int64` value.
#[derive(Debug)]
1237+
struct HardcodedIntProvider {}
1238+
1239+
// Every lookup resolves to the same constant, regardless of which variable
// name is requested.
impl VarProvider for HardcodedIntProvider {
1240+
/// Returns `ScalarValue::Int64(Some(1234))`; the requested names are ignored.
fn get_value(&self, _var_names: Vec<String>) -> Result<ScalarValue, DataFusionError> {
1241+
Ok(ScalarValue::Int64(Some(1234)))
1242+
}
1243+
1244+
/// Reports `DataType::Int64` for any variable, matching the value that
/// `get_value` produces.
fn get_type(&self, _: &[String]) -> Option<DataType> {
1245+
Some(DataType::Int64)
1246+
}
1247+
}
1248+
1249+
/// Runs a query whose filter compares a column against a user-defined
/// variable (`@var`) and checks only that planning and execution return `Ok`;
/// no result rows are inspected.
#[tokio::test]
1250+
async fn use_var_provider() -> Result<()> {
1251+
// Two non-nullable Int64 columns: `foo` is projected, `bar` is only used
// in the filter.
let schema = Arc::new(Schema::new(vec![
1252+
Field::new("foo", DataType::Int64, false),
1253+
Field::new("bar", DataType::Int64, false),
1254+
]));
1255+
1256+
// NOTE(review): the empty `vec![]` is presumably the partition list, i.e.
// the table holds no data — confirm against `MemTable::try_new`.
let mem_table = Arc::new(MemTable::try_new(schema, vec![])?);
1257+
1258+
// NOTE(review): disabling `skip_failed_rules` presumably makes a failing
// optimizer rule surface as an error instead of being skipped, which is
// what lets this test catch optimizer regressions — confirm against the
// config documentation.
let config = SessionConfig::new()
1259+
.with_target_partitions(4)
1260+
.set_bool("datafusion.optimizer.skip_failed_rules", false);
1261+
let ctx = SessionContext::with_config(config);
1262+
1263+
// The in-memory table is registered under the name "csv_table"; the name
// is just the identifier the SQL below refers to.
ctx.register_table("csv_table", mem_table)?;
1264+
// Register the constant provider for user-defined variables.
// NOTE(review): `@var` in the SQL below is presumably resolved through
// this provider — confirm.
ctx.register_variable(VarType::UserDefined, Arc::new(HardcodedIntProvider {}));
1265+
1266+
let dataframe = ctx
1267+
.sql("SELECT foo FROM csv_table WHERE bar > @var")
1268+
.await?;
1269+
// Any error from planning or execution propagates via `?` and fails the test.
dataframe.collect().await?;
1270+
Ok(())
1271+
}

datafusion/expr/src/utils.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -270,13 +270,11 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
270270
Expr::Column(qc) => {
271271
accum.insert(qc.clone());
272272
}
273-
Expr::ScalarVariable(_, var_names) => {
274-
accum.insert(Column::from_name(var_names.join(".")));
275-
}
276273
// Use explicit pattern match instead of a default
277274
// implementation, so that in the future if someone adds
278275
// new Expr types, they will check here as well
279-
Expr::Alias(_, _)
276+
Expr::ScalarVariable(_, _)
277+
| Expr::Alias(_, _)
280278
| Expr::Literal(_)
281279
| Expr::BinaryExpr { .. }
282280
| Expr::Like { .. }

0 commit comments

Comments
 (0)