diff --git a/src/enclave/App/App.cpp b/src/enclave/App/App.cpp index 6817863e69..64013d2ab7 100644 --- a/src/enclave/App/App.cpp +++ b/src/enclave/App/App.cpp @@ -518,47 +518,9 @@ JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEncla return ret; } -JNIEXPORT jbyteArray JNICALL -Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_ScanCollectLastPrimary( - JNIEnv *env, jobject obj, jlong eid, jbyteArray join_expr, jbyteArray input_rows) { - (void)obj; - - jboolean if_copy; - - uint32_t join_expr_length = (uint32_t) env->GetArrayLength(join_expr); - uint8_t *join_expr_ptr = (uint8_t *) env->GetByteArrayElements(join_expr, &if_copy); - - uint32_t input_rows_length = (uint32_t) env->GetArrayLength(input_rows); - uint8_t *input_rows_ptr = (uint8_t *) env->GetByteArrayElements(input_rows, &if_copy); - - uint8_t *output_rows = nullptr; - size_t output_rows_length = 0; - - if (input_rows_ptr == nullptr) { - ocall_throw("ScanCollectLastPrimary: JNI failed to get input byte array."); - } else { - oe_check_and_time("Scan Collect Last Primary", - ecall_scan_collect_last_primary( - (oe_enclave_t*)eid, - join_expr_ptr, join_expr_length, - input_rows_ptr, input_rows_length, - &output_rows, &output_rows_length)); - } - - jbyteArray ret = env->NewByteArray(output_rows_length); - env->SetByteArrayRegion(ret, 0, output_rows_length, (jbyte *) output_rows); - free(output_rows); - - env->ReleaseByteArrayElements(join_expr, (jbyte *) join_expr_ptr, 0); - env->ReleaseByteArrayElements(input_rows, (jbyte *) input_rows_ptr, 0); - - return ret; -} - JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin( - JNIEnv *env, jobject obj, jlong eid, jbyteArray join_expr, jbyteArray input_rows, - jbyteArray join_row) { + JNIEnv *env, jobject obj, jlong eid, jbyteArray join_expr, jbyteArray input_rows) { (void)obj; jboolean if_copy; @@ -569,9 +531,6 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin( uint32_t input_rows_length = (uint32_t) env->GetArrayLength(input_rows); uint8_t *input_rows_ptr = (uint8_t *) env->GetByteArrayElements(input_rows, &if_copy); - uint32_t join_row_length = (uint32_t) env->GetArrayLength(join_row); - uint8_t *join_row_ptr = (uint8_t *) env->GetByteArrayElements(join_row, &if_copy); - uint8_t *output_rows = nullptr; size_t output_rows_length = 0; @@ -583,7 +542,6 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin( (oe_enclave_t*)eid, join_expr_ptr, join_expr_length, input_rows_ptr, input_rows_length, - join_row_ptr, join_row_length, &output_rows, &output_rows_length)); } @@ -593,7 +551,6 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin( env->ReleaseByteArrayElements(join_expr, (jbyte *) join_expr_ptr, 0); env->ReleaseByteArrayElements(input_rows, (jbyte *) input_rows_ptr, 0); - env->ReleaseByteArrayElements(join_row, (jbyte *) join_row_ptr, 0); return ret; } diff --git a/src/enclave/App/SGXEnclave.h b/src/enclave/App/SGXEnclave.h index c2168ab6e3..2b74c42763 100644 --- a/src/enclave/App/SGXEnclave.h +++ b/src/enclave/App/SGXEnclave.h @@ -37,13 +37,9 @@ extern "C" { JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_ExternalSort( JNIEnv *, jobject, jlong, jbyteArray, jbyteArray); - JNIEXPORT jbyteArray JNICALL - Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_ScanCollectLastPrimary( - JNIEnv *, jobject, jlong, jbyteArray, jbyteArray); - JNIEXPORT jbyteArray JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin( - JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jbyteArray); + JNIEnv *, jobject, jlong, jbyteArray, jbyteArray); JNIEXPORT jobject JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregate( diff --git a/src/enclave/Enclave/Enclave.cpp b/src/enclave/Enclave/Enclave.cpp index 41eda5ec27..e9342875b2 100644 --- a/src/enclave/Enclave/Enclave.cpp +++ b/src/enclave/Enclave/Enclave.cpp @@ -145,35 +145,16 @@ void ecall_external_sort(uint8_t *sort_order, size_t sort_order_length, } } -void ecall_scan_collect_last_primary(uint8_t *join_expr, size_t join_expr_length, - uint8_t *input_rows, size_t input_rows_length, - uint8_t **output_rows, size_t *output_rows_length) { - // Guard against operating on arbitrary enclave memory - assert(oe_is_outside_enclave(input_rows, input_rows_length) == 1); - __builtin_ia32_lfence(); - - try { - scan_collect_last_primary(join_expr, join_expr_length, - input_rows, input_rows_length, - output_rows, output_rows_length); - } catch (const std::runtime_error &e) { - ocall_throw(e.what()); - } -} - void ecall_non_oblivious_sort_merge_join(uint8_t *join_expr, size_t join_expr_length, uint8_t *input_rows, size_t input_rows_length, - uint8_t *join_row, size_t join_row_length, uint8_t **output_rows, size_t *output_rows_length) { // Guard against operating on arbitrary enclave memory assert(oe_is_outside_enclave(input_rows, input_rows_length) == 1); - assert(oe_is_outside_enclave(join_row, join_row_length) == 1); __builtin_ia32_lfence(); try { non_oblivious_sort_merge_join(join_expr, join_expr_length, input_rows, input_rows_length, - join_row, join_row_length, output_rows, output_rows_length); } catch (const std::runtime_error &e) { ocall_throw(e.what()); diff --git a/src/enclave/Enclave/Enclave.edl b/src/enclave/Enclave/Enclave.edl index 5546840b31..0225c64efa 100644 --- a/src/enclave/Enclave/Enclave.edl +++ b/src/enclave/Enclave/Enclave.edl @@ -43,15 +43,9 @@ enclave { [user_check] uint8_t *input_rows, size_t input_rows_length, [out] uint8_t **output_rows, [out] size_t *output_rows_length); - public void ecall_scan_collect_last_primary( - [in, count=join_expr_length] uint8_t *join_expr, size_t join_expr_length, - [user_check] uint8_t *input_rows, size_t input_rows_length, - [out] uint8_t **output_rows, [out] size_t *output_rows_length); - public void ecall_non_oblivious_sort_merge_join( [in, count=join_expr_length] uint8_t *join_expr, size_t join_expr_length, [user_check] uint8_t *input_rows, size_t input_rows_length, - [user_check] uint8_t *join_row, size_t join_row_length, [out] uint8_t **output_rows, [out] size_t *output_rows_length); public void ecall_non_oblivious_aggregate( diff --git a/src/enclave/Enclave/ExpressionEvaluation.h b/src/enclave/Enclave/ExpressionEvaluation.h index 80475b877f..9405ddd34f 100644 --- a/src/enclave/Enclave/ExpressionEvaluation.h +++ b/src/enclave/Enclave/ExpressionEvaluation.h @@ -1682,6 +1682,7 @@ class FlatbuffersJoinExprEvaluator { } const tuix::JoinExpr* join_expr = flatbuffers::GetRoot(buf); + join_type = join_expr->join_type(); if (join_expr->left_keys()->size() != join_expr->right_keys()->size()) { throw std::runtime_error("Mismatched join key lengths"); @@ -1738,8 +1739,13 @@ class FlatbuffersJoinExprEvaluator { return true; } + tuix::JoinType get_join_type() { + return join_type; + } + private: flatbuffers::FlatBufferBuilder builder; + tuix::JoinType join_type; std::vector> left_key_evaluators; std::vector> right_key_evaluators; }; diff --git a/src/enclave/Enclave/Join.cpp b/src/enclave/Enclave/Join.cpp index b8797e8b45..828c963d40 100644 --- a/src/enclave/Enclave/Join.cpp +++ b/src/enclave/Enclave/Join.cpp @@ -5,59 +5,20 @@ #include "FlatbuffersWriters.h" #include "common.h" -void scan_collect_last_primary( - uint8_t *join_expr, size_t join_expr_length, - uint8_t *input_rows, size_t input_rows_length, - uint8_t **output_rows, size_t *output_rows_length) { - - FlatbuffersJoinExprEvaluator join_expr_eval(join_expr, join_expr_length); - RowReader r(BufferRefView(input_rows, input_rows_length)); - RowWriter w; - - FlatbuffersTemporaryRow last_primary; - - // Accumulate all primary table rows from the same group as the last primary row into `w`. - // - // Because our distributed sorting algorithm uses range partitioning over the join keys, all - // primary rows belonging to the same group will be colocated in the same partition. (The - // corresponding foreign rows may be in the same partition or the next partition.) Therefore it is - // sufficient to send primary rows at most one partition forward. - while (r.has_next()) { - const tuix::Row *row = r.next(); - if (join_expr_eval.is_primary(row)) { - if (!last_primary.get() || !join_expr_eval.is_same_group(last_primary.get(), row)) { - w.clear(); - last_primary.set(row); - } - - w.append(row); - } else { - w.clear(); - last_primary.set(nullptr); - } - } - - w.output_buffer(output_rows, output_rows_length); -} - void non_oblivious_sort_merge_join( uint8_t *join_expr, size_t join_expr_length, uint8_t *input_rows, size_t input_rows_length, - uint8_t *join_row, size_t join_row_length, uint8_t **output_rows, size_t *output_rows_length) { FlatbuffersJoinExprEvaluator join_expr_eval(join_expr, join_expr_length); + tuix::JoinType join_type = join_expr_eval.get_join_type(); RowReader r(BufferRefView(input_rows, input_rows_length)); - RowReader j(BufferRefView(join_row, join_row_length)); RowWriter w; RowWriter primary_group; FlatbuffersTemporaryRow last_primary_of_group; - while (j.has_next()) { - const tuix::Row *row = j.next(); - primary_group.append(row); - last_primary_of_group.set(row); - } + + bool pk_fk_match = false; while (r.has_next()) { const tuix::Row *current = r.next(); @@ -69,10 +30,22 @@ void non_oblivious_sort_merge_join( primary_group.append(current); last_primary_of_group.set(current); } else { - // Advance to a new group + // If a new primary group is encountered + if (join_type == tuix::JoinType_LeftAnti && !pk_fk_match) { + auto primary_group_buffer = primary_group.output_buffer(); + RowReader primary_group_reader(primary_group_buffer.view()); + + while (primary_group_reader.has_next()) { + const tuix::Row *primary = primary_group_reader.next(); + w.append(primary); + } + } + primary_group.clear(); primary_group.append(current); last_primary_of_group.set(current); + + pk_fk_match = false; } } else { // Output the joined rows resulting from this foreign row @@ -92,11 +65,34 @@ void non_oblivious_sort_merge_join( + to_string(current)); } - w.append(primary, current); + if (join_type == tuix::JoinType_Inner) { + w.append(primary, current); + } else if (join_type == tuix::JoinType_LeftSemi) { + // Only output the pk group ONCE + if (!pk_fk_match) { + w.append(primary); + } + } } + + pk_fk_match = true; + } else { + // If pk_fk_match were true, and the code got to here, then that means the group match has not been "cleared" yet + // It will be processed when the code advances to the next pk group + pk_fk_match &= true; } } } + if (join_type == tuix::JoinType_LeftAnti && !pk_fk_match) { + auto primary_group_buffer = primary_group.output_buffer(); + RowReader primary_group_reader(primary_group_buffer.view()); + + while (primary_group_reader.has_next()) { + const tuix::Row *primary = primary_group_reader.next(); + w.append(primary); + } + } + w.output_buffer(output_rows, output_rows_length); } diff --git a/src/enclave/Enclave/Join.h b/src/enclave/Enclave/Join.h index 83d34ccce5..b380909027 100644 --- a/src/enclave/Enclave/Join.h +++ b/src/enclave/Enclave/Join.h @@ -4,15 +4,9 @@ #ifndef JOIN_H #define JOIN_H -void scan_collect_last_primary( - uint8_t *join_expr, size_t join_expr_length, - uint8_t *input_rows, size_t input_rows_length, - uint8_t **output_rows, size_t *output_rows_length); - void non_oblivious_sort_merge_join( uint8_t *join_expr, size_t join_expr_length, uint8_t *input_rows, size_t input_rows_length, - uint8_t *join_row, size_t join_row_length, uint8_t **output_rows, size_t *output_rows_length); #endif diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/EncryptedSortExec.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/EncryptedSortExec.scala index 1ef97bce91..a32e7c10e8 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/EncryptedSortExec.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/EncryptedSortExec.scala @@ -30,62 +30,76 @@ case class EncryptedSortExec(order: Seq[SortOrder], isGlobal: Boolean, child: Sp override def executeBlocked(): RDD[Block] = { val orderSer = Utils.serializeSortOrder(order, child.output) - EncryptedSortExec.sort(child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), orderSer, isGlobal) + val childRDD = child.asInstanceOf[OpaqueOperatorExec].executeBlocked() + val partitionedRDD = isGlobal match { + case true => EncryptedSortExec.sampleAndPartition(childRDD, orderSer) + case false => childRDD + } + EncryptedSortExec.localSort(partitionedRDD, orderSer) + } +} + +case class EncryptedRangePartitionExec(order: Seq[SortOrder], child: SparkPlan) + extends UnaryExecNode with OpaqueOperatorExec { + + override def output: Seq[Attribute] = child.output + + override def executeBlocked(): RDD[Block] = { + val orderSer = Utils.serializeSortOrder(order, child.output) + EncryptedSortExec.sampleAndPartition(child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), orderSer) } } object EncryptedSortExec { import Utils.time - def sort(childRDD: RDD[Block], orderSer: Array[Byte], isGlobal: Boolean): RDD[Block] = { + def sampleAndPartition(childRDD: RDD[Block], orderSer: Array[Byte]): RDD[Block] = { Utils.ensureCached(childRDD) - time("force child of EncryptedSort") { childRDD.count } + time("force child of sampleAndPartition") { childRDD.count } - time("non-oblivious sort") { - val numPartitions = childRDD.partitions.length - val result = - if (numPartitions <= 1 || !isGlobal) { - childRDD.map { block => - val (enclave, eid) = Utils.initEnclave() - val sortedRows = enclave.ExternalSort(eid, orderSer, block.bytes) - Block(sortedRows) - } - } else { - // Collect a sample of the input rows - val sampled = time("non-oblivious sort - Sample") { - Utils.concatEncryptedBlocks(childRDD.map { block => - val (enclave, eid) = Utils.initEnclave() - val sampledBlock = enclave.Sample(eid, block.bytes) - Block(sampledBlock) - }.collect) - } - // Find range boundaries parceled out to a single worker - val boundaries = time("non-oblivious sort - FindRangeBounds") { - childRDD.context.parallelize(Array(sampled.bytes), 1).map { sampledBytes => - val (enclave, eid) = Utils.initEnclave() - enclave.FindRangeBounds(eid, orderSer, numPartitions, sampledBytes) - }.collect.head - } - // Broadcast the range boundaries and use them to partition the input - childRDD.flatMap { block => - val (enclave, eid) = Utils.initEnclave() - val partitions = enclave.PartitionForSort( - eid, orderSer, numPartitions, block.bytes, boundaries) - partitions.zipWithIndex.map { - case (partition, i) => (i, Block(partition)) - } - } - // Shuffle the input to achieve range partitioning and sort locally - .groupByKey(numPartitions).map { - case (i, blocks) => - val (enclave, eid) = Utils.initEnclave() - Block(enclave.ExternalSort( - eid, orderSer, Utils.concatEncryptedBlocks(blocks.toSeq).bytes)) - } + val numPartitions = childRDD.partitions.length + if (numPartitions <= 1) { + childRDD + } else { + // Collect a sample of the input rows + val sampled = time("non-oblivious sort - Sample") { + Utils.concatEncryptedBlocks(childRDD.map { block => + val (enclave, eid) = Utils.initEnclave() + val sampledBlock = enclave.Sample(eid, block.bytes) + Block(sampledBlock) + }.collect) + } + // Find range boundaries parceled out to a single worker + val boundaries = time("non-oblivious sort - FindRangeBounds") { + childRDD.context.parallelize(Array(sampled.bytes), 1).map { sampledBytes => + val (enclave, eid) = Utils.initEnclave() + enclave.FindRangeBounds(eid, orderSer, numPartitions, sampledBytes) + }.collect.head + } + // Broadcast the range boundaries and use them to partition the input + // Shuffle the input to achieve range partitioning and sort locally + val result = childRDD.flatMap { block => + val (enclave, eid) = Utils.initEnclave() + val partitions = enclave.PartitionForSort( + eid, orderSer, numPartitions, block.bytes, boundaries) + partitions.zipWithIndex.map { + case (partition, i) => (i, Block(partition)) } - Utils.ensureCached(result) - result.count() + }.groupByKey(numPartitions).map { + case (i, blocks) => + Utils.concatEncryptedBlocks(blocks.toSeq) + } result } } + + def localSort(childRDD: RDD[Block], orderSer: Array[Byte]): RDD[Block] = { + Utils.ensureCached(childRDD) + val result = childRDD.map { block => + val (enclave, eid) = Utils.initEnclave() + val sortedRows = enclave.ExternalSort(eid, orderSer, block.bytes) + Block(sortedRows) + } + result + } } diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala index aef4ba8303..b49090ced1 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala @@ -39,10 +39,8 @@ class SGXEnclave extends java.io.Serializable { boundaries: Array[Byte]): Array[Array[Byte]] @native def ExternalSort(eid: Long, order: Array[Byte], input: Array[Byte]): Array[Byte] - @native def ScanCollectLastPrimary( - eid: Long, joinExpr: Array[Byte], input: Array[Byte]): Array[Byte] @native def NonObliviousSortMergeJoin( - eid: Long, joinExpr: Array[Byte], input: Array[Byte], joinRow: Array[Byte]): Array[Byte] + eid: Long, joinExpr: Array[Byte], input: Array[Byte]): Array[Byte] @native def NonObliviousAggregate( eid: Long, aggOp: Array[Byte], inputRows: Array[Byte], isPartial: Boolean): (Array[Byte]) diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala index e40acbff78..7ed6862b6b 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala @@ -26,7 +26,10 @@ import org.apache.spark.sql.catalyst.expressions.AttributeSet import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.LeftAnti +import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.SparkPlan @@ -274,9 +277,15 @@ case class EncryptedSortMergeJoinExec( rightKeys: Seq[Expression], leftSchema: Seq[Attribute], rightSchema: Seq[Attribute], - output: Seq[Attribute], child: SparkPlan) - extends UnaryExecNode with OpaqueOperatorExec { + extends UnaryExecNode with OpaqueOperatorExec { + + override def output: Seq[Attribute] = { + joinType match { + case Inner => (leftSchema ++ rightSchema).map(_.toAttribute) + case LeftSemi | LeftAnti => leftSchema.map(_.toAttribute) + } + } override def executeBlocked(): RDD[Block] = { val joinExprSer = Utils.serializeJoinExpression( @@ -286,22 +295,9 @@ case class EncryptedSortMergeJoinExec( child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), "EncryptedSortMergeJoinExec") { childRDD => - val lastPrimaryRows = childRDD.map { block => + childRDD.map { block => val (enclave, eid) = Utils.initEnclave() - Block(enclave.ScanCollectLastPrimary(eid, joinExprSer, block.bytes)) - }.collect - val shifted = Utils.emptyBlock +: lastPrimaryRows.dropRight(1) - assert(shifted.size == childRDD.partitions.length) - val processedJoinRowsRDD = - sparkContext.parallelize(shifted, childRDD.partitions.length) - - childRDD.zipPartitions(processedJoinRowsRDD) { (blockIter, joinRowIter) => - (blockIter.toSeq, joinRowIter.toSeq) match { - case (Seq(block), Seq(joinRow)) => - val (enclave, eid) = Utils.initEnclave() - Iterator(Block(enclave.NonObliviousSortMergeJoin( - eid, joinExprSer, block.bytes, joinRow.bytes))) - } + Block(enclave.NonObliviousSortMergeJoin(eid, joinExprSer, block.bytes)) } } } diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala index f26551553d..0c8f188369 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala @@ -32,6 +32,9 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.planning.PhysicalAggregation import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.LeftAnti +import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.execution.SparkPlan import edu.berkeley.cs.rise.opaque.execution._ @@ -76,20 +79,30 @@ object OpaqueOperators extends Strategy { val leftProj = EncryptedProjectExec(leftProjSchema, planLater(left)) val rightProj = EncryptedProjectExec(rightProjSchema, planLater(right)) val unioned = EncryptedUnionExec(leftProj, rightProj) - val sorted = EncryptedSortExec(sortForJoin(leftKeysProj, tag, unioned.output), true, unioned) + // We partition based on the join keys only, so that rows from both the left and the right tables that match + // will colocate to the same partition + val partitionOrder = leftKeysProj.map(k => SortOrder(k, Ascending)) + val partitioned = EncryptedRangePartitionExec(partitionOrder, unioned) + val sortOrder = sortForJoin(leftKeysProj, tag, partitioned.output) + val sorted = EncryptedSortExec(sortOrder, false, partitioned) val joined = EncryptedSortMergeJoinExec( joinType, leftKeysProj, rightKeysProj, leftProjSchema.map(_.toAttribute), rightProjSchema.map(_.toAttribute), - (leftProjSchema ++ rightProjSchema).map(_.toAttribute), sorted) - val tagsDropped = EncryptedProjectExec(dropTags(left.output, right.output), joined) + + val tagsDropped = joinType match { + case Inner => EncryptedProjectExec(dropTags(left.output, right.output), joined) + case LeftSemi | LeftAnti => EncryptedProjectExec(left.output, joined) + } + val filtered = condition match { case Some(condition) => EncryptedFilterExec(condition, tagsDropped) case None => tagsDropped } + filtered :: Nil case a @ PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala index c8926c3df7..16c8082fbd 100644 --- a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala +++ b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala @@ -305,7 +305,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => val f_data = for (i <- 1 to 256 - 16) yield ((i % 16).toString, (i * 10).toString, i.toFloat) val p = makeDF(p_data, securityLevel, "pk", "x") val f = makeDF(f_data, securityLevel, "fk", "x", "y") - p.join(f, $"pk" === $"fk").collect.toSet + val df = p.join(f, $"pk" === $"fk").collect.toSet } testAgainstSpark("non-foreign-key join") { securityLevel => @@ -316,6 +316,33 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => p.join(f, $"join_col_1" === $"join_col_2").collect.toSet } + testAgainstSpark("left semi join") { securityLevel => + val p_data = for (i <- 1 to 16) yield (i, (i % 8).toString, i * 10) + val f_data = for (i <- 1 to 32) yield (i, (i % 8).toString, i * 10) + val p = makeDF(p_data, securityLevel, "id1", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id2", "join_col_2", "x") + val df = p.join(f, $"join_col_1" === $"join_col_2", "left_semi").sort($"join_col_1", $"id1") + df.collect + } + + testAgainstSpark("left anti join 1") { securityLevel => + val p_data = for (i <- 1 to 128) yield (i, (i % 16).toString, i * 10) + val f_data = for (i <- 1 to 256 if (i % 3) + 1 == 0 || (i % 3) + 5 == 0) yield (i, i.toString, i * 10) + val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") + val df = p.join(f, $"join_col_1" === $"join_col_2", "left_anti").sort($"join_col_1", $"id") + df.collect + } + + testAgainstSpark("left anti join 2") { securityLevel => + val p_data = for (i <- 1 to 16) yield (i, (i % 4).toString, i * 10) + val f_data = for (i <- 1 to 32) yield (i, i.toString, i * 10) + val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") + val df = p.join(f, $"join_col_1" === $"join_col_2", "left_anti").sort($"join_col_1", $"id") + df.collect + } + def abc(i: Int): String = (i % 3) match { case 0 => "A" case 1 => "B"