From a76f4c2c284e6c637f9a6f1c474161b3fbda6c78 Mon Sep 17 00:00:00 2001 From: Anton Okolnychyi Date: Tue, 1 Dec 2020 13:41:20 +0200 Subject: [PATCH 1/2] [SPARK-33621][SQL] Add a way to inject data source rewrite rules --- .../apache/spark/sql/SparkSessionExtensions.scala | 15 +++++++++++++++ .../sql/internal/BaseSessionStateBuilder.scala | 4 +++- .../spark/sql/SparkSessionExtensionSuite.scala | 6 ++++++ 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala index 6952f4bfd0566..31b46a705d0ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -199,6 +199,21 @@ class SparkSessionExtensions { optimizerRules += builder } + private[this] val dataSourceRewriteRules = mutable.Buffer.empty[RuleBuilder] + + private[sql] def buildDataSourceRewriteRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + dataSourceRewriteRules.map(_.apply(session)) + } + + /** + * Inject an optimizer `Rule` builder that rewrites data source plans into the [[SparkSession]]. + * The injected rules will be executed after the operator optimization batch and before rules + * that depend on stats. + */ + def injectDataSourceRewriteRule(builder: RuleBuilder): Unit = { + dataSourceRewriteRules += builder + } + private[this] val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder] private[sql] def buildPlannerStrategies(session: SparkSession): Seq[Strategy] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 8101f9e291b44..f51ee11091d02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -273,7 +273,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `optimizer` function. */ - protected def customDataSourceRewriteRules: Seq[Rule[LogicalPlan]] = Nil + protected def customDataSourceRewriteRules: Seq[Rule[LogicalPlan]] = { + extensions.buildDataSourceRewriteRules(session) + } /** * Planner that converts optimized logical plans to physical plans. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 12abd31b99e93..6fccef5d04aa0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -88,6 +88,12 @@ class SparkSessionExtensionSuite extends SparkFunSuite { } } + test("inject data source rewrite rule") { + withSession(Seq(_.injectDataSourceRewriteRule(MyRule))) { session => + assert(session.sessionState.optimizer.dataSourceRewriteRules.contains(MyRule(session))) + } + } + test("inject spark planner strategy") { withSession(Seq(_.injectPlannerStrategy(MySparkStrategy))) { session => assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session))) From 6b38928a7a72862a9b2408dfc931d0ef5961b466 Mon Sep 17 00:00:00 2001 From: Anton Okolnychyi Date: Mon, 7 Dec 2020 13:22:34 +0200 Subject: [PATCH 2/2] Review round 1 --- .../scala/org/apache/spark/sql/SparkSessionExtensions.scala | 3 ++- .../org/apache/spark/sql/SparkSessionExtensionSuite.scala | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala index 31b46a705d0ae..d5d969032a5e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan} *
  • Analyzer Rules.
  • *
  • Check Analysis Rules.
  • *
  • Optimizer Rules.
  • + *
  • Data Source Rewrite Rules.
  • *
  • Planning Strategies.
  • *
  • Customized Parser.
  • *
  • (External) Catalog listeners.
  • @@ -202,7 +203,7 @@ class SparkSessionExtensions { private[this] val dataSourceRewriteRules = mutable.Buffer.empty[RuleBuilder] private[sql] def buildDataSourceRewriteRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { - dataSourceRewriteRules.map(_.apply(session)) + dataSourceRewriteRules.map(_.apply(session)).toSeq } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 6fccef5d04aa0..37bb7a5450ae0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -88,7 +88,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite { } } - test("inject data source rewrite rule") { + test("SPARK-33621: inject data source rewrite rule") { withSession(Seq(_.injectDataSourceRewriteRule(MyRule))) { session => assert(session.sessionState.optimizer.dataSourceRewriteRules.contains(MyRule(session))) }