From 0dc67ade87e76c216979ae0dabd3fc68dfd4da51 Mon Sep 17 00:00:00 2001 From: Abhishek Somani Date: Wed, 9 Mar 2022 09:53:54 -0500 Subject: [PATCH 01/15] Initial commit --- .../apache/spark/util/sketch/BloomFilter.java | 7 + .../spark/util/sketch/BloomFilterImpl.java | 5 + .../expressions/BloomFilterMightContain.scala | 100 ++++ .../aggregate/BloomFilterAggregate.scala | 196 +++++++ .../expressions/objects/objects.scala | 4 +- .../expressions/regexpExpressions.scala | 5 +- .../optimizer/InjectRuntimeFilter.scala | 311 +++++++++++ .../sql/catalyst/trees/TreePatterns.scala | 3 + .../apache/spark/sql/internal/SQLConf.scala | 51 ++ .../spark/sql/execution/SparkOptimizer.scala | 2 + .../sql/BloomFilterAggregateQuerySuite.scala | 202 +++++++ .../spark/sql/InjectRuntimeFilterSuite.scala | 502 ++++++++++++++++++ 12 files changed, 1386 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index c53987ecf6e25..2a6e270a91267 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -163,6 +163,13 @@ int getVersionNumber() { */ public abstract void writeTo(OutputStream out) throws IOException; + /** + * @return the number of set bits in this {@link BloomFilter}. + */ + public long cardinality() { + throw new UnsupportedOperationException("Not implemented"); + } + /** * Reads in a {@link BloomFilter} from an input stream. It is the caller's responsibility to close * the stream. diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java index e7766ee903480..ccf1833af9945 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -207,6 +207,11 @@ public BloomFilter intersectInPlace(BloomFilter other) throws IncompatibleMergeE return this; } + @Override + public long cardinality() { + return this.bits.cardinality(); + } + private BloomFilterImpl checkCompatibilityForMerge(BloomFilter other) throws IncompatibleMergeException { // Duplicates the logic of `isCompatible` here to provide better error message. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala new file mode 100644 index 0000000000000..9a1cf637e5a73 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.io.ByteArrayInputStream + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE +import org.apache.spark.sql.types._ +import org.apache.spark.util.sketch.BloomFilter + +/** + * An internal scalar function that returns the membership check result (either true or false) + * for values of `valueExpression` in the Bloom filter represented by `bloomFilterExpression`. + * Not that since the function is "might contain", always returning true regardless is not + * wrong. + * Note that this expression requires that `bloomFilterExpression` is either a constant value or + * an uncorrelated scalar subquery. This is sufficient for the Bloom filter join rewrite. + * + * @param bloomFilterExpression the Binary data of Bloom filter. + * @param valueExpression the Long value to be tested for the membership of `bloomFilterExpression`. + */ +case class BloomFilterMightContain( + bloomFilterExpression: Expression, + valueExpression: Expression) extends BinaryExpression { + + override def nullable: Boolean = true + override def left: Expression = bloomFilterExpression + override def right: Expression = valueExpression + override def prettyName: String = "might_contain" + override def dataType: DataType = BooleanType + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = (left.dataType, right.dataType) match { + case (BinaryType, NullType) | (NullType, LongType) | (NullType, NullType) | + (BinaryType, LongType) => TypeCheckResult.TypeCheckSuccess + case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been ${BinaryType.simpleString} followed by a value with ${LongType.simpleString}, " + + s"but it's [${left.dataType.catalogString}, ${right.dataType.catalogString}].") + } + if (typeCheckResult.isFailure) { + return typeCheckResult + } + bloomFilterExpression match { + case e : Expression if e.foldable => TypeCheckResult.TypeCheckSuccess + case subquery : PlanExpression[_] if !subquery.containsPattern(OUTER_REFERENCE) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure(s"The Bloom filter binary input to $prettyName " + + s"should be either a constant value or a scalar subquery expression") + } + } + + override protected def withNewChildrenInternal( + newBloomFilterExpression: Expression, + newValueExpression: Expression): BloomFilterMightContain = + copy(bloomFilterExpression = newBloomFilterExpression, + valueExpression = newValueExpression) + + // The bloom filter created from `bloomFilterExpression`. + @transient private var bloomFilter: BloomFilter = _ + + override def nullSafeEval(bloomFilterBytes: Any, value: Any): Any = { + if (bloomFilter == null) { + bloomFilter = deserialize(bloomFilterBytes.asInstanceOf[Array[Byte]]) + } + bloomFilter.mightContainLong(value.asInstanceOf[Long]) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val thisObj = ctx.addReferenceObj("thisObj", this) + nullSafeCodeGen(ctx, ev, (bloomFilterBytes, value) => { + s"\n${ev.value} = (Boolean) $thisObj.nullSafeEval($bloomFilterBytes, $value);\n" + }) + } + + final def deserialize(bytes: Array[Byte]): BloomFilter = { + val in = new ByteArrayInputStream(bytes) + val bloomFilter = BloomFilter.readFrom(in) + in.close() + bloomFilter + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala new file mode 100644 index 0000000000000..86d3d62e1c643 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.TernaryLike +import org.apache.spark.sql.types._ +import org.apache.spark.util.sketch.BloomFilter + +/** + * An internal aggregate function that creates a Bloom filter from input values. + * + * @param child Child expression of Long values for creating a Bloom filter. + * @param estimatedNumItemsExpression The number of estimated distinct items (optional). + * @param numBitsExpression The number of bits to use (optional). + */ +case class BloomFilterAggregate( + child: Expression, + estimatedNumItemsExpression: Expression, + numBitsExpression: Expression, + override val mutableAggBufferOffset: Int, + override val inputAggBufferOffset: Int) + extends TypedImperativeAggregate[BloomFilter] with TernaryLike[Expression] { + + def this(child: Expression, estimatedNumItemsExpression: Expression, + numBitsExpression: Expression) = { + this(child, estimatedNumItemsExpression, numBitsExpression, 0, 0) + } + + def this(child: Expression, estimatedNumItemsExpression: Expression) = { + this(child, estimatedNumItemsExpression, + // 1 byte per item. + Multiply(estimatedNumItemsExpression, Literal(8L))) + } + + def this(child: Expression) = { + this(child, Literal(BloomFilterAggregate.DEFAULT_EXPECTED_NUM_ITEMS), + Literal(BloomFilterAggregate.DEFAULT_NUM_BITS)) + } + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = (first.dataType, second.dataType, third.dataType) match { + case (_, NullType, _) | (_, _, NullType) => + TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as size arguments") + case (LongType, LongType, LongType) => TypeCheckResult.TypeCheckSuccess + case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been a ${LongType.simpleString} value followed with two ${LongType.simpleString} size " + + s"arguments, but it's [${first.dataType.catalogString}, " + + s"${second.dataType.catalogString}, ${third.dataType.catalogString}]") + } + if (typeCheckResult.isFailure) { + return typeCheckResult + } + if (!estimatedNumItemsExpression.foldable) { + TypeCheckFailure("The estimated number of items provided must be a constant literal") + } else if (estimatedNumItems <= 0L) { + TypeCheckFailure("The estimated number of items must be a positive value " + + s" (current value = $estimatedNumItems)") + } else if (!numBitsExpression.foldable) { + TypeCheckFailure("The number of bits provided must be a constant literal") + } else if (numBits <= 0L) { + TypeCheckFailure("The number of bits must be a positive value " + + s" (current value = $numBits)") + } else { + require(estimatedNumItems <= BloomFilterAggregate.MAX_ALLOWED_NUM_ITEMS) + require(numBits <= BloomFilterAggregate.MAX_NUM_BITS) + TypeCheckSuccess + } + } + override def nullable: Boolean = true + + override def dataType: DataType = BinaryType + + override def prettyName: String = "bloom_filter_agg" + + // Mark as lazy so that `estimatedNumItems` is not evaluated during tree transformation. + private lazy val estimatedNumItems: Long = + Math.min(estimatedNumItemsExpression.eval().asInstanceOf[Number].longValue, + BloomFilterAggregate.MAX_ALLOWED_NUM_ITEMS) + + // Mark as lazy so that `numBits` is not evaluated during tree transformation. + private lazy val numBits: Long = + Math.min(numBitsExpression.eval().asInstanceOf[Number].longValue, + BloomFilterAggregate.MAX_NUM_BITS) + + override def first: Expression = child + + override def second: Expression = estimatedNumItemsExpression + + override def third: Expression = numBitsExpression + + override protected def withNewChildrenInternal(newChild: Expression, + newEstimatedNumItemsExpression: Expression, newNumBitsExpression: Expression) + : BloomFilterAggregate = { + copy(child = newChild, estimatedNumItemsExpression = newEstimatedNumItemsExpression, + numBitsExpression = newNumBitsExpression) + } + + override def createAggregationBuffer(): BloomFilter = { + BloomFilter.create(estimatedNumItems, numBits) + } + + override def update(buffer: BloomFilter, inputRow: InternalRow): BloomFilter = { + val value = child.eval(inputRow) + // Ignore null values. + if (value == null) { + return buffer + } + buffer.putLong(value.asInstanceOf[Long]) + buffer + } + + override def merge(buffer: BloomFilter, other: BloomFilter): BloomFilter = { + buffer.mergeInPlace(other) + } + + override def eval(buffer: BloomFilter): Any = { + if (buffer.cardinality() == 0) { + // There's no set bit in the Bloom filter and hence no not-null value is processed. + return null + } + serialize(buffer) + } + + override def withNewMutableAggBufferOffset(newOffset: Int): BloomFilterAggregate = + copy(mutableAggBufferOffset = newOffset) + + override def withNewInputAggBufferOffset(newOffset: Int): BloomFilterAggregate = + copy(inputAggBufferOffset = newOffset) + + override def serialize(obj: BloomFilter): Array[Byte] = { + BloomFilterAggregate.serde.serialize(obj) + } + + override def deserialize(bytes: Array[Byte]): BloomFilter = { + BloomFilterAggregate.serde.deserialize(bytes) + } +} + +object BloomFilterAggregate { + + val DEFAULT_EXPECTED_NUM_ITEMS: Long = 1000000L // Default 1M distinct items + + val MAX_ALLOWED_NUM_ITEMS: Long = 4000000L // At most 4M distinct items + + val DEFAULT_NUM_BITS: Long = 8388608 // Default 1MB + + val MAX_NUM_BITS: Long = 67108864 // At most 8MB + + /** + * Serializer/Deserializer for class [[BloomFilter]] + * + * This class is thread safe. + */ + class BloomFilterSerDe { + + final def serialize(obj: BloomFilter): Array[Byte] = { + val size = obj.bitSize()/8 + require(size <= Integer.MAX_VALUE, s"actual number of bits is too large $size") + val out = new ByteArrayOutputStream(size.intValue()) + obj.writeTo(out) + out.close() + out.toByteArray + } + + final def deserialize(bytes: Array[Byte]): BloomFilter = { + val in = new ByteArrayInputStream(bytes) + val bloomFilter = BloomFilter.readFrom(in) + in.close() + bloomFilter + } + } + + val serde: BloomFilterSerDe = new BloomFilterSerDe +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 4599c2a2d3055..130a59eb1634a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TernaryLike -import org.apache.spark.sql.catalyst.trees.TreePattern._ +import org.apache.spark.sql.catalyst.trees.TreePattern.{INVOKE, _} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ @@ -360,6 +360,8 @@ case class Invoke( lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments) + final override val nodePatterns: Seq[TreePattern] = Seq(INVOKE) + override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable override def children: Seq[Expression] = targetObject +: arguments override lazy val deterministic: Boolean = isDeterministic && arguments.forall(_.deterministic) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 368cbfd6be641..bfaaba514462f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.BinaryLike -import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, TreePattern} +import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern} import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ @@ -627,6 +627,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio @transient private var lastReplacementInUTF8: UTF8String = _ // result buffer write by Matcher @transient private lazy val result: StringBuffer = new StringBuffer + final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_REPLACE) override def nullSafeEval(s: Any, p: Any, r: Any, i: Any): Any = { if (!p.equals(lastRegex)) { @@ -751,6 +752,8 @@ abstract class RegExpExtractBase // last regex pattern, we cache it for performance concern @transient private var pattern: Pattern = _ + final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_EXTRACT_FAMILY) + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType) override def first: Expression = subject override def second: Expression = regexp diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala new file mode 100644 index 0000000000000..1118cb40551a1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate, Complete} +import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, PhysicalOperation} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.{INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, PYTHON_UDF, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, SCALA_UDF} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +/** + * Insert a filter on one side of the join if the other side has a selective predicate. + * The filter could be an IN subquery (converted to a semi join), a bloom filter, or something + * else in the future. + */ +object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with JoinSelectionHelper { + + // Wraps `expr` with a hash function if its byte size is larger than an integer. + private def mayWrapWithHash(expr: Expression): Expression = { + if (expr.dataType.defaultSize > IntegerType.defaultSize) { + new Murmur3Hash(Seq(expr)) + } else { + expr + } + } + + private def injectFilter( + filterApplicationSideExp: Expression, + filterApplicationSidePlan: LogicalPlan, + filterCreationSideExp: Expression, + filterCreationSidePlan: LogicalPlan + ): LogicalPlan = { + require(conf.runtimeFilterBloomFilterEnabled || conf.runtimeFilterSemiJoinReductionEnabled) + if (conf.runtimeFilterBloomFilterEnabled) { + injectBloomFilter( + filterApplicationSideExp, + filterApplicationSidePlan, + filterCreationSideExp, + filterCreationSidePlan + ) + } else { + injectInSubqueryFilter( + filterApplicationSideExp, + filterApplicationSidePlan, + filterCreationSideExp, + filterCreationSidePlan + ) + } + } + + private def injectBloomFilter( + filterApplicationSideExp: Expression, + filterApplicationSidePlan: LogicalPlan, + filterCreationSideExp: Expression, + filterCreationSidePlan: LogicalPlan + ): LogicalPlan = { + // Skip if the filter creation side is too big + if (filterCreationSidePlan.stats.sizeInBytes > conf.runtimeFilterBloomFilterThreshold) { + return filterApplicationSidePlan + } + val rowCount = filterCreationSidePlan.stats.rowCount + val bloomFilterAgg = + if (rowCount.isDefined && rowCount.get.longValue > 0L) { + new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)), + Literal(rowCount.get.longValue)) + } else { + new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp))) + } + val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None) + val alias = Alias(aggExp, "bloomFilter")() + val aggregate = ConstantFolding(Aggregate(Nil, Seq(alias), filterCreationSidePlan)) + val bloomFilterSubquery = ScalarSubquery(aggregate, Nil) + val filter = BloomFilterMightContain(bloomFilterSubquery, + new XxHash64(Seq(filterApplicationSideExp))) + Filter(filter, filterApplicationSidePlan) + } + + private def injectInSubqueryFilter( + filterApplicationSideExp: Expression, + filterApplicationSidePlan: LogicalPlan, + filterCreationSideExp: Expression, + filterCreationSidePlan: LogicalPlan + ): LogicalPlan = { + require(filterApplicationSideExp.dataType == filterCreationSideExp.dataType) + val actualFilterKeyExpr = mayWrapWithHash(filterCreationSideExp) + val alias = Alias(actualFilterKeyExpr, actualFilterKeyExpr.toString)() + val aggregate = Aggregate(Seq(alias), Seq(alias), filterCreationSidePlan) + if (!canBroadcastBySize(aggregate, conf)) { + // Skip the InSubquery filter if the size of `aggregate` is beyond broadcast join threshold, + // i.e., the semi-join will be a shuffled join, which is not worthwhile. + return filterApplicationSidePlan + } + val filter = InSubquery(Seq(mayWrapWithHash(filterApplicationSideExp)), + ListQuery(aggregate, childOutputs = aggregate.output)) + Filter(filter, filterApplicationSidePlan) + } + + /** + * Returns whether the plan is a simple filter over scan and the filter is likely selective + * Also check if the plan only has simple expressions (attribute reference, literals) so that we + * do not add a subquery that might have an expensive computation + */ + private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = { + plan.expressions + val ret = plan match { + case PhysicalOperation(_, filters, child) if child.isInstanceOf[LeafNode] => + filters.forall(isSimpleExpression) && + filters.exists(isLikelySelective) + case _ => false + } + !plan.isStreaming && ret + } + + private def isSimpleExpression(e: Expression): Boolean = { + !e.containsAnyPattern(PYTHON_UDF, SCALA_UDF, INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, + REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE) + } + + /** + * Returns whether an expression is likely to be selective + */ + private def isLikelySelective(e: Expression): Boolean = e match { + case Not(expr) => isLikelySelective(expr) + case And(l, r) => isLikelySelective(l) || isLikelySelective(r) + case Or(l, r) => isLikelySelective(l) && isLikelySelective(r) + case _: StringRegexExpression => true + case _: BinaryComparison => true + case _: In | _: InSet => true + case _: StringPredicate => true + case _: MultiLikeBase => true + case _ => false + } + + private def canFilterLeft(joinType: JoinType): Boolean = joinType match { + case Inner | RightOuter => true + case _ => false + } + + private def canFilterRight(joinType: JoinType): Boolean = joinType match { + case Inner | LeftOuter => true + case _ => false + } + + private def isProbablyShuffleJoin(left: LogicalPlan, + right: LogicalPlan, hint: JoinHint): Boolean = { + !hintToBroadcastLeft(hint) && !hintToBroadcastRight(hint) && + !canBroadcastBySize(left, conf) && !canBroadcastBySize(right, conf) + } + + private def probablyHasShuffle(plan: LogicalPlan): Boolean = { + plan.collect { + case j@Join(left, right, _, _, hint) + if !hintToBroadcastLeft(hint) && !hintToBroadcastRight(hint) && + !canBroadcastBySize(left, conf) && !canBroadcastBySize(right, conf) => j + case a: Aggregate => a + }.nonEmpty + } + + // Returns the max scan byte size in the subtree rooted at `filterApplicationSide`. + private def maxScanByteSize(filterApplicationSide: LogicalPlan): BigInt = { + val defaultSizeInBytes = conf.getConf(SQLConf.DEFAULT_SIZE_IN_BYTES) + filterApplicationSide.collect({ + case leaf: LeafNode => leaf + }).map(scan => { + // DEFAULT_SIZE_IN_BYTES means there's no byte size information in stats. Since we avoid + // creating a Bloom filter when the filter application side is very small, so using 0 + // as the byte size when the actual size is unknown can avoid regression by applying BF + // on a small table. + if (scan.stats.sizeInBytes == defaultSizeInBytes) BigInt(0) else scan.stats.sizeInBytes + }).max + } + + // Returns true if `filterApplicationSide` satisfies the byte size requirement to apply a + // Bloom filter; false otherwise. + private def satisfyByteSizeRequirement(filterApplicationSide: LogicalPlan): Boolean = { + // In case `filterApplicationSide` is a union of many small tables, disseminating the Bloom + // filter to each small task might be more costly than scanning them itself. Thus, we use max + // rather than sum here. + val maxScanSize = maxScanByteSize(filterApplicationSide) + maxScanSize >= + conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD) + } + + private def filteringHasBenefit( + filterApplicationSide: LogicalPlan, + filterCreationSide: LogicalPlan, + filterApplicationSideExp: Expression, + hint: JoinHint): Boolean = { + // Check that: + // 1. The filterApplicationSideJoinExp can be pushed down through joins and aggregates (ie the + // expression references originate from a single leaf node) + // 2. The filter creation side has a selective predicate + // 3. The current join is a shuffle join or a broadcast join that has a shuffle or aggregate + // in the filter application side + // 4. The filterApplicationSide is larger than the filterCreationSide by a configurable + // threshold + findExpressionAndTrackLineageDown(filterApplicationSideExp, + filterApplicationSide).isDefined && isSelectiveFilterOverScan(filterCreationSide) && + (isProbablyShuffleJoin(filterApplicationSide, filterCreationSide, hint) || + probablyHasShuffle(filterApplicationSide)) && + satisfyByteSizeRequirement(filterApplicationSide) + } + + def hasRuntimeFilter(left: LogicalPlan, right: LogicalPlan, leftKey: Expression, + rightKey: Expression): Boolean = { + if (conf.runtimeFilterBloomFilterEnabled) { + hasBloomFilter(left, right, leftKey, rightKey) + } else { + hasInSubquery(left, right, leftKey, rightKey) + } + } + + // This checks if there is already a DPP filter, as this rule is called just after DPP. + def hasDynamicPruningSubquery(left: LogicalPlan, right: LogicalPlan, leftKey: Expression, + rightKey: Expression): Boolean = { + (left, right) match { + case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan), _) => + pruningKey.fastEquals(leftKey) || hasDynamicPruningSubquery(plan, right, leftKey, rightKey) + case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan)) => + pruningKey.fastEquals(rightKey) || + hasDynamicPruningSubquery(left, plan, leftKey, rightKey) + case _ => false + } + } + + def hasBloomFilter(left: LogicalPlan, right: LogicalPlan, leftKey: Expression, + rightKey: Expression): Boolean = { + findBloomFilterWithExp(left, leftKey) || findBloomFilterWithExp(right, rightKey) + } + + private def findBloomFilterWithExp(plan: LogicalPlan, key: Expression): Boolean = { + plan.find { + case Filter(condition, _) => + splitConjunctivePredicates(condition).exists { + case BloomFilterMightContain(_, XxHash64(Seq(valueExpression), _)) + if valueExpression.fastEquals(key) => true + case _ => false + } + case _ => false + }.isDefined + } + + def hasInSubquery(left: LogicalPlan, right: LogicalPlan, leftKey: Expression, + rightKey: Expression): Boolean = { + (left, right) match { + case (Filter(InSubquery(Seq(key), + ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _), _) => + key.fastEquals(leftKey) || key.fastEquals(new Murmur3Hash(Seq(leftKey))) + case (_, Filter(InSubquery(Seq(key), + ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _)) => + key.fastEquals(rightKey) || key.fastEquals(new Murmur3Hash(Seq(rightKey))) + case _ => false + } + } + + private def tryInjectRuntimeFilter(plan: LogicalPlan): LogicalPlan = { + var filterCounter = 0 + val numFilterThreshold = conf.getConf(SQLConf.RUNTIME_FILTER_NUMBER_THRESHOLD) + plan transformUp { + case join @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, left, right, hint) => + var newLeft = left + var newRight = right + (leftKeys, rightKeys).zipped.foreach((l, r) => { + // Check if: + // 1. There is already a DPP filter on the key + // 2. There is already a runtime filter (Bloom filter or IN subquery) on the key + // 3. The keys are simple cheap expressions + if (filterCounter < numFilterThreshold && + !hasDynamicPruningSubquery(left, right, l, r) && + !hasRuntimeFilter(newLeft, newRight, l, r) && + isSimpleExpression(l) && isSimpleExpression(r)) { + if (canFilterLeft(joinType) && filteringHasBenefit(left, right, l, hint)) { + newLeft = injectFilter(l, newLeft, r, right) + filterCounter = filterCounter + 1 + } else if (canFilterRight(joinType) && filteringHasBenefit(right, left, r, hint)) { + newRight = injectFilter(r, newRight, l, left) + filterCounter = filterCounter + 1 + } + } + }) + Join(newLeft, newRight, joinType, join.condition, hint) + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan match { + case s: Subquery if s.correlated => plan + case _ if !conf.runtimeFilterSemiJoinReductionEnabled && + !conf.runtimeFilterBloomFilterEnabled => plan + case _ => tryInjectRuntimeFilter(plan) + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index b595966bcc235..3cf45d5f79f00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -54,6 +54,7 @@ object TreePattern extends Enumeration { val IN_SUBQUERY: Value = Value val INSET: Value = Value val INTERSECT: Value = Value + val INVOKE: Value = Value val JSON_TO_STRUCT: Value = Value val LAMBDA_FUNCTION: Value = Value val LAMBDA_VARIABLE: Value = Value @@ -72,6 +73,8 @@ object TreePattern extends Enumeration { val PIVOT: Value = Value val PLAN_EXPRESSION: Value = Value val PYTHON_UDF: Value = Value + val REGEXP_EXTRACT_FAMILY: Value = Value + val REGEXP_REPLACE: Value = Value val RUNTIME_REPLACEABLE: Value = Value val SCALAR_SUBQUERY: Value = Value val SCALA_UDF: Value = Value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a050156518c2c..80e78e7cb8e6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -341,6 +341,48 @@ object SQLConf { .booleanConf .createWithDefault(true) + val RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED = + buildConf("spark.sql.optimizer.runtimeFilter.semiJoinReduction.enabled") + .doc("When true and if one side of a shuffle join has a selective predicate, we attempt " + + "to insert a semi join in the other side to reduce the amount of shuffle data") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + + val RUNTIME_FILTER_NUMBER_THRESHOLD = + buildConf("spark.sql.optimizer.runtimeFilter.number.threshold") + .doc("The total number of injected runtime filters (non-DPP) for a single " + + "query. This is to prevent driver OOMs with too many Bloom filters") + .version("3.3.0") + .intConf + .checkValue(threshold => threshold >= 0, "The threshold should be >= 0") + .createWithDefault(10) + + lazy val RUNTIME_BLOOM_FILTER_ENABLED = + buildConf("spark.sql.optimizer.runtime.bloomFilter.enabled") + .doc("When true and if one side of a shuffle join has a selective predicate, we attempt " + + "to insert a bloom filter in the other side to reduce the amount of shuffle data") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + + val RUNTIME_BLOOM_FILTER_THRESHOLD = + buildConf("spark.sql.optimizer.runtime.bloomFilter.threshold") + .doc("Size threshold of the bloom filter creation side plan. Estimated size needs to be " + + "under this value to try to inject bloom filter") + .version("3.3.0") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("10MB") + + val RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD = + buildConf("spark.sql.optimizer.runtime.bloomFilter.applicationSideScanSizethreshold") + .doc("Byte size threshold of the Bloom filter application side plan's aggregated scan " + + "size. Aggregated scan byte size of the Bloom filter application side needs to be over " + + "this value to inject a bloom filter") + .version("3.3.0") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("10GB") + val COMPRESS_CACHED = buildConf("spark.sql.inMemoryColumnarStorage.compressed") .doc("When set to true Spark SQL will automatically select a compression codec for each " + "column based on statistics of the data.") @@ -3713,6 +3755,15 @@ class SQLConf extends Serializable with Logging { def dynamicPartitionPruningReuseBroadcastOnly: Boolean = getConf(DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY) + def runtimeFilterSemiJoinReductionEnabled: Boolean = + getConf(RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED) + + def runtimeFilterBloomFilterEnabled: Boolean = + getConf(RUNTIME_BLOOM_FILTER_ENABLED) + + def runtimeFilterBloomFilterThreshold: Long = + getConf(RUNTIME_BLOOM_FILTER_THRESHOLD) + def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS) def isStateSchemaCheckEnabled: Boolean = getConf(STATE_SCHEMA_CHECK_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index dc3ceb5c595d0..f2094b6b86722 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -43,6 +43,8 @@ class SparkOptimizer( Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("PartitionPruning", Once, PartitionPruning) :+ + Batch("InjectRuntimeFilter", FixedPoint(1), + InjectRuntimeFilter) :+ Batch("Pushdown Filters from PartitionPruning", fixedPoint, PushDownPredicates) :+ Batch("Cleanup filters that cannot be pushed down", Once, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala new file mode 100644 index 0000000000000..37b4f5f7c0f6d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate +import org.apache.spark.sql.execution.aggregate.BaseAggregateExec +import org.apache.spark.sql.test.SharedSparkSession + +/** + * Query tests for the Bloom filter aggregate and filter function. + */ +class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { + import testImplicits._ + + // Register 'bloom_filter_agg' to builtin. + FunctionRegistry.builtin.registerFunction(new FunctionIdentifier("bloom_filter_agg"), + new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"), + (children: Seq[Expression]) => children.size match { + case 1 => new BloomFilterAggregate(children.head) + case 2 => new BloomFilterAggregate(children.head, children(1)) + case 3 => new BloomFilterAggregate(children.head, children(1), children(2)) + }) + + // Register 'might_contain' to builtin. + FunctionRegistry.builtin.registerFunction(new FunctionIdentifier("might_contain"), + new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"), + (children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1))) + + test("Test bloom_filter_agg and might_contain") { + val table = "bloom_filter_test" + for (numEstimatedItems <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue, + BloomFilterAggregate.MAX_ALLOWED_NUM_ITEMS)) { + for (numBits <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue, + BloomFilterAggregate.MAX_NUM_BITS)) { + val sqlString = s""" + |SELECT every(might_contain( + | (SELECT bloom_filter_agg(col, + | cast($numEstimatedItems as long), + | cast($numBits as long)) + | FROM $table), + | col)) positive_membership_test, + | every(might_contain( + | (SELECT bloom_filter_agg(col, + | cast($numEstimatedItems as long), + | cast($numBits as long)) + | FROM values (-1L), (100001L), (20000L) as t(col)), + | col)) negative_membership_test + |FROM $table + """.stripMargin + withTempView(table) { + (Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 10000L)) + .toDF("col").createOrReplaceTempView(table) + // Validate error messages as well as answers when there's no error. + if (numEstimatedItems <= 0) { + val exception = intercept[AnalysisException] { + spark.sql(sqlString) + } + assert(exception.getMessage.contains( + "The estimated number of items must be a positive value")) + } else if (numBits <= 0) { + val exception = intercept[AnalysisException] { + spark.sql(sqlString) + } + assert(exception.getMessage.contains("The number of bits must be a positive value")) + } else { + checkAnswer(spark.sql(sqlString), Row(true, false)) + } + } + } + } + } + + test("Test that bloom_filter_agg errors out disallowed input value types") { + val exception1 = intercept[AnalysisException] { + spark.sql(""" + |SELECT bloom_filter_agg(a) + |FROM values (1.2), (2.5) as t(a)""" + .stripMargin) + } + assert(exception1.getMessage.contains( + "Input to function bloom_filter_agg should have been a bigint value")) + + val exception2 = intercept[AnalysisException] { + spark.sql(""" + |SELECT bloom_filter_agg(a, 2) + |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""" + .stripMargin) + } + assert(exception2.getMessage.contains( + "function bloom_filter_agg should have been a bigint value followed with two bigint")) + + val exception3 = intercept[AnalysisException] { + spark.sql(""" + |SELECT bloom_filter_agg(a, cast(2 as long), 5) + |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""" + .stripMargin) + } + assert(exception3.getMessage.contains( + "function bloom_filter_agg should have been a bigint value followed with two bigint")) + + val exception4 = intercept[AnalysisException] { + spark.sql(""" + |SELECT bloom_filter_agg(a, null, 5) + |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""" + .stripMargin) + } + assert(exception4.getMessage.contains("Null typed values cannot be used as size arguments")) + + val exception5 = intercept[AnalysisException] { + spark.sql(""" + |SELECT bloom_filter_agg(a, 5, null) + |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""" + .stripMargin) + } + assert(exception5.getMessage.contains("Null typed values cannot be used as size arguments")) + } + + test("Test that might_contain errors out disallowed input value types") { + val exception1 = intercept[AnalysisException] { + spark.sql("""|SELECT might_contain(1.0, 1L)""" + .stripMargin) + } + assert(exception1.getMessage.contains( + "Input to function might_contain should have been binary followed by a value with bigint")) + + val exception2 = intercept[AnalysisException] { + spark.sql("""|SELECT might_contain(NULL, 0.1)""" + .stripMargin) + } + assert(exception2.getMessage.contains( + "Input to function might_contain should have been binary followed by a value with bigint")) + } + + test("Test that might_contain errors out non-constant Bloom filter") { + val exception1 = intercept[AnalysisException] { + spark.sql(""" + |SELECT might_contain(cast(a as binary), cast(5 as long)) + |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""" + .stripMargin) + } + assert(exception1.getMessage.contains( + "The Bloom filter binary input to might_contain should be either a constant value or " + + "a scalar subquery expression")) + + val exception2 = intercept[AnalysisException] { + spark.sql(""" + |SELECT might_contain((select cast(a as binary)), cast(5 as long)) + |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""" + .stripMargin) + } + assert(exception2.getMessage.contains( + "The Bloom filter binary input to might_contain should be either a constant value or " + + "a scalar subquery expression")) + } + + test("Test that might_contain can take a constant value input") { + checkAnswer(spark.sql( + """SELECT might_contain( + |X'00000001000000050000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267', + |cast(201 as long))""".stripMargin), + Row(false)) + } + + test("Test that bloom_filter_agg produces a NULL with empty input") { + checkAnswer(spark.sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1)"""), + Row(null)) + } + + test("Test NULL inputs for might_contain") { + checkAnswer(spark.sql( + s""" + |SELECT might_contain(null, null) both_null, + | might_contain(null, 1L) null_bf, + | might_contain((SELECT bloom_filter_agg(cast(id as long)) from range(1, 10000)), + | null) null_value + """.stripMargin), + Row(null, null, null)) + } + + test("Test that a query with bloom_filter_agg has partial aggregates") { + spark.sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1000000)""") + .queryExecution.executedPlan.collect({case agg: BaseAggregateExec => agg}).size == 2 + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala new file mode 100644 index 0000000000000..ab1f0e61759d5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala @@ -0,0 +1,502 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.{Alias, BloomFilterMightContain, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} +import org.apache.spark.sql.types.{IntegerType, StructType} + +class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSparkSession { + + protected override def beforeAll(): Unit = { + super.beforeAll() + val schema = new StructType().add("a1", IntegerType, nullable = true) + .add("b1", IntegerType, nullable = true) + .add("c1", IntegerType, nullable = true) + .add("d1", IntegerType, nullable = true) + .add("e1", IntegerType, nullable = true) + .add("f1", IntegerType, nullable = true) + + val data1 = Seq(Seq(null, 47, null, 4, 6, 48), + Seq(73, 63, null, 92, null, null), + Seq(76, 10, 74, 98, 37, 5), + Seq(0, 63, null, null, null, null), + Seq(15, 77, null, null, null, null), + Seq(null, 57, 33, 55, null, 58), + Seq(4, 0, 86, null, 96, 14), + Seq(28, 16, 58, null, null, null), + Seq(1, 88, null, 8, null, 79), + Seq(59, null, null, null, 20, 25), + Seq(1, 50, null, 94, 94, null), + Seq(null, null, null, 67, 51, 57), + Seq(77, 50, 8, 90, 16, 21), + Seq(34, 28, null, 5, null, 64), + Seq(null, null, 88, 11, 63, 79), + Seq(92, 94, 23, 1, null, 64), + Seq(57, 56, null, 83, null, null), + Seq(null, 35, 8, 35, null, 70), + Seq(null, 8, null, 35, null, 87), + Seq(9, null, null, 60, null, 5), + Seq(null, 15, 66, null, 83, null)) + val rdd1 = spark.sparkContext.parallelize(data1) + val rddRow1 = rdd1.map(s => Row.fromSeq(s)) + spark.createDataFrame(rddRow1, schema).write.saveAsTable("bf1") + + val schema2 = new StructType().add("a2", IntegerType, nullable = true) + .add("b2", IntegerType, nullable = true) + .add("c2", IntegerType, nullable = true) + .add("d2", IntegerType, nullable = true) + .add("e2", IntegerType, nullable = true) + .add("f2", IntegerType, nullable = true) + + + val data2 = Seq(Seq(67, 17, 45, 91, null, null), + Seq(98, 63, 0, 89, null, 40), + Seq(null, 76, 68, 75, 20, 19), + Seq(8, null, null, null, 78, null), + Seq(48, 62, null, null, 11, 98), + Seq(84, null, 99, 65, 66, 51), + Seq(98, null, null, null, 42, 51), + Seq(10, 3, 29, null, 68, 8), + Seq(85, 36, 41, null, 28, 71), + Seq(89, null, 94, 95, 67, 21), + Seq(44, null, 24, 33, null, 6), + Seq(null, 6, 78, 31, null, 69), + Seq(59, 2, 63, 9, 66, 20), + Seq(5, 23, 10, 86, 68, null), + Seq(null, 63, 99, 55, 9, 65), + Seq(57, 62, 68, 5, null, 0), + Seq(75, null, 15, null, 81, null), + Seq(53, null, 6, 68, 28, 13), + Seq(null, null, null, null, 89, 23), + Seq(36, 73, 40, null, 8, null), + Seq(24, null, null, 40, null, null)) + val rdd2 = spark.sparkContext.parallelize(data2) + val rddRow2 = rdd2.map(s => Row.fromSeq(s)) + spark.createDataFrame(rddRow2, schema2).write.saveAsTable("bf2") + + val schema3 = new StructType().add("a3", IntegerType, nullable = true) + .add("b3", IntegerType, nullable = true) + .add("c3", IntegerType, nullable = true) + .add("d3", IntegerType, nullable = true) + .add("e3", IntegerType, nullable = true) + .add("f3", IntegerType, nullable = true) + + val data3 = Seq(Seq(67, 17, 45, 91, null, null), + Seq(98, 63, 0, 89, null, 40), + Seq(null, 76, 68, 75, 20, 19), + Seq(8, null, null, null, 78, null), + Seq(48, 62, null, null, 11, 98), + Seq(84, null, 99, 65, 66, 51), + Seq(98, null, null, null, 42, 51), + Seq(10, 3, 29, null, 68, 8), + Seq(85, 36, 41, null, 28, 71), + Seq(89, null, 94, 95, 67, 21), + Seq(44, null, 24, 33, null, 6), + Seq(null, 6, 78, 31, null, 69), + Seq(59, 2, 63, 9, 66, 20), + Seq(5, 23, 10, 86, 68, null), + Seq(null, 63, 99, 55, 9, 65), + Seq(57, 62, 68, 5, null, 0), + Seq(75, null, 15, null, 81, null), + Seq(53, null, 6, 68, 28, 13), + Seq(null, null, null, null, 89, 23), + Seq(36, 73, 40, null, 8, null), + Seq(24, null, null, 40, null, null)) + val rdd3 = spark.sparkContext.parallelize(data3) + val rddRow3 = rdd3.map(s => Row.fromSeq(s)) + spark.createDataFrame(rddRow3, schema3).write.saveAsTable("bf3") + + + val schema4 = new StructType().add("a4", IntegerType, nullable = true) + .add("b4", IntegerType, nullable = true) + .add("c4", IntegerType, nullable = true) + .add("d4", IntegerType, nullable = true) + .add("e4", IntegerType, nullable = true) + .add("f4", IntegerType, nullable = true) + + val data4 = Seq(Seq(67, 17, 45, 91, null, null), + Seq(98, 63, 0, 89, null, 40), + Seq(null, 76, 68, 75, 20, 19), + Seq(8, null, null, null, 78, null), + Seq(48, 62, null, null, 11, 98), + Seq(84, null, 99, 65, 66, 51), + Seq(98, null, null, null, 42, 51), + Seq(10, 3, 29, null, 68, 8), + Seq(85, 36, 41, null, 28, 71), + Seq(89, null, 94, 95, 67, 21), + Seq(44, null, 24, 33, null, 6), + Seq(null, 6, 78, 31, null, 69), + Seq(59, 2, 63, 9, 66, 20), + Seq(5, 23, 10, 86, 68, null), + Seq(null, 63, 99, 55, 9, 65), + Seq(57, 62, 68, 5, null, 0), + Seq(75, null, 15, null, 81, null), + Seq(53, null, 6, 68, 28, 13), + Seq(null, null, null, null, 89, 23), + Seq(36, 73, 40, null, 8, null), + Seq(24, null, null, 40, null, null)) + val rdd4 = spark.sparkContext.parallelize(data4) + val rddRow4 = rdd4.map(s => Row.fromSeq(s)) + spark.createDataFrame(rddRow4, schema4).write.saveAsTable("bf4") + + val schema5part = new StructType().add("a5", IntegerType, nullable = true) + .add("b5", IntegerType, nullable = true) + .add("c5", IntegerType, nullable = true) + .add("d5", IntegerType, nullable = true) + .add("e5", IntegerType, nullable = true) + .add("f5", IntegerType, nullable = true) + + val data5part = Seq(Seq(67, 17, 45, 91, null, null), + Seq(98, 63, 0, 89, null, 40), + Seq(null, 76, 68, 75, 20, 19), + Seq(8, null, null, null, 78, null), + Seq(48, 62, null, null, 11, 98), + Seq(84, null, 99, 65, 66, 51), + Seq(98, null, null, null, 42, 51), + Seq(10, 3, 29, null, 68, 8), + Seq(85, 36, 41, null, 28, 71), + Seq(89, null, 94, 95, 67, 21), + Seq(44, null, 24, 33, null, 6), + Seq(null, 6, 78, 31, null, 69), + Seq(59, 2, 63, 9, 66, 20), + Seq(5, 23, 10, 86, 68, null), + Seq(null, 63, 99, 55, 9, 65), + Seq(57, 62, 68, 5, null, 0), + Seq(75, null, 15, null, 81, null), + Seq(53, null, 6, 68, 28, 13), + Seq(null, null, null, null, 89, 23), + Seq(36, 73, 40, null, 8, null), + Seq(24, null, null, 40, null, null)) + val rdd5part = spark.sparkContext.parallelize(data5part) + val rddRow5part = rdd5part.map(s => Row.fromSeq(s)) + spark.createDataFrame(rddRow5part, schema5part).write.partitionBy("f5") + .saveAsTable("bf5part") + spark.createDataFrame(rddRow5part, schema5part).filter("a5 > 30") + .write.partitionBy("f5") + .saveAsTable("bf5filtered") + + sql("analyze table bf1 compute statistics for columns a1, b1, c1, d1, e1, f1") + sql("analyze table bf2 compute statistics for columns a2, b2, c2, d2, e2, f2") + sql("analyze table bf3 compute statistics for columns a3, b3, c3, d3, e3, f3") + sql("analyze table bf4 compute statistics for columns a4, b4, c4, d4, e4, f4") + sql("analyze table bf5part compute statistics for columns a5, b5, c5, d5, e5, f5") + sql("analyze table bf5filtered compute statistics for columns a5, b5, c5, d5, e5, f5") + } + + protected override def afterAll(): Unit = try { + sql("DROP TABLE IF EXISTS bf1") + sql("DROP TABLE IF EXISTS bf2") + sql("DROP TABLE IF EXISTS bf3") + sql("DROP TABLE IF EXISTS bf4") + sql("DROP TABLE IF EXISTS bf5part") + } finally { + super.afterAll() + } + + def checkWithAndWithoutFeatureEnabled(query: String, testSemiJoin: Boolean, + shouldReplace: Boolean): Unit = { + var planDisabled: LogicalPlan = null + var planEnabled: LogicalPlan = null + var expectedAnswer: Array[Row] = null + + withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false", + SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") { + planDisabled = sql(query).queryExecution.optimizedPlan + expectedAnswer = sql(query).collect() + } + + if (testSemiJoin) { + withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "true", + SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") { + planEnabled = sql(query).queryExecution.optimizedPlan + checkAnswer(sql(query), expectedAnswer) + } + if (shouldReplace) { + val normalizedEnabled = normalizePlan(normalizeExprIds(planEnabled)) + val normalizedDisabled = normalizePlan(normalizeExprIds(planDisabled)) + assert(normalizedEnabled != normalizedDisabled) + } else { + comparePlans(planDisabled, planEnabled) + } + } else { + withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false", + SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true") { + planEnabled = sql(query).queryExecution.optimizedPlan + checkAnswer(sql(query), expectedAnswer) + if (shouldReplace) { + assert(getNumBloomFilters(planEnabled) > getNumBloomFilters(planDisabled)) + } else { + assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled)) + } + } + } + } + + def getNumBloomFilters(plan: LogicalPlan): Integer = { + val numBloomFilterAggs = plan.collect { + case Filter(condition, _) => condition.collect { + case subquery: org.apache.spark.sql.catalyst.expressions.ScalarSubquery + => subquery.plan.collect { + case Aggregate(_, aggregateExpressions, _) => + aggregateExpressions.map { + case Alias(AggregateExpression(bfAgg : BloomFilterAggregate, _, _, _, _), + _) => + assert(bfAgg.estimatedNumItemsExpression.isInstanceOf[Literal]) + assert(bfAgg.numBitsExpression.isInstanceOf[Literal]) + 1 + }.sum + }.sum + }.sum + }.sum + val numMightContains = plan.collect { + case Filter(condition, _) => condition.collect { + case BloomFilterMightContain(_, _) => 1 + }.sum + }.sum + assert(numBloomFilterAggs == numMightContains) + numMightContains + } + + def assertRewroteSemiJoin(query: String): Unit = { + checkWithAndWithoutFeatureEnabled(query, testSemiJoin = true, shouldReplace = true) + } + + def assertDidNotRewriteSemiJoin(query: String): Unit = { + checkWithAndWithoutFeatureEnabled(query, testSemiJoin = true, shouldReplace = false) + } + + def assertRewroteWithBloomFilter(query: String): Unit = { + checkWithAndWithoutFeatureEnabled(query, testSemiJoin = false, shouldReplace = true) + } + + def assertDidNotRewriteWithBloomFilter(query: String): Unit = { + checkWithAndWithoutFeatureEnabled(query, testSemiJoin = false, shouldReplace = false) + } + + test(s"Runtime semi join reduction: simple") { + // Filter creation side is 3409 bytes + // Filter application side scan is 3362 bytes + withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { + assertRewroteSemiJoin(s"select * from bf1 join bf2 on bf1.c1 = bf2.c2 where bf2.a2 = 62") + assertDidNotRewriteSemiJoin(s"select * from bf1 join bf2 on bf1.c1 = bf2.c2") + } + } + + test(s"Runtime semi join reduction: two joins") { + withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { + assertRewroteSemiJoin(s"select * from bf1 join bf2 join bf3 on bf1.c1 = bf2.c2 " + + s"and bf3.c3 = bf2.c2 where bf2.a2 = 5") + } + } + + test(s"Runtime semi join reduction: three joins") { + withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { + assertRewroteSemiJoin(s"select * from bf1 join bf2 join bf3 join bf4 on " + + s"bf1.c1 = bf2.c2 and bf2.c2 = bf3.c3 and bf3.c3 = bf4.c4 where bf1.a1 = 5") + } + } + + test(s"Runtime semi join reduction: simple expressions only") { + withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { + val squared = (s: Long) => { + s * s + } + spark.udf.register("square", squared) + assertDidNotRewriteSemiJoin(s"select * from bf1 join bf2 on " + + s"bf1.c1 = bf2.c2 where square(bf2.a2) = 62") + assertDidNotRewriteSemiJoin(s"select * from bf1 join bf2 on " + + s"bf1.c1 = square(bf2.c2) where bf2.a2= 62") + } + } + + test(s"Runtime bloom filter join: simple") { + withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { + assertRewroteWithBloomFilter(s"select * from bf1 join bf2 on bf1.c1 = bf2.c2 " + + s"where bf2.a2 = 62") + assertDidNotRewriteWithBloomFilter(s"select * from bf1 join bf2 on bf1.c1 = bf2.c2") + } + } + + test(s"Runtime bloom filter join: two filters single join") { + withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { + var planDisabled: LogicalPlan = null + var planEnabled: LogicalPlan = null + var expectedAnswer: Array[Row] = null + + val query = s"select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " + + s"bf1.b1 = bf2.b2 where bf2.a2 = 62" + + withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false", + SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") { + planDisabled = sql(query).queryExecution.optimizedPlan + expectedAnswer = sql(query).collect() + } + + withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false", + SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true") { + planEnabled = sql(query).queryExecution.optimizedPlan + checkAnswer(sql(query), expectedAnswer) + } + assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) + 2) + } + } + + test(s"Runtime bloom filter join: test the number of filter threshold") { + withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { + var planDisabled: LogicalPlan = null + var planEnabled: LogicalPlan = null + var expectedAnswer: Array[Row] = null + + val query = s"select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " + + s"bf1.b1 = bf2.b2 where bf2.a2 = 62" + + withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false", + SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") { + planDisabled = sql(query).queryExecution.optimizedPlan + expectedAnswer = sql(query).collect() + } + + for (numFilterThreshold <- 0 to 3) { + withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false", + SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true", + SQLConf.RUNTIME_FILTER_NUMBER_THRESHOLD.key -> numFilterThreshold.toString) { + planEnabled = sql(query).queryExecution.optimizedPlan + checkAnswer(sql(query), expectedAnswer) + } + if (numFilterThreshold < 3) { + assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) + + numFilterThreshold) + } else { + assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) + 2) + } + } + } + } + + test(s"Runtime bloom filter join: insert one bloom filter per column") { + withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { + var planDisabled: LogicalPlan = null + var planEnabled: LogicalPlan = null + var expectedAnswer: Array[Row] = null + + val query = s"select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " + + s"bf1.c1 = bf2.b2 where bf2.a2 = 62" + + withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false", + SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") { + planDisabled = sql(query).queryExecution.optimizedPlan + expectedAnswer = sql(query).collect() + } + + withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false", + SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true") { + planEnabled = sql(query).queryExecution.optimizedPlan + checkAnswer(sql(query), expectedAnswer) + } + assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) + 1) + } + } + + test(s"Runtime bloom filter join: do not add bloom filter if dpp filter exists " + + s"on the same column") { + withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { + assertDidNotRewriteWithBloomFilter(s"select * from bf5part join bf2 on " + + s"bf5part.f5 = bf2.c2 where bf2.a2 = 62") + } + } + + test(s"Runtime bloom filter join: add bloom filter if dpp filter exists on " + + s"a different column") { + withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { + assertRewroteWithBloomFilter(s"select * from bf5part join bf2 on " + + s"bf5part.c5 = bf2.c2 and bf5part.f5 = bf2.f2 where bf2.a2 = 62") + } + } + + test(s"Runtime bloom filter join: BF rewrite triggering threshold test") { + // Filter creation side data size is 3409 bytes. On the filter application side, an individual + // scan's byte size is 3362. + withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "3000", + SQLConf.RUNTIME_BLOOM_FILTER_THRESHOLD.key -> "4000" + ) { + assertRewroteWithBloomFilter(s"select * from bf1 join bf2 on bf1.c1 = bf2.c2 " + + s"where bf2.a2 = 62") + } + withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50", + SQLConf.RUNTIME_BLOOM_FILTER_THRESHOLD.key -> "50" + ) { + assertDidNotRewriteWithBloomFilter(s"select * from bf1 join bf2 on bf1.c1 = bf2.c2 " + + s"where bf2.a2 = 62") + } + withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "5000", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "3000", + SQLConf.RUNTIME_BLOOM_FILTER_THRESHOLD.key -> "4000" + ) { + // Rewrite should not be triggered as the Bloom filter application side scan size is small. + assertDidNotRewriteWithBloomFilter(s"select * from bf1 join bf2 on bf1.c1 = bf2.c2 " + + s"where bf2.a2 = 62") + } + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "32", + SQLConf.RUNTIME_BLOOM_FILTER_THRESHOLD.key -> "4000") { + // Test that the max scan size rather than an individual scan size on the filter + // application side matters. `bf5filtered` has 14168 bytes and `bf2` has 3409 bytes. + withSQLConf( + SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "5000") { + assertRewroteWithBloomFilter(s"select * from " + + s"(select * from bf5filtered union all select * from bf2) t " + + s"join bf3 on t.c5 = bf3.c3 where bf3.a3 = 5") + } + withSQLConf( + SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "15000") { + assertDidNotRewriteWithBloomFilter(s"select * from " + + s"(select * from bf5filtered union all select * from bf2) t " + + s"join bf3 on t.c5 = bf3.c3 where bf3.a3 = 5") + } + } + } + + test(s"Runtime bloom filter join: simple expressions only") { + withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { + val squared = (s: Long) => { + s * s + } + spark.udf.register("square", squared) + assertDidNotRewriteWithBloomFilter(s"select * from bf1 join bf2 on " + + s"bf1.c1 = bf2.c2 where square(bf2.a2) = 62" ) + assertDidNotRewriteWithBloomFilter(s"select * from bf1 join bf2 on " + + s"bf1.c1 = square(bf2.c2) where bf2.a2 = 62" ) + } + } +} From c851317a2ef2e47bca135388f5b79a424deb49f4 Mon Sep 17 00:00:00 2001 From: Abhishek Somani Date: Fri, 11 Mar 2022 10:12:10 -0500 Subject: [PATCH 02/15] Dedupicate code --- .../sql/catalyst/expressions/predicates.scala | 16 ++++++++++++++++ .../catalyst/optimizer/InjectRuntimeFilter.scala | 15 --------------- .../dynamicpruning/PartitionPruning.scala | 15 --------------- 3 files changed, 16 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a2fd668f495e0..d16e09c5ed95c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -287,6 +287,22 @@ trait PredicateHelper extends AliasHelper with Logging { } } } + + /** + * Returns whether an expression is likely to be selective + */ + def isLikelySelective(e: Expression): Boolean = e match { + case Not(expr) => isLikelySelective(expr) + case And(l, r) => isLikelySelective(l) || isLikelySelective(r) + case Or(l, r) => isLikelySelective(l) && isLikelySelective(r) + case _: StringRegexExpression => true + case _: BinaryComparison => true + case _: In | _: InSet => true + case _: StringPredicate => true + case BinaryPredicate(_) => true + case _: MultiLikeBase => true + case _ => false + } } @ExpressionDescription( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index 1118cb40551a1..020ca39dc3837 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -135,21 +135,6 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE) } - /** - * Returns whether an expression is likely to be selective - */ - private def isLikelySelective(e: Expression): Boolean = e match { - case Not(expr) => isLikelySelective(expr) - case And(l, r) => isLikelySelective(l) || isLikelySelective(r) - case Or(l, r) => isLikelySelective(l) && isLikelySelective(r) - case _: StringRegexExpression => true - case _: BinaryComparison => true - case _: In | _: InSet => true - case _: StringPredicate => true - case _: MultiLikeBase => true - case _ => false - } - private def canFilterLeft(joinType: JoinType): Boolean = joinType match { case Inner | RightOuter => true case _ => false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala index 3b5fc4aea5d8b..89d66034f06cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala @@ -194,21 +194,6 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper { scanOverhead + cachedOverhead } - /** - * Returns whether an expression is likely to be selective - */ - private def isLikelySelective(e: Expression): Boolean = e match { - case Not(expr) => isLikelySelective(expr) - case And(l, r) => isLikelySelective(l) || isLikelySelective(r) - case Or(l, r) => isLikelySelective(l) && isLikelySelective(r) - case _: StringRegexExpression => true - case _: BinaryComparison => true - case _: In | _: InSet => true - case _: StringPredicate => true - case BinaryPredicate(_) => true - case _: MultiLikeBase => true - case _ => false - } /** * Search a filtering predicate in a given logical plan From 4017f31459cfd97f424ea6ab31bf6fa45617d847 Mon Sep 17 00:00:00 2001 From: Abhishek Somani Date: Fri, 11 Mar 2022 10:56:05 -0500 Subject: [PATCH 03/15] Missing assert --- .../apache/spark/sql/BloomFilterAggregateQuerySuite.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala index 37b4f5f7c0f6d..fb3388175a92a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.aggregate.BaseAggregateExec import org.apache.spark.sql.test.SharedSparkSession @@ -196,7 +197,8 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { } test("Test that a query with bloom_filter_agg has partial aggregates") { - spark.sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1000000)""") - .queryExecution.executedPlan.collect({case agg: BaseAggregateExec => agg}).size == 2 + assert(spark.sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1000000)""") + .queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].inputPlan + .collect({case agg: BaseAggregateExec => agg}).size == 2) } } From a67c216e34db7efb9c59b66f2799b277a770cc1b Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 15 Mar 2022 08:04:03 +0800 Subject: [PATCH 04/15] Format code --- .../expressions/BloomFilterMightContain.scala | 2 +- .../aggregate/BloomFilterAggregate.scala | 9 +- .../expressions/objects/objects.scala | 2 +- .../optimizer/InjectRuntimeFilter.scala | 8 +- .../apache/spark/sql/internal/SQLConf.scala | 10 +- .../spark/sql/InjectRuntimeFilterSuite.scala | 107 +++++++++--------- 6 files changed, 69 insertions(+), 69 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala index 9a1cf637e5a73..f069cfa8d1428 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala @@ -63,7 +63,7 @@ case class BloomFilterMightContain( TypeCheckResult.TypeCheckSuccess case _ => TypeCheckResult.TypeCheckFailure(s"The Bloom filter binary input to $prettyName " + - s"should be either a constant value or a scalar subquery expression") + "should be either a constant value or a scalar subquery expression") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala index 86d3d62e1c643..de189c1e8e6f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala @@ -110,9 +110,10 @@ case class BloomFilterAggregate( override def third: Expression = numBitsExpression - override protected def withNewChildrenInternal(newChild: Expression, - newEstimatedNumItemsExpression: Expression, newNumBitsExpression: Expression) - : BloomFilterAggregate = { + override protected def withNewChildrenInternal( + newChild: Expression, + newEstimatedNumItemsExpression: Expression, + newNumBitsExpression: Expression): BloomFilterAggregate = { copy(child = newChild, estimatedNumItemsExpression = newEstimatedNumItemsExpression, numBitsExpression = newNumBitsExpression) } @@ -176,7 +177,7 @@ object BloomFilterAggregate { class BloomFilterSerDe { final def serialize(obj: BloomFilter): Array[Byte] = { - val size = obj.bitSize()/8 + val size = obj.bitSize() / 8 require(size <= Integer.MAX_VALUE, s"actual number of bits is too large $size") val out = new ByteArrayOutputStream(size.intValue()) obj.writeTo(out) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 2e96c20bf3c29..2c879beeed623 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TernaryLike -import org.apache.spark.sql.catalyst.trees.TreePattern.{INVOKE, _} +import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index 020ca39dc3837..f8bc60457b08d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -36,7 +36,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J // Wraps `expr` with a hash function if its byte size is larger than an integer. private def mayWrapWithHash(expr: Expression): Expression = { - if (expr.dataType.defaultSize > IntegerType.defaultSize) { + if (expr.dataType.defaultSize > IntegerType.defaultSize) { new Murmur3Hash(Seq(expr)) } else { expr @@ -47,8 +47,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J filterApplicationSideExp: Expression, filterApplicationSidePlan: LogicalPlan, filterCreationSideExp: Expression, - filterCreationSidePlan: LogicalPlan - ): LogicalPlan = { + filterCreationSidePlan: LogicalPlan): LogicalPlan = { require(conf.runtimeFilterBloomFilterEnabled || conf.runtimeFilterSemiJoinReductionEnabled) if (conf.runtimeFilterBloomFilterEnabled) { injectBloomFilter( @@ -98,8 +97,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J filterApplicationSideExp: Expression, filterApplicationSidePlan: LogicalPlan, filterCreationSideExp: Expression, - filterCreationSidePlan: LogicalPlan - ): LogicalPlan = { + filterCreationSidePlan: LogicalPlan): LogicalPlan = { require(filterApplicationSideExp.dataType == filterCreationSideExp.dataType) val actualFilterKeyExpr = mayWrapWithHash(filterCreationSideExp) val alias = Alias(actualFilterKeyExpr, actualFilterKeyExpr.toString)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index bf97e73ba3b91..705e14ab48352 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -344,7 +344,7 @@ object SQLConf { val RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED = buildConf("spark.sql.optimizer.runtimeFilter.semiJoinReduction.enabled") .doc("When true and if one side of a shuffle join has a selective predicate, we attempt " + - "to insert a semi join in the other side to reduce the amount of shuffle data") + "to insert a semi join in the other side to reduce the amount of shuffle data.") .version("3.3.0") .booleanConf .createWithDefault(false) @@ -352,7 +352,7 @@ object SQLConf { val RUNTIME_FILTER_NUMBER_THRESHOLD = buildConf("spark.sql.optimizer.runtimeFilter.number.threshold") .doc("The total number of injected runtime filters (non-DPP) for a single " + - "query. This is to prevent driver OOMs with too many Bloom filters") + "query. This is to prevent driver OOMs with too many Bloom filters.") .version("3.3.0") .intConf .checkValue(threshold => threshold >= 0, "The threshold should be >= 0") @@ -361,7 +361,7 @@ object SQLConf { lazy val RUNTIME_BLOOM_FILTER_ENABLED = buildConf("spark.sql.optimizer.runtime.bloomFilter.enabled") .doc("When true and if one side of a shuffle join has a selective predicate, we attempt " + - "to insert a bloom filter in the other side to reduce the amount of shuffle data") + "to insert a bloom filter in the other side to reduce the amount of shuffle data.") .version("3.3.0") .booleanConf .createWithDefault(false) @@ -369,7 +369,7 @@ object SQLConf { val RUNTIME_BLOOM_FILTER_THRESHOLD = buildConf("spark.sql.optimizer.runtime.bloomFilter.threshold") .doc("Size threshold of the bloom filter creation side plan. Estimated size needs to be " + - "under this value to try to inject bloom filter") + "under this value to try to inject bloom filter.") .version("3.3.0") .bytesConf(ByteUnit.BYTE) .createWithDefaultString("10MB") @@ -378,7 +378,7 @@ object SQLConf { buildConf("spark.sql.optimizer.runtime.bloomFilter.applicationSideScanSizethreshold") .doc("Byte size threshold of the Bloom filter application side plan's aggregated scan " + "size. Aggregated scan byte size of the Bloom filter application side needs to be over " + - "this value to inject a bloom filter") + "this value to inject a bloom filter.") .version("3.3.0") .bytesConf(ByteUnit.BYTE) .createWithDefaultString("10GB") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala index ab1f0e61759d5..a7abd1f3b3f0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala @@ -208,6 +208,7 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp sql("DROP TABLE IF EXISTS bf3") sql("DROP TABLE IF EXISTS bf4") sql("DROP TABLE IF EXISTS bf5part") + sql("DROP TABLE IF EXISTS bf5filtered") } finally { super.afterAll() } @@ -292,64 +293,64 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp checkWithAndWithoutFeatureEnabled(query, testSemiJoin = false, shouldReplace = false) } - test(s"Runtime semi join reduction: simple") { + test("Runtime semi join reduction: simple") { // Filter creation side is 3409 bytes // Filter application side scan is 3362 bytes withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { - assertRewroteSemiJoin(s"select * from bf1 join bf2 on bf1.c1 = bf2.c2 where bf2.a2 = 62") - assertDidNotRewriteSemiJoin(s"select * from bf1 join bf2 on bf1.c1 = bf2.c2") + assertRewroteSemiJoin("select * from bf1 join bf2 on bf1.c1 = bf2.c2 where bf2.a2 = 62") + assertDidNotRewriteSemiJoin("select * from bf1 join bf2 on bf1.c1 = bf2.c2") } } - test(s"Runtime semi join reduction: two joins") { + test("Runtime semi join reduction: two joins") { withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { - assertRewroteSemiJoin(s"select * from bf1 join bf2 join bf3 on bf1.c1 = bf2.c2 " + - s"and bf3.c3 = bf2.c2 where bf2.a2 = 5") + assertRewroteSemiJoin("select * from bf1 join bf2 join bf3 on bf1.c1 = bf2.c2 " + + "and bf3.c3 = bf2.c2 where bf2.a2 = 5") } } - test(s"Runtime semi join reduction: three joins") { + test("Runtime semi join reduction: three joins") { withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { - assertRewroteSemiJoin(s"select * from bf1 join bf2 join bf3 join bf4 on " + - s"bf1.c1 = bf2.c2 and bf2.c2 = bf3.c3 and bf3.c3 = bf4.c4 where bf1.a1 = 5") + assertRewroteSemiJoin("select * from bf1 join bf2 join bf3 join bf4 on " + + "bf1.c1 = bf2.c2 and bf2.c2 = bf3.c3 and bf3.c3 = bf4.c4 where bf1.a1 = 5") } } - test(s"Runtime semi join reduction: simple expressions only") { + test("Runtime semi join reduction: simple expressions only") { withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { val squared = (s: Long) => { s * s } spark.udf.register("square", squared) - assertDidNotRewriteSemiJoin(s"select * from bf1 join bf2 on " + - s"bf1.c1 = bf2.c2 where square(bf2.a2) = 62") - assertDidNotRewriteSemiJoin(s"select * from bf1 join bf2 on " + - s"bf1.c1 = square(bf2.c2) where bf2.a2= 62") + assertDidNotRewriteSemiJoin("select * from bf1 join bf2 on " + + "bf1.c1 = bf2.c2 where square(bf2.a2) = 62") + assertDidNotRewriteSemiJoin("select * from bf1 join bf2 on " + + "bf1.c1 = square(bf2.c2) where bf2.a2= 62") } } - test(s"Runtime bloom filter join: simple") { + test("Runtime bloom filter join: simple") { withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { - assertRewroteWithBloomFilter(s"select * from bf1 join bf2 on bf1.c1 = bf2.c2 " + - s"where bf2.a2 = 62") - assertDidNotRewriteWithBloomFilter(s"select * from bf1 join bf2 on bf1.c1 = bf2.c2") + assertRewroteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " + + "where bf2.a2 = 62") + assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2") } } - test(s"Runtime bloom filter join: two filters single join") { + test("Runtime bloom filter join: two filters single join") { withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { var planDisabled: LogicalPlan = null var planEnabled: LogicalPlan = null var expectedAnswer: Array[Row] = null - val query = s"select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " + - s"bf1.b1 = bf2.b2 where bf2.a2 = 62" + val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " + + "bf1.b1 = bf2.b2 where bf2.a2 = 62" withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false", SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") { @@ -366,15 +367,15 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp } } - test(s"Runtime bloom filter join: test the number of filter threshold") { + test("Runtime bloom filter join: test the number of filter threshold") { withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { var planDisabled: LogicalPlan = null var planEnabled: LogicalPlan = null var expectedAnswer: Array[Row] = null - val query = s"select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " + - s"bf1.b1 = bf2.b2 where bf2.a2 = 62" + val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " + + "bf1.b1 = bf2.b2 where bf2.a2 = 62" withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false", SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") { @@ -399,15 +400,15 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp } } - test(s"Runtime bloom filter join: insert one bloom filter per column") { + test("Runtime bloom filter join: insert one bloom filter per column") { withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { var planDisabled: LogicalPlan = null var planEnabled: LogicalPlan = null var expectedAnswer: Array[Row] = null - val query = s"select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " + - s"bf1.c1 = bf2.b2 where bf2.a2 = 62" + val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " + + "bf1.c1 = bf2.b2 where bf2.a2 = 62" withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false", SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") { @@ -424,48 +425,48 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp } } - test(s"Runtime bloom filter join: do not add bloom filter if dpp filter exists " + - s"on the same column") { + test("Runtime bloom filter join: do not add bloom filter if dpp filter exists " + + "on the same column") { withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { - assertDidNotRewriteWithBloomFilter(s"select * from bf5part join bf2 on " + - s"bf5part.f5 = bf2.c2 where bf2.a2 = 62") + assertDidNotRewriteWithBloomFilter("select * from bf5part join bf2 on " + + "bf5part.f5 = bf2.c2 where bf2.a2 = 62") } } - test(s"Runtime bloom filter join: add bloom filter if dpp filter exists on " + - s"a different column") { + test("Runtime bloom filter join: add bloom filter if dpp filter exists on " + + "a different column") { withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { - assertRewroteWithBloomFilter(s"select * from bf5part join bf2 on " + - s"bf5part.c5 = bf2.c2 and bf5part.f5 = bf2.f2 where bf2.a2 = 62") + assertRewroteWithBloomFilter("select * from bf5part join bf2 on " + + "bf5part.c5 = bf2.c2 and bf5part.f5 = bf2.f2 where bf2.a2 = 62") } } - test(s"Runtime bloom filter join: BF rewrite triggering threshold test") { + test("Runtime bloom filter join: BF rewrite triggering threshold test") { // Filter creation side data size is 3409 bytes. On the filter application side, an individual // scan's byte size is 3362. withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "3000", SQLConf.RUNTIME_BLOOM_FILTER_THRESHOLD.key -> "4000" ) { - assertRewroteWithBloomFilter(s"select * from bf1 join bf2 on bf1.c1 = bf2.c2 " + - s"where bf2.a2 = 62") + assertRewroteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " + + "where bf2.a2 = 62") } withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50", SQLConf.RUNTIME_BLOOM_FILTER_THRESHOLD.key -> "50" ) { - assertDidNotRewriteWithBloomFilter(s"select * from bf1 join bf2 on bf1.c1 = bf2.c2 " + - s"where bf2.a2 = 62") + assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " + + "where bf2.a2 = 62") } withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "5000", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "3000", SQLConf.RUNTIME_BLOOM_FILTER_THRESHOLD.key -> "4000" ) { // Rewrite should not be triggered as the Bloom filter application side scan size is small. - assertDidNotRewriteWithBloomFilter(s"select * from bf1 join bf2 on bf1.c1 = bf2.c2 " - + s"where bf2.a2 = 62") + assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " + + "where bf2.a2 = 62") } withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "32", SQLConf.RUNTIME_BLOOM_FILTER_THRESHOLD.key -> "4000") { @@ -473,30 +474,30 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp // application side matters. `bf5filtered` has 14168 bytes and `bf2` has 3409 bytes. withSQLConf( SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "5000") { - assertRewroteWithBloomFilter(s"select * from " + - s"(select * from bf5filtered union all select * from bf2) t " + - s"join bf3 on t.c5 = bf3.c3 where bf3.a3 = 5") + assertRewroteWithBloomFilter("select * from " + + "(select * from bf5filtered union all select * from bf2) t " + + "join bf3 on t.c5 = bf3.c3 where bf3.a3 = 5") } withSQLConf( SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "15000") { - assertDidNotRewriteWithBloomFilter(s"select * from " + - s"(select * from bf5filtered union all select * from bf2) t " + - s"join bf3 on t.c5 = bf3.c3 where bf3.a3 = 5") + assertDidNotRewriteWithBloomFilter("select * from " + + "(select * from bf5filtered union all select * from bf2) t " + + "join bf3 on t.c5 = bf3.c3 where bf3.a3 = 5") } } } - test(s"Runtime bloom filter join: simple expressions only") { + test("Runtime bloom filter join: simple expressions only") { withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { val squared = (s: Long) => { s * s } spark.udf.register("square", squared) - assertDidNotRewriteWithBloomFilter(s"select * from bf1 join bf2 on " + - s"bf1.c1 = bf2.c2 where square(bf2.a2) = 62" ) - assertDidNotRewriteWithBloomFilter(s"select * from bf1 join bf2 on " + - s"bf1.c1 = square(bf2.c2) where bf2.a2 = 62" ) + assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on " + + "bf1.c1 = bf2.c2 where square(bf2.a2) = 62" ) + assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on " + + "bf1.c1 = square(bf2.c2) where bf2.a2 = 62" ) } } } From 722e9118a3779af2f999aa255872fe1e19d6aa3f Mon Sep 17 00:00:00 2001 From: Abhishek Somani Date: Tue, 15 Mar 2022 20:48:20 -0400 Subject: [PATCH 05/15] Addressed some review comments --- .../apache/spark/util/sketch/BloomFilter.java | 4 +-- .../expressions/BloomFilterMightContain.scala | 34 ++++++++++++++----- .../optimizer/InjectRuntimeFilter.scala | 14 +++++--- .../apache/spark/sql/internal/SQLConf.scala | 10 +++--- .../spark/sql/InjectRuntimeFilterSuite.scala | 8 ++--- 5 files changed, 45 insertions(+), 25 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index 2a6e270a91267..9f0454589dc6d 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -166,9 +166,7 @@ int getVersionNumber() { /** * @return the number of set bits in this {@link BloomFilter}. */ - public long cardinality() { - throw new UnsupportedOperationException("Not implemented"); - } + public abstract long cardinality(); /** * Reads in a {@link BloomFilter} from an input stream. It is the caller's responsibility to close diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala index f069cfa8d1428..21eb8eb9cc0ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala @@ -19,8 +19,10 @@ package org.apache.spark.sql.catalyst.expressions import java.io.ByteArrayInputStream +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, JavaCode, TrueLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE import org.apache.spark.sql.types._ import org.apache.spark.util.sketch.BloomFilter @@ -74,20 +76,34 @@ case class BloomFilterMightContain( valueExpression = newValueExpression) // The bloom filter created from `bloomFilterExpression`. - @transient private var bloomFilter: BloomFilter = _ + @transient private lazy val bloomFilter = { + val bytes = bloomFilterExpression.eval().asInstanceOf[Array[Byte]] + if (bytes == null) null else deserialize(bytes) + } - override def nullSafeEval(bloomFilterBytes: Any, value: Any): Any = { + override def eval(input: InternalRow): Any = { if (bloomFilter == null) { - bloomFilter = deserialize(bloomFilterBytes.asInstanceOf[Array[Byte]]) + null + } else { + val value = valueExpression.eval(input) + if (value == null) null else bloomFilter.mightContainLong(value.asInstanceOf[Long]) } - bloomFilter.mightContainLong(value.asInstanceOf[Long]) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val thisObj = ctx.addReferenceObj("thisObj", this) - nullSafeCodeGen(ctx, ev, (bloomFilterBytes, value) => { - s"\n${ev.value} = (Boolean) $thisObj.nullSafeEval($bloomFilterBytes, $value);\n" - }) + if (bloomFilter == null) { + ev.copy(isNull = TrueLiteral, value = JavaCode.defaultLiteral(dataType)) + } else { + val bf = ctx.addReferenceObj("bloomFilter", bloomFilter, classOf[BloomFilter].getName) + val valueEval = valueExpression.genCode(ctx) + ev.copy(code = code""" + ${valueEval.code} + boolean ${ev.isNull} = ${valueEval.isNull}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $bf.mightContainLong((Long)${valueEval.value}); + }""") + } } final def deserialize(bytes: Array[Byte]): BloomFilter = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index f8bc60457b08d..98d0665156e06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -73,7 +73,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J filterCreationSidePlan: LogicalPlan ): LogicalPlan = { // Skip if the filter creation side is too big - if (filterCreationSidePlan.stats.sizeInBytes > conf.runtimeFilterBloomFilterThreshold) { + if (filterCreationSidePlan.stats.sizeInBytes > conf.runtimeFilterCreationSideThreshold) { return filterApplicationSidePlan } val rowCount = filterCreationSidePlan.stats.rowCount @@ -118,7 +118,6 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J * do not add a subquery that might have an expensive computation */ private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = { - plan.expressions val ret = plan match { case PhysicalOperation(_, filters, child) if child.isInstanceOf[LeafNode] => filters.forall(isSimpleExpression) && @@ -183,6 +182,14 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD) } + /** + * Check that: + * - The filterApplicationSideJoinExp can be pushed down through joins and aggregates (ie the + * - expression references originate from a single leaf node) + * - The filter creation side has a selective predicate + * - The current join is a shuffle join or a broadcast join that has a shuffle below it + * - The max filterApplicationSide scan size is greater than a configurable threshold + */ private def filteringHasBenefit( filterApplicationSide: LogicalPlan, filterCreationSide: LogicalPlan, @@ -194,8 +201,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J // 2. The filter creation side has a selective predicate // 3. The current join is a shuffle join or a broadcast join that has a shuffle or aggregate // in the filter application side - // 4. The filterApplicationSide is larger than the filterCreationSide by a configurable - // threshold + // 4. The max filterApplicationSide scan size is greater than a configurable threshold findExpressionAndTrackLineageDown(filterApplicationSideExp, filterApplicationSide).isDefined && isSelectiveFilterOverScan(filterCreationSide) && (isProbablyShuffleJoin(filterApplicationSide, filterCreationSide, hint) || diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 705e14ab48352..3ee4fac8d7193 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -358,7 +358,7 @@ object SQLConf { .checkValue(threshold => threshold >= 0, "The threshold should be >= 0") .createWithDefault(10) - lazy val RUNTIME_BLOOM_FILTER_ENABLED = + val RUNTIME_BLOOM_FILTER_ENABLED = buildConf("spark.sql.optimizer.runtime.bloomFilter.enabled") .doc("When true and if one side of a shuffle join has a selective predicate, we attempt " + "to insert a bloom filter in the other side to reduce the amount of shuffle data.") @@ -366,8 +366,8 @@ object SQLConf { .booleanConf .createWithDefault(false) - val RUNTIME_BLOOM_FILTER_THRESHOLD = - buildConf("spark.sql.optimizer.runtime.bloomFilter.threshold") + val RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD = + buildConf("spark.sql.optimizer.runtime.bloomFilter.creationSideThreshold") .doc("Size threshold of the bloom filter creation side plan. Estimated size needs to be " + "under this value to try to inject bloom filter.") .version("3.3.0") @@ -3760,8 +3760,8 @@ class SQLConf extends Serializable with Logging { def runtimeFilterBloomFilterEnabled: Boolean = getConf(RUNTIME_BLOOM_FILTER_ENABLED) - def runtimeFilterBloomFilterThreshold: Long = - getConf(RUNTIME_BLOOM_FILTER_THRESHOLD) + def runtimeFilterCreationSideThreshold: Long = + getConf(RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD) def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala index a7abd1f3b3f0c..a5e27fbfda42a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala @@ -448,28 +448,28 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp // scan's byte size is 3362. withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "3000", - SQLConf.RUNTIME_BLOOM_FILTER_THRESHOLD.key -> "4000" + SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "4000" ) { assertRewroteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " + "where bf2.a2 = 62") } withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50", - SQLConf.RUNTIME_BLOOM_FILTER_THRESHOLD.key -> "50" + SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "50" ) { assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " + "where bf2.a2 = 62") } withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "5000", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "3000", - SQLConf.RUNTIME_BLOOM_FILTER_THRESHOLD.key -> "4000" + SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "4000" ) { // Rewrite should not be triggered as the Bloom filter application side scan size is small. assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " + "where bf2.a2 = 62") } withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "32", - SQLConf.RUNTIME_BLOOM_FILTER_THRESHOLD.key -> "4000") { + SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "4000") { // Test that the max scan size rather than an individual scan size on the filter // application side matters. `bf5filtered` has 14168 bytes and `bf2` has 3409 bytes. withSQLConf( From 3698a97982f071adedb097ba9a1f5e681dc8a7e7 Mon Sep 17 00:00:00 2001 From: Abhishek Somani Date: Wed, 16 Mar 2022 07:32:06 -0400 Subject: [PATCH 06/15] Addressed review comment: try inject both on the left and right correctly --- .../sql/catalyst/optimizer/InjectRuntimeFilter.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index 98d0665156e06..12700794fcc4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -277,11 +277,17 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J !hasDynamicPruningSubquery(left, right, l, r) && !hasRuntimeFilter(newLeft, newRight, l, r) && isSimpleExpression(l) && isSimpleExpression(r)) { + val oldLeft = newLeft + val oldRight = newRight if (canFilterLeft(joinType) && filteringHasBenefit(left, right, l, hint)) { newLeft = injectFilter(l, newLeft, r, right) - filterCounter = filterCounter + 1 - } else if (canFilterRight(joinType) && filteringHasBenefit(right, left, r, hint)) { + } + // Did we actually inject on the left? If not, try on the right + if (newLeft.fastEquals(oldLeft) && canFilterRight(joinType) && + filteringHasBenefit(right, left, r, hint)) { newRight = injectFilter(r, newRight, l, left) + } + if (!newLeft.fastEquals(oldLeft) || !newRight.fastEquals(oldRight)) { filterCounter = filterCounter + 1 } } From 55027968ac6544e80c54dc59827b1aa8ad071e30 Mon Sep 17 00:00:00 2001 From: Abhishek Somani Date: Wed, 16 Mar 2022 09:38:19 -0400 Subject: [PATCH 07/15] Provided back the default implementation for BloomFilter.cardinality() --- .../main/java/org/apache/spark/util/sketch/BloomFilter.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index 9f0454589dc6d..2a6e270a91267 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -166,7 +166,9 @@ int getVersionNumber() { /** * @return the number of set bits in this {@link BloomFilter}. */ - public abstract long cardinality(); + public long cardinality() { + throw new UnsupportedOperationException("Not implemented"); + } /** * Reads in a {@link BloomFilter} from an input stream. It is the caller's responsibility to close From 504b7d8f09de84c73206c2b6af5a0c5032a6174c Mon Sep 17 00:00:00 2001 From: Abhishek Somani Date: Wed, 16 Mar 2022 13:46:29 -0400 Subject: [PATCH 08/15] Addressed some review comments for code refactoring --- .../expressions/BloomFilterMightContain.scala | 23 +++++------- .../aggregate/BloomFilterAggregate.scala | 37 +++++++++---------- 2 files changed, 27 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala index 21eb8eb9cc0ec..cf052f865ea90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala @@ -49,24 +49,21 @@ case class BloomFilterMightContain( override def dataType: DataType = BooleanType override def checkInputDataTypes(): TypeCheckResult = { - val typeCheckResult = (left.dataType, right.dataType) match { + (left.dataType, right.dataType) match { case (BinaryType, NullType) | (NullType, LongType) | (NullType, NullType) | - (BinaryType, LongType) => TypeCheckResult.TypeCheckSuccess + (BinaryType, LongType) => + bloomFilterExpression match { + case e : Expression if e.foldable => TypeCheckResult.TypeCheckSuccess + case subquery : PlanExpression[_] if !subquery.containsPattern(OUTER_REFERENCE) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure(s"The Bloom filter binary input to $prettyName " + + "should be either a constant value or a scalar subquery expression") + } case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + s"been ${BinaryType.simpleString} followed by a value with ${LongType.simpleString}, " + s"but it's [${left.dataType.catalogString}, ${right.dataType.catalogString}].") } - if (typeCheckResult.isFailure) { - return typeCheckResult - } - bloomFilterExpression match { - case e : Expression if e.foldable => TypeCheckResult.TypeCheckSuccess - case subquery : PlanExpression[_] if !subquery.containsPattern(OUTER_REFERENCE) => - TypeCheckResult.TypeCheckSuccess - case _ => - TypeCheckResult.TypeCheckFailure(s"The Bloom filter binary input to $prettyName " + - "should be either a constant value or a scalar subquery expression") - } } override protected def withNewChildrenInternal( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala index de189c1e8e6f3..019e8d19745fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala @@ -60,33 +60,30 @@ case class BloomFilterAggregate( } override def checkInputDataTypes(): TypeCheckResult = { - val typeCheckResult = (first.dataType, second.dataType, third.dataType) match { + (first.dataType, second.dataType, third.dataType) match { case (_, NullType, _) | (_, _, NullType) => TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as size arguments") - case (LongType, LongType, LongType) => TypeCheckResult.TypeCheckSuccess + case (LongType, LongType, LongType) => + if (!estimatedNumItemsExpression.foldable) { + TypeCheckFailure("The estimated number of items provided must be a constant literal") + } else if (estimatedNumItems <= 0L) { + TypeCheckFailure("The estimated number of items must be a positive value " + + s" (current value = $estimatedNumItems)") + } else if (!numBitsExpression.foldable) { + TypeCheckFailure("The number of bits provided must be a constant literal") + } else if (numBits <= 0L) { + TypeCheckFailure("The number of bits must be a positive value " + + s" (current value = $numBits)") + } else { + require(estimatedNumItems <= BloomFilterAggregate.MAX_ALLOWED_NUM_ITEMS) + require(numBits <= BloomFilterAggregate.MAX_NUM_BITS) + TypeCheckSuccess + } case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + s"been a ${LongType.simpleString} value followed with two ${LongType.simpleString} size " + s"arguments, but it's [${first.dataType.catalogString}, " + s"${second.dataType.catalogString}, ${third.dataType.catalogString}]") } - if (typeCheckResult.isFailure) { - return typeCheckResult - } - if (!estimatedNumItemsExpression.foldable) { - TypeCheckFailure("The estimated number of items provided must be a constant literal") - } else if (estimatedNumItems <= 0L) { - TypeCheckFailure("The estimated number of items must be a positive value " + - s" (current value = $estimatedNumItems)") - } else if (!numBitsExpression.foldable) { - TypeCheckFailure("The number of bits provided must be a constant literal") - } else if (numBits <= 0L) { - TypeCheckFailure("The number of bits must be a positive value " + - s" (current value = $numBits)") - } else { - require(estimatedNumItems <= BloomFilterAggregate.MAX_ALLOWED_NUM_ITEMS) - require(numBits <= BloomFilterAggregate.MAX_NUM_BITS) - TypeCheckSuccess - } } override def nullable: Boolean = true From 0bf41c4fb63abe95a783060f21e5b44b88d648ef Mon Sep 17 00:00:00 2001 From: Abhishek Somani Date: Wed, 16 Mar 2022 13:48:27 -0400 Subject: [PATCH 09/15] Update sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala Co-authored-by: Prashant Singh <35593236+singhpk234@users.noreply.github.com> --- .../spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index 12700794fcc4c..6f1378507b1df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -196,7 +196,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J filterApplicationSideExp: Expression, hint: JoinHint): Boolean = { // Check that: - // 1. The filterApplicationSideJoinExp can be pushed down through joins and aggregates (ie the + // 1. The filterApplicationSideJoinExp can be pushed down through joins and aggregates (i.e the // expression references originate from a single leaf node) // 2. The filter creation side has a selective predicate // 3. The current join is a shuffle join or a broadcast join that has a shuffle or aggregate From db4ac3b904a9ed32dc87300a1e11e5493cbc7962 Mon Sep 17 00:00:00 2001 From: Abhishek Somani Date: Wed, 16 Mar 2022 13:53:43 -0400 Subject: [PATCH 10/15] more review comments --- .../spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index 6f1378507b1df..a97be652c19e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -149,10 +149,9 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J } private def probablyHasShuffle(plan: LogicalPlan): Boolean = { - plan.collect { + plan.collectFirst { case j@Join(left, right, _, _, hint) - if !hintToBroadcastLeft(hint) && !hintToBroadcastRight(hint) && - !canBroadcastBySize(left, conf) && !canBroadcastBySize(right, conf) => j + if isProbablyShuffleJoin(left, right, hint) => j case a: Aggregate => a }.nonEmpty } From e6113765ad2d91294387d0d1c9ddeec4e03b113c Mon Sep 17 00:00:00 2001 From: Abhishek Somani Date: Wed, 16 Mar 2022 21:43:43 -0400 Subject: [PATCH 11/15] More review comments --- .../aggregate/BloomFilterAggregate.scala | 63 +++++++------------ .../apache/spark/sql/internal/SQLConf.scala | 29 +++++++++ .../sql/BloomFilterAggregateQuerySuite.scala | 6 +- 3 files changed, 56 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala index 019e8d19745fb..3ae99e86da126 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.TernaryLike +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.sketch.BloomFilter @@ -55,8 +56,8 @@ case class BloomFilterAggregate( } def this(child: Expression) = { - this(child, Literal(BloomFilterAggregate.DEFAULT_EXPECTED_NUM_ITEMS), - Literal(BloomFilterAggregate.DEFAULT_NUM_BITS)) + this(child, Literal(SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_EXPECTED_NUM_ITEMS)), + Literal(SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_NUM_BITS))) } override def checkInputDataTypes(): TypeCheckResult = { @@ -75,8 +76,9 @@ case class BloomFilterAggregate( TypeCheckFailure("The number of bits must be a positive value " + s" (current value = $numBits)") } else { - require(estimatedNumItems <= BloomFilterAggregate.MAX_ALLOWED_NUM_ITEMS) - require(numBits <= BloomFilterAggregate.MAX_NUM_BITS) + require(estimatedNumItems <= + SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS)) + require(numBits <= SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS)) TypeCheckSuccess } case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + @@ -94,12 +96,12 @@ case class BloomFilterAggregate( // Mark as lazy so that `estimatedNumItems` is not evaluated during tree transformation. private lazy val estimatedNumItems: Long = Math.min(estimatedNumItemsExpression.eval().asInstanceOf[Number].longValue, - BloomFilterAggregate.MAX_ALLOWED_NUM_ITEMS) + SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS)) // Mark as lazy so that `numBits` is not evaluated during tree transformation. private lazy val numBits: Long = Math.min(numBitsExpression.eval().asInstanceOf[Number].longValue, - BloomFilterAggregate.MAX_NUM_BITS) + SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS)) override def first: Expression = child @@ -148,47 +150,28 @@ case class BloomFilterAggregate( copy(inputAggBufferOffset = newOffset) override def serialize(obj: BloomFilter): Array[Byte] = { - BloomFilterAggregate.serde.serialize(obj) + BloomFilterAggregate.serialize(obj) } override def deserialize(bytes: Array[Byte]): BloomFilter = { - BloomFilterAggregate.serde.deserialize(bytes) + BloomFilterAggregate.deserialize(bytes) } } object BloomFilterAggregate { - - val DEFAULT_EXPECTED_NUM_ITEMS: Long = 1000000L // Default 1M distinct items - - val MAX_ALLOWED_NUM_ITEMS: Long = 4000000L // At most 4M distinct items - - val DEFAULT_NUM_BITS: Long = 8388608 // Default 1MB - - val MAX_NUM_BITS: Long = 67108864 // At most 8MB - - /** - * Serializer/Deserializer for class [[BloomFilter]] - * - * This class is thread safe. - */ - class BloomFilterSerDe { - - final def serialize(obj: BloomFilter): Array[Byte] = { - val size = obj.bitSize() / 8 - require(size <= Integer.MAX_VALUE, s"actual number of bits is too large $size") - val out = new ByteArrayOutputStream(size.intValue()) - obj.writeTo(out) - out.close() - out.toByteArray - } - - final def deserialize(bytes: Array[Byte]): BloomFilter = { - val in = new ByteArrayInputStream(bytes) - val bloomFilter = BloomFilter.readFrom(in) - in.close() - bloomFilter - } + final def serialize(obj: BloomFilter): Array[Byte] = { + val size = obj.bitSize() / 8 + require(size <= Integer.MAX_VALUE, s"actual number of bits is too large $size") + val out = new ByteArrayOutputStream(size.intValue()) + obj.writeTo(out) + out.close() + out.toByteArray } - val serde: BloomFilterSerDe = new BloomFilterSerDe + final def deserialize(bytes: Array[Byte]): BloomFilter = { + val in = new ByteArrayInputStream(bytes) + val bloomFilter = BloomFilter.readFrom(in) + in.close() + bloomFilter + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3ee4fac8d7193..af911c21b2927 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -383,6 +383,35 @@ object SQLConf { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("10GB") + val RUNTIME_BLOOM_FILTER_EXPECTED_NUM_ITEMS = + buildConf("spark.sql.optimizer.runtime.bloomFilter.expectedNumItems") + .doc("The default number of expected items for the runtime bloomfilter") + .version("3.3.0") + .longConf + .createWithDefault(1000000L) + + val RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS = + buildConf("spark.sql.optimizer.runtime.bloomFilter.maxNumItems") + .doc("The max allowed number of expected items for the runtime bloom filter") + .version("3.3.0") + .longConf + .createWithDefault(4000000L) + + + val RUNTIME_BLOOM_FILTER_NUM_BITS = + buildConf("spark.sql.optimizer.runtime.bloomFilter.numBits") + .doc("The default number of bits to use for the runtime bloom filter") + .version("3.3.0") + .longConf + .createWithDefault(8388608L) + + val RUNTIME_BLOOM_FILTER_MAX_NUM_BITS = + buildConf("spark.sql.optimizer.runtime.bloomFilter.maxNumBits") + .doc("The max number of bits to use for the runtime bloom filter") + .version("3.3.0") + .longConf + .createWithDefault(67108864L) + val COMPRESS_CACHED = buildConf("spark.sql.inMemoryColumnarStorage.compressed") .doc("When set to true Spark SQL will automatically select a compression codec for each " + "column based on statistics of the data.") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala index fb3388175a92a..995e5c7d502c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.aggregate.BaseAggregateExec +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession /** @@ -46,11 +47,12 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { (children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1))) test("Test bloom_filter_agg and might_contain") { + val conf = SQLConf.get val table = "bloom_filter_test" for (numEstimatedItems <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue, - BloomFilterAggregate.MAX_ALLOWED_NUM_ITEMS)) { + conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))) { for (numBits <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue, - BloomFilterAggregate.MAX_NUM_BITS)) { + conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))) { val sqlString = s""" |SELECT every(might_contain( | (SELECT bloom_filter_agg(col, From c0a56c6835ec1f1fb5665d38c7ffa24066cffe5b Mon Sep 17 00:00:00 2001 From: Abhishek Somani Date: Mon, 21 Mar 2022 12:13:12 -0400 Subject: [PATCH 12/15] Update sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala Co-authored-by: Wenchen Fan --- .../spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index a97be652c19e4..f7fa7d3e3d849 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -218,7 +218,10 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J } // This checks if there is already a DPP filter, as this rule is called just after DPP. - def hasDynamicPruningSubquery(left: LogicalPlan, right: LogicalPlan, leftKey: Expression, + def hasDynamicPruningSubquery( + left: LogicalPlan, + right: LogicalPlan, + leftKey: Expression, rightKey: Expression): Boolean = { (left, right) match { case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan), _) => From d4e032c6344e30e2494364dee43fed0e3cab3c18 Mon Sep 17 00:00:00 2001 From: Abhishek Somani Date: Mon, 21 Mar 2022 12:13:29 -0400 Subject: [PATCH 13/15] Update sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala Co-authored-by: Wenchen Fan --- .../spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index f7fa7d3e3d849..f84af08f5f69a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -233,7 +233,10 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J } } - def hasBloomFilter(left: LogicalPlan, right: LogicalPlan, leftKey: Expression, + def hasBloomFilter( + left: LogicalPlan, + right: LogicalPlan, + leftKey: Expression, rightKey: Expression): Boolean = { findBloomFilterWithExp(left, leftKey) || findBloomFilterWithExp(right, rightKey) } From 7cd444f5f0fde1a02ef43a2b54557d8868844d97 Mon Sep 17 00:00:00 2001 From: Abhishek Somani Date: Mon, 21 Mar 2022 14:18:03 -0400 Subject: [PATCH 14/15] Review comments --- .../aggregate/BloomFilterAggregate.scala | 4 +++- .../catalyst/optimizer/InjectRuntimeFilter.scala | 11 ++--------- .../spark/sql/BloomFilterAggregateQuerySuite.scala | 13 +++++++++++-- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala index 3ae99e86da126..c734bca3ef8d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala @@ -160,7 +160,9 @@ case class BloomFilterAggregate( object BloomFilterAggregate { final def serialize(obj: BloomFilter): Array[Byte] = { - val size = obj.bitSize() / 8 + // BloomFilterImpl.writeTo() writes 2 integers (version number and num hash functions), hence + // the +8 + val size = (obj.bitSize() / 8) + 8 require(size <= Integer.MAX_VALUE, s"actual number of bits is too large $size") val out = new ByteArrayOutputStream(size.intValue()) obj.writeTo(out) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index f84af08f5f69a..3eca96de00e27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -184,7 +184,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J /** * Check that: * - The filterApplicationSideJoinExp can be pushed down through joins and aggregates (ie the - * - expression references originate from a single leaf node) + * expression references originate from a single leaf node) * - The filter creation side has a selective predicate * - The current join is a shuffle join or a broadcast join that has a shuffle below it * - The max filterApplicationSide scan size is greater than a configurable threshold @@ -194,13 +194,6 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J filterCreationSide: LogicalPlan, filterApplicationSideExp: Expression, hint: JoinHint): Boolean = { - // Check that: - // 1. The filterApplicationSideJoinExp can be pushed down through joins and aggregates (i.e the - // expression references originate from a single leaf node) - // 2. The filter creation side has a selective predicate - // 3. The current join is a shuffle join or a broadcast join that has a shuffle or aggregate - // in the filter application side - // 4. The max filterApplicationSide scan size is greater than a configurable threshold findExpressionAndTrackLineageDown(filterApplicationSideExp, filterApplicationSide).isDefined && isSelectiveFilterOverScan(filterCreationSide) && (isProbablyShuffleJoin(filterApplicationSide, filterCreationSide, hint) || @@ -297,7 +290,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J } } }) - Join(newLeft, newRight, joinType, join.condition, hint) + join.withNewChildren(Seq(newLeft, newRight)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala index 995e5c7d502c7..025593be4c959 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala @@ -32,8 +32,11 @@ import org.apache.spark.sql.test.SharedSparkSession class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { import testImplicits._ + val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg") + val funcId_might_contain = new FunctionIdentifier("might_contain") + // Register 'bloom_filter_agg' to builtin. - FunctionRegistry.builtin.registerFunction(new FunctionIdentifier("bloom_filter_agg"), + FunctionRegistry.builtin.registerFunction(funcId_bloom_filter_agg, new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"), (children: Seq[Expression]) => children.size match { case 1 => new BloomFilterAggregate(children.head) @@ -42,10 +45,16 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { }) // Register 'might_contain' to builtin. - FunctionRegistry.builtin.registerFunction(new FunctionIdentifier("might_contain"), + FunctionRegistry.builtin.registerFunction(funcId_might_contain, new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"), (children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1))) + override def afterAll(): Unit = { + FunctionRegistry.builtin.dropFunction(funcId_bloom_filter_agg) + FunctionRegistry.builtin.dropFunction(funcId_might_contain) + super.afterAll() + } + test("Test bloom_filter_agg and might_contain") { val conf = SQLConf.get val table = "bloom_filter_test" From 015d578f8f31a11810655db926bab5bd32311026 Mon Sep 17 00:00:00 2001 From: Abhishek Somani Date: Tue, 22 Mar 2022 12:56:19 -0400 Subject: [PATCH 15/15] Update sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala Co-authored-by: Wenchen Fan --- .../spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index 3eca96de00e27..35d0189f64651 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -70,8 +70,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J filterApplicationSideExp: Expression, filterApplicationSidePlan: LogicalPlan, filterCreationSideExp: Expression, - filterCreationSidePlan: LogicalPlan - ): LogicalPlan = { + filterCreationSidePlan: LogicalPlan): LogicalPlan = { // Skip if the filter creation side is too big if (filterCreationSidePlan.stats.sizeInBytes > conf.runtimeFilterCreationSideThreshold) { return filterApplicationSidePlan