Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 139 additions & 105 deletions prqlc/prqlc/src/sql/pq/anchor.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::ops::Deref;

use itertools::Itertools;

Expand Down Expand Up @@ -79,7 +81,10 @@ pub(super) fn split_off_back(

let mut following_transforms: HashSet<String> = HashSet::new();

let mut inputs_required = into_requirements(output.clone(), Complexity::highest(), true);
let mut inputs_required = Requirements::from_cids(output.iter())
.allow_up_to(Complexity::highest())
.should_select(true);

let mut inputs_avail = HashSet::new();

// iterate backwards
Expand All @@ -97,7 +102,7 @@ pub(super) fn split_off_back(
// anchor and record all requirements
let required = get_requirements(&transform, &following_transforms);
log::debug!(".. transform {} requires {required:?}", transform.as_str(),);
inputs_required.extend(required.clone());
inputs_required = inputs_required.append(required.clone());

match &transform {
SqlTransform::Super(Transform::Compute(compute)) => {
Expand All @@ -107,11 +112,8 @@ pub(super) fn split_off_back(
inputs_avail.insert(compute.id);

// add transitive dependencies
inputs_required.extend(required.into_iter().map(|x| Requirement {
col: x.col,
max_complexity,
selected: false,
}));
inputs_required = inputs_required
.append(required.allow_up_to(max_complexity).should_select(false));
} else {
pipeline.push(transform);
break;
Expand Down Expand Up @@ -152,11 +154,7 @@ pub(super) fn split_off_back(
log::debug!("finished table:");
log::debug!(".. avail={inputs_avail:?}");

let required = inputs_required
.into_iter()
.map(|r| r.col)
.unique()
.collect_vec();
let required = inputs_required.iter().map(|r| r.col).unique().collect_vec();
log::debug!(".. required={required:?}");

let missing = required
Expand Down Expand Up @@ -419,18 +417,72 @@ pub struct Requirement {
pub selected: bool,
}

fn into_requirements(
cids: Vec<CId>,
max_complexity: Complexity,
selected: bool,
) -> Vec<Requirement> {
cids.into_iter()
.map(|col| Requirement {
col,
max_complexity,
selected,
})
.collect()
/// An ordered list of column [`Requirement`]s.
///
/// A thin newtype over `Vec<Requirement>`; the derived `Default` yields an empty list.
#[derive(Clone, Default)]
pub struct Requirements(Vec<Requirement>);

/// Exposes the stored requirements as a read-only slice, so slice methods
/// (`iter()`, `len()`, indexing, ...) can be called directly on `Requirements`.
impl Deref for Requirements {
    type Target = [Requirement];

    fn deref(&self) -> &Self::Target {
        // `&Vec<Requirement>` coerces to `&[Requirement]` at the return site.
        &self.0
    }
}

/// Formats `Requirements` exactly like its inner `Vec`.
///
/// Delegates to the `Vec`'s own `Debug` implementation rather than going through
/// `write!(f, "{:?}", ..)`, so formatter flags (notably the pretty-printing `{:#?}`)
/// are honoured instead of being silently dropped.
impl fmt::Debug for Requirements {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.0, f)
    }
}

impl Requirements {
    /// Turns a sequence of `CId`s into requirements with the least allowed complexity
    /// and unselected by default.
    ///
    /// Accepts anything iterable over `&CId`: an iterator, a `&Vec<CId>`, a slice, ...
    pub fn from_cids<'a, I>(cids: I) -> Requirements
    where
        I: IntoIterator<Item = &'a CId>,
    {
        Requirements(
            cids.into_iter()
                .cloned()
                .map(|col| Requirement {
                    col,
                    max_complexity: Complexity::lowest(),
                    selected: false,
                })
                .collect(),
        )
    }

    /// Collects the columns referenced by the given `Expr` into requirements with
    /// the least allowed complexity and unselected by default.
    pub fn from_expr(expr: &Expr) -> Requirements {
        // `CidCollector::collect` takes the expression by value, hence the clone.
        let cids = CidCollector::collect(expr.clone());
        Requirements::from_cids(cids.iter())
    }

    /// Appends every requirement of `other` after the ones already in `self`.
    ///
    /// Both operands are consumed; the combined list is returned for chainability.
    #[must_use]
    pub fn append(mut self, other: Requirements) -> Requirements {
        self.0.extend(other.0);
        self
    }

    /// Sets the maximum allowed complexity of all stored requirements,
    /// overwriting whatever each requirement held before.
    #[must_use]
    pub fn allow_up_to(mut self, max_complexity: Complexity) -> Self {
        for r in &mut self.0 {
            r.max_complexity = max_complexity;
        }
        self
    }

    /// Sets the SELECT status of all stored requirements,
    /// overwriting whatever each requirement held before.
    #[must_use]
    pub fn should_select(mut self, selected: bool) -> Self {
        for r in &mut self.0 {
            r.selected = selected;
        }
        self
    }
}

impl std::fmt::Debug for Requirement {
Expand All @@ -444,100 +496,78 @@ impl std::fmt::Debug for Requirement {
pub(super) fn get_requirements(
transform: &SqlTransform,
following: &HashSet<String>,
) -> Vec<Requirement> {
) -> Requirements {
use SqlTransform::Super;
use Transform::*;

// special case for Aggregate, which contains two different complexities
if let Super(Aggregate { partition, compute }) = transform {
let mut r = Vec::new();
r.extend(into_requirements(
partition.clone(),
Complexity::Plain,
false,
));
r.extend(into_requirements(
compute.clone(),
Complexity::Aggregation,
false,
));
return r;
}

// special case for Compute, which contains two different complexities
if let Super(Compute(compute)) = transform {
// expr itself
let expr_cids = CidCollector::collect(compute.expr.clone());

let expr_max_complexity = match infer_complexity(compute) {
// plain expressions can be included in anything less complex than Aggregation
Complexity::Plain => Complexity::Aggregation,

// anything more complex can only be included in other plain expressions.
// in other words: complex expressions (aggregation, window functions) cannot
// be defined within other expressions.
_ => Complexity::Plain,
};
let mut requirements = into_requirements(expr_cids, expr_max_complexity, false);

// window
if let Some(window) = &compute.window {
// TODO: what kind of exprs can be in window frame?
// window.frame
match transform {
Super(Transform::Aggregate { partition, compute }) => {
let partition_requirements = Requirements::from_cids(partition.iter());
let compute_requirements =
Requirements::from_cids(compute.iter()).allow_up_to(Complexity::Aggregation);

let mut window_cids = window.partition.clone();
window_cids.extend(window.sort.iter().map(|s| s.column));

requirements.extend(into_requirements(window_cids, Complexity::Plain, false));
partition_requirements.append(compute_requirements)
}

return requirements;
}

// general case: extract cids
let cids = match transform {
Super(Compute(compute)) => CidCollector::collect(compute.expr.clone()),
Super(Filter(expr)) | SqlTransform::Join { filter: expr, .. } => {
CidCollector::collect(expr.clone())
}
// Aggregations require that all selected columns be wrapped in aggregate functions (e.g., SUM, COUNT).
Super(Sort(sorts)) if !following.contains("Aggregate") => {
sorts.iter().map(|s| s.column).collect()
}
Super(Take(rq::Take { range, .. })) => {
let mut cids = Vec::new();
if let Some(e) = &range.start {
cids.extend(CidCollector::collect(e.clone()));
}
if let Some(e) = &range.end {
cids.extend(CidCollector::collect(e.clone()));
Super(Transform::Compute(compute)) => {
let requirements = Requirements::from_expr(&compute.expr).allow_up_to(
match infer_complexity(compute) {
// plain expressions can be included in anything less complex than Aggregation
Complexity::Plain => Complexity::Aggregation,

// anything more complex can only be included in other plain expressions.
// in other words: complex expressions (aggregation, window functions) cannot
// be defined within other expressions.
_ => Complexity::Plain,
},
);

if let Some(window) = &compute.window {
// TODO: what kind of exprs can be in window frame?
// window.frame

let window_cids = window
.partition
.iter()
.chain(window.sort.iter().map(|s| &s.column));

requirements.append(Requirements::from_cids(window_cids))
} else {
requirements
}
cids
}

_ => return Vec::new(),
};

// general case: determine complexity
let (max_complexity, selected) = match transform {
Super(Filter(_)) => (
if !following.contains("Aggregate") {
Super(Transform::Filter(expr)) => {
Requirements::from_expr(expr).allow_up_to(if !following.contains("Aggregate") {
Complexity::Aggregation
} else {
Complexity::Plain
},
false,
),
// we only use SELECTed columns in ORDER BY, so the columns can have high complexity
Super(Sort(_)) => (Complexity::Aggregation, true),
})
}

// LIMIT and OFFSET can use constant expressions which don't need to be SELECTed
Super(Take(_)) => (Complexity::Plain, false),
SqlTransform::Join { .. } => (Complexity::Plain, false),
_ => unreachable!(),
};
// Aggregations require that all selected columns be wrapped in aggregate functions (e.g., SUM, COUNT).
Super(Transform::Sort(sorts)) if !following.contains("Aggregate") => {
Requirements::from_cids(sorts.iter().map(|s| &s.column))
// we only use SELECTed columns in ORDER BY, so the columns can have high complexity
.allow_up_to(Complexity::Aggregation)
.should_select(true)
}

into_requirements(cids, max_complexity, selected)
SqlTransform::DistinctOn(partition) => Requirements::from_cids(partition.iter())
// Partition columns must be selected in order to push compute columns down CTE.
.should_select(true)
// Since there is aggregation anyway, columns can have any complexity
.allow_up_to(Complexity::highest()),

Super(Transform::Take(rq::Take { range, .. })) => [&range.start, &range.end]
.into_iter()
.flatten()
.map(Requirements::from_expr)
.fold(Requirements::default(), Requirements::append),

SqlTransform::Join { filter, .. } => Requirements::from_expr(filter),

_ => Requirements::default(),
}
}

/// Complexity of a column expressions.
Expand All @@ -554,6 +584,10 @@ pub enum Complexity {
}

impl Complexity {
/// Returns `Self::Plain`, the variant this file uses as the lowest complexity level.
const fn lowest() -> Self {
Self::Plain
}

/// Returns `Self::Aggregation`, the variant this file uses as the highest complexity level.
const fn highest() -> Self {
Self::Aggregation
}
Expand Down
64 changes: 62 additions & 2 deletions prqlc/prqlc/tests/integration/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2518,11 +2518,17 @@ fn test_distinct_on_03() {
derive foo = 1
select foo
"###).unwrap()), @r"
WITH table_0 AS (
WITH table_1 AS (
SELECT
DISTINCT ON (col1) NULL
DISTINCT ON (col1) col1
FROM
tab1
),
table_0 AS (
SELECT
NULL
FROM
table_1
)
SELECT
1 AS foo
Expand Down Expand Up @@ -5669,3 +5675,57 @@ fn test_type_error_placement() {
t
");
}

#[test]
fn test_missing_columns_group_complex_compute() {
// https://github.com/PRQL/prql/issues/5354
// The focus of this test is whether the `hire_date` column is available where it's needed.
// The additional `city` derives are there only to trigger the issue.
assert_snapshot!(compile(
r#"prql target:sql.postgres
from employees
derive `year` = s'EXTRACT(year from {`hire_date`})'
derive { `year_label` = f"Year {`year`}" }
derive { `city` = case [ this.`city` == "Calgary" => "A city", true => this.`city` ] }
derive { `city` = case [ this.`city` == "Edmonton" => "Another city", true => this.`city` ] }
group {`year`, `year_label`} (take 1)
select {this.`year_label`}
"#,
)
.unwrap(), @r"
WITH table_0 AS (
SELECT
CONCAT(
'Year ',
EXTRACT(
year
from
hire_date
)
) AS year_label,
EXTRACT(
year
from
hire_date
) AS _expr_0,
CASE
WHEN city = 'Calgary' THEN 'A city'
ELSE city
END AS _expr_1,
city
FROM
employees
),
table_1 AS (
SELECT
DISTINCT ON (_expr_0, year_label) year_label,
_expr_0
FROM
table_0
)
SELECT
year_label
FROM
table_1
");
}
Loading