Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,22 @@ package org.apache.spark.sql.execution.datasources.v2

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, AliasHelper, And, Attribute, AttributeMap, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.CollapseProject.{buildCleanedProjectList, canCollapseExpressions}
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType, StructType, YearMonthIntervalType}
import org.apache.spark.sql.util.SchemaUtils._

object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper {
import DataSourceV2Implicits._

def apply(plan: LogicalPlan): LogicalPlan = {
Expand Down Expand Up @@ -88,27 +89,43 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
filterCondition.map(Filter(_, sHolder)).getOrElse(sHolder)
}

def pushDownAggregates(plan: LogicalPlan): LogicalPlan = plan.transform {
private def collapseProject(plan: LogicalPlan): LogicalPlan = {
val alwaysInline = conf.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE)
plan transformUp {
case agg @ Aggregate(_, aggregateExpressions, p: Project)
if canCollapseExpressions(aggregateExpressions, p.projectList, alwaysInline) =>
agg.copy(aggregateExpressions = buildCleanedProjectList(
aggregateExpressions, p.projectList))
}
}

def pushDownAggregates(plan: LogicalPlan): LogicalPlan = collapseProject(plan).transform {
// update the scan builder with agg pushdown and return a new plan with agg pushed
case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) =>
child match {
case ScanOperation(project, filters, sHolder: ScanBuilderHolder)
if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) =>
if filters.isEmpty && project.forall(_.deterministic) =>
sHolder.builder match {
case r: SupportsPushDownAggregates =>
val aliasMap = getAliasMap(project)
val newResultExpressions = resultExpressions.map(replaceAliasWithAttr(_, aliasMap))
val newGroupingExpressions = groupingExpressions.map {
case e: NamedExpression => replaceAliasWithAttr(e, aliasMap)
case other => other
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we make the code more explicit? We need to clearly show the steps

  1. collapse aggregate and project
  2. remove the alias from aggregate functions and group by expressions (this logic should be put here instead of AliasHelper as this is not a common logic)
  3. push down agg
  4. add back alias for group by expressions only.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the good idea.

val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
val aggregates = collectAggregates(resultExpressions, aggExprToOutputOrdinal)
val aggregates = collectAggregates(newResultExpressions, aggExprToOutputOrdinal)
val normalizedAggregates = DataSourceStrategy.normalizeExprs(
aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs(
groupingExpressions, sHolder.relation.output)
newGroupingExpressions, sHolder.relation.output)
val translatedAggregates = DataSourceStrategy.translateAggregation(
normalizedAggregates, normalizedGroupingExpressions)
val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = {
val (selectedResultExpressions, selectedAggregates, selectedTranslatedAggregates) = {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we rename these?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not confirmed yet.

if (translatedAggregates.isEmpty ||
r.supportCompletePushDown(translatedAggregates.get) ||
translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) {
(resultExpressions, aggregates, translatedAggregates)
(newResultExpressions, aggregates, translatedAggregates)
} else {
// scalastyle:off
// The data source doesn't support the complete push-down of this aggregation.
Expand Down Expand Up @@ -156,15 +173,15 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}
}

if (finalTranslatedAggregates.isEmpty) {
aggNode // return original plan node
} else if (!r.supportCompletePushDown(finalTranslatedAggregates.get) &&
!supportPartialAggPushDown(finalTranslatedAggregates.get)) {
aggNode // return original plan node
if (selectedTranslatedAggregates.isEmpty) {
return plan // return original plan node
} else if (!r.supportCompletePushDown(selectedTranslatedAggregates.get) &&
!supportPartialAggPushDown(selectedTranslatedAggregates.get)) {
return plan // return original plan node
} else {
val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation)
val pushedAggregates = selectedTranslatedAggregates.filter(r.pushAggregation)
if (pushedAggregates.isEmpty) {
aggNode // return original plan node
return plan // return original plan node
} else {
// No need to do column pruning because only the aggregate columns are used as
// DataSourceV2ScanRelation output columns. All the other columns are not
Expand All @@ -182,7 +199,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22]
// scalastyle:on
val newOutput = scan.readSchema().toAttributes
assert(newOutput.length == groupingExpressions.length + finalAggregates.length)
assert(newOutput.length == groupingExpressions.length + selectedAggregates.length)
val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map {
case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
case (_, b) => b
Expand All @@ -203,8 +220,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates)
val scanRelation =
DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
val aliasAttrMap = getAttrToAliasMap(aliasMap)
val finalResultExpressions = selectedResultExpressions.map {
case attr: AttributeReference =>
aliasAttrMap.getOrElse(attr, attr)
case other => other
}
if (r.supportCompletePushDown(pushedAggregates.get)) {
val projectExpressions = resultExpressions.map { expr =>
val projectExpressions = finalResultExpressions.map { expr =>
// TODO At present, only push down group by attribute is supported.
// In future, more attribute conversion is extended here. e.g. GetStructField
expr.transform {
Expand Down Expand Up @@ -259,12 +282,31 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}
}
}
case _ => aggNode
case _ => return plan
}
case _ => aggNode
case _ => return plan
}
}

/**
* Replace all alias, with the aliased attribute.
*/
private def replaceAliasWithAttr(
expr: NamedExpression,
aliasMap: AttributeMap[Alias]): NamedExpression = {
replaceAliasButKeepName(expr, aliasMap).transform {
case Alias(attr: Attribute, _) => attr
}.asInstanceOf[NamedExpression]
}

protected def getAttrToAliasMap(aliasMap: AttributeMap[Alias]): AttributeMap[Alias] = {
val attrToAliasMap = aliasMap.values.toSeq.collect {
case alias @ Alias(originAttr: Attribute, _) =>
(originAttr, alias)
}
AttributeMap(attrToAliasMap)
}

private def collectAggregates(resultExpressions: Seq[NamedExpression],
aggExprToOutputOrdinal: mutable.HashMap[Expression, Int]): Seq[AggregateExpression] = {
var ordinal = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ trait FileSourceAggregatePushDownSuite
}
}

test("aggregate over alias not push down") {
test("aggregate over alias push down") {
val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
(9, "mno", 7), (2, null, 6))
withDataSourceTable(data, "t") {
Expand All @@ -194,7 +194,7 @@ trait FileSourceAggregatePushDownSuite
query.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregation: []" // aggregate alias not pushed down
"PushedAggregation: [MIN(_1)]"
checkKeywordsExistsInExplain(query, expected_plan_fragment)
}
checkAnswer(query, Seq(Row(-2)))
Expand Down
Loading