Skip to content

Commit fe62f1f

Browse files
committed
fix(cuda_helpers): fix data race in set_shmem_of_kernel and add docstring
1 parent 8e90a94 commit fe62f1f

2 files changed

Lines changed: 57 additions & 25 deletions

File tree

cpp/src/utilities/cuda_helpers.cuh

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <thrust/host_vector.h>
1313
#include <thrust/tuple.h>
1414
#include <mutex>
15+
#include <shared_mutex>
1516
#include <raft/core/device_span.hpp>
1617
#include <raft/util/cuda_utils.cuh>
1718
#include <raft/util/cudart_utils.hpp>
@@ -175,29 +176,49 @@ HDI To bit_cast(const From& src)
175176
return *(To*)(&src);
176177
}
177178

179+
/**
180+
* @brief Raises the dynamic shared-memory limit for a CUDA kernel, with caching.
181+
*
182+
* Calls cudaFuncSetAttribute(cudaFuncAttributeMaxDynamicSharedMemorySize) only when
183+
* @p dynamic_request_size exceeds the previously set limit for @p function. The
184+
* per-kernel high-water mark is stored in a process-wide cache so that repeated
185+
* calls with the same or smaller sizes are cheap shared-lock reads.
186+
*
187+
* Thread safety: safe to call concurrently from multiple host threads.
188+
*
189+
* @param function Host pointer to the __global__ kernel function.
190+
* @param dynamic_request_size Requested dynamic shared memory in bytes.
191+
* A value of 0 is a no-op and always returns true.
192+
* @return true if the attribute was successfully set (or was already sufficient).
193+
* @return false if cudaFuncSetAttribute failed (e.g. size exceeds device limit);
194+
* the sticky CUDA error is consumed so it cannot surface later.
195+
*/
178196
template <typename Function>
179197
inline bool set_shmem_of_kernel(Function* function, size_t dynamic_request_size)
180198
{
181-
static std::mutex mtx;
199+
static std::shared_mutex mtx;
182200
static std::unordered_map<Function*, size_t> shmem_sizes;
183201

184202
if (dynamic_request_size != 0) {
185203
dynamic_request_size = raft::alignTo(dynamic_request_size, size_t(1024));
186-
size_t current_size = shmem_sizes[function];
204+
205+
{
206+
std::shared_lock<std::shared_mutex> rlock(mtx);
207+
auto it = shmem_sizes.find(function);
208+
if (it != shmem_sizes.end() && dynamic_request_size <= it->second) { return true; }
209+
}
210+
211+
std::unique_lock<std::shared_mutex> wlock(mtx);
212+
size_t current_size = shmem_sizes.count(function) ? shmem_sizes[function] : 0;
187213
if (dynamic_request_size > current_size) {
188-
std::lock_guard<std::mutex> lock(mtx);
189-
current_size = shmem_sizes[function];
190-
191-
if (dynamic_request_size > current_size) {
192-
auto err = cudaFuncSetAttribute(
193-
function, cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_request_size);
194-
if (err == cudaSuccess) {
195-
shmem_sizes[function] = dynamic_request_size;
196-
return true;
197-
} else {
198-
cudaGetLastError(); // clear sticky error so later RAFT_CHECK_CUDA doesn't catch it
199-
return false;
200-
}
214+
auto err = cudaFuncSetAttribute(
215+
function, cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_request_size);
216+
if (err == cudaSuccess) {
217+
shmem_sizes[function] = dynamic_request_size;
218+
return true;
219+
} else {
220+
cudaGetLastError(); // clear sticky error so later RAFT_CHECK_CUDA doesn't catch it
221+
return false;
201222
}
202223
}
203224
}

cpp/tests/routing/unit_tests/set_shmem_of_kernel.cu

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,58 +14,69 @@
1414
namespace cuopt {
1515
namespace test {
1616

17+
/// @brief Dummy kernel used to test a zero-byte shared-memory request.
1718
__global__ void kernel_zero() {}
19+
/// @brief Dummy kernel used to test a normal (within-limit) shared-memory request.
1820
__global__ void kernel_normal() {}
21+
/// @brief Dummy kernel used to test a too-large shared-memory request (first call).
1922
__global__ void kernel_too_large_a() {}
23+
/// @brief Dummy kernel used to test a too-large shared-memory request (repeated call).
2024
__global__ void kernel_too_large_b() {}
25+
/// @brief Dummy kernel used to verify that a failed request leaves no sticky CUDA error.
2126
__global__ void kernel_sticky_error() {}
2227

23-
// Zero request is a no-op and must return true.
28+
/// @brief Zero request is a no-op and must return true.
2429
TEST(set_shmem_of_kernel, zero_request)
2530
{
2631
EXPECT_TRUE(set_shmem_of_kernel(kernel_zero, 0));
2732
EXPECT_EQ(cudaSuccess, cudaGetLastError());
2833
}
2934

30-
// A modest request well within device limits must succeed.
35+
/// @brief A modest request well within device limits must succeed.
3136
TEST(set_shmem_of_kernel, normal_request)
3237
{
3338
EXPECT_TRUE(set_shmem_of_kernel(kernel_normal, 4096));
3439
EXPECT_EQ(cudaSuccess, cudaGetLastError());
3540
}
3641

37-
// Requesting more shared memory than the device supports must return false.
42+
/// @brief Requesting more shared memory than the device supports must return false.
3843
TEST(set_shmem_of_kernel, too_large_returns_false)
3944
{
4045
int shmem_max{};
41-
cudaDeviceGetAttribute(&shmem_max, cudaDevAttrMaxSharedMemoryPerBlockOptin, 0);
46+
ASSERT_EQ(cudaSuccess,
47+
cudaDeviceGetAttribute(&shmem_max, cudaDevAttrMaxSharedMemoryPerBlockOptin, 0))
48+
<< "cudaDeviceGetAttribute(cudaDevAttrMaxSharedMemoryPerBlockOptin) failed";
4249
size_t too_large = static_cast<size_t>(shmem_max) + 1024;
4350

4451
EXPECT_FALSE(set_shmem_of_kernel(kernel_too_large_a, too_large));
4552
EXPECT_EQ(cudaSuccess, cudaGetLastError());
4653
}
4754

48-
// A second call with the same too-large size must still return false
55+
/// @brief A second call with the same too-large size must still return false.
4956
TEST(set_shmem_of_kernel, cache_not_poisoned_on_failure)
5057
{
5158
int shmem_max{};
52-
cudaDeviceGetAttribute(&shmem_max, cudaDevAttrMaxSharedMemoryPerBlockOptin, 0);
59+
ASSERT_EQ(cudaSuccess,
60+
cudaDeviceGetAttribute(&shmem_max, cudaDevAttrMaxSharedMemoryPerBlockOptin, 0))
61+
<< "cudaDeviceGetAttribute(cudaDevAttrMaxSharedMemoryPerBlockOptin) failed";
5362
size_t too_large = static_cast<size_t>(shmem_max) + 1024;
5463

5564
EXPECT_FALSE(set_shmem_of_kernel(kernel_too_large_b, too_large));
5665
EXPECT_FALSE(set_shmem_of_kernel(kernel_too_large_b, too_large)); // must not return true
5766
EXPECT_EQ(cudaSuccess, cudaGetLastError());
5867
}
5968

60-
// A failed call must not leave a sticky CUDA error that would be caught
61-
// later by an unrelated RAFT_CHECK_CUDA.
69+
/// @brief A failed call must not leave a sticky CUDA error that would be caught
70+
/// later by an unrelated RAFT_CHECK_CUDA.
6271
TEST(set_shmem_of_kernel, no_sticky_error_after_failure)
6372
{
6473
int shmem_max{};
65-
cudaDeviceGetAttribute(&shmem_max, cudaDevAttrMaxSharedMemoryPerBlockOptin, 0);
74+
ASSERT_EQ(cudaSuccess,
75+
cudaDeviceGetAttribute(&shmem_max, cudaDevAttrMaxSharedMemoryPerBlockOptin, 0))
76+
<< "cudaDeviceGetAttribute(cudaDevAttrMaxSharedMemoryPerBlockOptin) failed";
6677
size_t too_large = static_cast<size_t>(shmem_max) + 1024;
6778

68-
set_shmem_of_kernel(kernel_sticky_error, too_large);
79+
EXPECT_FALSE(set_shmem_of_kernel(kernel_sticky_error, too_large)); // confirm failure branch taken
6980
EXPECT_EQ(cudaSuccess, cudaGetLastError());
7081
}
7182

0 commit comments

Comments
 (0)