Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion prqlc/prqlc/src/ir/pl/lineage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::fmt::{Debug, Display, Formatter};
use enum_as_inner::EnumAsInner;
use itertools::{Itertools, Position};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Serialize, Serializer};

use super::Ident;

Expand Down Expand Up @@ -48,10 +48,23 @@ pub enum LineageColumn {
/// All columns (including unknown ones) from an input (i.e. `foo_table.*`)
All {
input_id: usize,

#[serde(serialize_with = "sorted_set")]
except: HashSet<String>,
},
}

/// Serializes a `HashSet` as a sequence of its elements in ascending order.
///
/// `HashSet` iteration order is unspecified, so the elements are sorted first
/// to keep the serialized output stable between runs. Has the signature serde
/// expects for `#[serde(serialize_with = "sorted_set")]`.
pub fn sorted_set<S, V>(value: &HashSet<V>, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
    V: Serialize + Ord,
{
    let mut ordered: Vec<&V> = value.iter().collect();
    // Stable sort over borrowed elements; nothing is cloned.
    ordered.sort();
    ordered.serialize(serializer)
}

impl Display for Lineage {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
display_lineage(self, f, false)
Expand Down
39 changes: 23 additions & 16 deletions prqlc/prqlc/src/sql/pq/anchor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub(super) fn split_off_back(
return (None, Vec::new());
}

let mapping_before = compute_positional_mappings(&pipeline);
let mapping_before = compute_positional_mappings(&pipeline, None);

log::debug!("traversing pipeline to obtain columns: {output:?}");

Expand All @@ -100,7 +100,7 @@ pub(super) fn split_off_back(
}

// anchor and record all requirements
let required = get_requirements(&transform, &following_transforms);
let required = get_requirements(&transform, &following_transforms, &inputs_required);
log::debug!(".. transform {} requires {required:?}", transform.as_str(),);
inputs_required = inputs_required.append(required.clone());

Expand Down Expand Up @@ -188,10 +188,12 @@ pub(super) fn split_off_back(
curr_pipeline_rev.reverse();

// Compare columns for order-sensitive transforms and correct their order in the subsequent relation.
let mapping_after = compute_positional_mappings(&curr_pipeline_rev);
for (before, after) in mapping_before.iter().zip(mapping_after.iter()) {
ctx.positional_mapping
.compute_and_store_mapping(before, after);
let mapping_after = compute_positional_mappings(&curr_pipeline_rev, Some(&inputs_required));
for (riid, after) in mapping_after {
if let Some((_, before)) = mapping_before.iter().find(|(r, _)| &riid == r) {
ctx.positional_mapping
.compute_and_store_mapping(before, &after, &riid);
}
}

(remaining_pipeline, curr_pipeline_rev)
Expand Down Expand Up @@ -483,6 +485,14 @@ impl Requirements {
}
self
}

/// Returns `true` when at least one requirement refers to column `id`
/// and has its `selected` flag set.
pub fn is_selected(&self, id: &CId) -> bool {
    self.0
        .iter()
        .filter(|req| req.selected)
        .any(|req| &req.col == id)
}

/// Returns `true` when any requirement refers to column `id`,
/// whether or not that requirement is selected.
pub fn is_required(&self, id: &CId) -> bool {
    self.0.iter().map(|req| &req.col).any(|col| col == id)
}
}

impl std::fmt::Debug for Requirement {
Expand All @@ -496,19 +506,14 @@ impl std::fmt::Debug for Requirement {
pub(super) fn get_requirements(
transform: &SqlTransform,
following: &HashSet<String>,
previous_requirements: &Requirements,
) -> Requirements {
use SqlTransform::Super;

match transform {
Super(Transform::Aggregate { partition, compute }) => {
let partition_requirements = Requirements::from_cids(partition.iter());
let compute_requirements =
Requirements::from_cids(compute.iter()).allow_up_to(Complexity::Aggregation);
Super(Transform::Aggregate { partition, .. }) => Requirements::from_cids(partition.iter()),

partition_requirements.append(compute_requirements)
}

Super(Transform::Compute(compute)) => {
Super(Transform::Compute(compute)) if previous_requirements.is_required(&compute.id) => {
let requirements = Requirements::from_expr(&compute.expr).allow_up_to(
match infer_complexity(compute) {
// plain expressions can be included in anything less complex than Aggregation
Expand Down Expand Up @@ -552,9 +557,11 @@ pub(super) fn get_requirements(
.should_select(true)
}

SqlTransform::Sort(sorts) if !following.contains("Aggregate") => {
Requirements::from_cids(sorts.iter().map(|s| &s.column))
}

SqlTransform::DistinctOn(partition) => Requirements::from_cids(partition.iter())
// Partition columns must be selected in order to push compute columns down CTE.
.should_select(true)
// Since there is aggregation anyway, columns can have any complexity
.allow_up_to(Complexity::highest()),

Expand Down
56 changes: 35 additions & 21 deletions prqlc/prqlc/src/sql/pq/positional_mapping.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use std::collections::HashMap;

use crate::{
ir::rq::{CId, Transform},
sql::{pq::context::RIId, pq_ast::SqlTransform},
ir::rq::{CId, Compute, Transform},
sql::{
pq::{anchor::Requirements, context::RIId},
pq_ast::SqlTransform,
},
};

/// State required to properly handle the transforms that are order sensitive like `Union`.
Expand All @@ -16,29 +19,24 @@ impl PositionalMapper {
/// Makes the mapping stored for `riid` the active one, to be used by `apply_active_mapping`.
///
/// The entry is taken out of `relation_positional_mapping`, so it is consumed by this call.
/// Whatever `remove` returns becomes the new active mapping, replacing any previous one.
pub(crate) fn activate_mapping(&mut self, riid: &RIId) {
    let mapping = self.relation_positional_mapping.remove(riid);
    log::trace!("loading remapping for {riid:?}: {:?}", mapping);
    self.active_positional_mapping = mapping;
}

/// Reorders or removes columns of `output` according to the active positional mapping,
/// to satisfy position-sensitive transforms such as `Union`.
///
/// When a mapping is active, element `i` of the result is `output[mapping[i]]`, so the
/// result has one column per mapping entry. Without an active mapping, `output` is
/// returned unchanged.
///
/// # Panics
/// Panics if a mapping index is out of bounds for `output`.
pub(crate) fn apply_active_mapping(&mut self, output: Vec<CId>) -> Vec<CId> {
if let Some(mapping) = &self.active_positional_mapping {
let new_output = mapping.iter().map(|idx| output[*idx]).collect();
log::debug!("remapping {output:?} to {new_output:?}");
log::debug!("remapping {output:?} to {new_output:?} via {mapping:?}");
new_output
} else {
output
}
}

pub fn compute_and_store_mapping(
&mut self,
(_, before): &(RIId, Vec<CId>),
(riid, after): &(RIId, Vec<CId>),
) {
if after == before {
log::trace!(".. relation {riid:?} is already correctly mapped: {after:?}");
return;
}

pub fn compute_and_store_mapping(&mut self, before: &[CId], after: &[CId], riid: &RIId) {
let mapping: Vec<_> = after
.iter()
.flat_map(|a| match before.iter().position(|b| b == a) {
Expand All @@ -60,39 +58,55 @@ impl PositionalMapper {
/// Outputs the columns required for position sensitive transforms in the pipeline.
pub fn compute_positional_mappings(
pipeline: &[SqlTransform<RIId, Transform>],
requirements: Option<&Requirements>,
) -> Vec<(RIId, Vec<CId>)> {
let mut constraints = vec![];
let mut columns = vec![];

log::trace!("traversing pipeline to obtain positional mapping:");

// Only process selected columns to avoid supernumerary ones
let add_columns = |columns: &mut Vec<CId>, cids: &[CId]| {
if let Some(requirements) = requirements {
columns.extend(cids.iter().filter(|cid| requirements.is_selected(cid)));
} else {
columns.extend_from_slice(cids);
}
};

for transform in pipeline {
match transform {
SqlTransform::Super(s) => match s {
Transform::Compute(compute) => {
if !columns.contains(&compute.id) {
columns.push(compute.id);
Transform::Compute(Compute { id, .. }) => {
if !columns.contains(id) {
add_columns(&mut columns, &[*id]);
}
}
Transform::Select(cids) => {
columns.clear();
columns.extend_from_slice(cids.as_slice());
add_columns(&mut columns, cids);
}
Transform::Aggregate { partition, compute } => {
Transform::Aggregate { compute, .. } => {
columns.clear();
columns.extend_from_slice(partition.as_slice());
columns.extend_from_slice(compute.as_slice());
add_columns(&mut columns, compute);
}
_ => (),
},
SqlTransform::Except { bottom, .. }
| SqlTransform::Intersect { bottom, .. }
| SqlTransform::Union { bottom, .. } => {
constraints.push((*bottom, columns.clone()));
log::trace!(
".. mapping for {}/{bottom:?}: {columns:?}",
transform.as_str()
);
}
_ => (),
}
log::trace!(".. columns after {}: {columns:?}", transform.as_str());
log::trace!(
".. selected columns after {}: {columns:?}",
transform.as_str()
);
}

constraints
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from invoices
derive total = case [total < 10 => total * 2, true => total]
select { customer_id, invoice_id, total }
take 5
append (
from invoice_items
derive unit_price = case [unit_price < 1 => unit_price * 2, true => unit_price]
select { invoice_line_id, invoice_id, unit_price }
take 5
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
source: prqlc/prqlc/tests/integration/queries.rs
expression: "from invoices\nselect { customer_id, invoice_id, total }\ntake 5\nappend (\n from invoice_items\n select { invoice_line_id, invoice_id, unit_price }\n take 5\n)\nselect { a = customer_id * 2, b = math.round 1 (invoice_id * total) }\n"
expression: "from invoices\nderive total = case [total < 10 => total * 2, true => total]\nselect { customer_id, invoice_id, total }\ntake 5\nappend (\n from invoice_items\n derive unit_price = case [unit_price < 1 => unit_price * 2, true => unit_price]\n select { invoice_line_id, invoice_id, unit_price }\n take 5\n)\nselect { a = customer_id * 2, b = math.round 1 (invoice_id * total) }\n"
input_file: prqlc/prqlc/tests/integration/queries/append_select_compute.prql
---
WITH table_1 AS (
Expand All @@ -10,7 +10,10 @@ WITH table_1 AS (
(
SELECT
invoice_id,
total,
CASE
WHEN total < 10 THEN total * 2
ELSE total
END AS _expr_0,
customer_id
FROM
invoices
Expand All @@ -25,7 +28,10 @@ WITH table_1 AS (
(
SELECT
invoice_id,
unit_price,
CASE
WHEN unit_price < 1 THEN unit_price * 2
ELSE unit_price
END AS unit_price,
invoice_line_id
FROM
invoice_items
Expand All @@ -35,6 +41,6 @@ WITH table_1 AS (
)
SELECT
customer_id * 2 AS a,
ROUND(invoice_id * total, 1) AS b
ROUND(invoice_id * _expr_0, 1) AS b
FROM
table_1
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
---
source: prqlc/prqlc/tests/integration/queries.rs
expression: "from invoices\nselect { customer_id, invoice_id, total }\ntake 5\nappend (\n from invoice_items\n select { invoice_line_id, invoice_id, unit_price }\n take 5\n)\nselect { a = customer_id * 2, b = math.round 1 (invoice_id * total) }\n"
expression: "from invoices\nderive total = case [total < 10 => total * 2, true => total]\nselect { customer_id, invoice_id, total }\ntake 5\nappend (\n from invoice_items\n derive unit_price = case [unit_price < 1 => unit_price * 2, true => unit_price]\n select { invoice_line_id, invoice_id, unit_price }\n take 5\n)\nselect { a = customer_id * 2, b = math.round 1 (invoice_id * total) }\n"
input_file: prqlc/prqlc/tests/integration/queries/append_select_compute.prql
---
--- generic
+++ glaredb
@@ -23,13 +23,13 @@
unit_price,
@@ -29,13 +29,13 @@
END AS unit_price,
invoice_line_id
FROM
invoice_items
Expand All @@ -16,23 +16,26 @@ input_file: prqlc/prqlc/tests/integration/queries/append_select_compute.prql
)
SELECT
customer_id * 2 AS a,
- ROUND(invoice_id * total, 1) AS b
+ ROUND((invoice_id * total)::numeric, 1) AS b
- ROUND(invoice_id * _expr_0, 1) AS b
+ ROUND((invoice_id * _expr_0)::numeric, 1) AS b
FROM
table_1


--- generic
+++ postgres
@@ -1,35 +1,28 @@
@@ -1,41 +1,34 @@
WITH table_1 AS (
- SELECT
- *
- FROM
- (
- SELECT
- invoice_id,
- total,
- CASE
- WHEN total < 10 THEN total * 2
- ELSE total
- END AS _expr_0,
- customer_id
- FROM
- invoices
Expand All @@ -42,7 +45,10 @@ input_file: prqlc/prqlc/tests/integration/queries/append_select_compute.prql
+ (
+ SELECT
+ invoice_id,
+ total,
+ CASE
+ WHEN total < 10 THEN total * 2
+ ELSE total
+ END AS _expr_0,
+ customer_id
+ FROM
+ invoices
Expand All @@ -57,7 +63,10 @@ input_file: prqlc/prqlc/tests/integration/queries/append_select_compute.prql
- (
- SELECT
- invoice_id,
- unit_price,
- CASE
- WHEN unit_price < 1 THEN unit_price * 2
- ELSE unit_price
- END AS unit_price,
- invoice_line_id
- FROM
- invoice_items
Expand All @@ -67,7 +76,10 @@ input_file: prqlc/prqlc/tests/integration/queries/append_select_compute.prql
+ ALL (
+ SELECT
+ invoice_id,
+ unit_price,
+ CASE
+ WHEN unit_price < 1 THEN unit_price * 2
+ ELSE unit_price
+ END AS unit_price,
+ invoice_line_id
+ FROM
+ invoice_items
Expand All @@ -77,7 +89,7 @@ input_file: prqlc/prqlc/tests/integration/queries/append_select_compute.prql
)
SELECT
customer_id * 2 AS a,
- ROUND(invoice_id * total, 1) AS b
+ ROUND((invoice_id * total)::numeric, 1) AS b
- ROUND(invoice_id * _expr_0, 1) AS b
+ ROUND((invoice_id * _expr_0)::numeric, 1) AS b
FROM
table_1
Loading
Loading