Skip to content

Commit e980ef8

Browse files
seddonm1alamb
authored andcommitted
ARROW-10817: [Rust] [DataFusion] Implement TypedString and DATE coercion
This PR adds support for what the `sqlparser` crate calls `TypedString` which is basically syntactic sugar for an inline-cast. As this was an effort to get the `TPC-H` queries behaving correctly I then went a step further and added support for `Date` (temporal) coercion. I can split this PR if needed. ```sql where l_shipdate <= date '1998-09-02' ``` is equivalent to ```sql where l_shipdate <= CAST('1998-09-02' AS DATE) ``` FYI I am planning to tackle `INTERVAL` next. Closes #8892 from seddonm1/typed_string Authored-by: Mike Seddon <seddonm1@gmail.com> Signed-off-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 989757f commit e980ef8

4 files changed

Lines changed: 185 additions & 24 deletions

File tree

rust/arrow/src/compute/kernels/cast.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
7272
(Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8,
7373

7474
(Utf8, Date32(DateUnit::Day)) => true,
75+
(Utf8, Date64(DateUnit::Millisecond)) => true,
7576
(Utf8, _) => DataType::is_numeric(to_type),
7677
(_, Utf8) => DataType::is_numeric(from_type) || from_type == &Binary,
7778

@@ -399,6 +400,26 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
399400
}
400401
Ok(Arc::new(builder.finish()) as ArrayRef)
401402
}
403+
Date64(DateUnit::Millisecond) => {
404+
use chrono::{NaiveDate, NaiveTime};
405+
let zero_time = NaiveTime::from_hms(0, 0, 0);
406+
let string_array = array.as_any().downcast_ref::<StringArray>().unwrap();
407+
let mut builder = PrimitiveBuilder::<Date64Type>::new(string_array.len());
408+
for i in 0..string_array.len() {
409+
if string_array.is_null(i) {
410+
builder.append_null()?;
411+
} else {
412+
match NaiveDate::parse_from_str(string_array.value(i), "%Y-%m-%d")
413+
{
414+
Ok(date) => builder.append_value(
415+
date.and_time(zero_time).timestamp_millis() as i64,
416+
)?,
417+
Err(_) => builder.append_null()?, // not a valid date
418+
};
419+
}
420+
}
421+
Ok(Arc::new(builder.finish()) as ArrayRef)
422+
}
402423
_ => Err(ArrowError::ComputeError(format!(
403424
"Casting from {:?} to {:?} not supported",
404425
from_type, to_type,
@@ -2780,6 +2801,31 @@ mod tests {
27802801
assert_eq!(false, c.is_valid(4)); // "2000"
27812802
}
27822803

2804+
#[test]
2805+
fn test_cast_utf8_to_date64() {
2806+
let a = StringArray::from(vec![
2807+
"2000-01-01", // valid date with leading 0s
2808+
"2000-2-2", // valid date without leading 0s
2809+
"2000-00-00", // invalid month and day
2810+
"2000-01-01T12:00:00", // date + time is invalid
2811+
"2000", // just a year is invalid
2812+
]);
2813+
let array = Arc::new(a) as ArrayRef;
2814+
let b = cast(&array, &DataType::Date64(DateUnit::Millisecond)).unwrap();
2815+
let c = b.as_any().downcast_ref::<Date64Array>().unwrap();
2816+
2817+
// test valid inputs
2818+
assert_eq!(true, c.is_valid(0)); // "2000-01-01"
2819+
assert_eq!(946684800000, c.value(0));
2820+
assert_eq!(true, c.is_valid(1)); // "2000-2-2"
2821+
assert_eq!(949449600000, c.value(1));
2822+
2823+
// test invalid inputs
2824+
assert_eq!(false, c.is_valid(2)); // "2000-00-00"
2825+
assert_eq!(false, c.is_valid(3)); // "2000-01-01T12:00:00"
2826+
assert_eq!(false, c.is_valid(4)); // "2000"
2827+
}
2828+
27832829
#[test]
27842830
fn test_can_cast_types() {
27852831
// this function attempts to ensure that can_cast_types stays

rust/benchmarks/src/bin/tpch.rs

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use std::path::PathBuf;
2121
use std::sync::Arc;
2222
use std::time::Instant;
2323

24-
use arrow::datatypes::{DataType, Field, Schema};
24+
use arrow::datatypes::{DataType, DateUnit, Field, Schema};
2525
use arrow::util::pretty;
2626
use datafusion::datasource::parquet::ParquetTable;
2727
use datafusion::datasource::{CsvFile, MemTable, TableProvider};
@@ -187,7 +187,7 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<Logic
187187
from
188188
lineitem
189189
where
190-
l_shipdate <= '1998-09-02'
190+
l_shipdate <= date '1998-09-02'
191191
group by
192192
l_returnflag,
193193
l_linestatus
@@ -256,8 +256,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<Logic
256256
c_mktsegment = 'BUILDING'
257257
and c_custkey = o_custkey
258258
and l_orderkey = o_orderkey
259-
and o_orderdate < '1995-03-15'
260-
and l_shipdate > '1995-03-15'
259+
and o_orderdate < date '1995-03-15'
260+
and l_shipdate > date '1995-03-15'
261261
group by
262262
l_orderkey,
263263
o_orderdate,
@@ -337,8 +337,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<Logic
337337
and s_nationkey = n_nationkey
338338
and n_regionkey = r_regionkey
339339
and r_name = 'ASIA'
340-
and o_orderdate >= '1994-01-01'
341-
and o_orderdate < '1995-01-01'
340+
and o_orderdate >= date '1994-01-01'
341+
and o_orderdate < date '1995-01-01'
342342
group by
343343
n_name
344344
order by
@@ -363,9 +363,9 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<Logic
363363
from
364364
lineitem
365365
where
366-
l_shipdate >= '1994-01-01'
367-
and l_shipdate < '1995-01-01'
368-
and l_discount between 0.06 - 0.01 and 0.06 + 0.01
366+
l_shipdate >= date '1994-01-01'
367+
and l_shipdate < date '1995-01-01'
368+
and l_discount > 0.06 - 0.01 and l_discount < 0.06 + 0.01
369369
and l_quantity < 24;"
370370
),
371371

@@ -399,7 +399,7 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<Logic
399399
(n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY')
400400
or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE')
401401
)
402-
and l_shipdate > '1995-01-01' and l_shipdate < '1996-12-31'
402+
and l_shipdate > date '1995-01-01' and l_shipdate < date '1996-12-31'
403403
) as shipping
404404
group by
405405
supp_nation,
@@ -442,7 +442,7 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<Logic
442442
and n1.n_regionkey = r_regionkey
443443
and r_name = 'AMERICA'
444444
and s_nationkey = n2.n_nationkey
445-
and o_orderdate between '1995-01-01' and '1996-12-31'
445+
and o_orderdate between date '1995-01-01' and date '1996-12-31'
446446
and p_type = 'ECONOMY ANODIZED STEEL'
447447
) as all_nations
448448
group by
@@ -486,6 +486,39 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<Logic
486486
o_year desc;"
487487
),
488488

