19 changes: 19 additions & 0 deletions datafusion/src/execution/context.rs
@@ -1766,6 +1766,25 @@ mod tests {
"+-----+-------------+",
];
assert_batches_sorted_eq!(expected, &results);

// Now, use dict as an aggregate
let results = plan_and_collect(
&mut ctx,
"SELECT val, count(distinct dict) FROM t GROUP BY val",
)
.await
.expect("ran plan correctly");

let expected = vec![
"+-----+----------------------+",
"| val | COUNT(DISTINCT dict) |",
"+-----+----------------------+",
"| 1 | 2 |",
"| 2 | 2 |",
"| 4 | 1 |",
"+-----+----------------------+",
];
assert_batches_sorted_eq!(expected, &results);
}

run_test_case::<Int8Type>().await;
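The new test above exercises `COUNT(DISTINCT dict)` over a dictionary-encoded column registered on table `t` by `run_test_case`. As a rough, hypothetical sketch (not taken from the diff) of how such a column can be built with arrow, assuming the `FromIterator<Option<&str>>` impl for `DictionaryArray`:

// Hypothetical sketch: build a dictionary-encoded string column like `dict`.
// Repeated values share a single slot in the dictionary's values array.
use arrow::array::DictionaryArray;
use arrow::datatypes::Int8Type;

fn example_dict_column() -> DictionaryArray<Int8Type> {
    // Keys point into the values ["a", "b"]; None becomes a null row.
    vec![Some("a"), Some("a"), Some("b"), None, Some("b")]
        .into_iter()
        .collect()
}

In a test like this one, a column of that shape would typically be wrapped in a RecordBatch and registered as a MemTable before running the SQL shown above.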
37 changes: 27 additions & 10 deletions datafusion/src/physical_plan/distinct_expressions.rs
@@ -47,8 +47,8 @@ pub struct DistinctCount {
name: String,
/// The DataType for the final count
data_type: DataType,
/// The DataType for each input argument
input_data_types: Vec<DataType>,
/// The DataType used to hold the state for each input
state_data_types: Vec<DataType>,
/// The input arguments
exprs: Vec<Arc<dyn PhysicalExpr>>,
}
@@ -61,15 +61,26 @@ impl DistinctCount {
name: String,
data_type: DataType,
) -> Self {
let state_data_types = input_data_types.into_iter().map(state_type).collect();

Self {
input_data_types,
state_data_types,
exprs,
name,
data_type,
}
}
}

/// return the type to use to accumulate state for the specified input type
fn state_type(data_type: DataType) -> DataType {
match data_type {
// when aggregating dictionary values, use the underlying value type
DataType::Dictionary(_key_type, value_type) => *value_type,
t => t,
}
}

impl AggregateExpr for DistinctCount {
/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
@@ -82,12 +93,16 @@ impl AggregateExpr for DistinctCount {

fn state_fields(&self) -> Result<Vec<Field>> {
Ok(self
.input_data_types
.state_data_types
.iter()
.map(|data_type| {
.map(|state_data_type| {
Field::new(
&format_state_name(&self.name, "count distinct"),
DataType::List(Box::new(Field::new("item", data_type.clone(), true))),
DataType::List(Box::new(Field::new(
"item",
state_data_type.clone(),
true,
))),
false,
)
})
@@ -101,7 +116,7 @@
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(DistinctCountAccumulator {
values: HashSet::default(),
data_types: self.input_data_types.clone(),
state_data_types: self.state_data_types.clone(),
count_data_type: self.data_type.clone(),
}))
}
@@ -110,7 +125,7 @@
#[derive(Debug)]
struct DistinctCountAccumulator {
values: HashSet<DistinctScalarValues, RandomState>,
data_types: Vec<DataType>,
state_data_types: Vec<DataType>,
count_data_type: DataType,
}

@@ -156,9 +171,11 @@ impl Accumulator for DistinctCountAccumulator {

fn state(&self) -> Result<Vec<ScalarValue>> {
let mut cols_out = self
.data_types
.state_data_types
.iter()
.map(|data_type| ScalarValue::List(Some(Vec::new()), data_type.clone()))
.map(|state_data_type| {
ScalarValue::List(Some(Vec::new()), state_data_type.clone())
})
.collect::<Vec<_>>();

let mut cols_vec = cols_out
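For context on the hunks above: the new `state_type` helper means a `COUNT(DISTINCT ...)` over a dictionary column accumulates distinct values of the dictionary's value type rather than its key type, and `state_fields` now lists that value type. A minimal, self-contained restatement of the helper (assuming the arrow `DataType` API; the real function lives in distinct_expressions.rs above):

use arrow::datatypes::DataType;

// Restatement of the helper from the diff, with assertions showing its effect.
fn state_type(data_type: DataType) -> DataType {
    match data_type {
        // dictionary inputs accumulate their underlying value type
        DataType::Dictionary(_key_type, value_type) => *value_type,
        // all other types are stored as-is
        t => t,
    }
}

fn main() {
    let dict = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8));
    assert_eq!(state_type(dict), DataType::Utf8);
    assert_eq!(state_type(DataType::Int32), DataType::Int32);
}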
1 change: 1 addition & 0 deletions datafusion/src/physical_plan/mod.rs
@@ -274,6 +274,7 @@ pub trait AggregateExpr: Send + Sync + Debug {
/// Returns the aggregate expression as [`Any`](std::any::Any) so that it can be
/// downcast to a specific implementation.
fn as_any(&self) -> &dyn Any;

/// the field of the final result of this aggregation.
fn field(&self) -> Result<Field>;

48 changes: 45 additions & 3 deletions datafusion/src/scalar.rs
@@ -19,10 +19,13 @@

use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};

use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
use arrow::datatypes::{ArrowDictionaryKeyType, DataType, Field, IntervalUnit, TimeUnit};
use arrow::{
array::*,
datatypes::{ArrowNativeType, Float32Type, TimestampNanosecondType},
datatypes::{
ArrowNativeType, Float32Type, Int16Type, Int32Type, Int64Type, Int8Type,
TimestampNanosecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
},
};
use arrow::{
array::{
@@ -444,14 +447,53 @@ impl ScalarValue {
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
typed_cast!(array, index, TimestampNanosecondArray, TimestampNanosecond)
}
DataType::Dictionary(index_type, _) => match **index_type {
DataType::Int8 => Self::try_from_dict_array::<Int8Type>(array, index)?,
DataType::Int16 => Self::try_from_dict_array::<Int16Type>(array, index)?,
DataType::Int32 => Self::try_from_dict_array::<Int32Type>(array, index)?,
DataType::Int64 => Self::try_from_dict_array::<Int64Type>(array, index)?,
DataType::UInt8 => Self::try_from_dict_array::<UInt8Type>(array, index)?,
DataType::UInt16 => {
Self::try_from_dict_array::<UInt16Type>(array, index)?
}
DataType::UInt32 => {
Self::try_from_dict_array::<UInt32Type>(array, index)?
}
DataType::UInt64 => {
Self::try_from_dict_array::<UInt64Type>(array, index)?
}
_ => {
return Err(DataFusionError::Internal(format!(
"Index type not supported while creating scalar from dictionary: {}",
array.data_type(),
)))
}
},
other => {
return Err(DataFusionError::NotImplemented(format!(
"Can't create a scalar of array of type \"{:?}\"",
"Can't create a scalar from array of type \"{:?}\"",
other
)))
}
})
}

fn try_from_dict_array<K: ArrowDictionaryKeyType>(
array: &ArrayRef,
index: usize,
) -> Result<Self> {
let dict_array = array.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();

// look up the index in the values dictionary
let keys_col = dict_array.keys_array();
let values_index = keys_col.value(index).to_usize().ok_or_else(|| {
DataFusionError::Internal(format!(
"Can not convert index to usize in dictionary of type creating group by value {:?}",
keys_col.data_type()
))
})?;
Self::try_from_array(&dict_array.values(), values_index)
}
}

impl From<f64> for ScalarValue {
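A hedged usage sketch (not part of the diff) of the behavior `try_from_dict_array` enables: calling `ScalarValue::try_from_array` on a dictionary-encoded array at some row yields the underlying value, not the dictionary key. The crate paths and the `FromIterator` construction are assumptions about the DataFusion/arrow versions this PR targets:

use std::sync::Arc;
use arrow::array::{ArrayRef, DictionaryArray};
use arrow::datatypes::Int8Type;
use datafusion::error::Result;
use datafusion::scalar::ScalarValue;

fn main() -> Result<()> {
    // "foo" and "bar" are stored once in the dictionary; per-row keys reference them.
    let dict: DictionaryArray<Int8Type> = vec!["foo", "bar", "foo"].into_iter().collect();
    let array: ArrayRef = Arc::new(dict);

    // Row 1 holds the key for "bar"; the scalar resolves to the value itself.
    let scalar = ScalarValue::try_from_array(&array, 1)?;
    assert_eq!(scalar, ScalarValue::Utf8(Some("bar".to_string())));
    Ok(())
}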