From eaf2343bfbebb18ea0819b0768180a9afa2b4e68 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 22 Feb 2024 14:01:56 +0800 Subject: [PATCH 1/2] init --- python/pyspark/sql/tests/test_readwriter.py | 23 ++++++++++++- .../sql/catalyst/analysis/Analyzer.scala | 32 +++++++++++++++---- .../analysis/ColumnResolutionHelper.scala | 8 +++++ .../catalyst/plans/logical/LogicalPlan.scala | 5 ++- .../sql/catalyst/rules/RuleExecutor.scala | 4 +++ 5 files changed, 63 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py index 70a320fc53b69..85057f37a1817 100644 --- a/python/pyspark/sql/tests/test_readwriter.py +++ b/python/pyspark/sql/tests/test_readwriter.py @@ -20,7 +20,7 @@ import tempfile from pyspark.errors import AnalysisException -from pyspark.sql.functions import col +from pyspark.sql.functions import col, lit from pyspark.sql.readwriter import DataFrameWriterV2 from pyspark.sql.types import StructType, StructField, StringType from pyspark.testing.sqlutils import ReusedSQLTestCase @@ -181,6 +181,27 @@ def test_insert_into(self): df.write.mode("overwrite").insertInto("test_table", False) self.assertEqual(6, self.spark.sql("select * from test_table").count()) + def test_cached_table(self): + with self.table("test_cached_table_1"): + self.spark.range(10).withColumn( + "value_1", + lit(1), + ).write.saveAsTable("test_cached_table_1") + + with self.table("test_cached_table_2"): + self.spark.range(10).withColumnRenamed("id", "index").withColumn( + "value_2", lit(2) + ).write.saveAsTable("test_cached_table_2") + + df1 = self.spark.read.table("test_cached_table_1") + df2 = self.spark.read.table("test_cached_table_2") + df3 = self.spark.read.table("test_cached_table_1") + + join1 = df1.join(df2, on=df1.id == df2.index).select(df2.index, df2.value_2) + join2 = df3.join(join1, how="left", on=join1.index == df3.id) + + self.assertEqual(join2.columns, ["id", "value_1", "index", "value_2"]) + class ReadwriterV2TestsMixin: def test_api(self): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d8127fe03da4e..a4ff505b7ec5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1043,6 +1043,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor /** * Replaces unresolved relations (tables and views) with concrete relations from the catalog. */ + // scalastyle:off println object ResolveRelations extends Rule[LogicalPlan] { // The current catalog and namespace may be different from when the view was created, we must // resolve the view logical plan here, with the catalog and namespace stored in view metadata. @@ -1259,6 +1260,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor private def resolveRelation( u: UnresolvedRelation, timeTravelSpec: Option[TimeTravelSpec] = None): Option[LogicalPlan] = { + println() + println(s"resolving $u with id = ${u.getTagValue(LogicalPlan.PLAN_ID_TAG)}") + println() + val timeTravelSpecFromOptions = TimeTravelSpec.fromOptions( u.options, conf.getConf(SQLConf.TIME_TRAVEL_TIMESTAMP_KEY), @@ -1275,16 +1280,29 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor val key = ((catalog.name +: ident.namespace :+ ident.name).toImmutableArraySeq, finalTimeTravelSpec) - AnalysisContext.get.relationCache.get(key).map(_.transform { - case multi: MultiInstanceRelation => - val newRelation = multi.newInstance() - newRelation.copyTagsFrom(multi) - newRelation - }).orElse { + AnalysisContext.get.relationCache.get(key).map { cache => + val cachedRelation = cache.transform { + case multi: MultiInstanceRelation => + val newRelation = multi.newInstance() + newRelation.copyTagsFrom(multi) + newRelation + } + u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId => + val cachedConnectRelation = cachedRelation.clone() + cachedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId) + cachedConnectRelation + }.getOrElse(cachedRelation) + }.orElse { val table = CatalogV2Util.loadTable(catalog, ident, finalTimeTravelSpec) val loaded = createRelation(catalog, ident, table, u.options, u.isStreaming) loaded.foreach(AnalysisContext.get.relationCache.update(key, _)) - loaded + u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId => + loaded.map { loadedRelation => + val loadedConnectRelation = loadedRelation.clone() + loadedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId) + loadedConnectRelation + } + }.getOrElse(loaded) } case _ => None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index 8ea50e2ceb659..5f90dc380390e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -518,6 +518,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { case _ => e } + // scalastyle:off println private def resolveDataFrameColumn( u: UnresolvedAttribute, q: Seq[LogicalPlan]): Option[NamedExpression] = { @@ -533,6 +534,13 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { // df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]]) // df2 = spark.createDataFrame([Row(a = 1, b = 2)]]) // df1.select(df2.a) <- illegal reference df2.a + println() + println() + println() + println(s"Can not resolve $u with plan $planId") + println() + println() + println() throw QueryCompilationErrors.cannotResolveDataFrameColumn(u) } resolved diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index e1121d1f9026e..593809ff77182 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -103,7 +103,10 @@ abstract class LogicalPlan */ lazy val resolved: Boolean = expressions.forall(_.resolved) && childrenResolved - override protected def statePrefix = if (!resolved) "'" else super.statePrefix + override protected def statePrefix = { + val p = if (!resolved) "'" else super.statePrefix + this.getTagValue(LogicalPlan.PLAN_ID_TAG).map(id => s"$p[#$id]").getOrElse(p) + } /** * Returns true if all its children of this query plan have been resolved. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index d5cd5a90e3382..d9f5936ec1b01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -45,6 +45,7 @@ object RuleExecutor { } } +// scalastyle:off println class PlanChangeLogger[TreeType <: TreeNode[_]] extends Logging { private val logLevel = SQLConf.get.planChangeLogLevel @@ -64,6 +65,7 @@ class PlanChangeLogger[TreeType <: TreeNode[_]] extends Logging { } logBasedOnLevel(message()) + println(message()) } } } @@ -82,6 +84,7 @@ class PlanChangeLogger[TreeType <: TreeNode[_]] extends Logging { } logBasedOnLevel(message()) + println(message()) } } @@ -98,6 +101,7 @@ class PlanChangeLogger[TreeType <: TreeNode[_]] extends Logging { """.stripMargin logBasedOnLevel(message) + println(message) } private def logBasedOnLevel(f: => String): Unit = { From 369b8b8d168608450991e3d85ec2d90287dca9b1 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 22 Feb 2024 14:03:09 +0800 Subject: [PATCH 2/2] delete debug codes --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 5 ----- .../sql/catalyst/analysis/ColumnResolutionHelper.scala | 8 -------- .../spark/sql/catalyst/plans/logical/LogicalPlan.scala | 5 +---- .../apache/spark/sql/catalyst/rules/RuleExecutor.scala | 4 ---- 4 files changed, 1 insertion(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a4ff505b7ec5d..1fb5d00bdf39a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1043,7 +1043,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor /** * Replaces unresolved relations (tables and views) with concrete relations from the catalog. */ - // scalastyle:off println object ResolveRelations extends Rule[LogicalPlan] { // The current catalog and namespace may be different from when the view was created, we must // resolve the view logical plan here, with the catalog and namespace stored in view metadata. @@ -1260,10 +1259,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor private def resolveRelation( u: UnresolvedRelation, timeTravelSpec: Option[TimeTravelSpec] = None): Option[LogicalPlan] = { - println() - println(s"resolving $u with id = ${u.getTagValue(LogicalPlan.PLAN_ID_TAG)}") - println() - val timeTravelSpecFromOptions = TimeTravelSpec.fromOptions( u.options, conf.getConf(SQLConf.TIME_TRAVEL_TIMESTAMP_KEY), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index 5f90dc380390e..8ea50e2ceb659 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -518,7 +518,6 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { case _ => e } - // scalastyle:off println private def resolveDataFrameColumn( u: UnresolvedAttribute, q: Seq[LogicalPlan]): Option[NamedExpression] = { @@ -534,13 +533,6 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { // df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]]) // df2 = spark.createDataFrame([Row(a = 1, b = 2)]]) // df1.select(df2.a) <- illegal reference df2.a - println() - println() - println() - println(s"Can not resolve $u with plan $planId") - println() - println() - println() throw QueryCompilationErrors.cannotResolveDataFrameColumn(u) } resolved diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 593809ff77182..e1121d1f9026e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -103,10 +103,7 @@ abstract class LogicalPlan */ lazy val resolved: Boolean = expressions.forall(_.resolved) && childrenResolved - override protected def statePrefix = { - val p = if (!resolved) "'" else super.statePrefix - this.getTagValue(LogicalPlan.PLAN_ID_TAG).map(id => s"$p[#$id]").getOrElse(p) - } + override protected def statePrefix = if (!resolved) "'" else super.statePrefix /** * Returns true if all its children of this query plan have been resolved. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index d9f5936ec1b01..d5cd5a90e3382 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -45,7 +45,6 @@ object RuleExecutor { } } -// scalastyle:off println class PlanChangeLogger[TreeType <: TreeNode[_]] extends Logging { private val logLevel = SQLConf.get.planChangeLogLevel @@ -65,7 +64,6 @@ class PlanChangeLogger[TreeType <: TreeNode[_]] extends Logging { } logBasedOnLevel(message()) - println(message()) } } } @@ -84,7 +82,6 @@ class PlanChangeLogger[TreeType <: TreeNode[_]] extends Logging { } logBasedOnLevel(message()) - println(message()) } } @@ -101,7 +98,6 @@ class PlanChangeLogger[TreeType <: TreeNode[_]] extends Logging { """.stripMargin logBasedOnLevel(message) - println(message) } private def logBasedOnLevel(f: => String): Unit = {