diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index 011cd72cbb9a9..f8fc9fcdbbbbb 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -77,7 +77,9 @@ pub trait MemoryPool: Send + Sync + std::fmt::Debug { fn reserved(&self) -> usize; } -/// A memory consumer that can be tracked by [`MemoryReservation`] in a [`MemoryPool`] +/// A memory consumer that can be tracked by [`MemoryReservation`] in +/// a [`MemoryPool`]. All allocations are registered to a particular +/// `MemoryConsumer`; #[derive(Debug)] pub struct MemoryConsumer { name: String, @@ -113,20 +115,40 @@ impl MemoryConsumer { pub fn register(self, pool: &Arc) -> MemoryReservation { pool.register(&self); MemoryReservation { - consumer: self, + registration: Arc::new(SharedRegistration { + pool: Arc::clone(pool), + consumer: self, + }), size: 0, - policy: Arc::clone(pool), } } } -/// A [`MemoryReservation`] tracks a reservation of memory in a [`MemoryPool`] -/// that is freed back to the pool on drop +/// A registration of a [`MemoryConsumer`] with a [`MemoryPool`]. +/// +/// Calls [`MemoryPool::unregister`] on drop to return any memory to +/// the underlying pool. #[derive(Debug)] -pub struct MemoryReservation { +struct SharedRegistration { + pool: Arc, consumer: MemoryConsumer, +} + +impl Drop for SharedRegistration { + fn drop(&mut self) { + self.pool.unregister(&self.consumer); + } +} + +/// A [`MemoryReservation`] tracks an individual reservation of a +/// number of bytes of memory in a [`MemoryPool`] that is freed back +/// to the pool on drop. +/// +/// The reservation can be grown or shrunk over time. 
+#[derive(Debug)] +pub struct MemoryReservation { + registration: Arc, size: usize, - policy: Arc, } impl MemoryReservation { @@ -135,7 +157,8 @@ impl MemoryReservation { self.size } - /// Frees all bytes from this reservation returning the number of bytes freed + /// Frees all bytes from this reservation back to the underlying + /// pool, returning the number of bytes freed. pub fn free(&mut self) -> usize { let size = self.size; if size != 0 { @@ -151,7 +174,7 @@ impl MemoryReservation { /// Panics if `capacity` exceeds [`Self::size`] pub fn shrink(&mut self, capacity: usize) { let new_size = self.size.checked_sub(capacity).unwrap(); - self.policy.shrink(self, capacity); + self.registration.pool.shrink(self, capacity); self.size = new_size } @@ -176,22 +199,55 @@ impl MemoryReservation { /// Increase the size of this reservation by `capacity` bytes pub fn grow(&mut self, capacity: usize) { - self.policy.grow(self, capacity); + self.registration.pool.grow(self, capacity); self.size += capacity; } - /// Try to increase the size of this reservation by `capacity` bytes + /// Try to increase the size of this reservation by `capacity` + /// bytes, returning error if there is insufficient capacity left + /// in the pool. pub fn try_grow(&mut self, capacity: usize) -> Result<()> { - self.policy.try_grow(self, capacity)?; + self.registration.pool.try_grow(self, capacity)?; self.size += capacity; Ok(()) } + + /// Splits off `capacity` bytes from this [`MemoryReservation`] + /// into a new [`MemoryReservation`] with the same + /// [`MemoryConsumer`]. 
+ /// + /// This can be useful to free part of this reservation with RAII + /// style dropping + /// + /// # Panics + /// + /// Panics if `capacity` exceeds [`Self::size`] + pub fn split(&mut self, capacity: usize) -> MemoryReservation { + self.size = self.size.checked_sub(capacity).unwrap(); + Self { + size: capacity, + registration: self.registration.clone(), + } + } + + /// Returns a new empty [`MemoryReservation`] with the same [`MemoryConsumer`] + pub fn new_empty(&self) -> Self { + Self { + size: 0, + registration: self.registration.clone(), + } + } + + /// Splits off all the bytes from this [`MemoryReservation`] into + /// a new [`MemoryReservation`] with the same [`MemoryConsumer`] + pub fn take(&mut self) -> MemoryReservation { + self.split(self.size) + } } impl Drop for MemoryReservation { fn drop(&mut self) { self.free(); - self.policy.unregister(&self.consumer); } } @@ -251,4 +307,59 @@ mod tests { a2.try_grow(25).unwrap(); assert_eq!(pool.reserved(), 25); } + + #[test] + fn test_split() { + let pool = Arc::new(GreedyMemoryPool::new(50)) as _; + let mut r1 = MemoryConsumer::new("r1").register(&pool); + + r1.try_grow(20).unwrap(); + assert_eq!(r1.size(), 20); + assert_eq!(pool.reserved(), 20); + + // take 5 from r1, should still have same reservation split + let r2 = r1.split(5); + assert_eq!(r1.size(), 15); + assert_eq!(r2.size(), 5); + assert_eq!(pool.reserved(), 20); + + // dropping r1 frees 15 but retains 5 as they have the same consumer + drop(r1); + assert_eq!(r2.size(), 5); + assert_eq!(pool.reserved(), 5); + } + + #[test] + fn test_new_empty() { + let pool = Arc::new(GreedyMemoryPool::new(50)) as _; + let mut r1 = MemoryConsumer::new("r1").register(&pool); + + r1.try_grow(20).unwrap(); + let mut r2 = r1.new_empty(); + r2.try_grow(5).unwrap(); + + assert_eq!(r1.size(), 20); + assert_eq!(r2.size(), 5); + assert_eq!(pool.reserved(), 25); + } + + #[test] + fn test_take() { + let pool = Arc::new(GreedyMemoryPool::new(50)) as _; + let mut r1 =
MemoryConsumer::new("r1").register(&pool); + + r1.try_grow(20).unwrap(); + let mut r2 = r1.take(); + r2.try_grow(5).unwrap(); + + assert_eq!(r1.size(), 0); + assert_eq!(r2.size(), 25); + assert_eq!(pool.reserved(), 25); + + // r1 can still grow again + r1.try_grow(3).unwrap(); + assert_eq!(r1.size(), 3); + assert_eq!(r2.size(), 25); + assert_eq!(pool.reserved(), 28); + } } diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs index 7b68a86244b70..1242ce025ca2c 100644 --- a/datafusion/execution/src/memory_pool/pool.rs +++ b/datafusion/execution/src/memory_pool/pool.rs @@ -84,7 +84,11 @@ impl MemoryPool for GreedyMemoryPool { (new_used <= self.pool_size).then_some(new_used) }) .map_err(|used| { - insufficient_capacity_err(reservation, additional, self.pool_size - used) + insufficient_capacity_err( + reservation, + additional, + self.pool_size.saturating_sub(used), + ) })?; Ok(()) } @@ -159,13 +163,14 @@ impl MemoryPool for FairSpillPool { fn unregister(&self, consumer: &MemoryConsumer) { if consumer.can_spill { - self.state.lock().num_spill -= 1; + let mut state = self.state.lock(); + state.num_spill = state.num_spill.checked_sub(1).unwrap(); } } fn grow(&self, reservation: &MemoryReservation, additional: usize) { let mut state = self.state.lock(); - match reservation.consumer.can_spill { + match reservation.registration.consumer.can_spill { true => state.spillable += additional, false => state.unspillable += additional, } @@ -173,7 +178,7 @@ impl MemoryPool for FairSpillPool { fn shrink(&self, reservation: &MemoryReservation, shrink: usize) { let mut state = self.state.lock(); - match reservation.consumer.can_spill { + match reservation.registration.consumer.can_spill { true => state.spillable -= shrink, false => state.unspillable -= shrink, } @@ -182,7 +187,7 @@ impl MemoryPool for FairSpillPool { fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()> { let mut state = 
self.state.lock(); - match reservation.consumer.can_spill { + match reservation.registration.consumer.can_spill { true => { // The total amount of memory available to spilling consumers let spill_available = self.pool_size.saturating_sub(state.unspillable); @@ -230,7 +235,7 @@ fn insufficient_capacity_err( additional: usize, available: usize, ) -> DataFusionError { - DataFusionError::ResourcesExhausted(format!("Failed to allocate additional {} bytes for {} with {} bytes already allocated - maximum available is {}", additional, reservation.consumer.name, reservation.size, available)) + DataFusionError::ResourcesExhausted(format!("Failed to allocate additional {} bytes for {} with {} bytes already allocated - maximum available is {}", additional, reservation.registration.consumer.name, reservation.size, available)) } #[cfg(test)] @@ -247,7 +252,7 @@ mod tests { r1.grow(2000); assert_eq!(pool.reserved(), 2000); - let mut r2 = MemoryConsumer::new("s1") + let mut r2 = MemoryConsumer::new("r2") .with_can_spill(true) .register(&pool); // Can grow beyond capacity of pool @@ -256,10 +261,10 @@ mod tests { assert_eq!(pool.reserved(), 4000); let err = r2.try_grow(1).unwrap_err().to_string(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for s1 with 2000 bytes already allocated - maximum available is 0"); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for r2 with 2000 bytes already allocated - maximum available is 0"); let err = r2.try_grow(1).unwrap_err().to_string(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for s1 with 2000 bytes already allocated - maximum available is 0"); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for r2 with 2000 bytes already allocated - maximum available is 0"); r1.shrink(1990); r2.shrink(2000); @@ -269,7 +274,7 @@ mod tests { r1.try_grow(10).unwrap(); assert_eq!(pool.reserved(), 20); - // Can grow a2 to 80 as only 
spilling consumer + // Can grow r2 to 80 as only spilling consumer r2.try_grow(80).unwrap(); assert_eq!(pool.reserved(), 100); @@ -279,19 +284,19 @@ mod tests { assert_eq!(r2.size(), 10); assert_eq!(pool.reserved(), 30); - let mut r3 = MemoryConsumer::new("s2") + let mut r3 = MemoryConsumer::new("r3") .with_can_spill(true) .register(&pool); let err = r3.try_grow(70).unwrap_err().to_string(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for s2 with 0 bytes already allocated - maximum available is 40"); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for r3 with 0 bytes already allocated - maximum available is 40"); - //Shrinking a2 to zero doesn't allow a3 to allocate more than 45 + //Shrinking r2 to zero doesn't allow r3 to allocate more than 40 r2.free(); let err = r3.try_grow(70).unwrap_err().to_string(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for s2 with 0 bytes already allocated - maximum available is 40"); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for r3 with 0 bytes already allocated - maximum available is 40"); - // But dropping a2 does + // But dropping r2 does drop(r2); assert_eq!(pool.reserved(), 20); r3.try_grow(80).unwrap();