489+
// 10 => ctx.create_logical_plan(
490+
// "select
491+
// c_custkey,
492+
// c_name,
493+
// sum(l_extendedprice * (1 - l_discount)) as revenue,
494+
// c_acctbal,
495+
// n_name,
496+
// c_address,
497+
// c_phone,
498+
// c_comment
499+
// from
500+
// customer,
501+
// orders,
502+
// lineitem,
503+
// nation
504+
// where
505+
// c_custkey = o_custkey
506+
// and l_orderkey = o_orderkey
507+
// and o_orderdate >= date '1993-10-01'
508+
// and o_orderdate < date '1993-10-01' + interval '3' month
509+
// and l_returnflag = 'R'
510+
// and c_nationkey = n_nationkey
511+
// group by
512+
// c_custkey,
513+
// c_name,
514+
// c_acctbal,
515+
// c_phone,
516+
// n_name,
517+
// c_address,
518+
// c_comment
519+
// order by
520+
// revenue desc;"
521+
// ),
489522
10 => ctx.create_logical_plan(
490523
"select
491524
c_custkey,
@@ -504,8 +537,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<Logic
504537
where
505538
c_custkey = o_custkey
506539
and l_orderkey = o_orderkey
507-
and o_orderdate >= '1993-10-01'
508-
and o_orderdate < '1994-01-01'
540+
and o_orderdate >= date '1993-10-01'
541+
and o_orderdate < date '1994-01-01'
509542
and l_returnflag = 'R'
510543
and c_nationkey = n_nationkey
511544
group by
@@ -606,8 +639,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<Logic
606639
(l_shipmode = 'MAIL' or l_shipmode = 'SHIP')
607640
and l_commitdate < l_receiptdate
608641
and l_shipdate < l_commitdate
609-
and l_receiptdate >= '1994-01-01'
610-
and l_receiptdate < '1995-01-01'
642+
and l_receiptdate >= date '1994-01-01'
643+
and l_receiptdate < date '1995-01-01'
611644
group by
612645
l_shipmode
613646
order by
@@ -649,8 +682,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<Logic
649682
part
650683
where
651684
l_partkey = p_partkey
652-
and l_shipdate >= '1995-09-01'
653-
and l_shipdate < '1995-10-01';"
685+
and l_shipdate >= date '1995-09-01'
686+
and l_shipdate < date '1995-10-01';"
654687
),
655688

