diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/JobVerificationEngine.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/JobVerificationEngine.scala index 3d64eb4170..dc1cbae97b 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/JobVerificationEngine.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/JobVerificationEngine.scala @@ -21,6 +21,10 @@ package edu.berkeley.cs.rise.opaque import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.Map import scala.collection.mutable.Set +import scala.collection.mutable.Queue + +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.execution.SparkPlan // Wraps Crumb data specific to graph vertices and adds graph methods. class JobNode(val inputMacs: ArrayBuffer[ArrayBuffer[Byte]] = ArrayBuffer[ArrayBuffer[Byte]](), @@ -80,6 +84,11 @@ class JobNode(val inputMacs: ArrayBuffer[ArrayBuffer[Byte]] = ArrayBuffer[ArrayB return retval } + // Returns if this DAG is empty + def graphIsEmpty(): Boolean = { + return this.isSource && this.outgoingNeighbors.isEmpty + } + // Checks if JobNodeData originates from same partition (?) override def equals(that: Any): Boolean = { that match { @@ -96,24 +105,40 @@ class JobNode(val inputMacs: ArrayBuffer[ArrayBuffer[Byte]] = ArrayBuffer[ArrayB override def hashCode(): Int = { inputMacs.hashCode ^ allOutputsMac.hashCode } +} - def printNode() = { - println("====") - print("Ecall: ") - println(this.ecall) - print("Output: ") - for (i <- 0 until this.allOutputsMac.length) { - print(this.allOutputsMac(i)) - } - println - println("===") +// Used in construction of expected DAG. +class OperatorNode(val operatorName: String = "") { + var children: ArrayBuffer[OperatorNode] = ArrayBuffer[OperatorNode]() + var parents: ArrayBuffer[OperatorNode] = ArrayBuffer[OperatorNode]() + // Contains numPartitions * numEcalls job nodes. + // numPartitions rows (outer array), numEcalls columns (inner array) + var jobNodes: ArrayBuffer[ArrayBuffer[JobNode]] = ArrayBuffer[ArrayBuffer[JobNode]]() + + def addChild(child: OperatorNode) = { + this.children.append(child) + } + + def setChildren(children: ArrayBuffer[OperatorNode]) = { + this.children = children + } + + def addParent(parent: OperatorNode) = { + this.parents.append(parent) + } + + def setParents(parents: ArrayBuffer[OperatorNode]) = { + this.parents = parents + } + + def isOrphan(): Boolean = { + return this.parents.isEmpty } } object JobVerificationEngine { // An LogEntryChain object from each partition var logEntryChains = ArrayBuffer[tuix.LogEntryChain]() - var sparkOperators = ArrayBuffer[String]() val ecallId = Map( 1 -> "project", 2 -> "filter", @@ -130,31 +155,299 @@ object JobVerificationEngine { 13 -> "limitReturnRows" ).withDefaultValue("unknown") - def pathsEqual(path1: ArrayBuffer[List[Seq[Int]]], - path2: ArrayBuffer[List[Seq[Int]]]): Boolean = { - return path1.size == path2.size && path1.toSet == path2.toSet - } + val possibleSparkOperators = Seq[String]("EncryptedProject", + "EncryptedSortMergeJoin", + "EncryptedSort", + "EncryptedFilter", + "EncryptedAggregate", + "EncryptedGlobalLimit", + "EncryptedLocalLimit") def addLogEntryChain(logEntryChain: tuix.LogEntryChain): Unit = { logEntryChains += logEntryChain } - def addExpectedOperator(operator: String): Unit = { - sparkOperators += operator - } - def resetForNextJob(): Unit = { - sparkOperators.clear logEntryChains.clear } - def verify(): Boolean = { - if (sparkOperators.isEmpty) { + def isValidOperatorNode(node: OperatorNode): Boolean = { + for (targetSubstring <- possibleSparkOperators) { + if (node.operatorName contains targetSubstring) { + return true + } + } + return false + } + + def pathsEqual(executedPaths: ArrayBuffer[List[Seq[Int]]], + expectedPaths: ArrayBuffer[List[Seq[Int]]]): Boolean = { + // Executed paths might contain extraneous paths from + // MACs matching across ecalls if a block is unchanged from ecall to ecall (?) + return expectedPaths.toSet.subsetOf(executedPaths.toSet) + } + + // Recursively convert SparkPlan objects to OperatorNode object. + def sparkNodesToOperatorNodes(plan: SparkPlan): OperatorNode = { + var operatorName = "" + for (sparkOperator <- possibleSparkOperators) { + if (plan.toString.split("\n")(0) contains sparkOperator) { + operatorName = sparkOperator + } + } + val operatorNode = new OperatorNode(operatorName) + for (child <- plan.children) { + val parentOperatorNode = sparkNodesToOperatorNodes(child) + operatorNode.addParent(parentOperatorNode) + } + return operatorNode + } + + // Returns true if every OperatorNode in this list is "valid". + def allValidOperators(operators: ArrayBuffer[OperatorNode]): Boolean = { + for (operator <- operators) { + if (!isValidOperatorNode(operator)) { + return false + } + } + return true + } + + // Recursively prunes non valid nodes from an OperatorNode tree. + def fixOperatorTree(root: OperatorNode): Unit = { + if (root.isOrphan) { + return + } + while (!allValidOperators(root.parents)) { + val newParents = new ArrayBuffer[OperatorNode]() + for (parent <- root.parents) { + if (isValidOperatorNode(parent)) { + newParents.append(parent) + } else { + for (grandparent <- parent.parents) { + newParents.append(grandparent) + } + } + } + root.setParents(newParents) + } + for (parent <- root.parents) { + parent.addChild(root) + fixOperatorTree(parent) + } + } + + // Uses BFS to put all nodes in an OperatorNode tree into a list. + def treeToList(root: OperatorNode): ArrayBuffer[OperatorNode] = { + val retval = ArrayBuffer[OperatorNode]() + val queue = new Queue[OperatorNode]() + queue.enqueue(root) + while (!queue.isEmpty) { + val curr = queue.dequeue + retval.append(curr) + for (parent <- curr.parents) { + queue.enqueue(parent) + } + } + return retval + } + + // Converts a SparkPlan into a DAG of OperatorNode objects. + // Returns a list of all the nodes in the DAG. + def operatorDAGFromPlan(executedPlan: SparkPlan): ArrayBuffer[OperatorNode] = { + // Convert SparkPlan tree to OperatorNode tree + val leafOperatorNode = sparkNodesToOperatorNodes(executedPlan) + // Enlist the tree + val allOperatorNodes = treeToList(leafOperatorNode) + // Attach a sink to the tree and prune invalid OperatorNodes starting from the sink. + val sinkNode = new OperatorNode("sink") + for (operatorNode <- allOperatorNodes) { + if (operatorNode.children.isEmpty) { + operatorNode.addChild(sinkNode) + } + } + fixOperatorTree(sinkNode) + // Enlist the fixed tree. + val fixedOperatorNodes = treeToList(sinkNode) + fixedOperatorNodes -= sinkNode + return fixedOperatorNodes + } + + // expectedDAGFromOperatorDAG helper - links parent ecall partitions to child ecall partitions. + def linkEcalls(parentEcalls: ArrayBuffer[JobNode], childEcalls: ArrayBuffer[JobNode]): Unit = { + if (parentEcalls.length != childEcalls.length) { + println("Ecall lengths don't match! (linkEcalls)") + } + val numPartitions = parentEcalls.length + val ecall = parentEcalls(0).ecall + // project + if (ecall == 1) { + for (i <- 0 until numPartitions) { + parentEcalls(i).addOutgoingNeighbor(childEcalls(i)) + } + // filter + } else if (ecall == 2) { + for (i <- 0 until numPartitions) { + parentEcalls(i).addOutgoingNeighbor(childEcalls(i)) + } + // externalSort + } else if (ecall == 6) { + for (i <- 0 until numPartitions) { + parentEcalls(i).addOutgoingNeighbor(childEcalls(i)) + } + // sample + } else if (ecall == 3) { + for (i <- 0 until numPartitions) { + parentEcalls(i).addOutgoingNeighbor(childEcalls(0)) + } + // findRangeBounds + } else if (ecall == 4) { + for (i <- 0 until numPartitions) { + parentEcalls(0).addOutgoingNeighbor(childEcalls(i)) + } + // partitionForSort + } else if (ecall == 5) { + // All to all shuffle + for (i <- 0 until numPartitions) { + for (j <- 0 until numPartitions) { + parentEcalls(i).addOutgoingNeighbor(childEcalls(j)) + } + } + // nonObliviousAggregate + } else if (ecall == 9) { + for (i <- 0 until numPartitions) { + parentEcalls(i).addOutgoingNeighbor(childEcalls(i)) + } + // nonObliviousSortMergeJoin + } else if (ecall == 8) { + for (i <- 0 until numPartitions) { + parentEcalls(i).addOutgoingNeighbor(childEcalls(i)) + } + // countRowsPerPartition + } else if (ecall == 10) { + // Send from all partitions to partition 0 + for (i <- 0 until numPartitions) { + parentEcalls(i).addOutgoingNeighbor(childEcalls(0)) + } + // computeNumRowsPerPartition + } else if (ecall == 11) { + // Broadcast from one partition (assumed to be partition 0) to all partitions + for (i <- 0 until numPartitions) { + parentEcalls(0).addOutgoingNeighbor(childEcalls(i)) + } + // limitReturnRows + } else if (ecall == 13) { + for (i <- 0 until numPartitions) { + parentEcalls(i).addOutgoingNeighbor(childEcalls(i)) + } + } else { + throw new Exception("Job Verification Error creating expected DAG: " + + "ecall not supported - " + ecall) + } + } + + // expectedDAGFromOperatorDAG helper - generates a matrix of job nodes for each operator node. + def generateJobNodes(numPartitions: Int, operatorName: String): ArrayBuffer[ArrayBuffer[JobNode]] = { + val jobNodes = ArrayBuffer[ArrayBuffer[JobNode]]() + val expectedEcalls = ArrayBuffer[Int]() + if (operatorName == "EncryptedSort" && numPartitions == 1) { + // ("externalSort") + expectedEcalls.append(6) + } else if (operatorName == "EncryptedSort" && numPartitions > 1) { + // ("sample", "findRangeBounds", "partitionForSort", "externalSort") + expectedEcalls.append(3, 4, 5, 6) + } else if (operatorName == "EncryptedProject") { + // ("project") + expectedEcalls.append(1) + } else if (operatorName == "EncryptedFilter") { + // ("filter") + expectedEcalls.append(2) + } else if (operatorName == "EncryptedAggregate") { + // ("nonObliviousAggregate") + expectedEcalls.append(9) + } else if (operatorName == "EncryptedSortMergeJoin") { + // ("nonObliviousSortMergeJoin") + expectedEcalls.append(8) + } else if (operatorName == "EncryptedLocalLimit") { + // ("limitReturnRows") + expectedEcalls.append(13) + } else if (operatorName == "EncryptedGlobalLimit") { + // ("countRowsPerPartition", "computeNumRowsPerPartition", "limitReturnRows") + expectedEcalls.append(10, 11, 13) + } else { + throw new Exception("Executed unknown operator: " + operatorName) + } + for (ecallIdx <- 0 until expectedEcalls.length) { + val ecall = expectedEcalls(ecallIdx) + val ecallJobNodes = ArrayBuffer[JobNode]() + jobNodes.append(ecallJobNodes) + for (partitionIdx <- 0 until numPartitions) { + val jobNode = new JobNode() + jobNode.setEcall(ecall) + ecallJobNodes.append(jobNode) + } + } + return jobNodes + } + + // Converts a DAG of Spark operators to a DAG of ecalls and partitions. + def expectedDAGFromOperatorDAG(operatorNodes: ArrayBuffer[OperatorNode]): JobNode = { + val source = new JobNode() + val sink = new JobNode() + source.setSource + sink.setSink + // For each node, create numPartitions * numEcalls jobnodes. + for (node <- operatorNodes) { + node.jobNodes = generateJobNodes(logEntryChains.size, node.operatorName) + } + // Link all ecalls. + for (node <- operatorNodes) { + for (ecallIdx <- 0 until node.jobNodes.length) { + if (ecallIdx == node.jobNodes.length - 1) { + // last ecall of this operator, link to child operators if one exists. + for (child <- node.children) { + linkEcalls(node.jobNodes(ecallIdx), child.jobNodes(0)) + } + } else { + linkEcalls(node.jobNodes(ecallIdx), node.jobNodes(ecallIdx + 1)) + } + } + } + // Set source and sink + for (node <- operatorNodes) { + if (node.isOrphan) { + for (jobNode <- node.jobNodes(0)) { + source.addOutgoingNeighbor(jobNode) + } + } + if (node.children.isEmpty) { + for (jobNode <- node.jobNodes(node.jobNodes.length - 1)) { + jobNode.addOutgoingNeighbor(sink) + } + } + } + return source + } + + // Generates an expected DAG of ecalls and partitions from a dataframe's SparkPlan object. + def expectedDAGFromPlan(executedPlan: SparkPlan): JobNode = { + val operatorDAGRoot = operatorDAGFromPlan(executedPlan) + expectedDAGFromOperatorDAG(operatorDAGRoot) + } + + // Verify that the executed flow of information from ecall partition to ecall partition + // matches what is expected for a given Spark dataframe. + def verify(df: DataFrame): Boolean = { + // Get expected DAG. + val expectedSourceNode = expectedDAGFromPlan(df.queryExecution.executedPlan) + + // Quit if graph is empty. + if (expectedSourceNode.graphIsEmpty) { return true } - val OE_HMAC_SIZE = 32 - val numPartitions = logEntryChains.size + // Construct executed DAG. + val OE_HMAC_SIZE = 32 // Keep a set of nodes, since right now, the last nodes won't have outputs. val nodeSet = Set[JobNode]() // Set up map from allOutputsMAC --> JobNode. @@ -237,157 +530,14 @@ object JobVerificationEngine { } } - // ========================================== // - - // Construct expected DAG. - val expectedDAG = ArrayBuffer[ArrayBuffer[JobNode]]() - val expectedEcalls = ArrayBuffer[Int]() - for (operator <- sparkOperators) { - if (operator == "EncryptedSortExec" && numPartitions == 1) { - // ("externalSort") - expectedEcalls.append(6) - } else if (operator == "EncryptedSortExec" && numPartitions > 1) { - // ("sample", "findRangeBounds", "partitionForSort", "externalSort") - expectedEcalls.append(3, 4, 5, 6) - } else if (operator == "EncryptedProjectExec") { - // ("project") - expectedEcalls.append(1) - } else if (operator == "EncryptedFilterExec") { - // ("filter") - expectedEcalls.append(2) - } else if (operator == "EncryptedAggregateExec") { - // ("nonObliviousAggregate") - expectedEcalls.append(9) - } else if (operator == "EncryptedSortMergeJoinExec") { - // ("nonObliviousSortMergeJoin") - expectedEcalls.append(8) - } else if (operator == "EncryptedLocalLimitExec") { - // ("limitReturnRows") - expectedEcalls.append(13) - } else if (operator == "EncryptedGlobalLimitExec") { - // ("countRowsPerPartition", "computeNumRowsPerPartition", "limitReturnRows") - expectedEcalls.append(10, 11, 13) - } else { - throw new Exception("Executed unknown operator") - } - } - - // Initialize job nodes. - val expectedSourceNode = new JobNode() - expectedSourceNode.setSource - val expectedSinkNode = new JobNode() - expectedSinkNode.setSink - for (j <- 0 until numPartitions) { - val partitionJobNodes = ArrayBuffer[JobNode]() - expectedDAG.append(partitionJobNodes) - for (i <- 0 until expectedEcalls.length) { - val ecall = expectedEcalls(i) - val jobNode = new JobNode() - jobNode.setEcall(ecall) - partitionJobNodes.append(jobNode) - // Connect source node to starting ecall partitions. - if (i == 0) { - expectedSourceNode.addOutgoingNeighbor(jobNode) - } - // Connect ending ecall partitions to sink. - if (i == expectedEcalls.length - 1) { - jobNode.addOutgoingNeighbor(expectedSinkNode) - } - } - } - - // Set outgoing neighbors for all nodes, except for the ones in the last ecall. - for (i <- 0 until expectedEcalls.length - 1) { - // i represents the current ecall index - val operator = expectedEcalls(i) - // project - if (operator == 1) { - for (j <- 0 until numPartitions) { - expectedDAG(j)(i).addOutgoingNeighbor(expectedDAG(j)(i + 1)) - } - // filter - } else if (operator == 2) { - for (j <- 0 until numPartitions) { - expectedDAG(j)(i).addOutgoingNeighbor(expectedDAG(j)(i + 1)) - } - // externalSort - } else if (operator == 6) { - for (j <- 0 until numPartitions) { - expectedDAG(j)(i).addOutgoingNeighbor(expectedDAG(j)(i + 1)) - } - // sample - } else if (operator == 3) { - for (j <- 0 until numPartitions) { - // All EncryptedBlocks resulting from sample go to one worker - expectedDAG(j)(i).addOutgoingNeighbor(expectedDAG(0)(i + 1)) - } - // findRangeBounds - } else if (operator == 4) { - // Broadcast from one partition (assumed to be partition 0) to all partitions - for (j <- 0 until numPartitions) { - expectedDAG(0)(i).addOutgoingNeighbor(expectedDAG(j)(i + 1)) - } - // partitionForSort - } else if (operator == 5) { - // All to all shuffle - for (j <- 0 until numPartitions) { - for (k <- 0 until numPartitions) { - expectedDAG(j)(i).addOutgoingNeighbor(expectedDAG(k)(i + 1)) - } - } - // nonObliviousAggregate - } else if (operator == 9) { - for (j <- 0 until numPartitions) { - expectedDAG(j)(i).addOutgoingNeighbor(expectedDAG(j)(i + 1)) - } - // scanCollectLastPrimary - } else if (operator == 7) { - // Blocks sent to next partition - if (numPartitions == 1) { - expectedDAG(0)(i).addOutgoingNeighbor(expectedDAG(0)(i + 1)) - } else { - for (j <- 0 until numPartitions) { - if (j < numPartitions - 1) { - val next = j + 1 - expectedDAG(j)(i).addOutgoingNeighbor(expectedDAG(next)(i + 1)) - } - } - } - // nonObliviousSortMergeJoin - } else if (operator == 8) { - for (j <- 0 until numPartitions) { - expectedDAG(j)(i).addOutgoingNeighbor(expectedDAG(j)(i + 1)) - } - // countRowsPerPartition - } else if (operator == 10) { - // Send from all partitions to partition 0 - for (j <- 0 until numPartitions) { - expectedDAG(j)(i).addOutgoingNeighbor(expectedDAG(0)(i + 1)) - } - // computeNumRowsPerPartition - } else if (operator == 11) { - // Broadcast from one partition (assumed to be partition 0) to all partitions - for (j <- 0 until numPartitions) { - expectedDAG(0)(i).addOutgoingNeighbor(expectedDAG(j)(i + 1)) - } - // limitReturnRows - } else if (operator == 13) { - for (j <- 0 until numPartitions) { - expectedDAG(j)(i).addOutgoingNeighbor(expectedDAG(j)(i + 1)) - } - } else { - throw new Exception("Job Verification Error creating expected DAG: " - + "operator not supported - " + operator) - } - } val executedPathsToSink = executedSourceNode.pathsToSink val expectedPathsToSink = expectedSourceNode.pathsToSink val arePathsEqual = pathsEqual(executedPathsToSink, expectedPathsToSink) if (!arePathsEqual) { - println(executedPathsToSink.toString) - println(expectedPathsToSink.toString) + // println(executedPathsToSink.toString) + // println(expectedPathsToSink.toString) println("===========DAGS NOT EQUAL===========") } - return true + return true } } diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala index 815bf0e738..fd0796ac5b 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala @@ -33,6 +33,7 @@ import scala.collection.mutable.ArrayBuilder import com.google.flatbuffers.FlatBufferBuilder import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Dataset import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow @@ -73,7 +74,6 @@ import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.catalyst.expressions.StartsWith import org.apache.spark.sql.catalyst.expressions.Substring import org.apache.spark.sql.catalyst.expressions.Subtract -import org.apache.spark.sql.catalyst.expressions.TimeAdd import org.apache.spark.sql.catalyst.expressions.UnaryMinus import org.apache.spark.sql.catalyst.expressions.Upper import org.apache.spark.sql.catalyst.expressions.Year @@ -800,8 +800,8 @@ object Utils extends Logging { JobVerificationEngine.addLogEntryChain(blockLog) } - def verifyJob(): Boolean = { - return JobVerificationEngine.verify() + def verifyJob(df: DataFrame): Boolean = { + return JobVerificationEngine.verify(df) } def treeFold[BaseType <: TreeNode[BaseType], B]( diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/EncryptedSortExec.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/EncryptedSortExec.scala index 1dce88ed1a..a32e7c10e8 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/EncryptedSortExec.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/EncryptedSortExec.scala @@ -18,7 +18,6 @@ package edu.berkeley.cs.rise.opaque.execution import edu.berkeley.cs.rise.opaque.Utils -import edu.berkeley.cs.rise.opaque.JobVerificationEngine import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.SortOrder @@ -32,7 +31,6 @@ case class EncryptedSortExec(order: Seq[SortOrder], isGlobal: Boolean, child: Sp override def executeBlocked(): RDD[Block] = { val orderSer = Utils.serializeSortOrder(order, child.output) val childRDD = child.asInstanceOf[OpaqueOperatorExec].executeBlocked() - JobVerificationEngine.addExpectedOperator("EncryptedSortExec") val partitionedRDD = isGlobal match { case true => EncryptedSortExec.sampleAndPartition(childRDD, orderSer) case false => childRDD diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala index 252d8eb33f..0497b3cf2a 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala @@ -20,7 +20,6 @@ package edu.berkeley.cs.rise.opaque.execution import scala.collection.mutable.ArrayBuffer import edu.berkeley.cs.rise.opaque.Utils -import edu.berkeley.cs.rise.opaque.JobVerificationEngine import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.AttributeSet @@ -147,15 +146,9 @@ trait OpaqueOperatorExec extends SparkPlan { collectedRDD.map { block => Utils.addBlockForVerification(block) } - - val postVerificationPasses = Utils.verifyJob() - JobVerificationEngine.resetForNextJob() - if (postVerificationPasses) { - collectedRDD.flatMap { block => - Utils.decryptBlockFlatbuffers(block) - } - } else { - throw new Exception("Post Verification Failed") + + collectedRDD.flatMap { block => + Utils.decryptBlockFlatbuffers(block) } } @@ -218,7 +211,6 @@ case class EncryptedProjectExec(projectList: Seq[NamedExpression], child: SparkP val projectListSer = Utils.serializeProjectList(projectList, child.output) timeOperator(child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), "EncryptedProjectExec") { childRDD => - JobVerificationEngine.addExpectedOperator("EncryptedProjectExec") childRDD.map { block => val (enclave, eid) = Utils.initEnclave() Block(enclave.Project(eid, projectListSer, block.bytes)) @@ -237,7 +229,6 @@ case class EncryptedFilterExec(condition: Expression, child: SparkPlan) val conditionSer = Utils.serializeFilterExpression(condition, child.output) timeOperator(child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), "EncryptedFilterExec") { childRDD => - JobVerificationEngine.addExpectedOperator("EncryptedFilterExec") childRDD.map { block => val (enclave, eid) = Utils.initEnclave() Block(enclave.Filter(eid, conditionSer, block.bytes)) @@ -283,7 +274,6 @@ case class EncryptedAggregateExec( timeOperator(child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), "EncryptedPartialAggregateExec") { childRDD => - JobVerificationEngine.addExpectedOperator("EncryptedAggregateExec") childRDD.map { block => val (enclave, eid) = Utils.initEnclave() Block(enclave.NonObliviousAggregate(eid, aggExprSer, block.bytes, (mode == Partial))) @@ -316,7 +306,6 @@ case class EncryptedSortMergeJoinExec( child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), "EncryptedSortMergeJoinExec") { childRDD => - JobVerificationEngine.addExpectedOperator("EncryptedSortMergeJoinExec") childRDD.map { block => val (enclave, eid) = Utils.initEnclave() Block(enclave.NonObliviousSortMergeJoin(eid, joinExprSer, block.bytes)) @@ -373,7 +362,6 @@ case class EncryptedLocalLimitExec( override def executeBlocked(): RDD[Block] = { timeOperator(child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), "EncryptedLocalLimitExec") { childRDD => - JobVerificationEngine.addExpectedOperator("EncryptedLocalLimitExec") childRDD.map { block => val (enclave, eid) = Utils.initEnclave() Block(enclave.LocalLimit(eid, limit, block.bytes)) @@ -394,7 +382,6 @@ case class EncryptedGlobalLimitExec( override def executeBlocked(): RDD[Block] = { timeOperator(child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), "EncryptedGlobalLimitExec") { childRDD => - JobVerificationEngine.addExpectedOperator("EncryptedGlobalLimitExec") val numRowsPerPartition = Utils.concatEncryptedBlocks(childRDD.map { block => val (enclave, eid) = Utils.initEnclave() Block(enclave.CountRowsPerPartition(eid, block.bytes)) diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala index 26b9d01b7b..0aa55d3138 100644 --- a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala +++ b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala @@ -70,12 +70,22 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => import ExtraDFOperations._ + def integrityCollect(df: DataFrame): Seq[Row] = { + JobVerificationEngine.resetForNextJob() + val retval = df.collect + val postVerificationPasses = Utils.verifyJob(df) + if (!postVerificationPasses) { + println("Job Verification Failure") + } + return retval + } + testAgainstSpark("Interval SQL") { securityLevel => val data = Seq(Tuple2(1, new java.sql.Date(new java.util.Date().getTime()))) val df = makeDF(data, securityLevel, "index", "time") df.createTempView("Interval") try { - spark.sql("SELECT time + INTERVAL 7 DAY FROM Interval").collect + integrityCollect(spark.sql("SELECT time + INTERVAL 7 DAY FROM Interval")) } finally { spark.catalog.dropTempView("Interval") } @@ -86,7 +96,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => val df = makeDF(data, securityLevel, "index", "time") df.createTempView("Interval") try { - spark.sql("SELECT time + INTERVAL 7 WEEK FROM Interval").collect + integrityCollect(spark.sql("SELECT time + INTERVAL 7 WEEK FROM Interval")) } finally { spark.catalog.dropTempView("Interval") } @@ -97,7 +107,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => val df = makeDF(data, securityLevel, "index", "time") df.createTempView("Interval") try { - spark.sql("SELECT time + INTERVAL 6 MONTH FROM Interval").collect + integrityCollect(spark.sql("SELECT time + INTERVAL 6 MONTH FROM Interval")) } finally { spark.catalog.dropTempView("Interval") } @@ -106,18 +116,18 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => testAgainstSpark("Date Add") { securityLevel => val data = Seq(Tuple2(1, new java.sql.Date(new java.util.Date().getTime()))) val df = makeDF(data, securityLevel, "index", "time") - df.select(date_add($"time", 3)).collect + integrityCollect(df.select(date_add($"time", 3))) } testAgainstSpark("create DataFrame from sequence") { securityLevel => val data = for (i <- 0 until 5) yield ("foo", i) - makeDF(data, securityLevel, "word", "count").collect + integrityCollect(makeDF(data, securityLevel, "word", "count")) } testAgainstSpark("create DataFrame with BinaryType + ByteType") { securityLevel => val data: Seq[(Array[Byte], Byte)] = Seq((Array[Byte](0.toByte, -128.toByte, 127.toByte), 42.toByte)) - makeDF(data, securityLevel, "BinaryType", "ByteType").collect + integrityCollect(makeDF(data, securityLevel, "BinaryType", "ByteType")) } testAgainstSpark("create DataFrame with CalendarIntervalType + NullType") { securityLevel => @@ -126,15 +136,15 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => StructField("CalendarIntervalType", CalendarIntervalType), StructField("NullType", NullType))) - securityLevel.applyTo( + integrityCollect(securityLevel.applyTo( spark.createDataFrame( spark.sparkContext.makeRDD(data.map(Row.fromTuple), numPartitions), - schema)).collect + schema))) } testAgainstSpark("create DataFrame with ShortType + TimestampType") { securityLevel => val data: Seq[(Short, Timestamp)] = Seq((13.toShort, Timestamp.valueOf("2017-12-02 03:04:00"))) - makeDF(data, securityLevel, "ShortType", "TimestampType").collect + integrityCollect(makeDF(data, securityLevel, "ShortType", "TimestampType")) } testAgainstSpark("create DataFrame with ArrayType") { securityLevel => @@ -144,7 +154,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => (array, "cat"), (array, "ant")) val df = makeDF(data, securityLevel, "array", "string") - df.collect + integrityCollect(df) } testAgainstSpark("create DataFrame with MapType") { securityLevel => @@ -154,7 +164,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => (map, "cat"), (map, "ant")) val df = makeDF(data, securityLevel, "map", "string") - df.collect + integrityCollect(df) } testAgainstSpark("create DataFrame with nulls for all types") { securityLevel => @@ -175,10 +185,10 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => StructField("map_int_to_int", DataTypes.createMapType(IntegerType, IntegerType)), StructField("string", StringType))) - securityLevel.applyTo( + integrityCollect(securityLevel.applyTo( spark.createDataFrame( spark.sparkContext.makeRDD(Seq(Row.fromSeq(Seq.fill(schema.length) { null })), numPartitions), - schema)).collect + schema))) } testAgainstSpark("filter") { securityLevel => @@ -186,7 +196,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => (1 to 20).map(x => (true, "hello", 1.0, 2.0f, x)), securityLevel, "a", "b", "c", "d", "x") - df.filter($"x" > lit(10)).collect + integrityCollect(df.filter($"x" > lit(10))) } testAgainstSpark("filter with NULLs") { securityLevel => @@ -197,13 +207,13 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => Tuple1(x.asInstanceOf[Integer]) }).toSeq) val df = makeDF(data, securityLevel, "x") - df.filter($"x" > lit(10)).collect.toSet + integrityCollect(df.filter($"x" > lit(10))).toSet } testAgainstSpark("select") { securityLevel => val data = for (i <- 0 until 256) yield ("%03d".format(i) * 3, i.toFloat) val df = makeDF(data, securityLevel, "str", "x") - df.select($"str").collect + integrityCollect(df.select($"str")) } testAgainstSpark("select with expressions") { securityLevel => @@ -211,12 +221,12 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => (1 to 20).map(x => (true, "hello world!", 1.0, 2.0f, x)), securityLevel, "a", "b", "c", "d", "x") - df.select( + integrityCollect(df.select( $"x" + $"x" * $"x" - $"x", substring($"b", 5, 20), $"x" > $"x", $"x" >= $"x", - $"x" <= $"x").collect.toSet + $"x" <= $"x")).toSet } testAgainstSpark("union") { securityLevel => @@ -228,7 +238,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => (1 to 20).map(x => (x, (x + 1).toString)), securityLevel, "a", "b") - df1.union(df2).collect.toSet + integrityCollect(df1.union(df2)).toSet } testOpaqueOnly("cache") { securityLevel => @@ -254,31 +264,31 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => testAgainstSpark("sort") { securityLevel => val data = Random.shuffle((0 until 256).map(x => (x.toString, x)).toSeq) val df = makeDF(data, securityLevel, "str", "x") - df.sort($"x").collect + integrityCollect(df.sort($"x")) } testAgainstSpark("sort zero elements") { securityLevel => val data = Seq.empty[(String, Int)] val df = makeDF(data, securityLevel, "str", "x") - df.sort($"x").collect + integrityCollect(df.sort($"x")) } testAgainstSpark("sort by float") { securityLevel => val data = Random.shuffle((0 until 256).map(x => (x.toString, x.toFloat)).toSeq) val df = makeDF(data, securityLevel, "str", "x") - df.sort($"x").collect + integrityCollect(df.sort($"x")) } testAgainstSpark("sort by string") { securityLevel => val data = Random.shuffle((0 until 256).map(x => (x.toString, x.toFloat)).toSeq) val df = makeDF(data, securityLevel, "str", "x") - df.sort($"str").collect + integrityCollect(df.sort($"str")) } testAgainstSpark("sort by 2 columns") { securityLevel => val data = Random.shuffle((0 until 256).map(x => (x / 16, x)).toSeq) val df = makeDF(data, securityLevel, "x", "y") - df.sort($"x", $"y").collect + integrityCollect(df.sort($"x", $"y")) } testAgainstSpark("sort with null values") { securityLevel => @@ -289,7 +299,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => Tuple1(x.asInstanceOf[Integer]) }).toSeq) val df = makeDF(data, securityLevel, "x") - df.sort($"x").collect + integrityCollect(df.sort($"x")) } testAgainstSpark("join") { securityLevel => @@ -297,7 +307,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => val f_data = for (i <- 1 to 256 - 16) yield (i, (i % 16).toString, i * 10) val p = makeDF(p_data, securityLevel, "id", "pk", "x") val f = makeDF(f_data, securityLevel, "id", "fk", "x") - p.join(f, $"pk" === $"fk").collect.toSet + integrityCollect(p.join(f, $"pk" === $"fk")).toSet } testAgainstSpark("join on column 1") { securityLevel => @@ -305,7 +315,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => val f_data = for (i <- 1 to 256 - 16) yield ((i % 16).toString, (i * 10).toString, i.toFloat) val p = makeDF(p_data, securityLevel, "pk", "x") val f = makeDF(f_data, securityLevel, "fk", "x", "y") - val df = p.join(f, $"pk" === $"fk").collect.toSet + integrityCollect(p.join(f, $"pk" === $"fk")).toSet } testAgainstSpark("non-foreign-key join") { securityLevel => @@ -313,7 +323,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => val f_data = for (i <- 1 to 256 - 128) yield (i, (i % 16).toString, i * 10) val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") - p.join(f, $"join_col_1" === $"join_col_2").collect.toSet + integrityCollect(p.join(f, $"join_col_1" === $"join_col_2")).toSet } testAgainstSpark("left semi join") { securityLevel => @@ -322,7 +332,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => val p = makeDF(p_data, securityLevel, "id1", "join_col_1", "x") val f = makeDF(f_data, securityLevel, "id2", "join_col_2", "x") val df = p.join(f, $"join_col_1" === $"join_col_2", "left_semi").sort($"join_col_1", $"id1") - df.collect + integrityCollect(df) } testAgainstSpark("left anti join 1") { securityLevel => @@ -331,7 +341,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") val df = p.join(f, $"join_col_1" === $"join_col_2", "left_anti").sort($"join_col_1", $"id") - df.collect + integrityCollect(df) } testAgainstSpark("left anti join 2") { securityLevel => @@ -340,7 +350,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") val df = p.join(f, $"join_col_1" === $"join_col_2", "left_anti").sort($"join_col_1", $"id") - df.collect + integrityCollect(df) } def abc(i: Int): String = (i % 3) match { @@ -360,7 +370,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => words.setNullableStateOfColumn("price", true) val df = words.groupBy("category").agg(avg("price").as("avgPrice")) - df.collect.sortBy { case Row(category: String, _) => category } + integrityCollect(df).sortBy { case Row(category: String, _) => category } } testAgainstSpark("aggregate count") { securityLevel => @@ -372,40 +382,40 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => }.toSeq val words = makeDF(data, securityLevel, "id", "category", "price") words.setNullableStateOfColumn("price", true) - words.groupBy("category").agg(count("category").as("itemsInCategory")) - .collect.sortBy { case Row(category: String, _) => category } + integrityCollect(words.groupBy("category").agg(count("category").as("itemsInCategory"))) + .sortBy { case Row(category: String, _) => category } } testAgainstSpark("aggregate first") { securityLevel => val data = for (i <- 0 until 256) yield (i, abc(i), 1) val words = makeDF(data, securityLevel, "id", "category", "price") - val df = words.groupBy("category").agg(first("category").as("firstInCategory")) - .collect.sortBy { case Row(category: String, _) => category } + integrityCollect(words.groupBy("category").agg(first("category").as("firstInCategory"))) + .sortBy { case Row(category: String, _) => category } } testAgainstSpark("aggregate last") { securityLevel => val data = for (i <- 0 until 256) yield (i, abc(i), 1) val words = makeDF(data, securityLevel, "id", "category", "price") - words.groupBy("category").agg(last("category").as("lastInCategory")) - .collect.sortBy { case Row(category: String, _) => category } + integrityCollect(words.groupBy("category").agg(last("category").as("lastInCategory"))) + .sortBy { case Row(category: String, _) => category } } testAgainstSpark("aggregate max") { securityLevel => val data = for (i <- 0 until 256) yield (i, abc(i), 1) val words = makeDF(data, securityLevel, "id", "category", "price") - words.groupBy("category").agg(max("price").as("maxPrice")) - .collect.sortBy { case Row(category: String, _) => category } + integrityCollect(words.groupBy("category").agg(max("price").as("maxPrice"))) + .sortBy { case Row(category: String, _) => category } } testAgainstSpark("aggregate min") { securityLevel => val data = for (i <- 0 until 256) yield (i, abc(i), 1) val words = makeDF(data, securityLevel, "id", "category", "price") - words.groupBy("category").agg(min("price").as("minPrice")) - .collect.sortBy { case Row(category: String, _) => category } + integrityCollect(words.groupBy("category").agg(min("price").as("minPrice"))) + .sortBy { case Row(category: String, _) => category } } testAgainstSpark("aggregate sum") { securityLevel => @@ -419,16 +429,16 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => val words = makeDF(data, securityLevel, "id", "word", "count") words.setNullableStateOfColumn("count", true) - words.groupBy("word").agg(sum("count").as("totalCount")) - .collect.sortBy { case Row(word: String, _) => word } + integrityCollect(words.groupBy("word").agg(sum("count").as("totalCount"))) + .sortBy { case Row(word: String, _) => word } } testAgainstSpark("aggregate on multiple columns") { securityLevel => val data = for (i <- 0 until 256) yield (abc(i), 1, 1.0f) val words = makeDF(data, securityLevel, "str", "x", "y") - words.groupBy("str").agg(sum("y").as("totalY"), avg("x").as("avgX")) - .collect.sortBy { case Row(str: String, _, _) => str } + integrityCollect(words.groupBy("str").agg(sum("y").as("totalY"), avg("x").as("avgX"))) + .sortBy { case Row(str: String, _, _) => str } } testAgainstSpark("skewed aggregate sum") { securityLevel => @@ -437,34 +447,82 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => }).toSeq) val words = makeDF(data, securityLevel, "id", "word", "count") - words.groupBy("word").agg(sum("count").as("totalCount")) - .collect.sortBy { case Row(word: String, _) => word } + integrityCollect(words.groupBy("word").agg(sum("count").as("totalCount"))) + .sortBy { case Row(word: String, _) => word } } testAgainstSpark("grouping aggregate with 0 rows") { securityLevel => val data = for (i <- 0 until 256) yield (i, abc(i), 1) val words = makeDF(data, securityLevel, "id", "word", "count") - words.filter($"id" < lit(0)).groupBy("word").agg(sum("count")) - .collect.sortBy { case Row(word: String, _) => word } + integrityCollect(words.filter($"id" < lit(0)).groupBy("word").agg(sum("count"))) + .sortBy { case Row(word: String, _) => word } } testAgainstSpark("global aggregate") { securityLevel => val data = for (i <- 0 until 256) yield (i, abc(i), 1) val words = makeDF(data, securityLevel, "id", "word", "count") - words.agg(sum("count").as("totalCount")).collect + integrityCollect(words.agg(sum("count").as("totalCount"))) } testAgainstSpark("global aggregate with 0 rows") { securityLevel => val data = for (i <- 0 until 256) yield (i, abc(i), 1) val words = makeDF(data, securityLevel, "id", "word", "count") val result = words.filter($"id" < lit(0)).agg(count("*")).as("totalCount") - result.collect + integrityCollect(result) } testAgainstSpark("contains") { securityLevel => val data = for (i <- 0 until 256) yield(i.toString, abc(i)) val df = makeDF(data, securityLevel, "word", "abc") - df.filter($"word".contains(lit("1"))).collect + integrityCollect(df.filter($"word".contains(lit("1")))) + } + + testAgainstSpark("concat with string") { securityLevel => + val data = for (i <- 0 until 256) yield ("%03d".format(i) * 3, i.toString) + val df = makeDF(data, securityLevel, "str", "x") + integrityCollect(df.select(concat(col("str"),lit(","),col("x")))) + } + + testAgainstSpark("concat with other datatype") { securityLevel => + // float causes a formating issue where opaque outputs 1.000000 and spark produces 1.0 so the following line is commented out + // val data = for (i <- 0 until 3) yield ("%03d".format(i) * 3, i, 1.0f) + // you can't serialize date so that's not supported as well + // opaque doesn't support byte + val data = for (i <- 0 until 3) yield ("%03d".format(i) * 3, i, null.asInstanceOf[Int],"") + val df = makeDF(data, securityLevel, "str", "int","null","emptystring") + integrityCollect(df.select(concat(col("str"),lit(","),col("int"),col("null"),col("emptystring")))) + } + + testAgainstSpark("isin1") { securityLevel => + val ids = Seq((1, 2, 2), (2, 3, 1)) + val df = makeDF(ids, securityLevel, "x", "y", "id") + val c = $"id" isin ($"x", $"y") + val result = df.filter(c) + integrityCollect(result) + } + + testAgainstSpark("isin2") { securityLevel => + val ids2 = Seq((1, 1, 1), (2, 2, 2), (3,3,3), (4,4,4)) + val df2 = makeDF(ids2, securityLevel, "x", "y", "id") + val c2 = $"id" isin (1 ,2, 4, 5, 6) + val result = df2.filter(c2) + integrityCollect(result) + } + + testAgainstSpark("isin with string") { securityLevel => + val ids3 = Seq(("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"), ("b", "b", "b"), ("c","c","c"), ("d","d","d")) + val df3 = makeDF(ids3, securityLevel, "x", "y", "id") + val c3 = $"id" isin ("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" ,"b", "c", "d", "e") + val result = df3.filter(c3) + integrityCollect(result) + } + + testAgainstSpark("isin with null") { securityLevel => + val ids4 = Seq((1, 1, 1), (2, 2, 2), (3,3,null.asInstanceOf[Int]), (4,4,4)) + val df4 = makeDF(ids4, securityLevel, "x", "y", "id") + val c4 = $"id" isin (null.asInstanceOf[Int]) + val result = df4.filter(c4) + integrityCollect(result) } testAgainstSpark("concat with string") { securityLevel => @@ -518,97 +576,97 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => testAgainstSpark("between") { securityLevel => val data = for (i <- 0 until 256) yield(i.toString, i) val df = makeDF(data, securityLevel, "word", "count") - df.filter($"count".between(50, 150)).collect + integrityCollect(df.filter($"count".between(50, 150))) } testAgainstSpark("year") { securityLevel => val data = Seq(Tuple2(1, new java.sql.Date(new java.util.Date().getTime()))) val df = makeDF(data, securityLevel, "id", "date") - df.select(year($"date")).collect + integrityCollect(df.select(year($"date"))) } testAgainstSpark("case when - 1 branch with else (string)") { securityLevel => val data = Seq(("foo", 4), ("bar", 1), ("baz", 5), ("bear", null.asInstanceOf[Int])) val df = makeDF(data, securityLevel, "word", "count") - df.select(when(df("word") === "foo", "hi").otherwise("bye")).collect + integrityCollect(df.select(when(df("word") === "foo", "hi").otherwise("bye"))) } testAgainstSpark("case when - 1 branch with else (int)") { securityLevel => val data = Seq(("foo", 4), ("bar", 1), ("baz", 5), ("bear", null.asInstanceOf[Int])) val df = makeDF(data, securityLevel, "word", "count") - df.select(when(df("word") === "foo", 10).otherwise(30)).collect + integrityCollect(df.select(when(df("word") === "foo", 10).otherwise(30))) } testAgainstSpark("case when - 1 branch without else (string)") { securityLevel => val data = Seq(("foo", 4), ("bar", 1), ("baz", 5), ("bear", null.asInstanceOf[Int])) val df = makeDF(data, securityLevel, "word", "count") - df.select(when(df("word") === "foo", "hi")).collect + integrityCollect(df.select(when(df("word") === "foo", "hi"))) } testAgainstSpark("case when - 1 branch without else (int)") { securityLevel => val data = Seq(("foo", 4), ("bar", 1), ("baz", 5), ("bear", null.asInstanceOf[Int])) val df = makeDF(data, securityLevel, "word", "count") - df.select(when(df("word") === "foo", 10)).collect + integrityCollect(df.select(when(df("word") === "foo", 10))) } testAgainstSpark("case when - 2 branch with else (string)") { securityLevel => val data = Seq(("foo", 4), ("bar", 1), ("baz", 5), ("bear", null.asInstanceOf[Int])) val df = makeDF(data, securityLevel, "word", "count") - df.select(when(df("word") === "foo", "hi").when(df("word") === "baz", "hello").otherwise("bye")).collect + integrityCollect(df.select(when(df("word") === "foo", "hi").when(df("word") === "baz", "hello").otherwise("bye"))) } testAgainstSpark("case when - 2 branch with else (int)") { securityLevel => val data = Seq(("foo", 4), ("bar", 1), ("baz", 5), ("bear", null.asInstanceOf[Int])) val df = makeDF(data, securityLevel, "word", "count") - df.select(when(df("word") === "foo", 10).when(df("word") === "baz", 20).otherwise(30)).collect + integrityCollect(df.select(when(df("word") === "foo", 10).when(df("word") === "baz", 20).otherwise(30))) } testAgainstSpark("case when - 2 branch without else (string)") { securityLevel => val data = Seq(("foo", 4), ("bar", 1), ("baz", 5), ("bear", null.asInstanceOf[Int])) val df = makeDF(data, securityLevel, "word", "count") - df.select(when(df("word") === "foo", "hi").when(df("word") === "baz", "hello")).collect + integrityCollect(df.select(when(df("word") === "foo", "hi").when(df("word") === "baz", "hello"))) } testAgainstSpark("case when - 2 branch without else (int)") { securityLevel => val data = Seq(("foo", 4), ("bar", 1), ("baz", 5), ("bear", null.asInstanceOf[Int])) val df = makeDF(data, securityLevel, "word", "count") - df.select(when(df("word") === "foo", 3).when(df("word") === "baz", 2)).collect + integrityCollect(df.select(when(df("word") === "foo", 3).when(df("word") === "baz", 2))) } testAgainstSpark("LIKE - Contains") { securityLevel => val data = Seq(("foo", 4), ("bar", 1), ("baz", 5), (null.asInstanceOf[String], null.asInstanceOf[Int])) val df = makeDF(data, securityLevel, "word", "count") - df.filter($"word".like("%a%")).collect + integrityCollect(df.filter($"word".like("%a%"))) } testAgainstSpark("LIKE - StartsWith") { securityLevel => val data = Seq(("foo", 4), ("bar", 1), ("baz", 5), (null.asInstanceOf[String], null.asInstanceOf[Int])) val df = makeDF(data, securityLevel, "word", "count") - df.filter($"word".like("ba%")).collect + integrityCollect(df.filter($"word".like("ba%"))) } testAgainstSpark("LIKE - EndsWith") { securityLevel => val data = Seq(("foo", 4), ("bar", 1), ("baz", 5), (null.asInstanceOf[String], null.asInstanceOf[Int])) val df = makeDF(data, securityLevel, "word", "count") - df.filter($"word".like("%ar")).collect + integrityCollect(df.filter($"word".like("%ar"))) } testAgainstSpark("LIKE - Empty Pattern") { securityLevel => val data = Seq(("foo", 4), ("bar", 1), ("baz", 5), (null.asInstanceOf[String], null.asInstanceOf[Int])) val df = makeDF(data, securityLevel, "word", "count") - df.filter($"word".like("")).collect + integrityCollect(df.filter($"word".like(""))) } testAgainstSpark("LIKE - Match All") { securityLevel => val data = Seq(("foo", 4), ("bar", 1), ("baz", 5), (null.asInstanceOf[String], null.asInstanceOf[Int])) val df = makeDF(data, securityLevel, "word", "count") - df.filter($"word".like("%")).collect + integrityCollect(df.filter($"word".like("%"))) } testAgainstSpark("LIKE - Single Wildcard") { securityLevel => val data = Seq(("foo", 4), ("bar", 1), ("baz", 5), (null.asInstanceOf[String], null.asInstanceOf[Int])) val df = makeDF(data, securityLevel, "word", "count") - df.filter($"word".like("ba_")).collect + integrityCollect(df.filter($"word".like("ba_"))) } testAgainstSpark("LIKE - SQL API") { securityLevel => @@ -616,7 +674,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => val df = makeDF(data, securityLevel, "word", "count") df.createTempView("df") try { - spark.sql(""" SELECT word FROM df WHERE word LIKE '_a_' """).collect + integrityCollect(spark.sql(""" SELECT word FROM df WHERE word LIKE '_a_' """)) } finally { spark.catalog.dropTempView("df") } @@ -717,7 +775,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => "a", "b", "c", "d", "x") df.createTempView("df") try { - spark.sql("SELECT * FROM df WHERE x > 10").collect + integrityCollect(spark.sql("SELECT * FROM df WHERE x > 10")) } finally { spark.catalog.dropTempView("df") } @@ -753,8 +811,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => spark.createDataFrame( spark.sparkContext.makeRDD(data.map(Row.fromTuple), numPartitions), schema)) - - df.select(exp($"y")).collect + integrityCollect(df.select(exp($"y"))) } testAgainstSpark("vector multiply") { securityLevel => @@ -769,7 +826,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => spark.sparkContext.makeRDD(data.map(Row.fromTuple), numPartitions), schema)) - df.select(vectormultiply($"v", $"c")).collect + integrityCollect(df.select(vectormultiply($"v", $"c"))) } testAgainstSpark("dot product") { securityLevel => @@ -784,7 +841,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => spark.sparkContext.makeRDD(data.map(Row.fromTuple), numPartitions), schema)) - df.select(dot($"v1", $"v2")).collect + integrityCollect(df.select(dot($"v1", $"v2"))) } testAgainstSpark("upper") { securityLevel => @@ -798,7 +855,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => spark.sparkContext.makeRDD(data.map(Row.fromTuple), numPartitions), schema)) - df.select(upper($"v1")).collect + integrityCollect(df.select(upper($"v1"))) } testAgainstSpark("upper with null") { securityLevel => @@ -806,7 +863,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => val df = makeDF(data, securityLevel, "v1", "v2") - df.select(upper($"v2")).collect + integrityCollect(df.select(upper($"v2"))) } testAgainstSpark("vector sum") { securityLevel => @@ -823,7 +880,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => schema)) val vectorsum = new VectorSum - df.groupBy().agg(vectorsum($"v")).collect + integrityCollect(df.groupBy().agg(vectorsum($"v"))) } testAgainstSpark("create array") { securityLevel => @@ -839,7 +896,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => spark.sparkContext.makeRDD(data.map(Row.fromTuple), numPartitions), schema)) - df.select(array($"x1", $"x2").as("x")).collect + integrityCollect(df.select(array($"x1", $"x2").as("x"))) } testAgainstSpark("limit with fewer returned values") { securityLevel => @@ -851,7 +908,7 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => spark.createDataFrame( spark.sparkContext.makeRDD(data.map(Row.fromTuple), numPartitions), schema)) - df.sort($"id").limit(5).collect + integrityCollect(df.sort($"id").limit(5)) } testAgainstSpark("limit with more returned values") { securityLevel => @@ -863,11 +920,11 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => spark.createDataFrame( spark.sparkContext.makeRDD(data.map(Row.fromTuple), numPartitions), schema)) - df.sort($"id").limit(200).collect + integrityCollect(df.sort($"id").limit(200)) } testAgainstSpark("least squares") { securityLevel => - LeastSquares.query(spark, securityLevel, "tiny", numPartitions).collect + integrityCollect(LeastSquares.query(spark, securityLevel, "tiny", numPartitions)) } testAgainstSpark("logistic regression") { securityLevel => @@ -880,21 +937,20 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => } testAgainstSpark("pagerank") { securityLevel => - PageRank.run(spark, securityLevel, "256", numPartitions).collect.toSet + integrityCollect(PageRank.run(spark, securityLevel, "256", numPartitions)).toSet } - testAgainstSpark("big data 1") { securityLevel => - BigDataBenchmark.q1(spark, securityLevel, "tiny", numPartitions).collect + integrityCollect(BigDataBenchmark.q1(spark, securityLevel, "tiny", numPartitions)) } testAgainstSpark("big data 2") { securityLevel => - BigDataBenchmark.q2(spark, securityLevel, "tiny", numPartitions).collect + integrityCollect(BigDataBenchmark.q2(spark, securityLevel, "tiny", numPartitions)) .map { case Row(a: String, b: Double) => (a, b.toFloat) } .sortBy(_._1) } testAgainstSpark("big data 3") { securityLevel => - BigDataBenchmark.q3(spark, securityLevel, "tiny", numPartitions).collect + integrityCollect(BigDataBenchmark.q3(spark, securityLevel, "tiny", numPartitions)) } def makeDF[A <: Product : scala.reflect.ClassTag : scala.reflect.runtime.universe.TypeTag](