diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplans.scala index 151e9e49e781e..037abc207298a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplans.scala @@ -238,9 +238,7 @@ object MergeSubplans extends Rule[LogicalPlan] { levelFromSubqueries = levelFromSubqueries.max(level + 1) - val mergedOutput = mergeResult.outputMap(planWithReferences.output.head) - val outputIndex = - mergeResult.mergedPlan.plan.output.indexWhere(_.exprId == mergedOutput.exprId) + val outputIndex = mergeResult.outputMap(planWithReferences.output.head) ScalarSubqueryReference( level, mergeResult.mergedPlanIndex, @@ -262,9 +260,7 @@ object MergeSubplans extends Rule[LogicalPlan] { val mergeResult = getPlanMerger(planMergers, level).merge(aggregateWithReferences, false) - val mergedOutput = aggregateWithReferences.output.map(mergeResult.outputMap) - val outputIndices = - mergedOutput.map(a => mergeResult.mergedPlan.plan.output.indexWhere(_.exprId == a.exprId)) + val outputIndices = aggregateWithReferences.output.map(mergeResult.outputMap) val aggregateReference = NonGroupingAggregateReference( level, mergeResult.mergedPlanIndex, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala index 1623166e0a657..fe6b57dd86a50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PlanMerger.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.catalyst.optimizer -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeMap, Expression, If, Literal, NamedExpression, Or} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.internal.SQLConf /** * Result of attempting to merge a plan via [[PlanMerger.merge]]. @@ -31,14 +33,14 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, Log * - A newly merged plan combining the input with a cached plan * - The original input plan (if no merge was possible) * @param mergedPlanIndex The index of this plan in the PlanMerger's cache. - * @param outputMap Maps attributes from the input plan to corresponding attributes in - * `mergedPlan`. Used to rewrite expressions referencing the original plan - * to reference the merged plan instead. + * @param outputMap Maps attributes of the input plan to their positional index in + * `mergedPlan.plan.output`. The index remains stable across subsequent + * [[PlanMerger.merge]] calls because outputs are only ever appended. */ case class MergeResult( mergedPlan: MergedPlan, mergedPlanIndex: Int, - outputMap: AttributeMap[Attribute]) + outputMap: AttributeMap[Int]) /** * Represents a plan in the PlanMerger's cache. @@ -50,6 +52,19 @@ case class MergeResult( */ case class MergedPlan(plan: LogicalPlan, merged: Boolean) +object PlanMerger { + // Marker tag placed on Filter nodes that were produced by filter propagation. Its presence + // signals that the Filter's condition is already an OR of propagated filter attributes and + // its child Project already contains the corresponding aliases, so a subsequent merge only + // needs to add one new alias for the incoming plan rather than wrapping both sides again. + val MERGED_FILTER_TAG: TreeNodeTag[Unit] = TreeNodeTag("mergedFilter") + + // Global counter for generating unique names for propagated filter attributes across all + // PlanMerger instances. + private[optimizer] val curId = new java.util.concurrent.atomic.AtomicLong() + private[optimizer] def newId: Long = curId.getAndIncrement() +} + /** * A stateful utility for merging identical or similar logical plans to enable query plan reuse. * @@ -67,6 +82,31 @@ case class MergedPlan(plan: LogicalPlan, merged: Boolean) * - [[Filter]]: Requires identical filter conditions * - [[Join]]: Requires identical join type, hints, and conditions * + * When `filterPropagationEnabled` is true, non-grouping [[Aggregate]]s over the same base plan + * with different [[Filter]] conditions can also be merged. The filter conditions are exposed as + * boolean [[Project]] attributes and consumed at the [[Aggregate]] as FILTER clauses. + * When both sides carry a [[Filter]] (the symmetric case), merging broadens the scan to + * OR(f1, f2), which may reduce IO pruning. This path is separately gated by + * `symmetricFilterPropagationEnabled`. + * When plans also differ in intermediate [[Project]] expressions, those are wrapped with + * `If(filterAttr, expr, null)` to avoid computing the expression for rows that do not + * match that side's filter condition. + * + * {{{ + * // Input plans + * Aggregate [sum(a) AS sum_a] Aggregate [max(d) AS max_d] + * +- Filter (a < 1) +- Project [udf(a) AS d] + * +- Scan t +- Filter (a > 1) + * +- Scan t + * + * // Merged plan + * Aggregate [sum(a) FILTER f0 AS sum_a, max(d0) FILTER f1 AS max_d] + * +- Project [a, If(f1, udf(a), null) AS d0, f0, f1] + * +- Filter (f0 OR f1) [MERGED_FILTER_TAG] + * +- Project [a, (a < 1) AS f0, (a > 1) AS f1] + * +- Scan t + * }}} + * * @example * {{{ * val merger = PlanMerger() @@ -76,8 +116,12 @@ case class MergedPlan(plan: LogicalPlan, merged: Boolean) * // result2.outputMap maps plan2's attributes to the merged plan's attributes * }}} */ -class PlanMerger { - val cache = ArrayBuffer.empty[MergedPlan] +class PlanMerger( + filterPropagationEnabled: Boolean = + SQLConf.get.getConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_ENABLED), + symmetricFilterPropagationEnabled: Boolean = + SQLConf.get.getConf(SQLConf.MERGE_SUBPLANS_SYMMETRIC_FILTER_PROPAGATION_ENABLED)) { + val cache = mutable.ArrayBuffer.empty[MergedPlan] /** * Attempts to merge the given plan with cached plans, or adds it to the cache. @@ -97,19 +141,23 @@ class PlanMerger { def merge(plan: LogicalPlan, subqueryPlan: Boolean): MergeResult = { cache.zipWithIndex.collectFirst(Function.unlift { case (mp, i) => - checkIdenticalPlans(plan, mp.plan).map { outputMap => + checkIdenticalPlans(plan, mp.plan).map { _ => // Identical subquery expression plans are not marked as `merged` as the // `ReusedSubqueryExec` rule can handle them without extracting the plans to CTEs. // But, when a non-subquery subplan is identical to a cached plan we need to mark the plan // `merged` and so extract it to a CTE later. - val newMergePlan = MergedPlan(mp.plan, cache(i).merged || !subqueryPlan) - cache(i) = newMergePlan - MergeResult(newMergePlan, i, outputMap) + val newMergedPlan = MergedPlan(mp.plan, cache(i).merged || !subqueryPlan) + cache(i) = newMergedPlan + val outputMap = AttributeMap(plan.output.zipWithIndex) + MergeResult(newMergedPlan, i, outputMap) }.orElse { - tryMergePlans(plan, mp.plan).map { - case (mergedPlan, outputMap) => + tryMergePlans(plan, mp.plan, false).collect { + case TryMergeResult(mergedPlan, npMapping, _, None, None) => val newMergePlan = MergedPlan(mergedPlan, true) cache(i) = newMergePlan + val outputMap = AttributeMap(npMapping.iterator.map { case (origAttr, mergedAttr) => + origAttr -> mergedPlan.output.indexWhere(_.exprId == mergedAttr.exprId) + }.toSeq) MergeResult(newMergePlan, i, outputMap) } } @@ -117,7 +165,7 @@ class PlanMerger { }).getOrElse { val newMergePlan = MergedPlan(plan, false) cache += newMergePlan - val outputMap = AttributeMap(plan.output.map(a => a -> a)) + val outputMap = AttributeMap(plan.output.zipWithIndex) MergeResult(newMergePlan, cache.length - 1, outputMap) } } @@ -141,6 +189,35 @@ class PlanMerger { } } + /** + * Result of a successful [[tryMergePlans]] call. + * + * @param mergedPlan The combined logical plan. + * @param newPlanMapping Mapping from attributes in the new plan to the corresponding + * attributes in the merged plan. Used by parent nodes to remap + * new-plan-side expressions. + * @param cachedPlanMapping Mapping from original cached-plan attributes to their new alias + * attributes when a cached expression was wrapped with an `If`. Used by + * parent nodes to remap cached-plan-side expressions that would + * otherwise reference stale attributes after wrapping. Empty when no + * cached expressions were wrapped. + * @param newPlanFilter A boolean [[Attribute]] in the merged plan that encodes the filter + * condition from the new plan's side, to be applied as an aggregate + * `FILTER (WHERE ...)` clause when the propagation reaches an enclosing + * [[Aggregate]] node. The boolean component is `true` if the attribute was + * freshly aliased and must be appended to enclosing [[Project]] nodes, or + * `false` if it was reused from an existing alias already present in the + * merged plan. `None` when no differing filter was propagated. + * @param cachedPlanFilter Like `newPlanFilter` but for the cached plan's side. Always a freshly + * created alias when present, so no `isNew` flag is needed. + */ + case class TryMergeResult( + mergedPlan: LogicalPlan, + newPlanMapping: AttributeMap[Attribute], + cachedPlanMapping: AttributeMap[Attribute] = AttributeMap.empty, + newPlanFilter: Option[(Attribute, Boolean)] = None, + cachedPlanFilter: Option[Attribute] = None) + /** * Recursively attempts to merge two plans by traversing their tree structures. * @@ -157,83 +234,224 @@ class PlanMerger { * * @param newPlan The plan to merge into the cached plan. * @param cachedPlan The cached plan to merge with. - * @return Some((mergedPlan, outputMap)) if merge succeeds, where: - * - mergedPlan is the combined plan - * - outputMap maps newPlan's attributes to mergedPlan's attributes - * Returns None if plans cannot be merged. + * @return Some([[TryMergeResult]]) if merge succeeds, None if plans cannot be merged. */ private def tryMergePlans( newPlan: LogicalPlan, - cachedPlan: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute])] = { - checkIdenticalPlans(newPlan, cachedPlan).map(cachedPlan -> _).orElse( + cachedPlan: LogicalPlan, + filterPropagationSupported: Boolean): Option[TryMergeResult] = { + checkIdenticalPlans(newPlan, cachedPlan).map(TryMergeResult(cachedPlan, _)).orElse( (newPlan, cachedPlan) match { case (np: Project, cp: Project) => - tryMergePlans(np.child, cp.child).map { case (mergedChild, outputMap) => - val (mergedProjectList, newOutputMap) = - mergeNamedExpressions(np.projectList, outputMap, cp.projectList) - val mergedPlan = Project(mergedProjectList, mergedChild) - mergedPlan -> newOutputMap + tryMergePlans(np.child, cp.child, filterPropagationSupported).map { + case TryMergeResult(mergedChild, npMapping, cpMapping, npFilter, cpFilter) => + val (mergedProjectList, newNPMapping, newCPMapping) = + mergeNamedExpressions(np.projectList, cp.projectList, npMapping, cpMapping, + npFilter, cpFilter) + TryMergeResult(Project(mergedProjectList, mergedChild), newNPMapping, newCPMapping, + npFilter, cpFilter) } case (np, cp: Project) => - tryMergePlans(np, cp.child).map { case (mergedChild, outputMap) => - val (mergedProjectList, newOutputMap) = - mergeNamedExpressions(np.output, outputMap, cp.projectList) - val mergedPlan = Project(mergedProjectList, mergedChild) - mergedPlan -> newOutputMap + tryMergePlans(np, cp.child, filterPropagationSupported).map { + case TryMergeResult(mergedChild, npMapping, cpMapping, npFilter, cpFilter) => + val (mergedProjectList, newNPMapping, newCPMapping) = + mergeNamedExpressions(np.output, cp.projectList, npMapping, cpMapping, npFilter, + cpFilter) + TryMergeResult(Project(mergedProjectList, mergedChild), newNPMapping, newCPMapping, + npFilter, cpFilter) } case (np: Project, cp) => - tryMergePlans(np.child, cp).map { case (mergedChild, outputMap) => - val (mergedProjectList, newOutputMap) = - mergeNamedExpressions(np.projectList, outputMap, cp.output) - val mergedPlan = Project(mergedProjectList, mergedChild) - mergedPlan -> newOutputMap + tryMergePlans(np.child, cp, filterPropagationSupported).map { + case TryMergeResult(mergedChild, npMapping, cpMapping, npFilter, cpFilter) => + val (mergedProjectList, newNPMapping, newCPMapping) = + mergeNamedExpressions(np.projectList, cp.output, npMapping, cpMapping, npFilter, + cpFilter) + TryMergeResult(Project(mergedProjectList, mergedChild), newNPMapping, newCPMapping, + npFilter, cpFilter) } + case (np: Aggregate, cp: Aggregate) if supportedAggregateMerge(np, cp) => - tryMergePlans(np.child, cp.child).flatMap { case (mergedChild, outputMap) => - val mappedNewGroupingExpression = - np.groupingExpressions.map(mapAttributes(_, outputMap)) - // Order of grouping expression does matter as merging different grouping orders can - // introduce "extra" shuffles/sorts that might not present in all of the original - // subqueries. - if (mappedNewGroupingExpression.map(_.canonicalized) == - cp.groupingExpressions.map(_.canonicalized)) { - val (mergedAggregateExpressions, newOutputMap) = - mergeNamedExpressions(np.aggregateExpressions, outputMap, cp.aggregateExpressions) - val mergedPlan = - Aggregate(cp.groupingExpressions, mergedAggregateExpressions, mergedChild) - Some(mergedPlan -> newOutputMap) - } else { - None - } - } + // Filter propagation into the aggregate is only safe when there is no grouping. + val childFilterPropagationSupported = filterPropagationEnabled && + np.groupingExpressions.isEmpty && cp.groupingExpressions.isEmpty + tryMergePlans(np.child, cp.child, childFilterPropagationSupported).flatMap { + case TryMergeResult(mergedChild, npMapping, cpMapping, None, None) => + val mappedNPGroupingExpression = + np.groupingExpressions.map(mapAttributes(_, npMapping)) + val mappedCPGroupingExpression = + cp.groupingExpressions.map(mapAttributes(_, cpMapping)) + // Order of grouping expression does matter as merging different grouping orders can + // introduce "extra" shuffles/sorts that might not present in all of the original + // subqueries. + if (mappedNPGroupingExpression.map(_.canonicalized) == + mappedCPGroupingExpression.map(_.canonicalized)) { + val (mergedAggregateExpressions, newNPMapping, newCPMapping) = + mergeNamedExpressions(np.aggregateExpressions, cp.aggregateExpressions, npMapping, + cpMapping) + val mergedPlan = + Aggregate(mappedCPGroupingExpression, mergedAggregateExpressions, mergedChild) + Some(TryMergeResult(mergedPlan, newNPMapping, newCPMapping)) + } else { + None + } + case TryMergeResult(mergedChild, npMapping, cpMapping, npFilterOpt, cpFilterOpt) => + // childFilterPropagationSupported guarantees both aggregates have no grouping, so + // the grouping-match check is skipped. + assert(childFilterPropagationSupported) - case (np: Filter, cp: Filter) => - tryMergePlans(np.child, cp.child).flatMap { case (mergedChild, outputMap) => - val mappedNewCondition = mapAttributes(np.condition, outputMap) - // Comparing the canonicalized form is required to ignore different forms of the same - // expression. - if (mappedNewCondition.canonicalized == cp.condition.canonicalized) { - val mergedPlan = cp.withNewChildren(Seq(mergedChild)) - Some(mergedPlan -> outputMap) - } else { - None - } + // Apply each propagated boolean attribute as a FILTER (WHERE ...) clause on the + // corresponding side's aggregate expressions. + // A None filter means the side's aggregate expressions already carry their individual + // FILTER attributes from a previous merge round and should be left unchanged. + // Filter propagation is consumed here and not passed further up. + val filteredNPAggregateExpressions = npFilterOpt.fold(np.aggregateExpressions) { + case (f, _) => applyFilterToAggregateExpressions(np.aggregateExpressions, f) + } + val filteredCPAggregateExpressions = cpFilterOpt.fold(cp.aggregateExpressions)( + applyFilterToAggregateExpressions(cp.aggregateExpressions, _)) + val (mergedAggregateExpressions, newNPMapping, newCPMapping) = + mergeNamedExpressions(filteredNPAggregateExpressions, + filteredCPAggregateExpressions, npMapping, cpMapping) + val mergedPlan = Aggregate(Seq.empty, mergedAggregateExpressions, mergedChild) + Some(TryMergeResult(mergedPlan, newNPMapping, newCPMapping)) } - case (np: Join, cp: Join) if np.joinType == cp.joinType && np.hint == cp.hint => - tryMergePlans(np.left, cp.left).flatMap { case (mergedLeft, leftOutputMap) => - tryMergePlans(np.right, cp.right).flatMap { case (mergedRight, rightOutputMap) => - val outputMap = leftOutputMap ++ rightOutputMap - val mappedNewCondition = np.condition.map(mapAttributes(_, outputMap)) + case (np: Filter, cp: Filter) => + tryMergePlans(np.child, cp.child, filterPropagationSupported).flatMap { + case TryMergeResult(mergedChild, npMapping, cpMapping, npFilter, cpFilter) => + val mappedNPCondition = mapAttributes(np.condition, npMapping) + val mappedCPCondition = mapAttributes(cp.condition, cpMapping) // Comparing the canonicalized form is required to ignore different forms of the same - // expression and `AttributeReference.qualifier`s in `cp.condition`. - if (mappedNewCondition.map(_.canonicalized) == cp.condition.map(_.canonicalized)) { - val mergedPlan = cp.withNewChildren(Seq(mergedLeft, mergedRight)) - Some(mergedPlan -> outputMap) + // expression. + if (mappedNPCondition.canonicalized == mappedCPCondition.canonicalized) { + // Identical conditions: the filter node itself adds no new discrimination between + // the two sides, so we keep it unchanged and pass the child's mappings up. + val mergedPlan = Filter(mappedCPCondition, mergedChild) + Some(TryMergeResult(mergedPlan, npMapping, cpMapping, npFilter, cpFilter)) + } else if (filterPropagationSupported && symmetricFilterPropagationEnabled) { + if (cp.getTagValue(PlanMerger.MERGED_FILTER_TAG).isDefined) { + // cp Filter is already a merged filter from a previous round: its condition + // is OR(f0, f1, ...) and its child Project already contains aliases for those + // attributes. Only create a new alias for the np side, and extend the OR + // condition. + val newNPCondition = npFilter.fold(mappedNPCondition) { + case (f, _) => And(f, mappedNPCondition) + } + val childProject = mergedChild match { + case p: Project => p + case other => throw new IllegalStateException( + "Expected Project child under MERGED_FILTER_TAG filter, got " + + s"${other.getClass.getSimpleName}") + } + // If newNPCondition is already aliased in the child Project (e.g. a third + // subplan whose filter matches one from a previous merge round), reuse the + // existing attribute instead of creating a redundant alias. + val existingNPFilter = childProject.projectList.collectFirst { + case a: Alias if a.child.canonicalized == newNPCondition.canonicalized => + a.toAttribute + } + existingNPFilter match { + case Some(reusedFilter) => + Some(TryMergeResult(cp, npMapping, cpMapping, Some((reusedFilter, false)), + None)) + case None => + val newNPFilterAlias = + Alias(newNPCondition, s"propagatedFilter_${PlanMerger.newId}")() + val newNPFilter = newNPFilterAlias.toAttribute + val newProject = childProject.copy( + projectList = childProject.projectList ++ Seq(newNPFilterAlias)) + val newFilter = Filter(Or(mappedCPCondition, newNPFilter), newProject) + newFilter.copyTagsFrom(cp) + Some(TryMergeResult(newFilter, npMapping, cpMapping, + Some((newNPFilter, true)), None)) + } + } else { + // First-time filter propagation: alias both sides' conditions as boolean + // attributes in a new Project below the Filter, and set the Filter condition + // to OR(newNPFilter, newCPFilter). + // Note: the new Project always uses mergedChild as its child (rather than + // flattening into an existing Project below) because mergedChild.output may + // contain previously-propagated filter attributes that newCPCondition + // references. + val newNPCondition = npFilter.fold(mappedNPCondition) { + case (f, _) => And(f, mappedNPCondition) + } + val newCPCondition = cpFilter.fold(mappedCPCondition)(And(_, mappedCPCondition)) + val newNPFilterAlias = + Alias(newNPCondition, s"propagatedFilter_${PlanMerger.newId}")() + val newCPFilterAlias = + Alias(newCPCondition, s"propagatedFilter_${PlanMerger.newId}")() + val newNPFilter = newNPFilterAlias.toAttribute + val newCPFilter = newCPFilterAlias.toAttribute + val project = Project( + mergedChild.output.toList ++ Seq(newNPFilterAlias, newCPFilterAlias), + mergedChild) + val newFilter = Filter(Or(newNPFilter, newCPFilter), project) + newFilter.copyTagsFrom(cp) + newFilter.setTagValue(PlanMerger.MERGED_FILTER_TAG, ()) + Some(TryMergeResult(newFilter, npMapping, cpMapping, Some((newNPFilter, true)), + Some(newCPFilter))) + } } else { None } - } + } + case (np: Filter, cp) if filterPropagationSupported => + tryMergePlans(np.child, cp, filterPropagationSupported).collect { + // If the cp side already propagated a filter from deeper recursion, the merge is + // effectively symmetric (both sides have a filter condition). Abort unless + // symmetricFilterPropagationEnabled. + case TryMergeResult(mergedChild, npMapping, cpMapping, npFilter, cpFilter) + if cpFilter.isEmpty || symmetricFilterPropagationEnabled => + val mappedNPCondition = mapAttributes(np.condition, npMapping) + val newNPCondition = npFilter.fold(mappedNPCondition) { + case (f, _) => And(f, mappedNPCondition) + } + val newNPFilterAlias = + Alias(newNPCondition, s"propagatedFilter_${PlanMerger.newId}")() + val newNPFilter = newNPFilterAlias.toAttribute + val project = Project( + mergedChild.output.toList ++ Seq(newNPFilterAlias) ++ cpFilter.toSeq, + mergedChild) + TryMergeResult(project, npMapping, cpMapping, Some((newNPFilter, true)), cpFilter) + } + case (np, cp: Filter) if filterPropagationSupported => + tryMergePlans(np, cp.child, filterPropagationSupported).collect { + // If the np side already propagated a filter from deeper recursion, the merge is + // effectively symmetric (both sides have a filter condition). Abort unless + // symmetricFilterPropagationEnabled. + case TryMergeResult(mergedChild, npMapping, cpMapping, npFilter, cpFilter) + if npFilter.isEmpty || symmetricFilterPropagationEnabled => + val mappedCPCondition = mapAttributes(cp.condition, cpMapping) + val newCPCondition = cpFilter.fold(mappedCPCondition)(And(_, mappedCPCondition)) + val newCPFilterAlias = + Alias(newCPCondition, s"propagatedFilter_${PlanMerger.newId}")() + val newCPFilter = newCPFilterAlias.toAttribute + val project = Project( + mergedChild.output.toList ++ npFilter.map(_._1).toSeq ++ Seq(newCPFilterAlias), + mergedChild) + TryMergeResult(project, npMapping, cpMapping, npFilter, Some(newCPFilter)) + } + + case (np: Join, cp: Join) if np.joinType == cp.joinType && np.hint == cp.hint => + // Filter propagation across joins is not yet supported. + tryMergePlans(np.left, cp.left, false).flatMap { + case TryMergeResult(mergedLeft, leftNPMapping, _, None, None) => + tryMergePlans(np.right, cp.right, false).flatMap { + case TryMergeResult(mergedRight, rightNPMapping, _, None, None) => + val npMapping = leftNPMapping ++ rightNPMapping + val mappedNPCondition = np.condition.map(mapAttributes(_, npMapping)) + // Comparing the canonicalized form is required to ignore different forms of the + // same expression and `AttributeReference.qualifier`s in `cp.condition`. + if (mappedNPCondition.map(_.canonicalized) == cp.condition.map(_.canonicalized)) { + val mergedPlan = cp.withNewChildren(Seq(mergedLeft, mergedRight)) + Some(TryMergeResult(mergedPlan, npMapping)) + } else { + None + } + case _ => None + } + case _ => None } // Otherwise merging is not possible. @@ -247,29 +465,107 @@ class PlanMerger { }.asInstanceOf[T] } - // Applies `outputMap` attribute mapping on attributes of `newExpressions` and merges them into - // `cachedExpressions`. Returns the merged expressions and the attribute mapping from the new to - // the merged version that can be propagated up during merging nodes. + // Remaps attributes of `newPlanExpressions` through `newPlanMapping` and attributes of + // `cachedPlanExpressions` through `cachedPlanMapping`, then merges them into a single + // expression list. + // Returns a triple of: + // 1. The merged expression list + // 2. New plan output map: ne.toAttribute -> merged plan attr (for parent nodes to remap + // new-plan-side expressions) + // 3. Cached plan output map: old wrapped cached attr -> new alias attr (for parent nodes to + // remap cached-plan-side expressions that would otherwise reference stale attributes after + // wrapping). Empty when no cached expressions were wrapped. + // + // When `newPlanFilter`/`cachedPlanFilter` are provided (filter propagation active), non-matching + // expressions from each side are wrapped with `If(filterAttr, expr, null)`. This ensures that a + // non-matching expression from one side evaluates to null for rows that belong to the other side, + // which is safe for aggregate FILTER (WHERE ...) semantics and avoids computing values for + // irrelevant rows. The filter attributes themselves are appended to the merged expression list so + // they remain visible to the enclosing Aggregate that will consume them. A newPlanFilter with + // isNew=false was reused from a previous merge round and is already present in the merged child + // output, so it is not appended again. private def mergeNamedExpressions( - newExpressions: Seq[NamedExpression], - outputMap: AttributeMap[Attribute], - cachedExpressions: Seq[NamedExpression]) = { - val mergedExpressions = ArrayBuffer[NamedExpression](cachedExpressions: _*) - val newOutputMap = AttributeMap(newExpressions.map { ne => - val mapped = mapAttributes(ne, outputMap) + newPlanExpressions: Seq[NamedExpression], + cachedPlanExpressions: Seq[NamedExpression], + newPlanMapping: AttributeMap[Attribute], + cachedPlanMapping: AttributeMap[Attribute] = AttributeMap.empty, + newPlanFilter: Option[(Attribute, Boolean)] = None, + cachedPlanFilter: Option[Attribute] = None) = { + val mergedExpressions = mutable.ArrayBuffer[NamedExpression]( + cachedPlanExpressions.map(mapAttributes(_, cachedPlanMapping)): _*) + val matchedCachedIndices = mutable.HashSet.empty[Int] + val newNPMapping = AttributeMap(newPlanExpressions.map { ne => + val mapped = mapAttributes(ne, newPlanMapping) val withoutAlias = mapped match { case Alias(child, _) => child case e => e } - ne.toAttribute -> mergedExpressions.find { + val foundIdx = mergedExpressions.indexWhere { case Alias(child, _) => child semanticEquals withoutAlias case e => e semanticEquals withoutAlias - }.getOrElse { - mergedExpressions += mapped - mapped - }.toAttribute + } + val resultAttr = if (foundIdx >= 0) { + // Matching expression: both sides compute the same value, no wrapping needed. + matchedCachedIndices += foundIdx + mergedExpressions(foundIdx).toAttribute + } else { + // Non-matching expression from the new plan side: wrap with the new plan filter so it + // is only computed for rows that belong to the new plan side. Plain attribute references + // are not wrapped since reading a column value is free. + val wrappedExpr: NamedExpression = newPlanFilter match { + case Some((f, _)) if !withoutAlias.isInstanceOf[Attribute] => + Alias(If(f, withoutAlias, Literal(null, withoutAlias.dataType)), mapped.name)() + case _ => mapped + } + mergedExpressions += wrappedExpr + wrappedExpr.toAttribute + } + ne.toAttribute -> resultAttr }) - (mergedExpressions.toSeq, newOutputMap) + + // Wrap unmatched cached expressions with the cached plan's filter so they are only computed + // for rows that belong to the cached plan side. Plain attribute references are not wrapped. + // Record each attr rewrite in the cached plan map so ancestor nodes can remap their stale + // references. + val newCPMapping = AttributeMap(cachedPlanFilter.toSeq.flatMap { f => + mergedExpressions.zipWithIndex.flatMap { + case (ce, i) if !matchedCachedIndices.contains(i) => + val withoutAlias = ce match { + case Alias(child, _) => child + case e => e + } + // Plain attribute references are not wrapped: no remapping entry needed. + Option.when(!withoutAlias.isInstanceOf[Attribute]) { + val newAlias = + Alias(If(f, withoutAlias, Literal(null, withoutAlias.dataType)), ce.name)() + mergedExpressions(i) = newAlias + ce.toAttribute -> newAlias.toAttribute + } + case _ => None + } + }) + + newPlanFilter.foreach { + case (f, true) => mergedExpressions += f + case _ => + } + cachedPlanFilter.foreach(mergedExpressions += _) + + (mergedExpressions.toSeq, newNPMapping, newCPMapping) + } + + // Applies filter as a FILTER (WHERE ...) clause to every AggregateExpression in exprs, + // combining with any pre-existing filter on the aggregate via AND. + private def applyFilterToAggregateExpressions( + exprs: Seq[NamedExpression], + filter: Attribute): Seq[NamedExpression] = { + exprs.map(_.transform { + case ae: AggregateExpression => + val combinedFilter = ae.filter.fold[Expression](filter)(And(filter, _)) + val newAE = ae.copy(filter = Some(combinedFilter)) + newAE.copyTagsFrom(ae) + newAE + }.asInstanceOf[NamedExpression]) } // Only allow aggregates of the same implementation because merging different implementations diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d7612d5d78508..9d09683b442d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -6528,6 +6528,30 @@ object SQLConf { .booleanConf .createOptional + val MERGE_SUBPLANS_FILTER_PROPAGATION_ENABLED = + buildConf("spark.sql.optimizer.mergeSubplans.filterPropagation.enabled") + .doc("When set to true, subquery plans that differ only in their filter conditions can " + + "be merged by propagating filters up to enclosing non-grouping aggregates.") + .version("4.2.0") + .withBindingPolicy(ConfigBindingPolicy.SESSION) + .booleanConf + .createWithDefault(true) + + val MERGE_SUBPLANS_SYMMETRIC_FILTER_PROPAGATION_ENABLED = + buildConf("spark.sql.optimizer.mergeSubplans.symmetricFilterPropagation.enabled") + .doc("When set to true, two non-grouping aggregate subplans that both have filter " + + "conditions (but with different predicates) can be merged into a single scan using " + + "FILTER (WHERE ...) clauses on each aggregate expression. " + + "Merging two filtered scans broadens the combined filter to OR(f1, f2), which may " + + "reduce IO pruning (e.g. partition or file skipping) compared to the individual " + + "filters. Disabled by default; enable once the behaviour has been validated in your " + + "workload, particularly on heavily partitioned or file-pruned tables. " + + s"Has no effect when ${MERGE_SUBPLANS_FILTER_PROPAGATION_ENABLED.key} is false.") + .version("4.2.0") + .withBindingPolicy(ConfigBindingPolicy.SESSION) + .booleanConf + .createWithDefault(false) + val ERROR_MESSAGE_FORMAT = buildConf("spark.sql.error.messageFormat") .doc("When PRETTY, the error message consists of textual representation of error class, " + "message and query context. Stack traces are only shown for internal errors " + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala index b368035e278eb..23b3564f08e36 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala @@ -19,16 +19,17 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, GetStructField, Literal, ScalarSubquery} -import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectList, CollectSet} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, CreateNamedStruct, GetStructField, If, Literal, Or, ScalarSubquery} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf class MergeSubplansSuite extends PlanTest { override def beforeEach(): Unit = { CTERelationDef.curId.set(0) + PlanMerger.curId.set(0) } private object Optimize extends RuleExecutor[LogicalPlan] { @@ -471,11 +472,9 @@ class MergeSubplansSuite extends PlanTest { // supports ObjectHashAggregate val subquery3 = ScalarSubquery(testRelation - .groupBy($"b")(CollectList($"a"). - toAggregateExpression(isDistinct = false).as("collectlist_a"))) + .groupBy($"b")(collectList($"a").as("collectlist_a"))) val subquery4 = ScalarSubquery(testRelation - .groupBy($"b")(CollectSet($"a"). - toAggregateExpression(isDistinct = false).as("collectset_a"))) + .groupBy($"b")(collectSet($"a").as("collectset_a"))) // supports SortAggregate val subquery5 = ScalarSubquery(testRelation.groupBy($"b")(max($"c").as("max_c"))) @@ -501,8 +500,8 @@ class MergeSubplansSuite extends PlanTest { val analyzedHashAggregates = hashAggregates.analyze val objectHashAggregates = testRelation .groupBy($"b")( - CollectList($"a").toAggregateExpression(isDistinct = false).as("collectlist_a"), - CollectSet($"a").toAggregateExpression(isDistinct = false).as("collectset_a")) + collectList($"a").as("collectlist_a"), + collectSet($"a").as("collectset_a")) .select(CreateNamedStruct(Seq( Literal("collectlist_a"), $"collectlist_a", Literal("collectset_a"), $"collectset_a" @@ -721,4 +720,767 @@ class MergeSubplansSuite extends PlanTest { comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) } + + test("SPARK-40193: Merge non-grouping subqueries with different filter conditions") { + val subquery1 = ScalarSubquery(testRelation.where($"a" > 1).groupBy()(max($"a").as("max_a"))) + val subquery2 = ScalarSubquery(testRelation.where($"a" < 1).groupBy()(min($"a").as("min_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + val npFilterAlias = Alias($"a" < 1, "propagatedFilter_0")() + val cpFilterAlias = Alias($"a" > 1, "propagatedFilter_1")() + val npFilter = npFilterAlias.toAttribute + val cpFilter = cpFilterAlias.toAttribute + val mergedSubquery = testRelation + .select(testRelation.output ++ Seq(npFilterAlias, cpFilterAlias): _*) + .where(Or(npFilter, cpFilter)) + .groupBy()( + max($"a", Some(cpFilter)).as("max_a"), + min($"a", Some(npFilter)).as("min_a")) + .select(CreateNamedStruct(Seq( + Literal("max_a"), $"max_a", + Literal("min_a"), $"min_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + withSQLConf(SQLConf.MERGE_SUBPLANS_SYMMETRIC_FILTER_PROPAGATION_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + } + + test("SPARK-40193: Merge three non-grouping subqueries with different filter conditions") { + val subquery1 = ScalarSubquery(testRelation.where($"a" > 1).groupBy()(max($"a").as("max_a"))) + val subquery2 = ScalarSubquery(testRelation.where($"a" < 1).groupBy()(min($"a").as("min_a"))) + val subquery3 = ScalarSubquery(testRelation.where($"a" === 1).groupBy()(sum($"a").as("sum_a"))) + val originalQuery = testRelation.select(subquery1, subquery2, subquery3) + + // Step 1: subquery1 (cp) and subquery2 (np) merge: + // f0 = Alias(a < 1, "propagatedFilter_0") -- np / min + // f1 = Alias(a > 1, "propagatedFilter_1") -- cp / max + // -> Project([a,b,c, f0Alias, f1Alias], testRelation) + // -> Filter(OR(f0, f1), above) [tagged] + // propagates (Some(f0), Some(f1)) upward + // + // Step 2: subquery3 (np) merges with merged(1,2) (cp). The cp Filter is tagged, so only a + // new np alias is created and flattened into the existing Project (no nested Projects): + // f2 = Alias(a === 1, "propagatedFilter_2") -- np / sum + // -> Project([a,b,c, f0Alias, f1Alias, f2Alias], testRelation) + // -> Filter(OR(OR(f0, f1), f2), above) [tagged] + // propagates (Some(f2), None) upward + // + // Aggregate: cp agg expressions already carry their FILTERs from step 1 and are unchanged. + // max(a) FILTER f1 -- a > 1 + // min(a) FILTER f0 -- a < 1 + // sum(a) FILTER f2 -- a === 1 + val npFilter0Alias = Alias($"a" < 1, "propagatedFilter_0")() + val cpFilter0Alias = Alias($"a" > 1, "propagatedFilter_1")() + val npFilter0 = npFilter0Alias.toAttribute + val cpFilter0 = cpFilter0Alias.toAttribute + val npFilter1Alias = Alias($"a" === 1, "propagatedFilter_2")() + val npFilter1 = npFilter1Alias.toAttribute + val mergedSubquery = testRelation + .select(testRelation.output ++ Seq(npFilter0Alias, cpFilter0Alias, npFilter1Alias): _*) + .where(Or(Or(npFilter0, cpFilter0), npFilter1)) + .groupBy()( + max($"a", Some(cpFilter0)).as("max_a"), + min($"a", Some(npFilter0)).as("min_a"), + sum($"a", Some(npFilter1)).as("sum_a")) + .select(CreateNamedStruct(Seq( + Literal("max_a"), $"max_a", + Literal("min_a"), $"min_a", + Literal("sum_a"), $"sum_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1), + extractorExpression(0, analyzedMergedSubquery.output, 2)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + withSQLConf(SQLConf.MERGE_SUBPLANS_SYMMETRIC_FILTER_PROPAGATION_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + } + + test("SPARK-40193: Merge three non-grouping subqueries where the third has the same filter " + + "condition as the first") { + val subquery1 = ScalarSubquery(testRelation.where($"a" > 1).groupBy()(max($"a").as("max_a"))) + val subquery2 = ScalarSubquery(testRelation.where($"a" < 1).groupBy()(min($"a").as("min_a"))) + val subquery3 = ScalarSubquery(testRelation.where($"a" > 1).groupBy()(sum($"a").as("sum_a"))) + val originalQuery = testRelation.select(subquery1, subquery2, subquery3) + + // Step 1: subquery1 (cp) and subquery2 (np) merge as usual: + // f0 = Alias(a < 1, "propagatedFilter_0") -- np / min + // f1 = Alias(a > 1, "propagatedFilter_1") -- cp / max + // -> Project([a,b,c, f0Alias, f1Alias], testRelation) + // -> Filter(OR(f0, f1), above) [tagged] + // + // Step 2: subquery3 (np, condition a > 1) merges with merged(1,2) (cp). The cp Filter is + // tagged and (a > 1) is already aliased as f1 in the child Project, so f1 is reused and no + // new alias or extended OR condition is created. Only sum(a) FILTER f1 is added to the agg. + val f0Alias = Alias($"a" < 1, "propagatedFilter_0")() + val f1Alias = Alias($"a" > 1, "propagatedFilter_1")() + val f0 = f0Alias.toAttribute + val f1 = f1Alias.toAttribute + val mergedSubquery = testRelation + .select(testRelation.output ++ Seq(f0Alias, f1Alias): _*) + .where(Or(f0, f1)) + .groupBy()( + max($"a", Some(f1)).as("max_a"), + min($"a", Some(f0)).as("min_a"), + sum($"a", Some(f1)).as("sum_a")) + .select(CreateNamedStruct(Seq( + Literal("max_a"), $"max_a", + Literal("min_a"), $"min_a", + Literal("sum_a"), $"sum_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1), + extractorExpression(0, analyzedMergedSubquery.output, 2)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + withSQLConf(SQLConf.MERGE_SUBPLANS_SYMMETRIC_FILTER_PROPAGATION_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + } + + test("SPARK-40193: Do not merge non-grouping subqueries with different filter conditions when " + + "disabled") { + withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_ENABLED.key -> "false") { + val subquery1 = ScalarSubquery(testRelation.where($"a" > 1).groupBy()(max($"a").as("max_a"))) + val subquery2 = ScalarSubquery(testRelation.where($"a" < 1).groupBy()(min($"a").as("min_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze) + } + } + + test("SPARK-40193: Do not merge non-grouping subqueries with different filter conditions on " + + "both sides when symmetric filter propagation is disabled") { + withSQLConf(SQLConf.MERGE_SUBPLANS_SYMMETRIC_FILTER_PROPAGATION_ENABLED.key -> "false") { + val subquery1 = ScalarSubquery(testRelation.where($"a" > 1).groupBy()(max($"a").as("max_a"))) + val subquery2 = ScalarSubquery(testRelation.where($"a" < 1).groupBy()(min($"a").as("min_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze) + } + } + + test("SPARK-40193: Merge non-grouping aggregates with different filter conditions") { + val agg1 = testRelation.where($"a" > 1).groupBy()(max($"a").as("max_a")) + val agg2 = testRelation.where($"a" < 1).groupBy()(min($"a").as("min_a")) + val originalQuery = agg1.join(agg2) + + val npFilterAlias = Alias($"a" < 1, "propagatedFilter_0")() + val cpFilterAlias = Alias($"a" > 1, "propagatedFilter_1")() + val npFilter = npFilterAlias.toAttribute + val cpFilter = cpFilterAlias.toAttribute + val mergedSubquery = testRelation + .select(testRelation.output ++ Seq(npFilterAlias, cpFilterAlias): _*) + .where(Or(npFilter, cpFilter)) + .groupBy()( + max($"a", Some(cpFilter)).as("max_a"), + min($"a", Some(npFilter)).as("min_a")) + .select(CreateNamedStruct(Seq( + Literal("max_a"), $"max_a", + Literal("min_a"), $"min_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + OneRowRelation() + .select(extractorExpression(0, analyzedMergedSubquery.output, 0, "max_a")) + .join( + OneRowRelation() + .select(extractorExpression(0, analyzedMergedSubquery.output, 1, "min_a"))), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + withSQLConf(SQLConf.MERGE_SUBPLANS_SYMMETRIC_FILTER_PROPAGATION_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + } + + test("SPARK-40193: Do not merge grouping aggregates with different filter conditions") { + val subquery1 = + ScalarSubquery(testRelation.where($"a" > 1).groupBy($"b")(max($"a").as("max_a"))) + val subquery2 = + ScalarSubquery(testRelation.where($"a" < 1).groupBy($"b")(min($"a").as("min_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + withSQLConf(SQLConf.MERGE_SUBPLANS_SYMMETRIC_FILTER_PROPAGATION_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), originalQuery.analyze) + } + } + + test("SPARK-40193: Merge non-grouping subqueries where only the new plan has a filter") { + val subquery1 = ScalarSubquery(testRelation.groupBy()(max($"a").as("max_a"))) + val subquery2 = ScalarSubquery(testRelation.where($"a" < 1).groupBy()(min($"a").as("min_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + val npFilterAlias = Alias($"a" < 1, "propagatedFilter_0")() + val npFilter = npFilterAlias.toAttribute + val mergedSubquery = testRelation + .select(testRelation.output ++ Seq(npFilterAlias): _*) + .groupBy()( + max($"a").as("max_a"), + min($"a", Some(npFilter)).as("min_a")) + .select(CreateNamedStruct(Seq( + Literal("max_a"), $"max_a", + Literal("min_a"), $"min_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + + test("SPARK-40193: Merge non-grouping subqueries where only the cached plan has a filter") { + val subquery1 = ScalarSubquery(testRelation.where($"a" > 1).groupBy()(max($"a").as("max_a"))) + val subquery2 = ScalarSubquery(testRelation.groupBy()(min($"a").as("min_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + val cpFilterAlias = Alias($"a" > 1, "propagatedFilter_0")() + val cpFilter = cpFilterAlias.toAttribute + val mergedSubquery = testRelation + .select(testRelation.output ++ Seq(cpFilterAlias): _*) + .groupBy()( + max($"a", Some(cpFilter)).as("max_a"), + min($"a").as("min_a")) + .select(CreateNamedStruct(Seq( + Literal("max_a"), $"max_a", + Literal("min_a"), $"min_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + + test("SPARK-40193: Merge non-grouping subqueries with multiple stacked filter conditions") { + val subquery1 = + ScalarSubquery(testRelation.where($"a" > 1).where($"b" > 2).groupBy()(max($"a").as("max_a"))) + val subquery2 = + ScalarSubquery(testRelation.where($"a" < 1).where($"b" < 2).groupBy()(min($"a").as("min_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + // Merge traversal (inner-to-outer): + // + // Inner level - (np: Filter(a < 1), cp: Filter(a > 1)): + // f0 = Alias(a < 1, "propagatedFilter_0") -- np / min + // f1 = Alias(a > 1, "propagatedFilter_1") -- cp / max + // -> Project([a,b,c, f0_alias, f1_alias], testRelation) + // -> Filter(OR(f0, f1), above) [tagged] + // propagates (Some(f0), Some(f1)) upward + // + // Outer level - (np: Filter(b < 2), cp: Filter(b > 2)): + // f2 = Alias(AND(f0, b < 2), "propagatedFilter_2") -- np + // f3 = Alias(AND(f1, b > 2), "propagatedFilter_3") -- cp + // -> Project([a,b,c, f0, f1, f2_alias, f3_alias], innerFilter) + // -> Filter(OR(f2, f3), above) [tagged] + // propagates (Some(f2), Some(f3)) upward + // + // Aggregate consumes f2/f3 as FILTER clauses: + // max(a) FILTER f3 -- AND(a > 1, b > 2) + // min(a) FILTER f2 -- AND(a < 1, b < 2) + val f0Alias = Alias($"a" < 1, "propagatedFilter_0")() + val f1Alias = Alias($"a" > 1, "propagatedFilter_1")() + val f0 = f0Alias.toAttribute + val f1 = f1Alias.toAttribute + val innerProject = testRelation.select(testRelation.output ++ Seq(f0Alias, f1Alias): _*) + val innerFilter = innerProject.where(Or(f0, f1)) + val f2Alias = Alias(And(f0, $"b" < 2), "propagatedFilter_2")() + val f3Alias = Alias(And(f1, $"b" > 2), "propagatedFilter_3")() + val f2 = f2Alias.toAttribute + val f3 = f3Alias.toAttribute + val mergedSubquery = innerFilter + .select(innerFilter.output ++ Seq(f2Alias, f3Alias): _*) + .where(Or(f2, f3)) + .groupBy()( + max($"a", Some(f3)).as("max_a"), + min($"a", Some(f2)).as("min_a")) + .select(CreateNamedStruct(Seq( + Literal("max_a"), $"max_a", + Literal("min_a"), $"min_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + withSQLConf(SQLConf.MERGE_SUBPLANS_SYMMETRIC_FILTER_PROPAGATION_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + } + + test("SPARK-40193: Merge non-grouping subqueries where the new plan has more filter layers") { + val subquery1 = ScalarSubquery(testRelation.where($"a" > 1).groupBy()(max($"a").as("max_a"))) + val subquery2 = + ScalarSubquery(testRelation.where($"a" < 1).where($"b" < 2).groupBy()(min($"a").as("min_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + // Merge traversal (inner-to-outer): + // + // Inner level - (np: Filter(a < 1), cp: testRelation): + // cp has no filter -> (np: Filter, cp) case. No Filter node needed. + // f0 = Alias(a < 1, "propagatedFilter_0") + // -> Project([a, b, c, f0Alias], testRelation) + // propagates (Some(f0), None) upward + // + // Outer level - (np: Filter(b < 2), cp: Filter(a > 1)): + // Both are Filters. Child result has (npFilter=Some(f0), cpFilter=None). + // f1 = Alias(AND(f0, b < 2), "propagatedFilter_1") -- np combined condition + // f2 = Alias(a > 1, "propagatedFilter_2") -- cp condition + // -> Project([a, b, c, f0, f1Alias, f2Alias], innerProject) + // -> Filter(OR(f1, f2), above) [tagged] + // propagates (Some(f1), Some(f2)) upward + // + // Aggregate: + // max(a) FILTER f2 -- cp: a > 1 + // min(a) FILTER f1 -- np: a < 1 AND b < 2 + val f0Alias = Alias($"a" < 1, "propagatedFilter_0")() + val f0 = f0Alias.toAttribute + val innerProject = testRelation.select(testRelation.output ++ Seq(f0Alias): _*) + val f1Alias = Alias(And(f0, $"b" < 2), "propagatedFilter_1")() + val f2Alias = Alias($"a" > 1, "propagatedFilter_2")() + val f1 = f1Alias.toAttribute + val f2 = f2Alias.toAttribute + val mergedSubquery = innerProject + .select(innerProject.output ++ Seq(f1Alias, f2Alias): _*) + .where(Or(f1, f2)) + .groupBy()( + max($"a", Some(f2)).as("max_a"), + min($"a", Some(f1)).as("min_a")) + .select(CreateNamedStruct(Seq( + Literal("max_a"), $"max_a", + Literal("min_a"), $"min_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + withSQLConf(SQLConf.MERGE_SUBPLANS_SYMMETRIC_FILTER_PROPAGATION_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + } + + test("SPARK-40193: Merge non-grouping subqueries where the cached plan has more filter layers") { + val subquery1 = + ScalarSubquery(testRelation.where($"a" > 1).where($"b" > 2).groupBy()(max($"a").as("max_a"))) + val subquery2 = ScalarSubquery(testRelation.where($"a" < 1).groupBy()(min($"a").as("min_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + // Merge traversal (inner-to-outer): + // + // Inner level - (np: testRelation, cp: Filter(a > 1)): + // np has no filter -> (np, cp: Filter) case. No Filter node needed. + // f0 = Alias(a > 1, "propagatedFilter_0") + // -> Project([a, b, c, f0Alias], testRelation) + // propagates (None, Some(f0)) upward + // + // Outer level - (np: Filter(a < 1), cp: Filter(b > 2)): + // Both are Filters. Child result has (npFilter=None, cpFilter=Some(f0)). + // f1 = Alias(a < 1, "propagatedFilter_1") -- np condition + // f2 = Alias(AND(f0, b > 2), "propagatedFilter_2") -- cp combined condition + // -> Project([a, b, c, f0, f1Alias, f2Alias], innerProject) + // -> Filter(OR(f1, f2), above) [tagged] + // propagates (Some(f1), Some(f2)) upward + // + // Aggregate: + // max(a) FILTER f2 -- cp: a > 1 AND b > 2 + // min(a) FILTER f1 -- np: a < 1 + val f0Alias = Alias($"a" > 1, "propagatedFilter_0")() + val f0 = f0Alias.toAttribute + val innerProject = testRelation.select(testRelation.output ++ Seq(f0Alias): _*) + val f1Alias = Alias($"a" < 1, "propagatedFilter_1")() + val f2Alias = Alias(And(f0, $"b" > 2), "propagatedFilter_2")() + val f1 = f1Alias.toAttribute + val f2 = f2Alias.toAttribute + val mergedSubquery = innerProject + .select(innerProject.output ++ Seq(f1Alias, f2Alias): _*) + .where(Or(f1, f2)) + .groupBy()( + max($"a", Some(f2)).as("max_a"), + min($"a", Some(f1)).as("min_a")) + .select(CreateNamedStruct(Seq( + Literal("max_a"), $"max_a", + Literal("min_a"), $"min_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + withSQLConf(SQLConf.MERGE_SUBPLANS_SYMMETRIC_FILTER_PROPAGATION_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + } + + test("SPARK-40193: Merge non-grouping subqueries with equal outer stacked filter") { + val subquery1 = + ScalarSubquery(testRelation.where($"a" > 1).where($"b" > 2).groupBy()(max($"a").as("max_a"))) + val subquery2 = + ScalarSubquery(testRelation.where($"a" < 1).where($"b" > 2).groupBy()(min($"a").as("min_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + // Merge traversal (inner-to-outer): + // + // Inner level - (np: Filter(a < 1), cp: Filter(a > 1)): + // Different conditions -> first-time filter propagation. + // f0 = Alias(a < 1, "propagatedFilter_0") -- np + // f1 = Alias(a > 1, "propagatedFilter_1") -- cp + // -> Project([a, b, c, f0Alias, f1Alias], testRelation) + // -> Filter(OR(f0, f1)) [tagged] + // propagates (Some(f0), Some(f1)) upward + // + // Outer level - (np: Filter(b > 2), cp: Filter(b > 2)): + // Equal conditions -> Filter(b > 2, ...) passes filter attrs through. + // propagates (Some(f0), Some(f1)) unchanged + // + // Aggregate: + // max(a) FILTER f1 -- cp: a > 1 (plus the outer b > 2 applied to all rows) + // min(a) FILTER f0 -- np: a < 1 (plus the outer b > 2 applied to all rows) + val f0Alias = Alias($"a" < 1, "propagatedFilter_0")() + val f1Alias = Alias($"a" > 1, "propagatedFilter_1")() + val f0 = f0Alias.toAttribute + val f1 = f1Alias.toAttribute + val innerProject = testRelation.select(testRelation.output ++ Seq(f0Alias, f1Alias): _*) + val innerFilter = innerProject.where(Or(f0, f1)) + val mergedSubquery = innerFilter + .where($"b" > 2) + .groupBy()( + max($"a", Some(f1)).as("max_a"), + min($"a", Some(f0)).as("min_a")) + .select(CreateNamedStruct(Seq( + Literal("max_a"), $"max_a", + Literal("min_a"), $"min_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + withSQLConf(SQLConf.MERGE_SUBPLANS_SYMMETRIC_FILTER_PROPAGATION_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + } + + test("SPARK-40193: Merge non-grouping subqueries with equal inner stacked filter") { + val subquery1 = + ScalarSubquery(testRelation.where($"a" > 1).where($"b" > 2).groupBy()(max($"a").as("max_a"))) + val subquery2 = + ScalarSubquery(testRelation.where($"a" > 1).where($"b" < 2).groupBy()(min($"a").as("min_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + // Merge traversal (inner-to-outer): + // + // Inner level - (np: Filter(a > 1), cp: Filter(a > 1)): + // checkIdenticalPlans matches -> no filter propagation needed. + // -> Filter(a > 1, testRelation) (shared, unchanged) + // propagates (None, None) upward + // + // Outer level - (np: Filter(b < 2), cp: Filter(b > 2)): + // Different conditions -> first-time filter propagation. + // f0 = Alias(b < 2, "propagatedFilter_0") -- np + // f1 = Alias(b > 2, "propagatedFilter_1") -- cp + // -> Project([a, b, c, f0Alias, f1Alias], Filter(a > 1, testRelation)) + // -> Filter(OR(f0, f1)) [tagged] + // propagates (Some(f0), Some(f1)) upward + // + // Aggregate: + // max(a) FILTER f1 -- cp: a > 1 AND b > 2 + // min(a) FILTER f0 -- np: a > 1 AND b < 2 + val f0Alias = Alias($"b" < 2, "propagatedFilter_0")() + val f1Alias = Alias($"b" > 2, "propagatedFilter_1")() + val f0 = f0Alias.toAttribute + val f1 = f1Alias.toAttribute + val innerFilter = testRelation.where($"a" > 1) + val mergedSubquery = innerFilter + .select(innerFilter.output ++ Seq(f0Alias, f1Alias): _*) + .where(Or(f0, f1)) + .groupBy()( + max($"a", Some(f1)).as("max_a"), + min($"a", Some(f0)).as("min_a")) + .select(CreateNamedStruct(Seq( + Literal("max_a"), $"max_a", + Literal("min_a"), $"min_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + withSQLConf(SQLConf.MERGE_SUBPLANS_SYMMETRIC_FILTER_PROPAGATION_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + } + + test("SPARK-40193: Merge non-grouping subqueries where the new plan has an extra inner filter " + + "below a shared outer filter") { + val subquery1 = ScalarSubquery(testRelation.where($"a" > 1).groupBy()(max($"a").as("max_a"))) + val subquery2 = + ScalarSubquery(testRelation.where($"b" < 2).where($"a" > 1).groupBy()(min($"a").as("min_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + // Merge traversal (inner-to-outer): + // + // Inner level - (np: Filter(b < 2), cp: testRelation): + // cp has no filter -> (np: Filter, cp) case. No Filter node needed. + // f0 = Alias(b < 2, "propagatedFilter_0") + // -> Project([a, b, c, f0Alias], testRelation) + // propagates (Some(f0), None) upward + // + // Outer level - (np: Filter(a > 1), cp: Filter(a > 1)): + // Equal conditions -> just wraps with Filter(a > 1, ...) and passes filter attrs through. + // propagates (Some(f0), None) unchanged + // + // Aggregate: + // max(a) unfiltered -- cp: all rows where a > 1 (from outer Filter) + // min(a) FILTER f0 -- np: rows where a > 1 AND b < 2 + val f0Alias = Alias($"b" < 2, "propagatedFilter_0")() + val f0 = f0Alias.toAttribute + val innerProject = testRelation.select(testRelation.output ++ Seq(f0Alias): _*) + val mergedSubquery = innerProject + .where($"a" > 1) + .groupBy()( + max($"a").as("max_a"), + min($"a", Some(f0)).as("min_a")) + .select(CreateNamedStruct(Seq( + Literal("max_a"), $"max_a", + Literal("min_a"), $"min_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + + test("SPARK-40193: Merge non-grouping subqueries where the cached plan has an extra inner " + + "filter below a shared outer filter") { + val subquery1 = + ScalarSubquery(testRelation.where($"b" < 2).where($"a" > 1).groupBy()(max($"a").as("max_a"))) + val subquery2 = ScalarSubquery(testRelation.where($"a" > 1).groupBy()(min($"a").as("min_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + // Merge traversal (inner-to-outer): + // + // Inner level - (np: testRelation, cp: Filter(b < 2)): + // np has no filter -> (np, cp: Filter) case. No Filter node needed. + // f0 = Alias(b < 2, "propagatedFilter_0") + // -> Project([a, b, c, f0Alias], testRelation) + // propagates (None, Some(f0)) upward + // + // Outer level - (np: Filter(a > 1), cp: Filter(a > 1)): + // Equal conditions -> just wraps with Filter(a > 1, ...) and passes filter attrs through. + // propagates (None, Some(f0)) unchanged + // + // Aggregate: + // max(a) FILTER f0 -- cp: rows where a > 1 AND b < 2 + // min(a) unfiltered -- np: all rows where a > 1 (from outer Filter) + val f0Alias = Alias($"b" < 2, "propagatedFilter_0")() + val f0 = f0Alias.toAttribute + val innerProject = testRelation.select(testRelation.output ++ Seq(f0Alias): _*) + val mergedSubquery = innerProject + .where($"a" > 1) + .groupBy()( + max($"a", Some(f0)).as("max_a"), + min($"a").as("min_a")) + .select(CreateNamedStruct(Seq( + Literal("max_a"), $"max_a", + Literal("min_a"), $"min_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + + test("SPARK-40193: Merge non-grouping subqueries with equal conditions in reversed filter " + + "order") { + val subquery1 = + ScalarSubquery(testRelation.where($"a" > 1).where($"b" > 2).groupBy()(max($"a").as("max_a"))) + val subquery2 = + ScalarSubquery(testRelation.where($"b" > 2).where($"a" > 1).groupBy()(min($"a").as("min_a"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + // Merge traversal (inner-to-outer): + // + // Because the conditions are in opposite order, each pair of Filter nodes has different + // conditions and filter propagation is triggered at both levels, producing 4 filter + // attributes in total even though both sides ultimately encode a > 1 AND b > 2. + // + // Inner level - (np: Filter(b > 2), cp: Filter(a > 1)): + // f0 = Alias(b > 2, "propagatedFilter_0") -- np inner condition + // f1 = Alias(a > 1, "propagatedFilter_1") -- cp inner condition + // -> Project([a, b, c, f0Alias, f1Alias], testRelation) + // -> Filter(OR(f0, f1)) [tagged] + // propagates (Some(f0), Some(f1)) upward + // + // Outer level - (np: Filter(a > 1), cp: Filter(b > 2)): + // f2 = Alias(AND(f0, a > 1), "propagatedFilter_2") -- np: b > 2 AND a > 1 + // f3 = Alias(AND(f1, b > 2), "propagatedFilter_3") -- cp: a > 1 AND b > 2 + // -> Project([a, b, c, f0, f1, f2Alias, f3Alias], innerFilter) + // -> Filter(OR(f2, f3)) [tagged] + // propagates (Some(f2), Some(f3)) upward + // + // Aggregate: + // max(a) FILTER f3 -- cp: a > 1 AND b > 2 + // min(a) FILTER f2 -- np: b > 2 AND a > 1 (same predicate, different representation) + val f0Alias = Alias($"b" > 2, "propagatedFilter_0")() + val f1Alias = Alias($"a" > 1, "propagatedFilter_1")() + val f0 = f0Alias.toAttribute + val f1 = f1Alias.toAttribute + val innerProject = testRelation.select(testRelation.output ++ Seq(f0Alias, f1Alias): _*) + val innerFilter = innerProject.where(Or(f0, f1)) + val f2Alias = Alias(And(f0, $"a" > 1), "propagatedFilter_2")() + val f3Alias = Alias(And(f1, $"b" > 2), "propagatedFilter_3")() + val f2 = f2Alias.toAttribute + val f3 = f3Alias.toAttribute + val mergedSubquery = innerFilter + .select(innerFilter.output ++ Seq(f2Alias, f3Alias): _*) + .where(Or(f2, f3)) + .groupBy()( + max($"a", Some(f3)).as("max_a"), + min($"a", Some(f2)).as("min_a")) + .select(CreateNamedStruct(Seq( + Literal("max_a"), $"max_a", + Literal("min_a"), $"min_a" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + withSQLConf(SQLConf.MERGE_SUBPLANS_SYMMETRIC_FILTER_PROPAGATION_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + } + + test("SPARK-40193: Merge non-grouping subqueries with distinct aggregate and different " + + "filter conditions") { + val subquery1 = + ScalarSubquery(testRelation.where($"a" > 1).groupBy()(countDistinct($"a").as("cnt1"))) + val subquery2 = + ScalarSubquery(testRelation.where($"a" < 1).groupBy()(countDistinct($"a").as("cnt2"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + val npFilterAlias = Alias($"a" < 1, "propagatedFilter_0")() + val cpFilterAlias = Alias($"a" > 1, "propagatedFilter_1")() + val npFilter = npFilterAlias.toAttribute + val cpFilter = cpFilterAlias.toAttribute + val mergedSubquery = testRelation + .select(testRelation.output ++ Seq(npFilterAlias, cpFilterAlias): _*) + .where(Or(npFilter, cpFilter)) + .groupBy()( + countDistinctWithFilter(cpFilter, $"a").as("cnt1"), + countDistinctWithFilter(npFilter, $"a").as("cnt2")) + .select(CreateNamedStruct(Seq( + Literal("cnt1"), $"cnt1", + Literal("cnt2"), $"cnt2" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + withSQLConf(SQLConf.MERGE_SUBPLANS_SYMMETRIC_FILTER_PROPAGATION_ENABLED.key -> "true") { + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + } + + test("SPARK-40193: Merge non-grouping subqueries with If-wrapped computed Project expression") { + val subquery1 = ScalarSubquery(testRelation.groupBy()(sum($"a").as("sum_a"))) + val subquery2 = ScalarSubquery( + testRelation.where($"a" > 1).select(($"a" + 1).as("d")).groupBy()(max($"d").as("max_d"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + val f0Alias = Alias($"a" > 1, "propagatedFilter_0")() + val f0 = f0Alias.toAttribute + val dIfAlias = + Alias(If(f0, $"a" + 1, Literal(null, testRelation.output.head.dataType)), "d")() + val d = dIfAlias.toAttribute + val mergedSubquery = testRelation + .select(testRelation.output ++ Seq(f0Alias): _*) + .select(testRelation.output ++ Seq(dIfAlias, f0): _*) + .groupBy()( + sum($"a").as("sum_a"), + max(d, Some(f0)).as("max_d")) + .select(CreateNamedStruct(Seq( + Literal("sum_a"), $"sum_a", + Literal("max_d"), $"max_d" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + + test("SPARK-40193: Merge non-grouping subqueries where one aggregate already carries a " + + "FILTER clause") { + val subquery1 = ScalarSubquery(testRelation.groupBy()(max($"a").as("max_a"))) + val subquery2 = + ScalarSubquery(testRelation.where($"a" > 1).groupBy()(count($"a", Some($"b" > 0)).as("cnt"))) + val originalQuery = testRelation.select(subquery1, subquery2) + + val f0Alias = Alias($"a" > 1, "propagatedFilter_0")() + val f0 = f0Alias.toAttribute + val mergedSubquery = testRelation + .select(testRelation.output ++ Seq(f0Alias): _*) + .groupBy()( + max($"a").as("max_a"), + count($"a", Some(And(f0, $"b" > 0))).as("cnt")) + .select(CreateNamedStruct(Seq( + Literal("max_a"), $"max_a", + Literal("cnt"), $"cnt" + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + val correctAnswer = WithCTE( + testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanMergeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanMergeSuite.scala index b7557b42702e8..e86f782f361df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/PlanMergeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanMergeSuite.scala @@ -339,4 +339,61 @@ class PlanMergeSuite extends QueryTest Row(8, 6)) } } + + test("SPARK-40193: Merge non-grouping scalar subqueries with different filter conditions") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, + SQLConf.MERGE_SUBPLANS_SYMMETRIC_FILTER_PROPAGATION_ENABLED.key -> "true") { + val df = sql( + """ + |SELECT + | (SELECT sum(key) FROM testData WHERE key > 50), + | (SELECT sum(key) FROM testData WHERE key <= 50) + """.stripMargin) + + checkAnswer(df, Row(3775, 1275) :: Nil) + + val plan = df.queryExecution.executedPlan + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 1, + "Missing or unexpected ReusedSubqueryExec in the plan") + } + } + } + + test("SPARK-40193: Merge non-grouping scalar subqueries where only one has a filter") { + Seq(false, true).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, + // ObjectSerializerPruning produces different scan shapes depending on whether a Filter is + // present. Disabling the rule makes both scans identical so PlanMerger can merge them. + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> + "org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning") { + val df = sql( + """ + |SELECT + | (SELECT sum(key) FROM testData), + | (SELECT sum(key) FROM testData WHERE key > 50) + """.stripMargin) + + checkAnswer(df, Row(5050, 3775) :: Nil) + + val plan = df.queryExecution.executedPlan + val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id } + val reusedSubqueryIds = collectWithSubqueries(plan) { + case rs: ReusedSubqueryExec => rs.child.id + } + + assert(subqueryIds.size == 1, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 1, + "Missing or unexpected ReusedSubqueryExec in the plan") + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala index e4b5e10f7c3be..6cd49948630da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala @@ -80,6 +80,7 @@ trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite { // Do not match `id=#123` like ids as those are actually plan ids in `SubqueryExec` nodes. private val exprIdRegexp = "(?(?(plan_id=|id=#))\\d+".r + private val propagatedFilterIdRegex = "(?propagatedFilter_)\\d+".r private val clsName = this.getClass.getCanonicalName @@ -226,7 +227,11 @@ trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite { s"$padding$thisNode\n${subqueriesSimplified.mkString("")}${childrenSimplified.mkString("")}" } - simplifyNode(plan, 0) + val simplified = simplifyNode(plan, 0) + val propagatedFilterIdMap = new mutable.HashMap[String, String]() + propagatedFilterIdRegex.replaceAllIn(simplified, + m => propagatedFilterIdMap.getOrElseUpdate( + s"$m", s"${m.group("prefix")}${propagatedFilterIdMap.size + 1}")) } private def normalizeIds(plan: String): String = { @@ -237,8 +242,14 @@ trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite { // Normalize the plan ids in Exchange and Subquery nodes. // See `Exchange.stringArgs` and `SubqueryExec.stringArgs` val planIdMap = new mutable.HashMap[String, String]() - planIdRegex.replaceAllIn(exprIdNormalized, + val planIdNormalized = planIdRegex.replaceAllIn(exprIdNormalized, m => planIdMap.getOrElseUpdate(s"$m", s"${m.group("prefix")}${planIdMap.size + 1}")) + + // Normalize propagatedFilter aliases introduced by PlanMerger's filter propagation. + val propagatedFilterIdMap = new mutable.HashMap[String, String]() + propagatedFilterIdRegex.replaceAllIn(planIdNormalized, + m => propagatedFilterIdMap.getOrElseUpdate( + s"$m", s"${m.group("prefix")}${propagatedFilterIdMap.size + 1}")) } private def normalizeLocation(plan: String): String = {