Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions prqlc/prqlc/src/sql/pq/anchor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -660,32 +660,6 @@ impl<'a> CidRedirector<'a> {

fold_column_sorts(&mut redirector, sorts).unwrap()
}

// revert sort columns back to their original pre-split columns
pub fn revert_sorts(
sorts: Vec<ColumnSort<CId>>,
ctx: &'a mut AnchorContext,
) -> Vec<ColumnSort<CId>> {
sorts
.into_iter()
.map(|sort| {
let decl = ctx.column_decls.get(&sort.column).unwrap();
if let ColumnDecl::RelationColumn(riid, cid, _) = decl {
let cid_redirects = &ctx.relation_instances[riid].cid_redirects;
for (source, target) in cid_redirects.iter() {
if target == cid {
log::debug!("reverting {target:?} back to {source:?}");
return ColumnSort {
direction: sort.direction,
column: *source,
};
}
}
}
sort
})
.collect()
}
}

impl RqFold for CidRedirector<'_> {
Expand Down
54 changes: 53 additions & 1 deletion prqlc/prqlc/src/sql/pq/gen_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ impl PqMapper<RIId, pq::RelationExpr, rq::Transform, ()> for TransformCompiler<'
pub(super) fn compile_relation_instance(riid: RIId, ctx: &mut Context) -> Result<pq::RelationExpr> {
ctx.anchor.positional_mapping.activate_mapping(&riid);

let table_ref = &ctx.anchor.relation_instances.get(&riid).unwrap().table_ref;
let rel_instance = &ctx.anchor.relation_instances[&riid];
let nb_redirects = rel_instance.cid_redirects.len();
let table_ref = &rel_instance.table_ref;
let source = table_ref.source;
let decl = ctx.anchor.table_decls.get_mut(&table_ref.source).unwrap();

Expand All @@ -200,6 +202,56 @@ pub(super) fn compile_relation_instance(riid: RIId, ctx: &mut Context) -> Result
}

let relation = compile_relation(sql_relation, ctx)?;

