diff --git a/prqlc/prqlc/src/sql/pq/anchor.rs b/prqlc/prqlc/src/sql/pq/anchor.rs index d2d697ddb35b..0e424b7389e7 100644 --- a/prqlc/prqlc/src/sql/pq/anchor.rs +++ b/prqlc/prqlc/src/sql/pq/anchor.rs @@ -1,4 +1,6 @@ use std::collections::{HashMap, HashSet}; +use std::fmt; +use std::ops::Deref; use itertools::Itertools; @@ -79,7 +81,10 @@ pub(super) fn split_off_back( let mut following_transforms: HashSet = HashSet::new(); - let mut inputs_required = into_requirements(output.clone(), Complexity::highest(), true); + let mut inputs_required = Requirements::from_cids(output.iter()) + .allow_up_to(Complexity::highest()) + .should_select(true); + let mut inputs_avail = HashSet::new(); // iterate backwards @@ -97,7 +102,7 @@ pub(super) fn split_off_back( // anchor and record all requirements let required = get_requirements(&transform, &following_transforms); log::debug!(".. transform {} requires {required:?}", transform.as_str(),); - inputs_required.extend(required.clone()); + inputs_required = inputs_required.append(required.clone()); match &transform { SqlTransform::Super(Transform::Compute(compute)) => { @@ -107,11 +112,8 @@ pub(super) fn split_off_back( inputs_avail.insert(compute.id); // add transitive dependencies - inputs_required.extend(required.into_iter().map(|x| Requirement { - col: x.col, - max_complexity, - selected: false, - })); + inputs_required = inputs_required + .append(required.allow_up_to(max_complexity).should_select(false)); } else { pipeline.push(transform); break; @@ -152,11 +154,7 @@ pub(super) fn split_off_back( log::debug!("finished table:"); log::debug!(".. avail={inputs_avail:?}"); - let required = inputs_required - .into_iter() - .map(|r| r.col) - .unique() - .collect_vec(); + let required = inputs_required.iter().map(|r| r.col).unique().collect_vec(); log::debug!(".. 
required={required:?}"); let missing = required @@ -419,18 +417,72 @@ pub struct Requirement { pub selected: bool, } -fn into_requirements( - cids: Vec, - max_complexity: Complexity, - selected: bool, -) -> Vec { - cids.into_iter() - .map(|col| Requirement { - col, - max_complexity, - selected, - }) - .collect() +#[derive(Clone, Default)] +pub struct Requirements(Vec); + +/// To iterate over `Requirements` as if it's a simple `Vec`. +impl Deref for Requirements { + type Target = [Requirement]; + + fn deref(&self) -> &[Requirement] { + self.0.as_slice() + } +} + +/// To debug `Requirements` as if it's a simple `Vec`. +impl fmt::Debug for Requirements { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self.0) + } +} + +impl Requirements { + /// Turns a list of `CId` into requirements with the least allowed complexity + /// and unselected by default. + pub fn from_cids<'a, I>(cids: I) -> Requirements + where + I: Iterator, + { + Requirements( + cids.cloned() + .map(|col| Requirement { + col, + max_complexity: Complexity::lowest(), + selected: false, + }) + .collect(), + ) + } + + /// Collect columns from the given `Expr` into requirements with + /// the least allowed complexity and unselected by default. + pub fn from_expr(expr: &Expr) -> Requirements { + let cids = CidCollector::collect(expr.clone()); + Requirements::from_cids(cids.iter()) + } + + /// Moves all the elements of `other` into `self`, leaving `other` empty, + /// then returns `self` for chainability. + pub fn append(mut self, mut other: Requirements) -> Requirements { + self.0.append(&mut other.0); + self + } + + /// Set a maximum complexity to all stored requirements. + pub fn allow_up_to(mut self, max_complexity: Complexity) -> Self { + for r in &mut self.0 { + r.max_complexity = max_complexity; + } + self + } + + /// Set a SELECT status to all stored requirements. 
+ pub fn should_select(mut self, selected: bool) -> Self { + for r in &mut self.0 { + r.selected = selected; + } + self + } } impl std::fmt::Debug for Requirement { @@ -444,100 +496,78 @@ impl std::fmt::Debug for Requirement { pub(super) fn get_requirements( transform: &SqlTransform, following: &HashSet, -) -> Vec { +) -> Requirements { use SqlTransform::Super; - use Transform::*; - - // special case for Aggregate, which contain two difference Complexity-ies - if let Super(Aggregate { partition, compute }) = transform { - let mut r = Vec::new(); - r.extend(into_requirements( - partition.clone(), - Complexity::Plain, - false, - )); - r.extend(into_requirements( - compute.clone(), - Complexity::Aggregation, - false, - )); - return r; - } - - // special case for Compute, which contain two difference Complexity-ies - if let Super(Compute(compute)) = transform { - // expr itself - let expr_cids = CidCollector::collect(compute.expr.clone()); - - let expr_max_complexity = match infer_complexity(compute) { - // plain expressions can be included in anything less complex than Aggregation - Complexity::Plain => Complexity::Aggregation, - - // anything more complex can only use included in other plain expressions. - // in other words: complex expressions (aggregation, window functions) cannot - // be defined within other expressions. - _ => Complexity::Plain, - }; - let mut requirements = into_requirements(expr_cids, expr_max_complexity, false); - // window - if let Some(window) = &compute.window { - // TODO: what kind of exprs can be in window frame? 
- // window.frame + match transform { + Super(Transform::Aggregate { partition, compute }) => { + let partition_requirements = Requirements::from_cids(partition.iter()); + let compute_requirements = + Requirements::from_cids(compute.iter()).allow_up_to(Complexity::Aggregation); - let mut window_cids = window.partition.clone(); - window_cids.extend(window.sort.iter().map(|s| s.column)); - - requirements.extend(into_requirements(window_cids, Complexity::Plain, false)); + partition_requirements.append(compute_requirements) } - return requirements; - } - - // general case: extract cids - let cids = match transform { - Super(Compute(compute)) => CidCollector::collect(compute.expr.clone()), - Super(Filter(expr)) | SqlTransform::Join { filter: expr, .. } => { - CidCollector::collect(expr.clone()) - } - // Aggregations require that all selected columns be wrapped in aggregate functions (e.g., SUM, COUNT). - Super(Sort(sorts)) if !following.contains("Aggregate") => { - sorts.iter().map(|s| s.column).collect() - } - Super(Take(rq::Take { range, .. })) => { - let mut cids = Vec::new(); - if let Some(e) = &range.start { - cids.extend(CidCollector::collect(e.clone())); - } - if let Some(e) = &range.end { - cids.extend(CidCollector::collect(e.clone())); + Super(Transform::Compute(compute)) => { + let requirements = Requirements::from_expr(&compute.expr).allow_up_to( + match infer_complexity(compute) { + // plain expressions can be included in anything less complex than Aggregation + Complexity::Plain => Complexity::Aggregation, + + // anything more complex can only be included in other plain expressions. + // in other words: complex expressions (aggregation, window functions) cannot + // be defined within other expressions. + _ => Complexity::Plain, + }, + ); + + if let Some(window) = &compute.window { + // TODO: what kind of exprs can be in window frame? 
+ // window.frame + + let window_cids = window + .partition + .iter() + .chain(window.sort.iter().map(|s| &s.column)); + + requirements.append(Requirements::from_cids(window_cids)) + } else { + requirements } - cids } - _ => return Vec::new(), - }; - - // general case: determine complexity - let (max_complexity, selected) = match transform { - Super(Filter(_)) => ( - if !following.contains("Aggregate") { + Super(Transform::Filter(expr)) => { + Requirements::from_expr(expr).allow_up_to(if !following.contains("Aggregate") { Complexity::Aggregation } else { Complexity::Plain - }, - false, - ), - // we only use SELECTed columns in ORDER BY, so the columns can have high complexity - Super(Sort(_)) => (Complexity::Aggregation, true), + }) + } - // LIMIT and OFFSET can use constant expressions which don't need to be SELECTed - Super(Take(_)) => (Complexity::Plain, false), - SqlTransform::Join { .. } => (Complexity::Plain, false), - _ => unreachable!(), - }; + // Aggregations require that all selected columns be wrapped in aggregate functions (e.g., SUM, COUNT). + Super(Transform::Sort(sorts)) if !following.contains("Aggregate") => { + Requirements::from_cids(sorts.iter().map(|s| &s.column)) + // we only use SELECTed columns in ORDER BY, so the columns can have high complexity + .allow_up_to(Complexity::Aggregation) + .should_select(true) + } - into_requirements(cids, max_complexity, selected) + SqlTransform::DistinctOn(partition) => Requirements::from_cids(partition.iter()) + // Partition columns must be selected in order to push compute columns down CTE. + .should_select(true) + // Since there is aggregation anyway, columns can have any complexity + .allow_up_to(Complexity::highest()), + + Super(Transform::Take(rq::Take { range, .. })) => [&range.start, &range.end] + .into_iter() + .flatten() + .map(Requirements::from_expr) + .fold(Requirements::default(), Requirements::append), + + SqlTransform::Join { filter, .. 
} => Requirements::from_expr(filter), + + _ => Requirements::default(), + } } /// Complexity of a column expressions. @@ -554,6 +584,10 @@ pub enum Complexity { } impl Complexity { + const fn lowest() -> Self { + Self::Plain + } + const fn highest() -> Self { Self::Aggregation } diff --git a/prqlc/prqlc/tests/integration/sql.rs b/prqlc/prqlc/tests/integration/sql.rs index 385769f3c8b3..0964a69fc7d8 100644 --- a/prqlc/prqlc/tests/integration/sql.rs +++ b/prqlc/prqlc/tests/integration/sql.rs @@ -2518,11 +2518,17 @@ fn test_distinct_on_03() { derive foo = 1 select foo "###).unwrap()), @r" - WITH table_0 AS ( + WITH table_1 AS ( SELECT - DISTINCT ON (col1) NULL + DISTINCT ON (col1) col1 FROM tab1 + ), + table_0 AS ( + SELECT + NULL + FROM + table_1 ) SELECT 1 AS foo @@ -5669,3 +5675,57 @@ fn test_type_error_placement() { t "); } + +#[test] +fn test_missing_columns_group_complex_compute() { + // https://github.com/PRQL/prql/issues/5354 + // The focus of this test is on whether the `hire_date` column is available where it's needed. + // The additional `city` derives are there only to trigger the issue. + assert_snapshot!(compile( + r#"prql target:sql.postgres + from employees + derive `year` = s'EXTRACT(year from {`hire_date`})' + derive { `year_label` = f"Year {`year`}" } + derive { `city` = case [ this.`city` == "Calgary" => "A city", true => this.`city` ] } + derive { `city` = case [ this.`city` == "Edmonton" => "Another city", true => this.`city` ] } + group {`year`, `year_label`} (take 1) + select {this.`year_label`} + "#, + ) + .unwrap(), @r" + WITH table_0 AS ( + SELECT + CONCAT( + 'Year ', + EXTRACT( + year + from + hire_date + ) + ) AS year_label, + EXTRACT( + year + from + hire_date + ) AS _expr_0, + CASE + WHEN city = 'Calgary' THEN 'A city' + ELSE city + END AS _expr_1, + city + FROM + employees + ), + table_1 AS ( + SELECT + DISTINCT ON (_expr_0, year_label) year_label, + _expr_0 + FROM + table_0 + ) + SELECT + year_label + FROM + table_1 + "); +}