Skip to content

Commit 5335f80

Browse files
authored
Add SQL planner support for grouping() aggregate expressions (#2486)
* Add SQL planner support for grouping() aggregate function * complex test with rank and partition * fix window aggregate case * code cleanup
1 parent 8a29ed5 commit 5335f80

10 files changed

Lines changed: 193 additions & 22 deletions

File tree

datafusion/core/src/logical_plan/expr.rs

Lines changed: 51 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use crate::sql::utils::find_columns_referenced_by_expr;
2626
use arrow::datatypes::DataType;
2727
pub use datafusion_common::{Column, ExprSchema};
2828
pub use datafusion_expr::expr_fn::*;
29+
use datafusion_expr::logical_plan::Aggregate;
2930
use datafusion_expr::BuiltinScalarFunction;
3031
pub use datafusion_expr::Expr;
3132
use datafusion_expr::StateTypeFunction;
@@ -136,35 +137,63 @@ pub fn create_udaf(
136137
)
137138
}
138139

140+
/// Find all columns referenced from an aggregate query
141+
fn agg_cols(agg: &Aggregate) -> Result<Vec<Column>> {
142+
Ok(agg
143+
.aggr_expr
144+
.iter()
145+
.chain(&agg.group_expr)
146+
.flat_map(find_columns_referenced_by_expr)
147+
.collect())
148+
}
149+
150+
fn exprlist_to_fields_aggregate(
151+
exprs: &[Expr],
152+
plan: &LogicalPlan,
153+
agg: &Aggregate,
154+
) -> Result<Vec<DFField>> {
155+
let agg_cols = agg_cols(agg)?;
156+
let mut fields = vec![];
157+
for expr in exprs {
158+
match expr {
159+
Expr::Column(c) if agg_cols.iter().any(|x| x == c) => {
160+
// resolve against schema of input to aggregate
161+
fields.push(expr.to_field(agg.input.schema())?);
162+
}
163+
_ => fields.push(expr.to_field(plan.schema())?),
164+
}
165+
}
166+
Ok(fields)
167+
}
168+
139169
/// Create field meta-data from an expression, for use in a result set schema
140170
pub fn exprlist_to_fields<'a>(
141171
expr: impl IntoIterator<Item = &'a Expr>,
142172
plan: &LogicalPlan,
143173
) -> Result<Vec<DFField>> {
144-
match plan {
174+
let exprs: Vec<Expr> = expr.into_iter().cloned().collect();
175+
// when dealing with aggregate plans we cannot simply look in the aggregate output schema
176+
// because it will contain columns representing complex expressions (such a column named
177+
// `#GROUPING(person.state)` so in order to resolve `person.state` in this case we need to
178+
// look at the input to the aggregate instead.
179+
let fields = match plan {
145180
LogicalPlan::Aggregate(agg) => {
146-
let group_expr: Vec<Column> = agg
147-
.group_expr
148-
.iter()
149-
.flat_map(find_columns_referenced_by_expr)
150-
.collect();
151-
let exprs: Vec<Expr> = expr.into_iter().cloned().collect();
152-
let mut fields = vec![];
153-
for expr in &exprs {
154-
match expr {
155-
Expr::Column(c) if group_expr.iter().any(|x| x == c) => {
156-
// resolve against schema of input to aggregate
157-
fields.push(expr.to_field(agg.input.schema())?);
158-
}
159-
_ => fields.push(expr.to_field(plan.schema())?),
160-
}
161-
}
162-
Ok(fields)
163-
}
164-
_ => {
165-
let input_schema = &plan.schema();
166-
expr.into_iter().map(|e| e.to_field(input_schema)).collect()
181+
Some(exprlist_to_fields_aggregate(&exprs, plan, agg))
167182
}
183+
LogicalPlan::Window(window) => match window.input.as_ref() {
184+
LogicalPlan::Aggregate(agg) => {
185+
Some(exprlist_to_fields_aggregate(&exprs, plan, agg))
186+
}
187+
_ => None,
188+
},
189+
_ => None,
190+
};
191+
if let Some(fields) = fields {
192+
fields
193+
} else {
194+
// look for exact match in plan's output schema
195+
let input_schema = &plan.schema();
196+
exprs.iter().map(|e| e.to_field(input_schema)).collect()
168197
}
169198
}
170199

datafusion/core/src/sql/planner.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4694,6 +4694,38 @@ mod tests {
46944694
quick_test(sql, expected);
46954695
}
46964696

4697+
#[tokio::test]
4698+
async fn aggregate_with_rollup_with_grouping() {
4699+
let sql = "SELECT id, state, age, grouping(state), grouping(age), grouping(state) + grouping(age), COUNT(*) \
4700+
FROM person GROUP BY id, ROLLUP (state, age)";
4701+
let expected = "Projection: #person.id, #person.state, #person.age, #GROUPING(person.state), #GROUPING(person.age), #GROUPING(person.state) + #GROUPING(person.age), #COUNT(UInt8(1))\
4702+
\n Aggregate: groupBy=[[#person.id, ROLLUP (#person.state, #person.age)]], aggr=[[GROUPING(#person.state), GROUPING(#person.age), COUNT(UInt8(1))]]\
4703+
\n TableScan: person projection=None";
4704+
quick_test(sql, expected);
4705+
}
4706+
4707+
#[tokio::test]
4708+
async fn rank_partition_grouping() {
4709+
let sql = "select
4710+
sum(age) as total_sum,
4711+
state,
4712+
last_name,
4713+
grouping(state) + grouping(last_name) as x,
4714+
rank() over (
4715+
partition by grouping(state) + grouping(last_name),
4716+
case when grouping(last_name) = 0 then state end
4717+
order by sum(age) desc
4718+
) as the_rank
4719+
from
4720+
person
4721+
group by rollup(state, last_name)";
4722+
let expected = "Projection: #SUM(person.age) AS total_sum, #person.state, #person.last_name, #GROUPING(person.state) + #GROUPING(person.last_name) AS x, #RANK() PARTITION BY [#GROUPING(person.state) + #GROUPING(person.last_name), CASE WHEN #GROUPING(person.last_name) = Int64(0) THEN #person.state END] ORDER BY [#SUM(person.age) DESC NULLS FIRST] AS the_rank\
4723+
\n WindowAggr: windowExpr=[[RANK() PARTITION BY [#GROUPING(person.state) + #GROUPING(person.last_name), CASE WHEN #GROUPING(person.last_name) = Int64(0) THEN #person.state END] ORDER BY [#SUM(person.age) DESC NULLS FIRST]]]\
4724+
\n Aggregate: groupBy=[[ROLLUP (#person.state, #person.last_name)]], aggr=[[SUM(#person.age), GROUPING(#person.state), GROUPING(#person.last_name)]]\
4725+
\n TableScan: person projection=None";
4726+
quick_test(sql, expected);
4727+
}
4728+
46974729
#[tokio::test]
46984730
async fn aggregate_with_cube() {
46994731
let sql =

datafusion/expr/src/aggregate_function.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ pub enum AggregateFunction {
8686
ApproxPercentileContWithWeight,
8787
/// ApproxMedian
8888
ApproxMedian,
89+
/// Grouping
90+
Grouping,
8991
}
9092

9193
impl fmt::Display for AggregateFunction {
@@ -121,6 +123,7 @@ impl FromStr for AggregateFunction {
121123
AggregateFunction::ApproxPercentileContWithWeight
122124
}
123125
"approx_median" => AggregateFunction::ApproxMedian,
126+
"grouping" => AggregateFunction::Grouping,
124127
_ => {
125128
return Err(DataFusionError::Plan(format!(
126129
"There is no built-in function named {}",
@@ -173,6 +176,7 @@ pub fn return_type(
173176
Ok(coerced_data_types[0].clone())
174177
}
175178
AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
179+
AggregateFunction::Grouping => Ok(DataType::Int32),
176180
}
177181
}
178182

@@ -326,6 +330,7 @@ pub fn coerce_types(
326330
}
327331
Ok(input_types.to_vec())
328332
}
333+
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
329334
}
330335
}
331336

@@ -335,6 +340,7 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
335340
match fun {
336341
AggregateFunction::Count
337342
| AggregateFunction::ApproxDistinct
343+
| AggregateFunction::Grouping
338344
| AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable),
339345
AggregateFunction::Min | AggregateFunction::Max => {
340346
let valid = STRINGS

datafusion/physical-expr/src/aggregate/build_in.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,11 @@ pub fn create_aggregate_expr(
8282
name,
8383
return_type,
8484
)),
85+
(AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new(
86+
coerced_phy_exprs[0].clone(),
87+
name,
88+
return_type,
89+
)),
8590
(AggregateFunction::Sum, false) => Arc::new(expressions::Sum::new(
8691
coerced_phy_exprs[0].clone(),
8792
name,
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Defines physical expressions that can evaluated at runtime during query execution
19+
20+
use std::any::Any;
21+
use std::sync::Arc;
22+
23+
use crate::{AggregateExpr, PhysicalExpr};
24+
use arrow::datatypes::DataType;
25+
use arrow::datatypes::Field;
26+
use datafusion_common::{DataFusionError, Result};
27+
use datafusion_expr::Accumulator;
28+
29+
use crate::expressions::format_state_name;
30+
31+
/// GROUPING aggregate expression
32+
/// Returns the amount of non-null values of the given expression.
33+
#[derive(Debug)]
34+
pub struct Grouping {
35+
name: String,
36+
data_type: DataType,
37+
nullable: bool,
38+
expr: Arc<dyn PhysicalExpr>,
39+
}
40+
41+
impl Grouping {
42+
/// Create a new GROUPING aggregate function.
43+
pub fn new(
44+
expr: Arc<dyn PhysicalExpr>,
45+
name: impl Into<String>,
46+
data_type: DataType,
47+
) -> Self {
48+
Self {
49+
name: name.into(),
50+
expr,
51+
data_type,
52+
nullable: true,
53+
}
54+
}
55+
}
56+
57+
impl AggregateExpr for Grouping {
58+
/// Return a reference to Any that can be used for downcasting
59+
fn as_any(&self) -> &dyn Any {
60+
self
61+
}
62+
63+
fn field(&self) -> Result<Field> {
64+
Ok(Field::new(
65+
&self.name,
66+
self.data_type.clone(),
67+
self.nullable,
68+
))
69+
}
70+
71+
fn state_fields(&self) -> Result<Vec<Field>> {
72+
Ok(vec![Field::new(
73+
&format_state_name(&self.name, "grouping"),
74+
self.data_type.clone(),
75+
true,
76+
)])
77+
}
78+
79+
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
80+
vec![self.expr.clone()]
81+
}
82+
83+
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
84+
Err(DataFusionError::NotImplemented(
85+
"physical plan is not yet implemented for GROUPING aggregate function"
86+
.to_owned(),
87+
))
88+
}
89+
90+
fn name(&self) -> &str {
91+
&self.name
92+
}
93+
}

datafusion/physical-expr/src/aggregate/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ pub(crate) mod correlation;
3636
pub(crate) mod count;
3737
pub(crate) mod count_distinct;
3838
pub(crate) mod covariance;
39+
pub(crate) mod grouping;
3940
#[macro_use]
4041
pub(crate) mod min_max;
4142
pub mod build_in;

datafusion/physical-expr/src/expressions/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ pub use crate::aggregate::correlation::Correlation;
5050
pub use crate::aggregate::count::Count;
5151
pub use crate::aggregate::count_distinct::DistinctCount;
5252
pub use crate::aggregate::covariance::{Covariance, CovariancePop};
53+
pub use crate::aggregate::grouping::Grouping;
5354
pub use crate::aggregate::min_max::{Max, Min};
5455
pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator};
5556
pub use crate::aggregate::stats::StatsType;

datafusion/proto/proto/datafusion.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ enum AggregateFunction {
211211
APPROX_PERCENTILE_CONT = 14;
212212
APPROX_MEDIAN=15;
213213
APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16;
214+
GROUPING = 17;
214215
}
215216

216217
message AggregateExprNode {

datafusion/proto/src/from_proto.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,7 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
496496
Self::ApproxPercentileContWithWeight
497497
}
498498
protobuf::AggregateFunction::ApproxMedian => Self::ApproxMedian,
499+
protobuf::AggregateFunction::Grouping => Self::Grouping,
499500
}
500501
}
501502
}

datafusion/proto/src/to_proto.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
356356
Self::ApproxPercentileContWithWeight
357357
}
358358
AggregateFunction::ApproxMedian => Self::ApproxMedian,
359+
AggregateFunction::Grouping => Self::Grouping,
359360
}
360361
}
361362
}
@@ -541,6 +542,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
541542
AggregateFunction::ApproxMedian => {
542543
protobuf::AggregateFunction::ApproxMedian
543544
}
545+
AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping,
544546
};
545547

546548
let aggregate_expr = protobuf::AggregateExprNode {

0 commit comments

Comments
 (0)