diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 2c73a80f64ebf..826e3b76f9bd8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -122,6 +122,10 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
   protected val EXCEPT = Keyword("EXCEPT")
   protected val SUBSTR = Keyword("SUBSTR")
   protected val SUBSTRING = Keyword("SUBSTRING")
+  protected val LEN = Keyword("LEN")
+  protected val LENGTH = Keyword("LENGTH")
+  protected val CHAR_LEN = Keyword("CHAR_LEN")
+  protected val OCTET_LEN = Keyword("OCTET_LEN")
 
   // Use reflection to find the reserved words defined in this class.
   protected val reservedWords =
@@ -323,6 +327,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
     (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ {
       case s ~ "," ~ p ~ "," ~ l => Substring(s,p,l)
     } |
+    (LEN | LENGTH | CHAR_LEN) ~> "(" ~> expression <~ ")" ^^ { case s => Length(s) } |
+    OCTET_LEN ~> "(" ~> expression ~ "," ~ expression <~ ")" ^^ {
+      case s ~ "," ~ e => OctetLength(s, e)
+    } |
+    OCTET_LEN ~> "(" ~> expression <~ ")" ^^ {
+      case s => OctetLength(s, Literal(OctetLengthConstants.DefaultEncoding))
+    } |
     ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ {
       case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs)
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index ba62dabe3dd6a..f3a7233b0b46d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -37,7 +37,8 @@ abstract class Expression extends TreeNode[Expression] {
    *  - A [[BinaryExpression]] is foldable if its both left and right child are foldable
    *  - A [[Not]], [[IsNull]], or [[IsNotNull]] is foldable if its child is foldable
    *  - A [[Literal]] is foldable
-   *  - A [[Cast]] or [[UnaryMinus]] is foldable if its child is foldable
+   *  - A [[Cast]], [[UnaryMinus]] or [[Length]] is foldable if its child is foldable
+   *  - An [[OctetLength]] is foldable if both its child and its encoding are foldable
    */
   def foldable: Boolean = false
   def nullable: Boolean
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 97fc3a3b14b88..65be1a76fa35d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -17,13 +17,16 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.io.UnsupportedEncodingException
 import java.util.regex.Pattern
 
 import scala.collection.IndexedSeqOptimized
 
 
+import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
-import org.apache.spark.sql.catalyst.types.{BinaryType, BooleanType, DataType, StringType}
+import org.apache.spark.sql.catalyst.types.{BinaryType, BooleanType, DataType, IntegerType,
+  StringType}
 
 trait StringRegexExpression {
   self: BinaryExpression =>
@@ -208,6 +211,91 @@ case class EndsWith(left: Expression, right: Expression)
   def compare(l: String, r: String) = l.endsWith(r)
 }
 
+
+/**
+ * A function that returns the number of characters (Unicode code points) in a string
+ * expression.  Non-string inputs are measured by the length of their `toString` form and a
+ * null input evaluates to null.
+ */
+case class Length(child: Expression) extends UnaryExpression {
+
+  type EvaluatedType = Any
+
+  override def dataType = IntegerType
+
+  override def foldable = child.foldable
+
+  override def nullable = child.nullable
+
+  override def toString = s"Length($child)"
+
+  override def eval(input: Row): EvaluatedType = {
+    val value: Any = child.eval(input)
+    value match {
+      case null => null
+      // Count code points, not UTF-16 code units, so a surrogate pair is one character.
+      case str: String => str.codePointCount(0, str.length)
+      case other => other.toString.length
+    }
+  }
+}
+
+object OctetLengthConstants {
+  /** Character set used by [[OctetLength]] when no (or a null) encoding is supplied. */
+  val DefaultEncoding = "UTF-8"
+}
+
+/**
+ * A function that returns the number of bytes (octets) a string expression occupies when
+ * encoded with the character set named by `encoding`.
+ *
+ * Evaluates to null when the input is null or is not a string; a null `encoding` falls back
+ * to [[OctetLengthConstants.DefaultEncoding]].
+ */
+case class OctetLength(child: Expression, encoding: Expression)
+  extends Expression with Logging {
+
+  type EvaluatedType = Any
+
+  // Both arguments are children so that `encoding` is also resolved and transformed.
+  override def children = child :: encoding :: Nil
+
+  override def references = child.references ++ encoding.references
+
+  override def dataType = IntegerType
+
+  // The result is only constant when the encoding is constant as well.
+  override def foldable = child.foldable && encoding.foldable
+
+  override def nullable = true
+
+  override def toString = s"OctetLen($child, $encoding)"
+
+  override def eval(input: Row): EvaluatedType = {
+    val value: Any = child.eval(input)
+    value match {
+      case null => null
+      case str: String =>
+        val charsetName = Option(encoding.eval(input))
+          .map(_.toString)
+          .getOrElse(OctetLengthConstants.DefaultEncoding)
+        try {
+          str.getBytes(charsetName).length
+        } catch {
+          case ue: UnsupportedEncodingException =>
+            // Re-throw naming the offending encoding, but keep the original as the cause.
+            val wrapped = new UnsupportedEncodingException(
+              s"OctetLen: Caught UnsupportedEncodingException for encoding=[$charsetName]")
+            wrapped.initCause(ue)
+            throw wrapped
+        }
+      case other =>
+        log.debug(s"Non-string value [$other] provided to OctetLen")
+        null
+    }
+  }
+}
+
 /**
  * A function that takes a substring of its first argument starting at a given position.
  * Defined for String and Binary types.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 5f86d6047cb9c..02b81b799e4c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -174,6 +174,9 @@ object NullPropagation extends Rule[LogicalPlan] {
       case e @ Substring(_, Literal(null, _), _) => Literal(null, e.dataType)
       case e @ Substring(_, _, Literal(null, _)) => Literal(null, e.dataType)
 
+      case e @ Length(Literal(null, _)) => Literal(null, e.dataType)
+      case e @ OctetLength(Literal(null, _), _) => Literal(null, e.dataType)
+
       // Put exceptional cases above if any
       case e: BinaryArithmetic => e.children match {
         case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 999c9fff38d60..bc8f1cdfe8fc7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -567,4 +567,42 @@ class ExpressionEvaluationSuite extends FunSuite {
     checkEvaluation(s.substring(0, 2), "ex", row)
     checkEvaluation(s.substring(0), "example", row)
   }
+
+  test("Length") {
+    checkEvaluation(Length(Literal(null, IntegerType)), null)
+    checkEvaluation(Length(Literal(0, IntegerType)), 1)
+    checkEvaluation(Length(Literal(12, IntegerType)), 2)
+    checkEvaluation(Length(Literal(123, IntegerType)), 3)
+    checkEvaluation(Length(Literal(12.4F, FloatType)), 4)
+    checkEvaluation(Length(Literal(12345678901L, LongType)), 11)
+    checkEvaluation(Length(Literal(1234567890.2D, DoubleType)), 14)
+    checkEvaluation(Length(Literal("1234567890ABC", StringType)), 13)
+    checkEvaluation(Length(Literal("\uF93D\uF936\uF949\uF942", StringType)), 4)
+  }
+
+  test("OctetLen") {
+    checkEvaluation(OctetLength(Literal(null, StringType), "ISO-8859-1"), null)
+    checkEvaluation(OctetLength(Literal(null, StringType), "UTF-8"), null)
+    checkEvaluation(OctetLength(Literal(null, StringType), "UTF-16"), null)
+    checkEvaluation(OctetLength(Literal("1234567890ABC", StringType), "ISO-8859-1"), 13)
+    checkEvaluation(OctetLength(Literal("1234567890ABC", StringType), "UTF-8"), 13)
+    checkEvaluation(OctetLength(Literal("1234567890ABC", StringType), "UTF-16"), 28)
+    checkEvaluation(OctetLength(Literal("1234567890ABC", StringType), "UTF-32"), 52)
+    // ISO-8859-1 cannot represent these characters, so each one encodes as a single byte
+    checkEvaluation(OctetLength(
+      Literal("\uF93D\uF936\uF949\uF942", StringType), "ISO-8859-1"), 4)
+    checkEvaluation(OctetLength(
+      Literal("\uF93D\uF936\uF949\uF942", StringType), "UTF-8"), 12) // chinese characters
+    checkEvaluation(OctetLength(
+      Literal("\uD840\uDC0B\uD842\uDFB7", StringType), "UTF-8"), 8) // 2 surrogate pairs
+    checkEvaluation(OctetLength(
+      Literal("\uF93D\uF936\uF949\uF942", StringType), "UTF-16"), 10) // chinese characters
+    checkEvaluation(OctetLength(
+      Literal("\uD840\uDC0B\uD842\uDFB7", StringType), "UTF-16"), 10) // 2 surrogate pairs
+    checkEvaluation(OctetLength(
+      Literal("\uF93D\uF936\uF949\uF942", StringType), "UTF-32"), 16) // chinese characters
+    checkEvaluation(OctetLength(
+      Literal("\uD840\uDC0B\uD842\uDFB7", StringType), "UTF-32"), 8) // 2 surrogate pairs
+  }
+
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index 0a27cce337482..d6549e9de1d1b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -208,7 +208,11 @@ class ConstantFoldingSuite extends PlanTest {
         Substring("abc", 0, Literal(null, IntegerType)) as 'c18,
 
         Contains(Literal(null, StringType), "abc") as 'c19,
-        Contains("abc", Literal(null, StringType)) as 'c20
+        Contains("abc", Literal(null, StringType)) as 'c20,
+
+        Length(Literal(null, IntegerType)) as 'c21,
+        OctetLength(Literal(null, StringType), "ISO-8859-1") as 'c22
+
       )
 
     val optimized = Optimize(originalQuery.analyze)
@@ -243,7 +247,11 @@ class ConstantFoldingSuite extends PlanTest {
         Literal(null, StringType) as 'c18,
 
         Literal(null, BooleanType) as 'c19,
-        Literal(null, BooleanType) as 'c20
+        Literal(null, BooleanType) as 'c20,
+
+        Literal(null, IntegerType) as 'c21,
+        Literal(null, IntegerType) as 'c22
+
       ).analyze
 
     comparePlans(optimized, correctAnswer)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 1fd8d27b34c59..e7aa3265d7b60 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import java.io.{PrintWriter, StringWriter}
+
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 5c571d35d1bb9..595b11c2a305d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -47,6 +47,36 @@ class SQLQuerySuite extends QueryTest {
     checkAnswer(
       sql("SELECT substring(tableName, 3) FROM tableName"),
       "st")
+    checkAnswer(
+      sql("SELECT substring(tableName, 2) FROM tableName group by substring(tableName, 2)"),
+      "est")
+  }
+
+  test("SPARK-2686 Added Parser of SQL LENGTH()") {
+    checkAnswer(
+      sql("SELECT char_len(key) as keylen from testData where key = 100"), 3)
+    checkAnswer(
+      sql("SELECT len(key), count(*) as cnt from testData where key <= 100 group by len(key)"),
+      Seq(Seq(1, 9), Seq(2, 90), Seq(3, 1)))
+    checkAnswer(
+      sql("SELECT max(length(key * key) - len(key)) from testData where key <= 100"), 2)
+    checkAnswer(
+      sql("SELECT min(Length(s)) FROM nullableRepeatedData where s is not null"), 4)
+    checkAnswer(
+      sql("SELECT max(LENGTH(s)) FROM nullableRepeatedData"), 4)
+  }
+
+  test("SPARK-2686 Added Parser of SQL OCTET_LEN()") {
+    checkAnswer(
+      sql("SELECT octet_len(s) from repeatedData"), Seq(Seq(4), Seq(4)))
+    checkAnswer(
+      sql("SELECT octet_len(s,'UTF-8') from repeatedData"), Seq(Seq(4), Seq(4)))
+    checkAnswer(
+      sql("SELECT max(octet_len(s,'UTF-8')) from nullStrings"), 3)
+    checkAnswer(
+      sql("SELECT octet_len('a','ISO-8859-1') + octet_len('abcde','ISO-8859-1') " +
+        "FROM testData limit 1"),
+      6)
   }
 
   test("index into array") {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 3d2eb1eefaeda..e6769646e4af2 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -866,6 +866,8 @@ private[hive] object HiveQl {
   val WHEN = "(?i)WHEN".r
   val CASE = "(?i)CASE".r
   val SUBSTR = "(?i)SUBSTR(?:ING)?".r
+  val CHAR_LEN = "(?i)CHAR_LEN".r
+  val OCTET_LEN = "(?i)OCTET_LEN".r
 
   protected def nodeToExpr(node: Node): Expression = node match {
     /* Attribute References */
@@ -997,6 +999,12 @@ private[hive] object HiveQl {
       Substring(nodeToExpr(string), nodeToExpr(pos), Literal(Integer.MAX_VALUE, IntegerType))
     case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) =>
       Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length))
+    case Token("TOK_FUNCTION", Token(CHAR_LEN(), Nil) :: string :: Nil) =>
+      Length(nodeToExpr(string))
+    case Token("TOK_FUNCTION", Token(OCTET_LEN(), Nil) :: string :: Nil) =>
+      OctetLength(nodeToExpr(string), Literal(OctetLengthConstants.DefaultEncoding))
+    case Token("TOK_FUNCTION", Token(OCTET_LEN(), Nil) :: string :: encoding :: Nil) =>
+      OctetLength(nodeToExpr(string), nodeToExpr(encoding))
 
     /* UDFs - Must be last otherwise will preempt built in functions */
     case Token("TOK_FUNCTION", Token(name, Nil) :: args) =>