Skip to content

Commit 3de723f

Browse files
feat: Add fair unified memory pool (apache#1369)
Current Comet unified memory pool is a greedy pool. One thread (consumer) can take a large amount of memory that causes OOM for other threads, especially for aggregation. Added a fair version of unified memory pool similar to DataFusion `FairSpilPool` that caps the memory usage at `pool_size/num` The fair unified memory pool is the default for off-heap mode with this PR Exisiting tests
1 parent 2daf490 commit 3de723f

8 files changed

Lines changed: 186 additions & 5 deletions

File tree

common/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,8 +483,8 @@ object CometConf extends ShimCometConf {
483483
.doc(
484484
"The type of memory pool to be used for Comet native execution. " +
485485
"Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', " +
486-
"'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global', By default, " +
487-
"this config is 'greedy_task_shared'.")
486+
"'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global'. For off-heap " +
487+
"types are 'unified' and `fair_unified`.")
488488
.stringConf
489489
.createWithDefault("greedy_task_shared")
490490

docs/source/user-guide/configs.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ Comet provides the following configuration settings.
4848
| spark.comet.exec.hashJoin.enabled | Whether to enable hashJoin by default. | true |
4949
| spark.comet.exec.initCap.enabled | Whether to enable initCap by default. | false |
5050
| spark.comet.exec.localLimit.enabled | Whether to enable localLimit by default. | true |
51-
| spark.comet.exec.memoryPool | The type of memory pool to be used for Comet native execution. Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', 'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global', By default, this config is 'greedy_task_shared'. | greedy_task_shared |
51+
| spark.comet.exec.memoryPool | The type of memory pool to be used for Comet native execution. Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', 'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global'. For off-heap types are 'unified' and `fair_unified`. | greedy_task_shared |
5252
| spark.comet.exec.project.enabled | Whether to enable project by default. | true |
5353
| spark.comet.exec.replaceSortMergeJoin | Experimental feature to force Spark to replace SortMergeJoin with ShuffledHashJoin for improved performance. This feature is not stable yet. For more information, refer to the Comet Tuning Guide (https://datafusion.apache.org/comet/user-guide/tuning.html). | false |
5454
| spark.comet.exec.shuffle.compression.codec | The codec of Comet native shuffle used to compress shuffle data. lz4, zstd, and snappy are supported. Compression can be disabled by setting spark.shuffle.compress=false. | lz4 |

native/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

native/core/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ crc32fast = "1.3.2"
7474
simd-adler32 = "0.3.7"
7575
datafusion-comet-spark-expr = { workspace = true }
7676
datafusion-comet-proto = { workspace = true }
77+
parking_lot = "0.12.3"
7778

7879
[dev-dependencies]
7980
pprof = { version = "0.13.0", features = ["flamegraph"] }
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::{
19+
fmt::{Debug, Formatter, Result as FmtResult},
20+
sync::Arc,
21+
};
22+
23+
use jni::objects::GlobalRef;
24+
25+
use crate::{
26+
errors::CometResult,
27+
jvm_bridge::{jni_call, JVMClasses},
28+
};
29+
use datafusion::{
30+
common::DataFusionError,
31+
execution::memory_pool::{MemoryPool, MemoryReservation},
32+
};
33+
use datafusion_common::resources_err;
34+
use datafusion_execution::memory_pool::MemoryConsumer;
35+
use parking_lot::Mutex;
36+
37+
/// A DataFusion fair `MemoryPool` implementation for Comet. Internally this is
38+
/// implemented via delegating calls to [`crate::jvm_bridge::CometTaskMemoryManager`].
39+
pub struct CometFairMemoryPool {
40+
task_memory_manager_handle: Arc<GlobalRef>,
41+
pool_size: usize,
42+
state: Mutex<CometFairPoolState>,
43+
}
44+
45+
struct CometFairPoolState {
46+
used: usize,
47+
num: usize,
48+
}
49+
50+
impl Debug for CometFairMemoryPool {
51+
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
52+
let state = self.state.lock();
53+
f.debug_struct("CometFairMemoryPool")
54+
.field("pool_size", &self.pool_size)
55+
.field("used", &state.used)
56+
.field("num", &state.num)
57+
.finish()
58+
}
59+
}
60+
61+
impl CometFairMemoryPool {
62+
pub fn new(
63+
task_memory_manager_handle: Arc<GlobalRef>,
64+
pool_size: usize,
65+
) -> CometFairMemoryPool {
66+
Self {
67+
task_memory_manager_handle,
68+
pool_size,
69+
state: Mutex::new(CometFairPoolState { used: 0, num: 0 }),
70+
}
71+
}
72+
73+
fn acquire(&self, additional: usize) -> CometResult<i64> {
74+
let mut env = JVMClasses::get_env()?;
75+
let handle = self.task_memory_manager_handle.as_obj();
76+
unsafe {
77+
jni_call!(&mut env,
78+
comet_task_memory_manager(handle).acquire_memory(additional as i64) -> i64)
79+
}
80+
}
81+
82+
fn release(&self, size: usize) -> CometResult<()> {
83+
let mut env = JVMClasses::get_env()?;
84+
let handle = self.task_memory_manager_handle.as_obj();
85+
unsafe {
86+
jni_call!(&mut env, comet_task_memory_manager(handle).release_memory(size as i64) -> ())
87+
}
88+
}
89+
}
90+
91+
unsafe impl Send for CometFairMemoryPool {}
92+
unsafe impl Sync for CometFairMemoryPool {}
93+
94+
impl MemoryPool for CometFairMemoryPool {
95+
fn register(&self, _: &MemoryConsumer) {
96+
let mut state = self.state.lock();
97+
state.num = state
98+
.num
99+
.checked_add(1)
100+
.expect("unexpected amount of register happened");
101+
}
102+
103+
fn unregister(&self, _: &MemoryConsumer) {
104+
let mut state = self.state.lock();
105+
state.num = state
106+
.num
107+
.checked_sub(1)
108+
.expect("unexpected amount of unregister happened");
109+
}
110+
111+
fn grow(&self, reservation: &MemoryReservation, additional: usize) {
112+
self.try_grow(reservation, additional).unwrap();
113+
}
114+
115+
fn shrink(&self, reservation: &MemoryReservation, subtractive: usize) {
116+
if subtractive > 0 {
117+
let mut state = self.state.lock();
118+
let size = reservation.size();
119+
if size < subtractive {
120+
panic!("Failed to release {subtractive} bytes where only {size} bytes reserved")
121+
}
122+
self.release(subtractive)
123+
.unwrap_or_else(|_| panic!("Failed to release {} bytes", subtractive));
124+
state.used = state.used.checked_sub(subtractive).unwrap();
125+
}
126+
}
127+
128+
fn try_grow(
129+
&self,
130+
reservation: &MemoryReservation,
131+
additional: usize,
132+
) -> Result<(), DataFusionError> {
133+
if additional > 0 {
134+
let mut state = self.state.lock();
135+
let num = state.num;
136+
let limit = self.pool_size.checked_div(num).unwrap();
137+
let size = reservation.size();
138+
if limit < size + additional {
139+
return resources_err!(
140+
"Failed to acquire {additional} bytes where {size} bytes already reserved and the fair limit is {limit} bytes, {num} registered"
141+
);
142+
}
143+
144+
let acquired = self.acquire(additional)?;
145+
// If the number of bytes we acquired is less than the requested, return an error,
146+
// and hopefully will trigger spilling from the caller side.
147+
if acquired < additional as i64 {
148+
// Release the acquired bytes before throwing error
149+
self.release(acquired as usize)?;
150+
151+
return resources_err!(
152+
"Failed to acquire {} bytes, only got {} bytes. Reserved: {} bytes",
153+
additional,
154+
acquired,
155+
state.used
156+
);
157+
}
158+
state.used = state.used.checked_add(additional).unwrap();
159+
}
160+
Ok(())
161+
}
162+
163+
fn reserved(&self) -> usize {
164+
self.state.lock().used
165+
}
166+
}

native/core/src/execution/jni_api.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ use std::num::NonZeroUsize;
6363
use std::sync::Mutex;
6464
use tokio::runtime::Runtime;
6565

66+
use crate::execution::fair_memory_pool::CometFairMemoryPool;
6667
use crate::execution::operators::ScanExec;
6768
use crate::execution::shuffle::{read_ipc_compressed, CompressionCodec};
6869
use crate::execution::spark_plan::SparkPlan;
@@ -108,6 +109,7 @@ struct ExecutionContext {
108109
#[derive(PartialEq, Eq)]
109110
enum MemoryPoolType {
110111
Unified,
112+
FairUnified,
111113
Greedy,
112114
FairSpill,
113115
GreedyTaskShared,
@@ -291,11 +293,14 @@ fn parse_memory_pool_config(
291293
memory_limit: i64,
292294
memory_limit_per_task: i64,
293295
) -> CometResult<MemoryPoolConfig> {
296+
let pool_size = memory_limit as usize;
294297
let memory_pool_config = if use_unified_memory_manager {
295-
MemoryPoolConfig::new(MemoryPoolType::Unified, 0)
298+
match memory_pool_type.as_str() {
299+
"fair_unified" => MemoryPoolConfig::new(MemoryPoolType::FairUnified, pool_size),
300+
_ => MemoryPoolConfig::new(MemoryPoolType::Unified, 0),
301+
}
296302
} else {
297303
// Use the memory pool from DF
298-
let pool_size = memory_limit as usize;
299304
let pool_size_per_task = memory_limit_per_task as usize;
300305
match memory_pool_type.as_str() {
301306
"fair_spill_task_shared" => {
@@ -333,6 +338,12 @@ fn create_memory_pool(
333338
let memory_pool = CometMemoryPool::new(comet_task_memory_manager);
334339
Arc::new(memory_pool)
335340
}
341+
MemoryPoolType::FairUnified => {
342+
// Set Comet fair memory pool for native
343+
let memory_pool =
344+
CometFairMemoryPool::new(comet_task_memory_manager, memory_pool_config.pool_size);
345+
Arc::new(memory_pool)
346+
}
336347
MemoryPoolType::Greedy => Arc::new(TrackConsumersPool::new(
337348
GreedyMemoryPool::new(memory_pool_config.pool_size),
338349
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),

native/core/src/execution/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ pub(crate) mod util;
2929
pub use datafusion_comet_spark_expr::timezone;
3030
pub(crate) mod utils;
3131

32+
mod fair_memory_pool;
3233
mod memory_pool;
3334
pub use memory_pool::*;
3435

spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class CometTPCHQuerySuite extends QueryTest with TPCBase with ShimCometTPCHQuery
9393
conf.set(CometConf.COMET_SHUFFLE_MODE.key, "jvm")
9494
conf.set(MEMORY_OFFHEAP_ENABLED.key, "true")
9595
conf.set(MEMORY_OFFHEAP_SIZE.key, "2g")
96+
conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "2g")
9697
}
9798

9899
protected override def createSparkSession: TestSparkSession = {

0 commit comments

Comments
 (0)