Skip to content

Commit 4257c16

Browse files
committed
Add FunctionRegistry::register_udaf and FunctionRegistry::register_udwf
mo
1 parent c843226 commit 4257c16

2 files changed

Lines changed: 45 additions & 21 deletions

File tree

datafusion/core/src/execution/context/mod.rs

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -823,15 +823,7 @@ impl SessionContext {
823823
/// Any functions registered with the udf name or its aliases will be overwritten with this new function
824824
pub fn register_udf(&self, f: ScalarUDF) {
825825
let mut state = self.state.write();
826-
let aliases = f.aliases();
827-
for alias in aliases {
828-
state
829-
.scalar_functions
830-
.insert(alias.to_string(), Arc::new(f.clone()));
831-
}
832-
state
833-
.scalar_functions
834-
.insert(f.name().to_string(), Arc::new(f));
826+
state.register_udf(Arc::new(f)).ok();
835827
}
836828

837829
/// Registers an aggregate UDF within this context.
@@ -842,10 +834,7 @@ impl SessionContext {
842834
/// - `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"`
843835
/// - `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"`
844836
pub fn register_udaf(&self, f: AggregateUDF) {
845-
self.state
846-
.write()
847-
.aggregate_functions
848-
.insert(f.name().to_string(), Arc::new(f));
837+
self.state.write().register_udaf(Arc::new(f)).ok();
849838
}
850839

851840
/// Registers a window UDF within this context.
@@ -856,10 +845,7 @@ impl SessionContext {
856845
/// - `SELECT MY_UDWF(x)...` will look for a window function named `"my_udwf"`
857846
/// - `SELECT "my_UDWF"(x)` will look for a window function named `"my_UDWF"`
858847
pub fn register_udwf(&self, f: WindowUDF) {
859-
self.state
860-
.write()
861-
.window_functions
862-
.insert(f.name().to_string(), Arc::new(f));
848+
self.state.write().register_udwf(Arc::new(f)).ok();
863849
}
864850

865851
/// Creates a [`DataFrame`] for reading a data source.
@@ -1984,8 +1970,24 @@ impl FunctionRegistry for SessionState {
19841970
}
19851971

19861972
fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
1973+
let aliases = udf.aliases();
1974+
for alias in aliases {
1975+
self.scalar_functions.insert(alias.to_string(), udf.clone());
1976+
}
1977+
19871978
Ok(self.scalar_functions.insert(udf.name().into(), udf))
19881979
}
1980+
1981+
fn register_udaf(
1982+
&mut self,
1983+
udaf: Arc<AggregateUDF>,
1984+
) -> Result<Option<Arc<AggregateUDF>>> {
1985+
Ok(self.aggregate_functions.insert(udaf.name().into(), udaf))
1986+
}
1987+
1988+
fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
1989+
Ok(self.window_functions.insert(udwf.name().into(), udwf))
1990+
}
19891991
}
19901992

19911993
impl OptimizerConfig for SessionState {

datafusion/execution/src/registry.rs

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,16 @@ pub trait FunctionRegistry {
2727
/// Set of all available udfs.
2828
fn udfs(&self) -> HashSet<String>;
2929

30-
/// Returns a reference to the udf named `name`.
30+
/// Returns a reference to the user defined scalar function (udf) named
31+
/// `name`.
3132
fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>>;
3233

33-
/// Returns a reference to the udaf named `name`.
34+
/// Returns a reference to the user defined table function (udaf) named
35+
/// `name`.
3436
fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>>;
3537

36-
/// Returns a reference to the udwf named `name`.
38+
/// Returns a reference to the user defined window function (udwf) named
39+
/// `name`.
3740
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>>;
3841

3942
/// Registers a new [`ScalarUDF`], returning any previously registered
@@ -45,7 +48,26 @@ pub trait FunctionRegistry {
4548
not_impl_err!("Registering ScalarUDF")
4649
}
4750

48-
// TODO add register_udaf and register_udwf
51+
/// Registers a new [`AggregateUDF`], returning any previously registered
52+
/// implementation.
53+
///
54+
/// Returns an error (the default) if the function can not be registered,
55+
/// for example if the registry is read only.
56+
fn register_udaf(
57+
&mut self,
58+
_udaf: Arc<AggregateUDF>,
59+
) -> Result<Option<Arc<AggregateUDF>>> {
60+
not_impl_err!("Registering AggregateUDF")
61+
}
62+
63+
/// Registers a new [`WindowUDF`], returning any previously registered
64+
/// implementation.
65+
///
66+
/// Returns an error (the default) if the function can not be registered,
67+
/// for example if the registry is read only.
68+
fn register_udwf(&mut self, _udaf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
69+
not_impl_err!("Registering WindowUDF")
70+
}
4971
}
5072

5173
/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode].

0 commit comments

Comments
 (0)