From 1be858870a71d6fe1694a267f777a965c17b64ec Mon Sep 17 00:00:00 2001 From: Wenting Zheng Date: Tue, 2 Feb 2021 01:24:13 +0000 Subject: [PATCH 1/8] Semijoin support --- src/enclave/Enclave/ExpressionEvaluation.h | 6 +++++ src/enclave/Enclave/Join.cpp | 23 ++++++++++++++++++- .../berkeley/cs/rise/opaque/strategies.scala | 9 +++++++- .../cs/rise/opaque/OpaqueOperatorTests.scala | 9 ++++++++ 4 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/enclave/Enclave/ExpressionEvaluation.h b/src/enclave/Enclave/ExpressionEvaluation.h index 737f92ac83..a2b495dd60 100644 --- a/src/enclave/Enclave/ExpressionEvaluation.h +++ b/src/enclave/Enclave/ExpressionEvaluation.h @@ -1583,6 +1583,7 @@ class FlatbuffersJoinExprEvaluator { } const tuix::JoinExpr* join_expr = flatbuffers::GetRoot(buf); + join_type = join_expr->join_type(); if (join_expr->left_keys()->size() != join_expr->right_keys()->size()) { throw std::runtime_error("Mismatched join key lengths"); @@ -1639,8 +1640,13 @@ class FlatbuffersJoinExprEvaluator { return true; } + tuix::JoinType get_join_type() { + return join_type; + } + private: flatbuffers::FlatBufferBuilder builder; + tuix::JoinType join_type; std::vector> left_key_evaluators; std::vector> right_key_evaluators; }; diff --git a/src/enclave/Enclave/Join.cpp b/src/enclave/Enclave/Join.cpp index b8797e8b45..3125418ba6 100644 --- a/src/enclave/Enclave/Join.cpp +++ b/src/enclave/Enclave/Join.cpp @@ -59,6 +59,8 @@ void non_oblivious_sort_merge_join( last_primary_of_group.set(row); } + bool leftsemi_add_row = true; + while (r.has_next()) { const tuix::Row *current = r.next(); @@ -73,6 +75,7 @@ void non_oblivious_sort_merge_join( primary_group.clear(); primary_group.append(current); last_primary_of_group.set(current); + leftsemi_add_row = true; } } else { // Output the joined rows resulting from this foreign row @@ -92,7 +95,25 @@ void non_oblivious_sort_merge_join( + to_string(current)); } - w.append(primary, current); + tuix::JoinType join_type = join_expr_eval.get_join_type(); + switch (join_type) { + case tuix::JoinType_Inner: + w.append(primary, current); + break; + + case tuix::JoinType_LeftSemi: + if (leftsemi_add_row) { + w.append(primary, current); + leftsemi_add_row = false; + } + break; + + default: + throw std::runtime_error(std::string("Join type ") + + tuix::EnumNameJoinType(join_type) + + std::string(" is not supported")); + break; + } } } } diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala index f26551553d..4b39340e5d 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala @@ -32,6 +32,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.planning.PhysicalAggregation import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.execution.SparkPlan import edu.berkeley.cs.rise.opaque.execution._ @@ -90,7 +92,12 @@ object OpaqueOperators extends Strategy { case Some(condition) => EncryptedFilterExec(condition, tagsDropped) case None => tagsDropped } - filtered :: Nil + + joinType match { + case Inner => filtered :: Nil + case LeftSemi => EncryptedProjectExec(left.output, filtered) :: Nil + case _ => Nil + } case a @ PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) if (isEncrypted(child) && aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression])) => diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala index 79e1bee374..beabe2ea15 100644 --- a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala +++ b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala @@ -316,6 +316,15 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => p.join(f, $"join_col_1" === $"join_col_2").collect.toSet } + testAgainstSpark("left semi join") { securityLevel => + val p_data = for (i <- 1 to 16) yield (i, i.toString, i * 10) + val f_data = for (i <- 1 to 32) yield (i, (i % 8).toString, i * 10) + val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") + val df = p.join(f, $"join_col_1" === $"join_col_2", "left_semi").sort($"join_col_1") + df.collect + } + def abc(i: Int): String = (i % 3) match { case 0 => "A" case 1 => "B" From e7e328209c50b48e26b372bcdf4f7ba1adb3af12 Mon Sep 17 00:00:00 2001 From: Wenting Zheng Date: Tue, 2 Feb 2021 04:44:21 +0000 Subject: [PATCH 2/8] Fix bug with multiple rows matching for the same value for the left table --- src/enclave/Enclave/Join.cpp | 3 ++- .../edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/enclave/Enclave/Join.cpp b/src/enclave/Enclave/Join.cpp index 3125418ba6..cf40dad0c9 100644 --- a/src/enclave/Enclave/Join.cpp +++ b/src/enclave/Enclave/Join.cpp @@ -104,7 +104,6 @@ void non_oblivious_sort_merge_join( case tuix::JoinType_LeftSemi: if (leftsemi_add_row) { w.append(primary, current); - leftsemi_add_row = false; } break; @@ -115,6 +114,8 @@ void non_oblivious_sort_merge_join( break; } } + + leftsemi_add_row = false; } } } diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala index beabe2ea15..fa6a3b2a97 100644 --- a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala +++ b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala @@ -317,11 +317,11 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => } testAgainstSpark("left semi join") { securityLevel => - val p_data = for (i <- 1 to 16) yield (i, i.toString, i * 10) + val p_data = for (i <- 1 to 16) yield (i, (i % 8).toString, i * 10) val f_data = for (i <- 1 to 32) yield (i, (i % 8).toString, i * 10) val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") - val df = p.join(f, $"join_col_1" === $"join_col_2", "left_semi").sort($"join_col_1") + val df = p.join(f, $"join_col_1" === $"join_col_2", "left_semi").sort($"join_col_1", $"id") df.collect } From ed419b09e6384b3faaa9c61dfc1989051228e256 Mon Sep 17 00:00:00 2001 From: Wenting Zheng Date: Tue, 2 Feb 2021 04:46:50 +0000 Subject: [PATCH 3/8] Add back TPC-H q4 --- src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala index d003c835f3..b99e7b18d7 100644 --- a/src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala +++ b/src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala @@ -40,7 +40,7 @@ trait TPCHTests extends OpaqueTestsBase { self => tpch.query(3, securityLevel, spark.sqlContext, numPartitions).collect.toSet } - testAgainstSpark("TPC-H 4", ignore) { securityLevel => + testAgainstSpark("TPC-H 4") { securityLevel => tpch.query(4, securityLevel, spark.sqlContext, numPartitions).collect.toSet } From 24cbd8281f2a84d10b8df01ddd41b5ccdf47fcde Mon Sep 17 00:00:00 2001 From: Wenting Zheng Date: Thu, 4 Feb 2021 04:54:22 +0000 Subject: [PATCH 4/8] Antijoin impl --- src/enclave/Enclave/Join.cpp | 55 ++++++++++++------- .../berkeley/cs/rise/opaque/strategies.scala | 14 +++-- .../cs/rise/opaque/OpaqueOperatorTests.scala | 9 +++ 3 files changed, 51 insertions(+), 27 deletions(-) diff --git a/src/enclave/Enclave/Join.cpp b/src/enclave/Enclave/Join.cpp index cf40dad0c9..8402372ffc 100644 --- a/src/enclave/Enclave/Join.cpp +++ b/src/enclave/Enclave/Join.cpp @@ -47,6 +47,7 @@ void non_oblivious_sort_merge_join( uint8_t **output_rows, size_t *output_rows_length) { FlatbuffersJoinExprEvaluator join_expr_eval(join_expr, join_expr_length); + tuix::JoinType join_type = join_expr_eval.get_join_type(); RowReader r(BufferRefView(input_rows, input_rows_length)); RowReader j(BufferRefView(join_row, join_row_length)); RowWriter w; @@ -59,7 +60,7 @@ void non_oblivious_sort_merge_join( last_primary_of_group.set(row); } - bool leftsemi_add_row = true; + bool pk_fk_match = false; while (r.has_next()) { const tuix::Row *current = r.next(); @@ -71,11 +72,22 @@ void non_oblivious_sort_merge_join( primary_group.append(current); last_primary_of_group.set(current); } else { - // Advance to a new group + // If a new primary group is encountered + if (join_type == tuix::JoinType_LeftAnti && !pk_fk_match) { + auto primary_group_buffer = primary_group.output_buffer(); + RowReader primary_group_reader(primary_group_buffer.view()); + + while (primary_group_reader.has_next()) { + const tuix::Row *primary = primary_group_reader.next(); + w.append(primary); + } + } + primary_group.clear(); primary_group.append(current); last_primary_of_group.set(current); - leftsemi_add_row = true; + + pk_fk_match = false; } } else { // Output the joined rows resulting from this foreign row @@ -95,30 +107,31 @@ void non_oblivious_sort_merge_join( + to_string(current)); } - tuix::JoinType join_type = join_expr_eval.get_join_type(); - switch (join_type) { - case tuix::JoinType_Inner: - w.append(primary, current); - break; - - case tuix::JoinType_LeftSemi: - if (leftsemi_add_row) { - w.append(primary, current); - } - break; - - default: - throw std::runtime_error(std::string("Join type ") + - tuix::EnumNameJoinType(join_type) + - std::string(" is not supported")); - break; + if (join_type == tuix::JoinType_Inner) { + w.append(primary, current); + } else if (join_type == tuix::JoinType_LeftSemi) { + if (!pk_fk_match) { + w.append(primary); + } } } - leftsemi_add_row = false; + pk_fk_match = true; + } else { + pk_fk_match = false; } } } + if (!pk_fk_match && join_type == tuix::JoinType_LeftAnti) { + auto primary_group_buffer = primary_group.output_buffer(); + RowReader primary_group_reader(primary_group_buffer.view()); + + while (primary_group_reader.has_next()) { + const tuix::Row *primary = primary_group_reader.next(); + w.append(primary); + } + } + w.output_buffer(output_rows, output_rows_length); } diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala index 4b39340e5d..d5243414dd 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.planning.PhysicalAggregation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.LeftAnti import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.execution.SparkPlan @@ -87,17 +88,18 @@ object OpaqueOperators extends Strategy { rightProjSchema.map(_.toAttribute), (leftProjSchema ++ rightProjSchema).map(_.toAttribute), sorted) - val tagsDropped = EncryptedProjectExec(dropTags(left.output, right.output), joined) + + val tagsDropped = joinType match { + case Inner => EncryptedProjectExec(dropTags(left.output, right.output), joined) + case LeftSemi | LeftAnti => EncryptedProjectExec(left.output, joined) + } + val filtered = condition match { case Some(condition) => EncryptedFilterExec(condition, tagsDropped) case None => tagsDropped } - joinType match { - case Inner => filtered :: Nil - case LeftSemi => EncryptedProjectExec(left.output, filtered) :: Nil - case _ => Nil - } + filtered :: Nil case a @ PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) if (isEncrypted(child) && aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression])) => diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala index fa6a3b2a97..c5f49d3424 100644 --- a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala +++ b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala @@ -325,6 +325,15 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => df.collect } + testAgainstSpark("left anti join") { securityLevel => + val p_data = for (i <- 1 to 16) yield (i, (i % 9).toString, i * 10) + val f_data = for (i <- 1 to 32) yield (i, ((i % 7) + 1).toString, i * 10) + val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") + val df = p.join(f, $"join_col_1" === $"join_col_2", "left_anti").sort($"join_col_1", $"id") + df.collect + } + def abc(i: Int): String = (i % 3) match { case 0 => "A" case 1 => "B" From dcee39049de3fd8b86a9b13468a9b8432be55992 Mon Sep 17 00:00:00 2001 From: Wenting Zheng Date: Fri, 5 Feb 2021 17:54:58 +0000 Subject: [PATCH 5/8] WIP for join --- src/enclave/Enclave/Join.cpp | 8 ++++++-- .../cs/rise/opaque/OpaqueOperatorTests.scala | 15 ++++++++++++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/enclave/Enclave/Join.cpp b/src/enclave/Enclave/Join.cpp index 8402372ffc..8b17376c45 100644 --- a/src/enclave/Enclave/Join.cpp +++ b/src/enclave/Enclave/Join.cpp @@ -64,6 +64,7 @@ void non_oblivious_sort_merge_join( while (r.has_next()) { const tuix::Row *current = r.next(); + print(current); if (join_expr_eval.is_primary(current)) { if (last_primary_of_group.get() @@ -110,6 +111,7 @@ void non_oblivious_sort_merge_join( if (join_type == tuix::JoinType_Inner) { w.append(primary, current); } else if (join_type == tuix::JoinType_LeftSemi) { + // Only output the pk group ONCE if (!pk_fk_match) { w.append(primary); } @@ -118,12 +120,14 @@ void non_oblivious_sort_merge_join( pk_fk_match = true; } else { - pk_fk_match = false; + // If pk_fk_match were true, and the code got to here, then that means the group match has not been "cleared" yet + // It will be processed when the code advances to the next pk group + pk_fk_match &= true; } } } - if (!pk_fk_match && join_type == tuix::JoinType_LeftAnti) { + if (join_type == tuix::JoinType_LeftAnti && !pk_fk_match) { auto primary_group_buffer = primary_group.output_buffer(); RowReader primary_group_reader(primary_group_buffer.view()); diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala index c5f49d3424..432e8d77df 100644 --- a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala +++ b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala @@ -325,9 +325,18 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => df.collect } - testAgainstSpark("left anti join") { securityLevel => - val p_data = for (i <- 1 to 16) yield (i, (i % 9).toString, i * 10) - val f_data = for (i <- 1 to 32) yield (i, ((i % 7) + 1).toString, i * 10) + testAgainstSpark("left anti join 1") { securityLevel => + val p_data = for (i <- 1 to 128) yield (i, (i % 16).toString, i * 10) + val f_data = for (i <- 1 to 256 if (i % 3) + 1 == 0 || (i % 3) + 5 == 0) yield (i, i.toString, i * 10) + val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") + val df = p.join(f, $"join_col_1" === $"join_col_2", "left_anti").sort($"join_col_1", $"id") + df.collect + } + + testAgainstSpark("left anti join 2") { securityLevel => + val p_data = for (i <- 1 to 16) yield (i, (i % 4).toString, i * 10) + val f_data = for (i <- 1 to 32) yield (i, i.toString, i * 10) val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") val df = p.join(f, $"join_col_1" === $"join_col_2", "left_anti").sort($"join_col_1", $"id") From 449d621898588398e213f352515cb067406787dc Mon Sep 17 00:00:00 2001 From: Wenting Zheng Date: Sat, 6 Feb 2021 03:12:55 +0000 Subject: [PATCH 6/8] Join rewrite --- src/enclave/App/App.cpp | 45 +------- src/enclave/App/SGXEnclave.h | 6 +- src/enclave/Enclave/Enclave.cpp | 19 ---- src/enclave/Enclave/Enclave.edl | 6 -- src/enclave/Enclave/Join.cpp | 43 -------- src/enclave/Enclave/Join.h | 6 -- .../opaque/execution/EncryptedSortExec.scala | 102 +++++++++++------- .../cs/rise/opaque/execution/SGXEnclave.scala | 4 +- .../cs/rise/opaque/execution/operators.scala | 26 ++--- .../berkeley/cs/rise/opaque/strategies.scala | 7 +- .../cs/rise/opaque/OpaqueOperatorTests.scala | 8 +- 11 files changed, 86 insertions(+), 186 deletions(-) diff --git a/src/enclave/App/App.cpp b/src/enclave/App/App.cpp index 6817863e69..64013d2ab7 100644 --- a/src/enclave/App/App.cpp +++ b/src/enclave/App/App.cpp @@ -518,47 +518,9 @@ JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEncla return ret; } -JNIEXPORT jbyteArray JNICALL -Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_ScanCollectLastPrimary( - JNIEnv *env, jobject obj, jlong eid, jbyteArray join_expr, jbyteArray input_rows) { - (void)obj; - - jboolean if_copy; - - uint32_t join_expr_length = (uint32_t) env->GetArrayLength(join_expr); - uint8_t *join_expr_ptr = (uint8_t *) env->GetByteArrayElements(join_expr, &if_copy); - - uint32_t input_rows_length = (uint32_t) env->GetArrayLength(input_rows); - uint8_t *input_rows_ptr = (uint8_t *) env->GetByteArrayElements(input_rows, &if_copy); - - uint8_t *output_rows = nullptr; - size_t output_rows_length = 0; - - if (input_rows_ptr == nullptr) { - ocall_throw("ScanCollectLastPrimary: JNI failed to get input byte array."); - } else { - oe_check_and_time("Scan Collect Last Primary", - ecall_scan_collect_last_primary( - (oe_enclave_t*)eid, - join_expr_ptr, join_expr_length, - input_rows_ptr, input_rows_length, - &output_rows, &output_rows_length)); - } - - jbyteArray ret = env->NewByteArray(output_rows_length); - env->SetByteArrayRegion(ret, 0, output_rows_length, (jbyte *) output_rows); - free(output_rows); - - env->ReleaseByteArrayElements(join_expr, (jbyte *) join_expr_ptr, 0); - env->ReleaseByteArrayElements(input_rows, (jbyte *) input_rows_ptr, 0); - - return ret; -} - JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin( - JNIEnv *env, jobject obj, jlong eid, jbyteArray join_expr, jbyteArray input_rows, - jbyteArray join_row) { + JNIEnv *env, jobject obj, jlong eid, jbyteArray join_expr, jbyteArray input_rows) { (void)obj; jboolean if_copy; @@ -569,9 +531,6 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin( uint32_t input_rows_length = (uint32_t) env->GetArrayLength(input_rows); uint8_t *input_rows_ptr = (uint8_t *) env->GetByteArrayElements(input_rows, &if_copy); - uint32_t join_row_length = (uint32_t) env->GetArrayLength(join_row); - uint8_t *join_row_ptr = (uint8_t *) env->GetByteArrayElements(join_row, &if_copy); - uint8_t *output_rows = nullptr; size_t output_rows_length = 0; @@ -583,7 +542,6 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin( (oe_enclave_t*)eid, join_expr_ptr, join_expr_length, input_rows_ptr, input_rows_length, - join_row_ptr, join_row_length, &output_rows, &output_rows_length)); } @@ -593,7 +551,6 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin( env->ReleaseByteArrayElements(join_expr, (jbyte *) join_expr_ptr, 0); env->ReleaseByteArrayElements(input_rows, (jbyte *) input_rows_ptr, 0); - env->ReleaseByteArrayElements(join_row, (jbyte *) join_row_ptr, 0); return ret; } diff --git a/src/enclave/App/SGXEnclave.h b/src/enclave/App/SGXEnclave.h index c2168ab6e3..2b74c42763 100644 --- a/src/enclave/App/SGXEnclave.h +++ b/src/enclave/App/SGXEnclave.h @@ -37,13 +37,9 @@ extern "C" { JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_ExternalSort( JNIEnv *, jobject, jlong, jbyteArray, jbyteArray); - JNIEXPORT jbyteArray JNICALL - Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_ScanCollectLastPrimary( - JNIEnv *, jobject, jlong, jbyteArray, jbyteArray); - JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin( - JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jbyteArray); + JNIEnv *, jobject, jlong, jbyteArray, jbyteArray); JNIEXPORT jobject JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregate( diff --git a/src/enclave/Enclave/Enclave.cpp b/src/enclave/Enclave/Enclave.cpp index 41eda5ec27..e9342875b2 100644 --- a/src/enclave/Enclave/Enclave.cpp +++ b/src/enclave/Enclave/Enclave.cpp @@ -145,35 +145,16 @@ void ecall_external_sort(uint8_t *sort_order, size_t sort_order_length, } } -void ecall_scan_collect_last_primary(uint8_t *join_expr, size_t join_expr_length, - uint8_t *input_rows, size_t input_rows_length, - uint8_t **output_rows, size_t *output_rows_length) { - // Guard against operating on arbitrary enclave memory - assert(oe_is_outside_enclave(input_rows, input_rows_length) == 1); - __builtin_ia32_lfence(); - - try { - scan_collect_last_primary(join_expr, join_expr_length, - input_rows, input_rows_length, - output_rows, output_rows_length); - } catch (const std::runtime_error &e) { - ocall_throw(e.what()); - } -} - void ecall_non_oblivious_sort_merge_join(uint8_t *join_expr, size_t join_expr_length, uint8_t *input_rows, size_t input_rows_length, - uint8_t *join_row, size_t join_row_length, uint8_t **output_rows, size_t *output_rows_length) { // Guard against operating on arbitrary enclave memory assert(oe_is_outside_enclave(input_rows, input_rows_length) == 1); - assert(oe_is_outside_enclave(join_row, join_row_length) == 1); __builtin_ia32_lfence(); try { non_oblivious_sort_merge_join(join_expr, join_expr_length, input_rows, input_rows_length, - join_row, join_row_length, output_rows, output_rows_length); } catch (const std::runtime_error &e) { ocall_throw(e.what()); diff --git a/src/enclave/Enclave/Enclave.edl b/src/enclave/Enclave/Enclave.edl index 5546840b31..0225c64efa 100644 --- a/src/enclave/Enclave/Enclave.edl +++ b/src/enclave/Enclave/Enclave.edl @@ -43,15 +43,9 @@ enclave { [user_check] uint8_t *input_rows, size_t input_rows_length, [out] uint8_t **output_rows, [out] size_t *output_rows_length); - public void ecall_scan_collect_last_primary( - [in, count=join_expr_length] uint8_t *join_expr, size_t join_expr_length, - [user_check] uint8_t *input_rows, size_t input_rows_length, - [out] uint8_t **output_rows, [out] size_t *output_rows_length); - public void ecall_non_oblivious_sort_merge_join( [in, count=join_expr_length] uint8_t *join_expr, size_t join_expr_length, [user_check] uint8_t *input_rows, size_t input_rows_length, - [user_check] uint8_t *join_row, size_t join_row_length, [out] uint8_t **output_rows, [out] size_t *output_rows_length); public void ecall_non_oblivious_aggregate( diff --git a/src/enclave/Enclave/Join.cpp b/src/enclave/Enclave/Join.cpp index 8b17376c45..828c963d40 100644 --- a/src/enclave/Enclave/Join.cpp +++ b/src/enclave/Enclave/Join.cpp @@ -5,66 +5,23 @@ #include "FlatbuffersWriters.h" #include "common.h" -void scan_collect_last_primary( - uint8_t *join_expr, size_t join_expr_length, - uint8_t *input_rows, size_t input_rows_length, - uint8_t **output_rows, size_t *output_rows_length) { - - FlatbuffersJoinExprEvaluator join_expr_eval(join_expr, join_expr_length); - RowReader r(BufferRefView(input_rows, input_rows_length)); - RowWriter w; - - FlatbuffersTemporaryRow last_primary; - - // Accumulate all primary table rows from the same group as the last primary row into `w`. - // - // Because our distributed sorting algorithm uses range partitioning over the join keys, all - // primary rows belonging to the same group will be colocated in the same partition. (The - // corresponding foreign rows may be in the same partition or the next partition.) Therefore it is - // sufficient to send primary rows at most one partition forward. - while (r.has_next()) { - const tuix::Row *row = r.next(); - if (join_expr_eval.is_primary(row)) { - if (!last_primary.get() || !join_expr_eval.is_same_group(last_primary.get(), row)) { - w.clear(); - last_primary.set(row); - } - - w.append(row); - } else { - w.clear(); - last_primary.set(nullptr); - } - } - - w.output_buffer(output_rows, output_rows_length); -} - void non_oblivious_sort_merge_join( uint8_t *join_expr, size_t join_expr_length, uint8_t *input_rows, size_t input_rows_length, - uint8_t *join_row, size_t join_row_length, uint8_t **output_rows, size_t *output_rows_length) { FlatbuffersJoinExprEvaluator join_expr_eval(join_expr, join_expr_length); tuix::JoinType join_type = join_expr_eval.get_join_type(); RowReader r(BufferRefView(input_rows, input_rows_length)); - RowReader j(BufferRefView(join_row, join_row_length)); RowWriter w; RowWriter primary_group; FlatbuffersTemporaryRow last_primary_of_group; - while (j.has_next()) { - const tuix::Row *row = j.next(); - primary_group.append(row); - last_primary_of_group.set(row); - } bool pk_fk_match = false; while (r.has_next()) { const tuix::Row *current = r.next(); - print(current); if (join_expr_eval.is_primary(current)) { if (last_primary_of_group.get() diff --git a/src/enclave/Enclave/Join.h b/src/enclave/Enclave/Join.h index 83d34ccce5..b380909027 100644 --- a/src/enclave/Enclave/Join.h +++ b/src/enclave/Enclave/Join.h @@ -4,15 +4,9 @@ #ifndef JOIN_H #define JOIN_H -void scan_collect_last_primary( - uint8_t *join_expr, size_t join_expr_length, - uint8_t *input_rows, size_t input_rows_length, - uint8_t **output_rows, size_t *output_rows_length); - void non_oblivious_sort_merge_join( uint8_t *join_expr, size_t join_expr_length, uint8_t *input_rows, size_t input_rows_length, - uint8_t *join_row, size_t join_row_length, uint8_t **output_rows, size_t *output_rows_length); #endif diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/EncryptedSortExec.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/EncryptedSortExec.scala index 1ef97bce91..59a528b95a 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/EncryptedSortExec.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/EncryptedSortExec.scala @@ -30,59 +30,81 @@ case class EncryptedSortExec(order: Seq[SortOrder], isGlobal: Boolean, child: Sp override def executeBlocked(): RDD[Block] = { val orderSer = Utils.serializeSortOrder(order, child.output) - EncryptedSortExec.sort(child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), orderSer, isGlobal) + if (isGlobal) { + EncryptedSortExec.sampleAndPartition(child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), orderSer) + } + EncryptedSortExec.localSort(child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), orderSer) + } +} + +case class EncryptedRangePartitionExec(order: Seq[SortOrder], child: SparkPlan) + extends UnaryExecNode with OpaqueOperatorExec { + + override def output: Seq[Attribute] = child.output + + override def executeBlocked(): RDD[Block] = { + val orderSer = Utils.serializeSortOrder(order, child.output) + EncryptedSortExec.sampleAndPartition(child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), orderSer) } } object EncryptedSortExec { import Utils.time - def sort(childRDD: RDD[Block], orderSer: Array[Byte], isGlobal: Boolean): RDD[Block] = { + def sampleAndPartition(childRDD: RDD[Block], orderSer: Array[Byte]): RDD[Block] = { Utils.ensureCached(childRDD) - time("force child of EncryptedSort") { childRDD.count } + time("force child of sampleAndPartition") { childRDD.count } - time("non-oblivious sort") { + time("non-oblivious range partitioning") { val numPartitions = childRDD.partitions.length - val result = - if (numPartitions <= 1 || !isGlobal) { - childRDD.map { block => + if (numPartitions <= 1) { + Utils.ensureCached(childRDD) + childRDD.count + childRDD + } else { + // Collect a sample of the input rows + val sampled = time("non-oblivious sort - Sample") { + Utils.concatEncryptedBlocks(childRDD.map { block => val (enclave, eid) = Utils.initEnclave() - val sortedRows = enclave.ExternalSort(eid, orderSer, block.bytes) - Block(sortedRows) - } - } else { - // Collect a sample of the input rows - val sampled = time("non-oblivious sort - Sample") { - Utils.concatEncryptedBlocks(childRDD.map { block => - val (enclave, eid) = Utils.initEnclave() - val sampledBlock = enclave.Sample(eid, block.bytes) - Block(sampledBlock) - }.collect) - } - // Find range boundaries parceled out to a single worker - val boundaries = time("non-oblivious sort - FindRangeBounds") { - childRDD.context.parallelize(Array(sampled.bytes), 1).map { sampledBytes => - val (enclave, eid) = Utils.initEnclave() - enclave.FindRangeBounds(eid, orderSer, numPartitions, sampledBytes) - }.collect.head - } - // Broadcast the range boundaries and use them to partition the input - childRDD.flatMap { block => + val sampledBlock = enclave.Sample(eid, block.bytes) + Block(sampledBlock) + }.collect) + } + // Find range boundaries parceled out to a single worker + val boundaries = time("non-oblivious sort - FindRangeBounds") { + childRDD.context.parallelize(Array(sampled.bytes), 1).map { sampledBytes => val (enclave, eid) = Utils.initEnclave() - val partitions = enclave.PartitionForSort( - eid, orderSer, numPartitions, block.bytes, boundaries) - partitions.zipWithIndex.map { - case (partition, i) => (i, Block(partition)) - } + enclave.FindRangeBounds(eid, orderSer, numPartitions, sampledBytes) + }.collect.head + } + // Broadcast the range boundaries and use them to partition the input + // Shuffle the input to achieve range partitioning and sort locally + val result = childRDD.flatMap { block => + val (enclave, eid) = Utils.initEnclave() + val partitions = enclave.PartitionForSort( + eid, orderSer, numPartitions, block.bytes, boundaries) + partitions.zipWithIndex.map { + case (partition, i) => (i, Block(partition)) } - // Shuffle the input to achieve range partitioning and sort locally - .groupByKey(numPartitions).map { - case (i, blocks) => - val (enclave, eid) = Utils.initEnclave() - Block(enclave.ExternalSort( - eid, orderSer, Utils.concatEncryptedBlocks(blocks.toSeq).bytes)) - } + }.groupByKey(numPartitions).map { case (i, blocks) => + Utils.concatEncryptedBlocks(blocks.toSeq) } + + Utils.ensureCached(result) + result.count + result + } + } + } + + def localSort(childRDD: RDD[Block], orderSer: Array[Byte]): RDD[Block] = { + Utils.ensureCached(childRDD) + time("non-oblivious local sort") { + val result = childRDD.map { block => + val (enclave, eid) = Utils.initEnclave() + val sortedRows = enclave.ExternalSort(eid, orderSer, block.bytes) + Block(sortedRows) + } Utils.ensureCached(result) result.count() result diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala index aef4ba8303..b49090ced1 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala @@ -39,10 +39,8 @@ class SGXEnclave extends java.io.Serializable { boundaries: Array[Byte]): Array[Array[Byte]] @native def ExternalSort(eid: Long, order: Array[Byte], input: Array[Byte]): Array[Byte] - @native def ScanCollectLastPrimary( - eid: Long, joinExpr: Array[Byte], input: Array[Byte]): Array[Byte] @native def NonObliviousSortMergeJoin( - eid: Long, joinExpr: Array[Byte], input: Array[Byte], joinRow: Array[Byte]): Array[Byte] + eid: Long, joinExpr: Array[Byte], input: Array[Byte]): Array[Byte] @native def NonObliviousAggregate( eid: Long, aggOp: Array[Byte], inputRows: Array[Byte], isPartial: Boolean): (Array[Byte]) diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala index e40acbff78..3292e7acca 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala @@ -286,22 +286,18 @@ case class EncryptedSortMergeJoinExec( child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), "EncryptedSortMergeJoinExec") { childRDD => - val lastPrimaryRows = childRDD.map { block => + // val lastPrimaryRows = childRDD.map { block => + // val (enclave, eid) = Utils.initEnclave() + // Block(enclave.ScanCollectLastPrimary(eid, joinExprSer, block.bytes)) + // }.collect + // val shifted = Utils.emptyBlock +: lastPrimaryRows.dropRight(1) + // assert(shifted.size == childRDD.partitions.length) + // val processedJoinRowsRDD = + // sparkContext.parallelize(shifted, childRDD.partitions.length) + + childRDD.map { block => val (enclave, eid) = Utils.initEnclave() - Block(enclave.ScanCollectLastPrimary(eid, joinExprSer, block.bytes)) - }.collect - val shifted = Utils.emptyBlock +: lastPrimaryRows.dropRight(1) - assert(shifted.size == childRDD.partitions.length) - val processedJoinRowsRDD = - sparkContext.parallelize(shifted, childRDD.partitions.length) - - childRDD.zipPartitions(processedJoinRowsRDD) { (blockIter, joinRowIter) => - (blockIter.toSeq, joinRowIter.toSeq) match { - case (Seq(block), Seq(joinRow)) => - val (enclave, eid) = Utils.initEnclave() - Iterator(Block(enclave.NonObliviousSortMergeJoin( - eid, joinExprSer, block.bytes, joinRow.bytes))) - } + Block(enclave.NonObliviousSortMergeJoin(eid, joinExprSer, block.bytes)) } } } diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala index d5243414dd..416f4a0769 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala @@ -79,7 +79,12 @@ object OpaqueOperators extends Strategy { val leftProj = EncryptedProjectExec(leftProjSchema, planLater(left)) val rightProj = EncryptedProjectExec(rightProjSchema, planLater(right)) val unioned = EncryptedUnionExec(leftProj, rightProj) - val sorted = EncryptedSortExec(sortForJoin(leftKeysProj, tag, unioned.output), true, unioned) + // We partition based on the join keys only, so that rows from both the left and the right tables that match + // will colocate to the same partition + val partitionOrder = leftKeysProj.map(k => SortOrder(k, Ascending)) + val partitioned = EncryptedRangePartitionExec(partitionOrder, unioned) + val sortOrder = sortForJoin(leftKeysProj, tag, partitioned.output) + val sorted = EncryptedSortExec(sortOrder, false, partitioned) val joined = EncryptedSortMergeJoinExec( joinType, leftKeysProj, diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala index 432e8d77df..03edfce279 100644 --- a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala +++ b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala @@ -305,7 +305,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => val f_data = for (i <- 1 to 256 - 16) yield ((i % 16).toString, (i * 10).toString, i.toFloat) val p = makeDF(p_data, securityLevel, "pk", "x") val f = makeDF(f_data, securityLevel, "fk", "x", "y") - p.join(f, $"pk" === $"fk").collect.toSet + val df = p.join(f, $"pk" === $"fk").collect.toSet } testAgainstSpark("non-foreign-key join") { securityLevel => @@ -319,9 +319,9 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => testAgainstSpark("left semi join") { securityLevel => val p_data = for (i <- 1 to 16) yield (i, (i % 8).toString, i * 10) val f_data = for (i <- 1 to 32) yield (i, (i % 8).toString, i * 10) - val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") - val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") - val df = p.join(f, $"join_col_1" === $"join_col_2", "left_semi").sort($"join_col_1", $"id") + val p = makeDF(p_data, securityLevel, "id1", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id2", "join_col_2", "x") + val df = p.join(f, $"join_col_1" === $"join_col_2", "left_semi").sort($"join_col_1", $"id1") df.collect } From 546f904d1581fdb76095666e99f843f9d3f0f973 Mon Sep 17 00:00:00 2001 From: Wenting Zheng Date: Mon, 8 Feb 2021 02:59:18 +0000 Subject: [PATCH 7/8] Fix sort bug --- .../cs/rise/opaque/execution/EncryptedSortExec.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/EncryptedSortExec.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/EncryptedSortExec.scala index 59a528b95a..0cb71c87ff 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/EncryptedSortExec.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/EncryptedSortExec.scala @@ -30,10 +30,12 @@ case class EncryptedSortExec(order: Seq[SortOrder], isGlobal: Boolean, child: Sp override def executeBlocked(): RDD[Block] = { val orderSer = Utils.serializeSortOrder(order, child.output) - if (isGlobal) { - EncryptedSortExec.sampleAndPartition(child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), orderSer) + val childRDD = child.asInstanceOf[OpaqueOperatorExec].executeBlocked() + val partitionedRDD = isGlobal match { + case true => EncryptedSortExec.sampleAndPartition(childRDD, orderSer) + case false => childRDD } - EncryptedSortExec.localSort(child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), orderSer) + EncryptedSortExec.localSort(partitionedRDD, orderSer) } } @@ -86,7 +88,8 @@ object EncryptedSortExec { partitions.zipWithIndex.map { case (partition, i) => (i, Block(partition)) } - }.groupByKey(numPartitions).map { case (i, blocks) => + }.groupByKey(numPartitions).map { + case (i, blocks) => Utils.concatEncryptedBlocks(blocks.toSeq) } From 4acc5d2fa41da370598ea3e7cd9a082dd6f04d54 Mon Sep 17 00:00:00 2001 From: Wenting Zheng Date: Tue, 9 Feb 2021 04:01:37 +0000 Subject: [PATCH 8/8] Address comments --- .../opaque/execution/EncryptedSortExec.scala | 81 ++++++++----------- .../cs/rise/opaque/execution/operators.scala | 22 ++--- .../berkeley/cs/rise/opaque/strategies.scala | 1 - 3 files changed, 46 insertions(+), 58 deletions(-) diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/EncryptedSortExec.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/EncryptedSortExec.scala index 0cb71c87ff..a32e7c10e8 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/EncryptedSortExec.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/EncryptedSortExec.scala @@ -57,60 +57,49 @@ object EncryptedSortExec { Utils.ensureCached(childRDD) time("force child of sampleAndPartition") { childRDD.count } - time("non-oblivious range partitioning") { - val numPartitions = childRDD.partitions.length - if (numPartitions <= 1) { - Utils.ensureCached(childRDD) - childRDD.count - childRDD - } else { - // Collect a sample of the input rows - val sampled = time("non-oblivious sort - Sample") { - Utils.concatEncryptedBlocks(childRDD.map { block => - val (enclave, eid) = Utils.initEnclave() - val sampledBlock = enclave.Sample(eid, block.bytes) - Block(sampledBlock) - }.collect) - } - // Find range boundaries parceled out to a single worker - val boundaries = time("non-oblivious sort - FindRangeBounds") { - childRDD.context.parallelize(Array(sampled.bytes), 1).map { sampledBytes => - val (enclave, eid) = Utils.initEnclave() - enclave.FindRangeBounds(eid, orderSer, numPartitions, sampledBytes) - }.collect.head - } - // Broadcast the range boundaries and use them to partition the input - // Shuffle the input to achieve range partitioning and sort locally - val result = childRDD.flatMap { block => + val numPartitions = childRDD.partitions.length + if (numPartitions <= 1) { + childRDD + } else { + // Collect a sample of the input rows + val sampled = time("non-oblivious sort - Sample") { + Utils.concatEncryptedBlocks(childRDD.map { block => + val (enclave, eid) = Utils.initEnclave() + val sampledBlock = enclave.Sample(eid, block.bytes) + Block(sampledBlock) + }.collect) + } + // Find range boundaries parceled out to a single worker + val boundaries = time("non-oblivious sort - FindRangeBounds") { + childRDD.context.parallelize(Array(sampled.bytes), 1).map { sampledBytes => val (enclave, eid) = Utils.initEnclave() - val partitions = enclave.PartitionForSort( - eid, orderSer, numPartitions, block.bytes, boundaries) - partitions.zipWithIndex.map { - case (partition, i) => (i, Block(partition)) - } - }.groupByKey(numPartitions).map { - case (i, blocks) => - Utils.concatEncryptedBlocks(blocks.toSeq) + enclave.FindRangeBounds(eid, orderSer, numPartitions, sampledBytes) + }.collect.head + } + // Broadcast the range boundaries and use them to partition the input + // Shuffle the input to achieve range partitioning and sort locally + val result = childRDD.flatMap { block => + val (enclave, eid) = Utils.initEnclave() + val partitions = enclave.PartitionForSort( + eid, orderSer, numPartitions, block.bytes, boundaries) + partitions.zipWithIndex.map { + case (partition, i) => (i, Block(partition)) } - - Utils.ensureCached(result) - result.count - result + }.groupByKey(numPartitions).map { + case (i, blocks) => + Utils.concatEncryptedBlocks(blocks.toSeq) } + result } } def localSort(childRDD: RDD[Block], orderSer: Array[Byte]): RDD[Block] = { Utils.ensureCached(childRDD) - time("non-oblivious local sort") { - val result = childRDD.map { block => - val (enclave, eid) = Utils.initEnclave() - val sortedRows = enclave.ExternalSort(eid, orderSer, block.bytes) - Block(sortedRows) - } - Utils.ensureCached(result) - result.count() - result + val result = childRDD.map { block => + val (enclave, eid) = Utils.initEnclave() + val sortedRows = enclave.ExternalSort(eid, orderSer, block.bytes) + Block(sortedRows) } + result } } diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala index 3292e7acca..7ed6862b6b 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala @@ -26,7 +26,10 @@ import org.apache.spark.sql.catalyst.expressions.AttributeSet import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.LeftAnti +import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.SparkPlan @@ -274,9 +277,15 @@ case class EncryptedSortMergeJoinExec( rightKeys: Seq[Expression], leftSchema: Seq[Attribute], rightSchema: Seq[Attribute], - output: Seq[Attribute], child: SparkPlan) - extends UnaryExecNode with OpaqueOperatorExec { + extends UnaryExecNode with OpaqueOperatorExec { + + override def output: Seq[Attribute] = { + joinType match { + case Inner => (leftSchema ++ rightSchema).map(_.toAttribute) + case LeftSemi | LeftAnti => leftSchema.map(_.toAttribute) + } + } override def executeBlocked(): RDD[Block] = { val joinExprSer = Utils.serializeJoinExpression( @@ -286,15 +295,6 @@ case class EncryptedSortMergeJoinExec( child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), "EncryptedSortMergeJoinExec") { childRDD => - // val lastPrimaryRows = childRDD.map { block => - // val (enclave, eid) = Utils.initEnclave() - // Block(enclave.ScanCollectLastPrimary(eid, joinExprSer, block.bytes)) - // }.collect - // val shifted = Utils.emptyBlock +: lastPrimaryRows.dropRight(1) - // assert(shifted.size == childRDD.partitions.length) - // val processedJoinRowsRDD = - // sparkContext.parallelize(shifted, childRDD.partitions.length) - childRDD.map { block => val (enclave, eid) = Utils.initEnclave() Block(enclave.NonObliviousSortMergeJoin(eid, joinExprSer, block.bytes)) diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala index 416f4a0769..0c8f188369 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala @@ -91,7 +91,6 @@ object OpaqueOperators extends Strategy { rightKeysProj, leftProjSchema.map(_.toAttribute), rightProjSchema.map(_.toAttribute), - (leftProjSchema ++ rightProjSchema).map(_.toAttribute), sorted) val tagsDropped = joinType match {