656689
15 => ctx.create_logical_plan(
@@ -1072,7 +1105,7 @@ fn get_schema(table: &str) -> Schema {
10721105
Field::new("o_custkey", DataType::UInt32, false),
10731106
Field::new("o_orderstatus", DataType::Utf8, false),
10741107
Field::new("o_totalprice", DataType::Float64, false), // decimal
1075-
Field::new("o_orderdate", DataType::Utf8, false),
1108+
Field::new("o_orderdate", DataType::Date32(DateUnit::Day), false),
10761109
Field::new("o_orderpriority", DataType::Utf8, false),
10771110
Field::new("o_clerk", DataType::Utf8, false),
10781111
Field::new("o_shippriority", DataType::UInt32, false),
@@ -1090,9 +1123,9 @@ fn get_schema(table: &str) -> Schema {
10901123
Field::new("l_tax", DataType::Float64, false), // decimal
10911124
Field::new("l_returnflag", DataType::Utf8, false),
10921125
Field::new("l_linestatus", DataType::Utf8, false),
1093-
Field::new("l_shipdate", DataType::Utf8, false),
1094-
Field::new("l_commitdate", DataType::Utf8, false),
1095-
Field::new("l_receiptdate", DataType::Utf8, false),
1126+
Field::new("l_shipdate", DataType::Date32(DateUnit::Day), false),
1127+
Field::new("l_commitdate", DataType::Date32(DateUnit::Day), false),
1128+
Field::new("l_receiptdate", DataType::Date32(DateUnit::Day), false),
10961129
Field::new("l_shipinstruct", DataType::Utf8, false),
10971130
Field::new("l_shipmode", DataType::Utf8, false),
10981131
Field::new("l_comment", DataType::Utf8, false),

rust/datafusion/src/physical_plan/expressions.rs

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ use arrow::datatypes::{DataType, DateUnit, Schema, TimeUnit};
4848
use arrow::record_batch::RecordBatch;
4949
use arrow::{
5050
array::{
51-
ArrayRef, BooleanArray, Date32Array, Float32Array, Float64Array, Int16Array,
52-
Int32Array, Int64Array, Int8Array, StringArray, TimestampNanosecondArray,
53-
UInt16Array, UInt32Array, UInt64Array, UInt8Array,
51+
ArrayRef, BooleanArray, Date32Array, Date64Array, Float32Array, Float64Array,
52+
Int16Array, Int32Array, Int64Array, Int8Array, StringArray,
53+
TimestampNanosecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
5454
},
5555
datatypes::Field,
5656
};
@@ -1135,6 +1135,9 @@ macro_rules! binary_array_op {
11351135
DataType::Date32(DateUnit::Day) => {
11361136
compute_op!($LEFT, $RIGHT, $OP, Date32Array)
11371137
}
1138+
DataType::Date64(DateUnit::Millisecond) => {
1139+
compute_op!($LEFT, $RIGHT, $OP, Date64Array)
1140+
}
11381141
other => Err(DataFusionError::Internal(format!(
11391142
"Unsupported data type {:?}",
11401143
other
@@ -1227,6 +1230,19 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
12271230
}
12281231
}
12291232

1233+
/// Coercion rules for Temporal columns: the type that both lhs and rhs can be
1234+
/// casted to for the purpose of a date computation
1235+
fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1236+
use arrow::datatypes::DataType::*;
1237+
match (lhs_type, rhs_type) {
1238+
(Utf8, Date32(DateUnit::Day)) => Some(Date32(DateUnit::Day)),
1239+
(Date32(DateUnit::Day), Utf8) => Some(Date32(DateUnit::Day)),
1240+
(Utf8, Date64(DateUnit::Millisecond)) => Some(Date64(DateUnit::Millisecond)),
1241+
(Date64(DateUnit::Millisecond), Utf8) => Some(Date64(DateUnit::Millisecond)),
1242+
_ => None,
1243+
}
1244+
}
1245+
12301246
/// Coercion rule for numerical types: The type that both lhs and rhs
12311247
/// can be casted to for numerical calculation, while maintaining
12321248
/// maximum precision
@@ -1288,6 +1304,7 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
12881304
}
12891305
numerical_coercion(lhs_type, rhs_type)
12901306
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
1307+
.or_else(|| temporal_coercion(lhs_type, rhs_type))
12911308
}
12921309

12931310
// coercion rules that assume an ordered set, such as "less than".
@@ -1301,6 +1318,7 @@ fn order_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
13011318
numerical_coercion(lhs_type, rhs_type)
13021319
.or_else(|| string_coercion(lhs_type, rhs_type))
13031320
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
1321+
.or_else(|| temporal_coercion(lhs_type, rhs_type))
13041322
}
13051323

13061324
/// Coercion rules for all binary operators. Returns the output type
@@ -2638,6 +2656,54 @@ mod tests {
26382656
DataType::Boolean,
26392657
vec![true, false]
26402658
);
2659+
test_coercion!(
2660+
StringArray,
2661+
DataType::Utf8,
2662+
vec!["1994-12-13", "1995-01-26"],
2663+
Date32Array,
2664+
DataType::Date32(DateUnit::Day),
2665+
vec![9112, 9156],
2666+
Operator::Eq,
2667+
BooleanArray,
2668+
DataType::Boolean,
2669+
vec![true, true]
2670+
);
2671+
test_coercion!(
2672+
StringArray,
2673+
DataType::Utf8,
2674+
vec!["1994-12-13", "1995-01-26"],
2675+
Date32Array,
2676+
DataType::Date32(DateUnit::Day),
2677+
vec![9113, 9154],
2678+
Operator::Lt,
2679+
BooleanArray,
2680+
DataType::Boolean,
2681+
vec![true, false]
2682+
);
2683+
test_coercion!(
2684+
StringArray,
2685+
DataType::Utf8,
2686+
vec!["1994-12-13", "1995-01-26"],
2687+
Date64Array,
2688+
DataType::Date64(DateUnit::Millisecond),
2689+
vec![787276800000, 791078400000],
2690+
Operator::Eq,
2691+
BooleanArray,
2692+
DataType::Boolean,
2693+
vec![true, true]
2694+
);
2695+
test_coercion!(
2696+
StringArray,
2697+
DataType::Utf8,
2698+
vec!["1994-12-13", "1995-01-26"],
2699+
Date64Array,
2700+
DataType::Date64(DateUnit::Millisecond),
2701+
vec![787276800001, 791078399999],
2702+
Operator::Lt,
2703+
BooleanArray,
2704+
DataType::Boolean,
2705+
vec![true, false]
2706+
);
26412707
Ok(())
26422708
}
26432709