if let pq::SqlRelation::AtomicPipeline(pipeline) = &relation {
// Finding the last select statement of the pipeline
let last_select_columns = pipeline.iter().rev().find_map(|transform| match transform {
pq::SqlTransform::Select(cids) => Some(cids),
_ => None,
});

log::debug!("last select CIds for {riid:?}: {last_select_columns:?}");

// If the pipeline ends with a select, we must recompute its CId redirects
if let Some(cids) = last_select_columns {
// Only recompute the CId redirects if there are exactly as many columns in the
// SELECT as there are CId redirects. This probably means that it is a projecting
// select added by `anchor_split`
if nb_redirects == cids.len() {
log::debug!(
"recomputing cid_redirects for {riid:?}. current redirects: {:?}",
ctx.anchor.relation_instances[&riid].cid_redirects
);
// Inefficient but only way to ensure that the new redirects match the original cids
let new_redirects = cids
.iter()
.zip(&ctx.anchor.relation_instances[&riid].original_cids)
.map(|(new_cid, original_cid)| {
let key_for_value = ctx.anchor.relation_instances[&riid]
.cid_redirects
.iter()
.find_map(|(k, v)| if v == original_cid { Some(k) } else { None })
.unwrap();

(
*new_cid,
ctx.anchor.relation_instances[&riid].cid_redirects[key_for_value],
)
})
.collect();

log::debug!(
"recomputed cid_redirects for {riid:?}. new redirects: {new_redirects:?}",
);

ctx.anchor
.relation_instances
.get_mut(&riid)
.unwrap()
.cid_redirects = new_redirects;
}
}
}
ctx.ctes.push(pq::Cte {
tid: source,
kind: pq::CteKind::Normal(relation),
Expand Down
147 changes: 138 additions & 9 deletions prqlc/prqlc/src/sql/pq/postprocess.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
//!
//! Currently only moves [SqlTransform::Sort]s.

use std::collections::{HashMap, HashSet};
use std::collections::{HashMap, HashSet, VecDeque};

use itertools::Itertools;

use super::anchor::CidRedirector;
use super::ast::*;
use crate::ir::generic::ColumnSort;
use crate::ir::pl::Ident;
use crate::ir::rq::{CId, RqFold, TId};
use crate::ir::rq::{CId, ExprKind, RqFold, TId};
use crate::sql::pq::context::{ColumnDecl, RIId};
use crate::sql::Context;
use crate::Result;

Expand Down Expand Up @@ -41,6 +42,110 @@ struct SortingInference<'a> {
ctx: &'a mut Context,
}

impl SortingInference<'_> {
/// Prepares the last sorting that will be appended to the pipeline of the `SqlQuery` by
/// `fold_sql_query`. It does so by reverting all columns in the sorting to their very first
/// form, and then transforming their value in the final select, while applying
/// renaming/aliasing when possible. This cannot be done directly in `fold_sql_transforms`
/// because renames are not considered to be SQL transforms.
fn alias_last_sorting(&mut self, mut last_sorting: Sorting, final_select: &[CId]) -> Sorting {
log::debug!("unaliasing last sorting: {last_sorting:?}");
let redirects = self
.ctx
.anchor
.relation_instances
.iter()
.map(|(riid, rel_inst)| (riid, &rel_inst.cid_redirects))
.collect::<HashMap<_, _>>();

// a map of column -> alias
let column_aliases = self
.ctx
.anchor
.column_decls
.values()
.filter_map(|col| {
if let ColumnDecl::Compute(compute) = col {
if let ExprKind::ColumnRef(referenced_id) = compute.expr.kind {
Some((referenced_id, compute.id))
} else {
None
}
} else {
None
}
})
.collect::<HashMap<_, _>>();
log::debug!(".. column aliases: {column_aliases:?}");

// column -> list of tables that did a revert
let mut reverts: HashMap<CId, VecDeque<RIId>> = HashMap::new();
log::debug!(".. reverting all columns to their original value");
last_sorting.iter_mut().for_each(|sort| {
let mut riids = VecDeque::new();
let mut changed = true;
while changed {
changed = false;
if let Some(ColumnDecl::RelationColumn(riid, cid, _)) =
self.ctx.anchor.column_decls.get(&sort.column)
{
let cid_redirects = redirects[riid];
for (source, target) in cid_redirects.iter() {
if target == cid {
log::debug!(
".. reverting {target:?} back to {source:?} via redirects of {riid:?}"
);
sort.column = *source;
changed = true;
riids.push_front(*riid);
break;
}
}
}
}
reverts.insert(sort.column, riids);
});
log::debug!(".. done reverting all columns to their original value: {last_sorting:?}");

log::debug!(".. reverting columns forward and aliasing them");
// reverting forward
last_sorting.iter_mut().for_each(|sort| {
let col_reverts = &reverts[&sort.column];
for riid in col_reverts {
if final_select.contains(&sort.column) {
log::debug!(
".. sort column {:?} is in the final select columns, skip reverting",
&sort.column
);
return;
}
// try renaming
if column_aliases.contains_key(&sort.column) {
let alias = column_aliases[&sort.column];
log::debug!("..aliasing {:?} as {alias:?}", &sort.column);
sort.column = alias;
}
// try de-reverting with the target table
let cid_mappings = redirects[riid];
if cid_mappings.contains_key(&sort.column) {
log::debug!(
".. reverting {:?} forward to {:?} via redirects of {riid:?} ({:?})",
&sort.column,
&cid_mappings[&sort.column],
&cid_mappings
);
sort.column = cid_mappings[&sort.column];
}
}
});

log::debug!("aliased and reverted last sorting forward: {last_sorting:?}");

last_sorting
}
}

