Skip to content

Commit b378fff

Browse files
committed
Add null filtering logic to the aggregate function along with tests
1 parent 291a13d commit b378fff

6 files changed

Lines changed: 71 additions & 41 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,24 +40,25 @@ case class AnyAgg(child: Expression) extends DeclarativeAggregate with ImplicitC
4040

4141
private lazy val some = AttributeReference("some", BooleanType)()
4242

43-
private lazy val emptySet = AttributeReference("emptySet", BooleanType)()
43+
private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
4444

45-
override lazy val aggBufferAttributes = some :: emptySet :: Nil
45+
override lazy val aggBufferAttributes = some :: valueSet :: Nil
4646

4747
override lazy val initialValues: Seq[Expression] = Seq(
48-
Literal(false),
49-
Literal(true)
48+
/* some = */ Literal.create(false, BooleanType),
49+
/* valueSet = */ Literal.create(false, BooleanType)
5050
)
5151

5252
override lazy val updateExpressions: Seq[Expression] = Seq(
53-
Or(some, Coalesce(Seq(child, Literal(false)))),
54-
Literal(false)
53+
/* some = */ Or(some, If (child.isNull, some, child)),
54+
/* valueSet = */ valueSet || child.isNotNull
5555
)
5656

5757
override lazy val mergeExpressions: Seq[Expression] = Seq(
58-
Or(some.left, some.right),
59-
And(emptySet.left, emptySet.right)
58+
/* some = */ Or(some.left, some.right),
59+
/* valueSet = */ valueSet.right || valueSet.left
6060
)
6161

62-
override lazy val evaluateExpression: Expression = And(!emptySet, some)
62+
override lazy val evaluateExpression: Expression =
63+
If (valueSet, some, Literal.create(null, BooleanType))
6364
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Every.scala

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,24 +40,25 @@ case class Every(child: Expression) extends DeclarativeAggregate with ImplicitCa
4040

4141
private lazy val every = AttributeReference("every", BooleanType)()
4242

43-
private lazy val emptySet = AttributeReference("emptySet", BooleanType)()
43+
private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
4444

45-
override lazy val aggBufferAttributes = every :: emptySet :: Nil
45+
override lazy val aggBufferAttributes = every :: valueSet :: Nil
4646

4747
override lazy val initialValues: Seq[Expression] = Seq(
48-
Literal(true),
49-
Literal(true)
48+
/* every = */ Literal.create(true, BooleanType),
49+
/* valueSet = */ Literal.create(false, BooleanType)
5050
)
5151

5252
override lazy val updateExpressions: Seq[Expression] = Seq(
53-
And(every, Coalesce(Seq(child, Literal(false)))),
54-
Literal(false)
53+
/* every = */ And(every, If (child.isNull, every, child)),
54+
/* valueSet = */ valueSet || child.isNotNull
5555
)
5656

5757
override lazy val mergeExpressions: Seq[Expression] = Seq(
58-
And(every.left, every.right),
59-
And(emptySet.left, emptySet.right)
58+
/* every = */ And(every.left, every.right),
59+
/* valueSet = */ valueSet.right || valueSet.left
6060
)
6161

