diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala index 641223a62d..cb054a3d36 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala @@ -73,7 +73,6 @@ import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.catalyst.expressions.StartsWith import org.apache.spark.sql.catalyst.expressions.Substring import org.apache.spark.sql.catalyst.expressions.Subtract -import org.apache.spark.sql.catalyst.expressions.TimeAdd import org.apache.spark.sql.catalyst.expressions.UnaryMinus import org.apache.spark.sql.catalyst.expressions.Upper import org.apache.spark.sql.catalyst.expressions.Year @@ -109,6 +108,8 @@ import edu.berkeley.cs.rise.opaque.expressions.VectorMultiply import edu.berkeley.cs.rise.opaque.expressions.VectorSum import edu.berkeley.cs.rise.opaque.logical.ConvertToOpaqueOperators import edu.berkeley.cs.rise.opaque.logical.EncryptLocalRelation +import org.apache.spark.sql.catalyst.expressions.PromotePrecision +import org.apache.spark.sql.catalyst.expressions.CheckOverflow object Utils extends Logging { private val perf: Boolean = System.getenv("SGX_PERF") == "1" @@ -350,8 +351,6 @@ object Utils extends Logging { rdd.foreach(x => {}) } - - def flatbuffersCreateField( builder: FlatBufferBuilder, value: Any, dataType: DataType, isNull: Boolean): Int = { (value, dataType) match { @@ -403,6 +402,18 @@ object Utils extends Logging { tuix.FieldUnion.FloatField, tuix.FloatField.createFloatField(builder, 0), isNull) + case (x: Decimal, DecimalType()) => + tuix.Field.createField( + builder, + tuix.FieldUnion.FloatField, + tuix.FloatField.createFloatField(builder, x.toFloat), + isNull) + case (null, DecimalType()) => + tuix.Field.createField( + builder, + tuix.FieldUnion.FloatField, + tuix.FloatField.createFloatField(builder, 0), + isNull) case (x: Double, DoubleType) => tuix.Field.createField( builder, @@ -779,6 +790,18 @@ object Utils extends Logging { op(fromChildren, tree) } + def getColType(dataType: DataType) = { + dataType match { + case IntegerType => tuix.ColType.IntegerType + case LongType => tuix.ColType.LongType + case FloatType => tuix.ColType.FloatType + case DecimalType() => tuix.ColType.FloatType + case DoubleType => tuix.ColType.DoubleType + case StringType => tuix.ColType.StringType + case _ => throw new OpaqueException("Type not supported: " + dataType.toString()) + } + } + /** Serialize an Expression into a tuix.Expr. Returns the offset of the written tuix.Expr. */ def flatbuffersSerializeExpression( builder: FlatBufferBuilder, expr: Expression, input: Seq[Attribute]): Int = { @@ -811,14 +834,7 @@ object Utils extends Logging { tuix.Cast.createCast( builder, childOffset, - dataType match { - case IntegerType => tuix.ColType.IntegerType - case LongType => tuix.ColType.LongType - case FloatType => tuix.ColType.FloatType - case DoubleType => tuix.ColType.DoubleType - case StringType => tuix.ColType.StringType - })) - + getColType(dataType))) // Arithmetic case (Add(left, right), Seq(leftOffset, rightOffset)) => tuix.Expr.createExpr( @@ -1087,6 +1103,17 @@ object Utils extends Logging { tuix.ExprUnion.ClosestPoint, tuix.ClosestPoint.createClosestPoint( builder, leftOffset, rightOffset)) + + case (PromotePrecision(child), Seq(childOffset)) => + // TODO: Implement decimal serialization, followed by PromotePrecision + childOffset + + case (CheckOverflow(child, dataType, _), Seq(childOffset)) => + // TODO: Implement decimal serialization, followed by CheckOverflow + childOffset + + case (_, Seq(childOffset)) => + throw new OpaqueException("Expression not supported: " + expr.toString()) } } } diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala index c32eb8436b..ed8da375c5 100644 --- a/src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala +++ b/src/test/scala/edu/berkeley/cs/rise/opaque/TPCHTests.scala @@ -104,8 +104,8 @@ trait TPCHTests extends OpaqueTestsBase { self => tpch.query(19, securityLevel, spark.sqlContext, numPartitions).collect.toSet } - testAgainstSpark("TPC-H 20", ignore) { securityLevel => - tpch.query(20, securityLevel, spark.sqlContext, numPartitions).collect + testAgainstSpark("TPC-H 20") { securityLevel => + tpch.query(20, securityLevel, spark.sqlContext, numPartitions).collect.toSet } testAgainstSpark("TPC-H 21", ignore) { securityLevel =>