#[derive(Debug)]
struct CteSorting {
sorting: Sorting,
}
Expand All @@ -50,6 +155,7 @@ impl RqFold for SortingInference<'_> {}
impl PqFold for SortingInference<'_> {
fn fold_sql_query(&mut self, query: SqlQuery) -> Result<SqlQuery> {
let mut ctes = Vec::with_capacity(query.ctes.len());

for cte in query.ctes {
log::debug!("infer_sorts: {0:?}", cte.tid);
let cte = self.fold_cte(cte)?;
Expand All @@ -68,10 +174,38 @@ impl PqFold for SortingInference<'_> {
self.main_relation = true;
let mut main_relation = self.fold_sql_relation(query.main_relation)?;
log::debug!("--== last_sorting {0:?}", self.last_sorting);
let last_sorting = self.last_sorting.drain(..).collect::<Vec<_>>();

// push a sort at the back of the main pipeline
if let SqlRelation::AtomicPipeline(pipeline) = &mut main_relation {
pipeline.push(SqlTransform::Sort(self.last_sorting.drain(..).collect()));
let from_id = pipeline
.iter()
.find_map(|transform| match transform {
SqlTransform::From(rel) => Some(rel.riid),
_ => None,
})
.unwrap();

let final_select = pipeline
.iter()
.rev()
.find_map(|transform| match transform {
SqlTransform::Select(select) => Some(select),
_ => None,
})
.unwrap();
log::debug!("--== final select: {final_select:?}");

let unaliased_last_sorting = self.alias_last_sorting(last_sorting, final_select);
log::debug!("--== unaliased last sorting: {unaliased_last_sorting:?}");
let redirected_last_sorting = CidRedirector::redirect_sorts(
unaliased_last_sorting,
&from_id,
&mut self.ctx.anchor,
);
log::debug!("--== redirected last sorting: {redirected_last_sorting:?}");

pipeline.push(SqlTransform::Sort(redirected_last_sorting));
}

Ok(SqlQuery {
Expand Down Expand Up @@ -143,6 +277,7 @@ impl PqMapper<RelationExpr, RelationExpr, (), ()> for SortingInference<'_> {
}
result.push(transform)
}
log::debug!("-- relation sorting {sorting:?}");

if !self.main_relation {
// if this is a CTE, make sure that its SELECT includes the
Expand All @@ -156,12 +291,6 @@ impl PqMapper<RelationExpr, RelationExpr, (), ()> for SortingInference<'_> {
select.push(cid);
}
}

// now revert the sort columns so that the output
// sorting reflects the input column cids, needed to
// ensure proper column reference lookup in the final
// steps
sorting = CidRedirector::revert_sorts(sorting, &mut self.ctx.anchor);
}

// remember sorting for this pipeline
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
s"SELECT album_id,title,artist_id FROM albums"
group {artist_id} (aggregate { album_title_count = count this.`title`})
sort {this.artist_id, this.album_title_count}
derive {new_album_count = this.album_title_count}
select {this.artist_id, this.new_album_count}
join side:left ( s"SELECT artist_id,name as artist_name FROM artists" ) (this.artist_id == that.artist_id)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
s"SELECT album_id,title,artist_id FROM albums"
group {artist_id} (aggregate { album_title_count = count this.`title`})
sort {this.artist_id, this.album_title_count}
filter (this.album_title_count) > 10
derive {new_album_count = this.album_title_count}
select {this.artist_id, this.new_album_count}
join side:left ( s"SELECT artist_id,name as artist_name FROM artists" ) (this.artist_id == that.artist_id)
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ SELECT
FROM
table_1
ORDER BY
_expr_0
d1
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
---
source: prqlc/prqlc/tests/integration/queries.rs
expression: "s\"SELECT album_id,title,artist_id FROM albums\"\ngroup {artist_id} (aggregate { album_title_count = count this.`title`})\nsort {this.artist_id, this.album_title_count}\nderive {new_album_count = this.album_title_count}\nselect {this.artist_id, this.new_album_count}\njoin side:left ( s\"SELECT artist_id,name as artist_name FROM artists\" ) (this.artist_id == that.artist_id)\n"
input_file: prqlc/prqlc/tests/integration/queries/group_sort_derive_select_join.prql
---
WITH table_0 AS (
SELECT
album_id,
title,
artist_id
FROM
albums
),
table_4 AS (
SELECT
artist_id,
COUNT(*) AS _expr_0
FROM
table_0
GROUP BY
artist_id
),
table_2 AS (
SELECT
artist_id,
_expr_0 AS new_album_count,
_expr_0
FROM
table_4
),
table_1 AS (
SELECT
artist_id,
name as artist_name
FROM
artists
),
table_3 AS (
SELECT
table_2.artist_id,
table_2.new_album_count,
table_1.artist_id AS _expr_1,
table_1.artist_name,
table_2._expr_0
FROM
table_2
LEFT OUTER JOIN table_1 ON table_2.artist_id = table_1.artist_id
)
SELECT
artist_id,
new_album_count,
_expr_1,
artist_name
FROM
table_3
ORDER BY
artist_id,
new_album_count
Loading
Loading