62-
override lazy val evaluateExpression: Expression = And(!emptySet, every)
62+
override lazy val evaluateExpression: Expression =
63+
If (valueSet, every, Literal.create(null, BooleanType))
6364
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyTestSuite.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,42 +26,43 @@ class AnyTestSuite extends SparkFunSuite {
2626
val evaluator = DeclarativeAggregateEvaluator(AnyAgg(input), Seq(input))
2727

2828
test("empty buffer") {
29-
assert(evaluator.initialize() === InternalRow(false, true))
29+
assert(evaluator.initialize() === InternalRow(false, false))
3030
}
3131

3232
test("update") {
3333
val result = evaluator.update(
3434
InternalRow(true),
3535
InternalRow(false),
3636
InternalRow(true))
37-
assert(result === InternalRow(true, false))
37+
assert(result === InternalRow(true, true))
3838
}
3939

4040
test("merge") {
4141
// Empty merge
4242
val p0 = evaluator.initialize()
43-
assert(evaluator.merge(p0) === InternalRow(false, true))
43+
assert(evaluator.merge(p0) === InternalRow(false, false))
4444

4545
// Single merge
4646
val p1 = evaluator.update(InternalRow(true), InternalRow(true))
47-
assert(evaluator.merge(p1) === InternalRow(true, false))
47+
assert(evaluator.merge(p1) === InternalRow(true, true))
4848

4949
// Multiple merges.
5050
val p2 = evaluator.update(InternalRow(false), InternalRow(null))
51-
assert(evaluator.merge(p1, p2) === InternalRow(true, false))
51+
assert(evaluator.merge(p1, p2) === InternalRow(true, true))
5252

5353
// Empty partitions (p0 is empty)
54-
assert(evaluator.merge(p0, p2) === InternalRow(false, false))
55-
assert(evaluator.merge(p2, p1, p0) === InternalRow(true, false))
54+
assert(evaluator.merge(p0, p2) === InternalRow(false, true))
55+
assert(evaluator.merge(p2, p1, p0) === InternalRow(true, true))
5656
}
5757

5858
test("eval") {
5959
// Null Eval: a buffer whose valueSet slot is false evaluates to null
60-
assert(evaluator.eval(InternalRow(null, true)) === InternalRow(false))
60+
assert(evaluator.eval(InternalRow(true, false)) === InternalRow(null))
61+
assert(evaluator.eval(InternalRow(false, false)) === InternalRow(null))
6162

6263
// Empty Eval
6364
val p0 = evaluator.initialize()
64-
assert(evaluator.eval(p0) === InternalRow(false))
65+
assert(evaluator.eval(p0) === InternalRow(null))
6566

6667
// Update - Eval
6768
val p1 = evaluator.update(InternalRow(true), InternalRow(null))
@@ -76,5 +77,4 @@ class AnyTestSuite extends SparkFunSuite {
7677
val m2 = evaluator.merge(p2, p1, p0)
7778
assert(evaluator.eval(m2) === InternalRow(true))
7879
}
79-
8080
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/EveryTestSuite.scala

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,42 +26,43 @@ class EveryTestSuite extends SparkFunSuite {
2626
val evaluator = DeclarativeAggregateEvaluator(Every(input), Seq(input))
2727

2828
test("empty buffer") {
29-
assert(evaluator.initialize() === InternalRow(true, true))
29+
assert(evaluator.initialize() === InternalRow(true, false))
3030
}
3131

3232
test("update") {
3333
val result = evaluator.update(
3434
InternalRow(true),
3535
InternalRow(false),
3636
InternalRow(true))
37-
assert(result === InternalRow(false, false))
37+
assert(result === InternalRow(false, true))
3838
}
3939

4040
test("merge") {
4141
// Empty merge
4242
val p0 = evaluator.initialize()
43-
assert(evaluator.merge(p0) === InternalRow(true, true))
43+
assert(evaluator.merge(p0) === InternalRow(true, false))
4444

4545
// Single merge
4646
val p1 = evaluator.update(InternalRow(true), InternalRow(true))
47-
assert(evaluator.merge(p1) === InternalRow(true, false))
47+
assert(evaluator.merge(p1) === InternalRow(true, true))
4848

4949
// Multiple merges.
5050
val p2 = evaluator.update(InternalRow(true), InternalRow(null))
51-
assert(evaluator.merge(p1, p2) === InternalRow(false, false))
51+
assert(evaluator.merge(p1, p2) === InternalRow(true, true))
5252

5353
// Empty partitions (p0 is empty)
54-
assert(evaluator.merge(p1, p0, p2) === InternalRow(false, false))
55-
assert(evaluator.merge(p2, p1, p0) === InternalRow(false, false))
54+
assert(evaluator.merge(p1, p0, p2) === InternalRow(true, true))
55+
assert(evaluator.merge(p2, p1, p0) === InternalRow(true, true))
5656
}
5757

5858
test("eval") {
5959
// Null Eval: a buffer whose valueSet slot is false evaluates to null
60-
assert(evaluator.eval(InternalRow(null, true)) === InternalRow(false))
60+
assert(evaluator.eval(InternalRow(true, false)) === InternalRow(null))
61+
assert(evaluator.eval(InternalRow(false, false)) === InternalRow(null))
6162

6263
// Empty Eval
6364
val p0 = evaluator.initialize()
64-
assert(evaluator.eval(p0) === InternalRow(false))
65+
assert(evaluator.eval(p0) === InternalRow(null))
6566

6667
// Update - Eval
6768
val p1 = evaluator.update(InternalRow(true), InternalRow(true))

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -736,11 +736,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
736736
Seq(Row(1, false), Row(2, true), Row(3, false)))
737737
}
738738

739+
test("every null values") {
740+
val df = Seq[(java.lang.Integer, java.lang.Boolean)](
741+
(1, true), (1, false),
742+
(2, true),
743+
(3, false), (3, null),
744+
(4, null), (4, null))
745+
.toDF("a", "b")
746+
checkAnswer(
747+
df.groupBy("a").agg(every('b)),
748+
Seq(Row(1, false), Row(2, true), Row(3, false), Row(4, null)))
749+
}
750+
739751
test("every empty table") {
740752
val df = Seq.empty[(Int, Boolean)].toDF("a", "b")
741753
checkAnswer(
742754
df.agg(every('b)),
743-
Seq(Row(false)))
755+
Seq(Row(null)))
744756
}
745757

746758
test("any") {
@@ -755,7 +767,22 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
755767
val df = Seq.empty[(Int, Boolean)].toDF("a", "b")
756768
checkAnswer(
757769
df.agg(any('b)),
758-
Seq(Row(false)))
770+
Seq(Row(null)))
771+
}
772+
773+
test("any/some null values") {
774+
val df = Seq[(java.lang.Integer, java.lang.Boolean)] (
775+
(1, true), (1, false),
776+
(2, true),
777+
(3, true), (3, false), (3, null),
778+
(4, null), (4, null))
779+
.toDF("a", "b")
780+
checkAnswer(
781+
df.groupBy("a").agg(any('b)),
782+
Seq(Row(1, true), Row(2, true), Row(3, true), Row(4, null)))
783+
checkAnswer(
784+
df.groupBy("a").agg(some('b)),
785+
Seq(Row(1, true), Row(2, true), Row(3, true), Row(4, null)))
759786
}
760787

761788
test("some") {

sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -709,8 +709,8 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
709709
Row("b", true, true, true, true),
710710
Row("b", true, true, true, true),
711711
Row("c", false, false, false, false),
712-
Row("d", true, false, true, true),
713-
Row("d", null, false, true, true)
712+
Row("d", true, true, true, true),
713+
Row("d", null, true, true, true)
714714
)
715715
)
716716
}

0 commit comments

Comments
 (0)