rust/datafusion/src/sql/planner.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,14 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> {
629629
data_type: convert_data_type(data_type)?,
630630
}),
631631

632+
SQLExpr::TypedString {
633+
ref data_type,
634+
ref value,
635+
} => Ok(Expr::Cast {
636+
expr: Box::new(lit(&**value)),
637+
data_type: convert_data_type(data_type)?,
638+
}),
639+
632640
SQLExpr::IsNull(ref expr) => {
633641
Ok(Expr::IsNull(Box::new(self.sql_to_rex(expr, schema)?)))
634642
}
@@ -1311,6 +1319,14 @@ mod tests {
13111319
quick_test(sql, expected);
13121320
}
13131321

1322+
#[test]
1323+
fn select_typedstring() {
1324+
let sql = "SELECT date '2020-12-10' AS date FROM person";
1325+
let expected = "Projection: CAST(Utf8(\"2020-12-10\") AS Date32(Day)) AS date\
1326+
\n TableScan: person projection=None";
1327+
quick_test(sql, expected);
1328+
}
1329+
13141330
fn logical_plan(sql: &str) -> Result<LogicalPlan> {
13151331
let planner = SqlToRel::new(&MockSchemaProvider {});
13161332
let result = DFParser::parse_sql(&sql);

0 commit comments

Comments
 (0)