diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index c7fcc6723e450..3276ab5067500 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.sources.v2.reader._ /** @@ -46,4 +47,8 @@ case class BatchScanExec( override lazy val inputRDD: RDD[InternalRow] = { new DataSourceRDD(sparkContext, partitions, readerFactory, supportsBatch) } + + override def doCanonicalize(): BatchScanExec = { + this.copy(output = output.map(QueryPlan.normalizeExprId(_, output))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index 337aac9ea651d..70d59321d2cfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjectio import org.apache.spark.sql.execution.PartitionedFileUtil import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, Scan} +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -85,6 +86,11 @@ abstract class FileScan( override def readSchema(): StructType = StructType(readDataSchema.fields ++ readPartitionSchema.fields) + // Returns whether the two given arrays of [[Filter]]s are equivalent. + protected def equivalentFilters(a: Array[Filter], b: Array[Filter]): Boolean = { + a.sortBy(_.hashCode()).sameElements(b.sortBy(_.hashCode())) + } + private val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis private def normalizeName(name: String): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index dc6b67ceb7e55..b129c942ccc53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -34,7 +35,8 @@ case class OrcScan( dataSchema: StructType, readDataSchema: StructType, readPartitionSchema: StructType, - options: CaseInsensitiveStringMap) + options: CaseInsensitiveStringMap, + pushedFilters: Array[Filter]) extends FileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema) { override def isSplitable(path: Path): Boolean = true @@ -46,4 +48,14 @@ case class OrcScan( OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, dataSchema, readDataSchema, readPartitionSchema) } + + override def equals(obj: Any): Boolean = obj match { + case o: OrcScan => + fileIndex == o.fileIndex && dataSchema == o.dataSchema && + readDataSchema == o.readDataSchema && readPartitionSchema == o.readPartitionSchema && + options == o.options && equivalentFilters(pushedFilters, o.pushedFilters) + case _ => false + } + + override def hashCode(): Int = getClass.hashCode() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index 4c1ec520c6ea7..458b98c627be4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -45,7 +45,7 @@ case class OrcScanBuilder( override def build(): Scan = { OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, - readDataSchema(), readPartitionSchema(), options) + readDataSchema(), readPartitionSchema(), options, pushedFilters()) } private var _pushedFilters: Array[Filter] = Array.empty diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala index d088e24e53bfe..4731da47a19dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.IntegerType @@ -47,6 +49,65 @@ class SameResultSuite extends QueryTest with SharedSQLContext { } } + test("FileScan: different orders of data filters and partition filters") { + withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "") { + Seq("orc", "json", "csv").foreach { format => + withTempPath { path => + val tmpDir = path.getCanonicalPath + spark.range(10) + .selectExpr("id as a", "id + 1 as b", "id + 2 as c", "id + 3 as d") + .write + .partitionBy("a", "b") + .format(format) + .option("header", true) + .save(tmpDir) + val df = spark.read.format(format).option("header", true).load(tmpDir) + // partition filters: a > 1 AND b < 9 + // data filters: c > 1 AND d < 9 + val plan1 = df.where("a > 1 AND b < 9 AND c > 1 AND d < 9").queryExecution.sparkPlan + val plan2 = df.where("b < 9 AND a > 1 AND d < 9 AND c > 1").queryExecution.sparkPlan + assert(plan1.sameResult(plan2)) + val scan1 = getBatchScanExec(plan1) + val scan2 = getBatchScanExec(plan2) + assert(scan1.sameResult(scan2)) + val plan3 = df.where("b < 9 AND a > 1 AND d < 8 AND c > 1").queryExecution.sparkPlan + assert(!plan1.sameResult(plan3)) + // The [[FileScan]]s should have different results if they support filter pushdown. + if (format == "orc") { + val scan3 = getBatchScanExec(plan3) + assert(!scan1.sameResult(scan3)) + } + } + } + } + } + + test("TextScan") { + withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "") { + withTempPath { path => + val tmpDir = path.getCanonicalPath + spark.range(10) + .selectExpr("id as a", "id + 1 as b", "cast(id as string) value") + .write + .partitionBy("a", "b") + .text(tmpDir) + val df = spark.read.text(tmpDir) + // partition filters: a > 1 AND b < 9 + // data filters: c > 1 AND d < 9 + val plan1 = df.where("a > 1 AND b < 9 AND value == '3'").queryExecution.sparkPlan + val plan2 = df.where("value == '3' AND a > 1 AND b < 9").queryExecution.sparkPlan + assert(plan1.sameResult(plan2)) + val scan1 = getBatchScanExec(plan1) + val scan2 = getBatchScanExec(plan2) + assert(scan1.sameResult(scan2)) + } + } + } + + private def getBatchScanExec(plan: SparkPlan): BatchScanExec = { + plan.find(_.isInstanceOf[BatchScanExec]).get.asInstanceOf[BatchScanExec] + } + private def getFileSourceScanExec(df: DataFrame): FileSourceScanExec = { df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get .asInstanceOf[FileSourceScanExec] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 955c3e3fa6f74..b38f0f7f228a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -414,19 +414,25 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi } test("[SPARK-16818] partition pruned file scans implement sameResult correctly") { - withTempPath { path => - val tempDir = path.getCanonicalPath - spark.range(100) - .selectExpr("id", "id as b") - .write - .partitionBy("id") - .parquet(tempDir) - val df = spark.read.parquet(tempDir) - def getPlan(df: DataFrame): SparkPlan = { - df.queryExecution.executedPlan + Seq("orc", "").foreach { useV1ReaderList => + withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> useV1ReaderList) { + withTempPath { path => + val tempDir = path.getCanonicalPath + spark.range(100) + .selectExpr("id", "id as b") + .write + .partitionBy("id") + .orc(tempDir) + val df = spark.read.orc(tempDir) + + def getPlan(df: DataFrame): SparkPlan = { + df.queryExecution.executedPlan + } + + assert(getPlan(df.where("id = 2")).sameResult(getPlan(df.where("id = 2")))) + assert(!getPlan(df.where("id = 2")).sameResult(getPlan(df.where("id = 3")))) + } } - assert(getPlan(df.where("id = 2")).sameResult(getPlan(df.where("id = 2")))) - assert(!getPlan(df.where("id = 2")).sameResult(getPlan(df.where("id = 3")))) } }