Skip to content

Commit 6308bba

Browse files
authored
[STF] Use opaque struct pointers for C API handles (#8205)
Replace `typedef void*` with `typedef struct <name>_t* <name>` for all opaque handle types in the STF C API. This provides type safety at the C level, preventing accidental mixing of different handle types. The naming convention (e.g. stf_ctx_handle_t / stf_ctx_handle) matches the Cython declarations in the stf_c_api branch for consistency. Made-with: Cursor
1 parent 4816308 commit 6308bba

2 files changed

Lines changed: 46 additions & 46 deletions

File tree

  • c/experimental/stf

c/experimental/stf/include/cccl/c/experimental/stf/stf.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ typedef struct stf_dim4
235235
typedef stf_pos4 (*stf_get_executor_fn)(stf_pos4 data_coords, stf_dim4 data_dims, stf_dim4 grid_dims);
236236

237237
//! \brief Opaque handle for an execution place grid (e.g. one place per stream).
238-
typedef void* stf_exec_place_grid_handle;
238+
typedef struct stf_exec_place_grid_handle_t* stf_exec_place_grid_handle;
239239

240240
//! \}
241241

@@ -414,31 +414,31 @@ void stf_exec_place_grid_destroy(stf_exec_place_grid_handle grid);
414414
//! Context stores the state of the STF library and serves as entry point for all API calls.
415415
//! Must be created with stf_ctx_create() or stf_ctx_create_graph() and destroyed with stf_ctx_finalize().
416416

417-
typedef void* stf_ctx_handle;
417+
typedef struct stf_ctx_handle_t* stf_ctx_handle;
418418

419419
//!
420420
//! \brief Opaque handle for logical data
421421
//!
422422
//! Represents abstract data that may exist in multiple memory locations.
423423
//! Created with stf_logical_data() or stf_logical_data_empty() and destroyed with stf_logical_data_destroy().
424424

425-
typedef void* stf_logical_data_handle;
425+
typedef struct stf_logical_data_handle_t* stf_logical_data_handle;
426426

427427
//!
428428
//! \brief Opaque handle for task
429429
//!
430430
//! Represents a computational task that operates on logical data.
431431
//! Created with stf_task_create() and destroyed with stf_task_destroy().
432432

433-
typedef void* stf_task_handle;
433+
typedef struct stf_task_handle_t* stf_task_handle;
434434

435435
//!
436436
//! \brief Opaque handle for CUDA kernel task
437437
//!
438438
//! Specialized task optimized for CUDA kernel execution.
439439
//! Created with stf_cuda_kernel_create() and destroyed with stf_cuda_kernel_destroy().
440440

441-
typedef void* stf_cuda_kernel_handle;
441+
typedef struct stf_cuda_kernel_handle_t* stf_cuda_kernel_handle;
442442

443443
//! \}
444444

c/experimental/stf/src/stf.cu

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ static data_place to_data_place(const stf_data_place* data_p)
5858
{
5959
return data_place::invalid();
6060
}
61-
auto grid_ptr = static_cast<exec_place*>(grid_handle);
61+
auto* grid_ptr = reinterpret_cast<exec_place*>(grid_handle);
6262
// Layout-compatible: pass C mapper directly so the runtime calls it
6363
partition_fn_t cpp_mapper = reinterpret_cast<partition_fn_t>(mapper);
6464
return data_place::composite(cpp_mapper, *grid_ptr);
@@ -73,27 +73,27 @@ static data_place to_data_place(const stf_data_place* data_p)
7373
void stf_ctx_create(stf_ctx_handle* ctx)
7474
{
7575
_CCCL_ASSERT(ctx != nullptr, "context handle pointer must not be null");
76-
*ctx = new context{};
76+
*ctx = reinterpret_cast<stf_ctx_handle>(new context{});
7777
}
7878

7979
void stf_ctx_create_graph(stf_ctx_handle* ctx)
8080
{
8181
_CCCL_ASSERT(ctx != nullptr, "context handle pointer must not be null");
82-
*ctx = new context{graph_ctx()};
82+
*ctx = reinterpret_cast<stf_ctx_handle>(new context{graph_ctx()});
8383
}
8484

8585
void stf_ctx_finalize(stf_ctx_handle ctx)
8686
{
8787
_CCCL_ASSERT(ctx != nullptr, "context handle must not be null");
88-
auto* context_ptr = static_cast<context*>(ctx);
88+
auto* context_ptr = reinterpret_cast<context*>(ctx);
8989
context_ptr->finalize();
9090
delete context_ptr;
9191
}
9292

9393
cudaStream_t stf_fence(stf_ctx_handle ctx)
9494
{
9595
_CCCL_ASSERT(ctx != nullptr, "context handle must not be null");
96-
auto* context_ptr = static_cast<context*>(ctx);
96+
auto* context_ptr = reinterpret_cast<context*>(ctx);
9797
return context_ptr->fence();
9898
}
9999

@@ -109,7 +109,7 @@ void stf_logical_data_with_place(
109109
_CCCL_ASSERT(ctx != nullptr, "context handle pointer must not be null");
110110
_CCCL_ASSERT(ld != nullptr, "logical data handle pointer must not be null");
111111

112-
auto* context_ptr = static_cast<context*>(ctx);
112+
auto* context_ptr = reinterpret_cast<context*>(ctx);
113113

114114
// Convert C data_place to C++ data_place
115115
cuda::experimental::stf::data_place cpp_dplace;
@@ -141,23 +141,23 @@ void stf_logical_data_with_place(
141141
auto ld_typed = context_ptr->logical_data(make_slice((char*) addr, sz), cpp_dplace);
142142

143143
// Store the logical_data_untyped directly as opaque pointer
144-
*ld = new logical_data_untyped{ld_typed};
144+
*ld = reinterpret_cast<stf_logical_data_handle>(new logical_data_untyped{ld_typed});
145145
}
146146

147147
void stf_logical_data_set_symbol(stf_logical_data_handle ld, const char* symbol)
148148
{
149149
_CCCL_ASSERT(ld != nullptr, "logical data handle must not be null");
150150
_CCCL_ASSERT(symbol != nullptr, "symbol string must not be null");
151151

152-
auto* ld_ptr = static_cast<logical_data_untyped*>(ld);
152+
auto* ld_ptr = reinterpret_cast<logical_data_untyped*>(ld);
153153
ld_ptr->set_symbol(symbol);
154154
}
155155

156156
void stf_logical_data_destroy(stf_logical_data_handle ld)
157157
{
158158
_CCCL_ASSERT(ld != nullptr, "logical data handle must not be null");
159159

160-
auto* ld_ptr = static_cast<logical_data_untyped*>(ld);
160+
auto* ld_ptr = reinterpret_cast<logical_data_untyped*>(ld);
161161
delete ld_ptr;
162162
}
163163

@@ -166,18 +166,18 @@ void stf_logical_data_empty(stf_ctx_handle ctx, size_t length, stf_logical_data_
166166
_CCCL_ASSERT(ctx != nullptr, "context handle must not be null");
167167
_CCCL_ASSERT(to != nullptr, "logical data output pointer must not be null");
168168

169-
auto* context_ptr = static_cast<context*>(ctx);
169+
auto* context_ptr = reinterpret_cast<context*>(ctx);
170170
auto ld_typed = context_ptr->logical_data(shape_of<slice<char>>(length));
171-
*to = new logical_data_untyped{ld_typed};
171+
*to = reinterpret_cast<stf_logical_data_handle>(new logical_data_untyped{ld_typed});
172172
}
173173

174174
void stf_token(stf_ctx_handle ctx, stf_logical_data_handle* ld)
175175
{
176176
_CCCL_ASSERT(ctx != nullptr, "context handle must not be null");
177177
_CCCL_ASSERT(ld != nullptr, "logical data handle output pointer must not be null");
178178

179-
auto* context_ptr = static_cast<context*>(ctx);
180-
*ld = new logical_data_untyped{context_ptr->token()};
179+
auto* context_ptr = reinterpret_cast<context*>(ctx);
180+
*ld = reinterpret_cast<stf_logical_data_handle>(new logical_data_untyped{context_ptr->token()});
181181
}
182182

183183
/* Convert the C-API stf_exec_place to a C++ exec_place object */
@@ -204,16 +204,16 @@ void stf_task_create(stf_ctx_handle ctx, stf_task_handle* t)
204204
_CCCL_ASSERT(ctx != nullptr, "context handle must not be null");
205205
_CCCL_ASSERT(t != nullptr, "task handle output pointer must not be null");
206206

207-
auto* context_ptr = static_cast<context*>(ctx);
208-
*t = new context::unified_task<>{context_ptr->task()};
207+
auto* context_ptr = reinterpret_cast<context*>(ctx);
208+
*t = reinterpret_cast<stf_task_handle>(new context::unified_task<>{context_ptr->task()});
209209
}
210210

211211
void stf_task_set_exec_place(stf_task_handle t, stf_exec_place* exec_p)
212212
{
213213
_CCCL_ASSERT(t != nullptr, "task handle must not be null");
214214
_CCCL_ASSERT(exec_p != nullptr, "exec_place pointer must not be null");
215215

216-
auto* task_ptr = static_cast<context::unified_task<>*>(t);
216+
auto* task_ptr = reinterpret_cast<context::unified_task<>*>(t);
217217
task_ptr->set_exec_place(to_exec_place(exec_p));
218218
}
219219

@@ -222,7 +222,7 @@ void stf_task_set_symbol(stf_task_handle t, const char* symbol)
222222
_CCCL_ASSERT(t != nullptr, "task handle must not be null");
223223
_CCCL_ASSERT(symbol != nullptr, "symbol string must not be null");
224224

225-
auto* task_ptr = static_cast<context::unified_task<>*>(t);
225+
auto* task_ptr = reinterpret_cast<context::unified_task<>*>(t);
226226
task_ptr->set_symbol(symbol);
227227
}
228228

@@ -231,8 +231,8 @@ void stf_task_add_dep(stf_task_handle t, stf_logical_data_handle ld, stf_access_
231231
_CCCL_ASSERT(t != nullptr, "task handle must not be null");
232232
_CCCL_ASSERT(ld != nullptr, "logical data handle must not be null");
233233

234-
auto* task_ptr = static_cast<context::unified_task<>*>(t);
235-
auto* ld_ptr = static_cast<logical_data_untyped*>(ld);
234+
auto* task_ptr = reinterpret_cast<context::unified_task<>*>(t);
235+
auto* ld_ptr = reinterpret_cast<logical_data_untyped*>(ld);
236236
task_ptr->add_deps(task_dep_untyped(*ld_ptr, access_mode(m)));
237237
}
238238

@@ -243,16 +243,16 @@ void stf_task_add_dep_with_dplace(
243243
_CCCL_ASSERT(ld != nullptr, "logical data handle must not be null");
244244
_CCCL_ASSERT(data_p != nullptr, "data_place pointer must not be null");
245245

246-
auto* task_ptr = static_cast<context::unified_task<>*>(t);
247-
auto* ld_ptr = static_cast<logical_data_untyped*>(ld);
246+
auto* task_ptr = reinterpret_cast<context::unified_task<>*>(t);
247+
auto* ld_ptr = reinterpret_cast<logical_data_untyped*>(ld);
248248
task_ptr->add_deps(task_dep_untyped(*ld_ptr, access_mode(m), to_data_place(data_p)));
249249
}
250250

251251
void* stf_task_get(stf_task_handle t, int index)
252252
{
253253
_CCCL_ASSERT(t != nullptr, "task handle must not be null");
254254

255-
auto* task_ptr = static_cast<context::unified_task<>*>(t);
255+
auto* task_ptr = reinterpret_cast<context::unified_task<>*>(t);
256256
auto s = task_ptr->template get<slice<const char>>(index);
257257
return (void*) s.data_handle();
258258
}
@@ -261,39 +261,39 @@ void stf_task_start(stf_task_handle t)
261261
{
262262
_CCCL_ASSERT(t != nullptr, "task handle must not be null");
263263

264-
auto* task_ptr = static_cast<context::unified_task<>*>(t);
264+
auto* task_ptr = reinterpret_cast<context::unified_task<>*>(t);
265265
task_ptr->start();
266266
}
267267

268268
void stf_task_end(stf_task_handle t)
269269
{
270270
_CCCL_ASSERT(t != nullptr, "task handle must not be null");
271271

272-
auto* task_ptr = static_cast<context::unified_task<>*>(t);
272+
auto* task_ptr = reinterpret_cast<context::unified_task<>*>(t);
273273
task_ptr->end();
274274
}
275275

276276
void stf_task_enable_capture(stf_task_handle t)
277277
{
278278
_CCCL_ASSERT(t != nullptr, "task handle must not be null");
279279

280-
auto* task_ptr = static_cast<context::unified_task<>*>(t);
280+
auto* task_ptr = reinterpret_cast<context::unified_task<>*>(t);
281281
task_ptr->enable_capture();
282282
}
283283

284284
CUstream stf_task_get_custream(stf_task_handle t)
285285
{
286286
_CCCL_ASSERT(t != nullptr, "task handle must not be null");
287287

288-
auto* task_ptr = static_cast<context::unified_task<>*>(t);
288+
auto* task_ptr = reinterpret_cast<context::unified_task<>*>(t);
289289
return static_cast<CUstream>(task_ptr->get_stream());
290290
}
291291

292292
void stf_task_destroy(stf_task_handle t)
293293
{
294294
_CCCL_ASSERT(t != nullptr, "task handle must not be null");
295295

296-
auto* task_ptr = static_cast<context::unified_task<>*>(t);
296+
auto* task_ptr = reinterpret_cast<context::unified_task<>*>(t);
297297
delete task_ptr;
298298
}
299299

@@ -319,9 +319,9 @@ void stf_cuda_kernel_create(stf_ctx_handle ctx, stf_cuda_kernel_handle* k)
319319
_CCCL_ASSERT(ctx != nullptr, "context handle must not be null");
320320
_CCCL_ASSERT(k != nullptr, "cuda kernel handle output pointer must not be null");
321321

322-
auto* context_ptr = static_cast<context*>(ctx);
322+
auto* context_ptr = reinterpret_cast<context*>(ctx);
323323
using kernel_type = decltype(context_ptr->cuda_kernel());
324-
*k = new kernel_type{context_ptr->cuda_kernel()};
324+
*k = reinterpret_cast<stf_cuda_kernel_handle>(new kernel_type{context_ptr->cuda_kernel()});
325325
}
326326

327327
void stf_cuda_kernel_set_exec_place(stf_cuda_kernel_handle k, stf_exec_place* exec_p)
@@ -330,7 +330,7 @@ void stf_cuda_kernel_set_exec_place(stf_cuda_kernel_handle k, stf_exec_place* ex
330330
_CCCL_ASSERT(exec_p != nullptr, "exec_place pointer must not be null");
331331

332332
using kernel_type = decltype(::std::declval<context>().cuda_kernel());
333-
auto* kernel_ptr = static_cast<kernel_type*>(k);
333+
auto* kernel_ptr = reinterpret_cast<kernel_type*>(k);
334334
kernel_ptr->set_exec_place(to_exec_place(exec_p));
335335
}
336336

@@ -340,7 +340,7 @@ void stf_cuda_kernel_set_symbol(stf_cuda_kernel_handle k, const char* symbol)
340340
_CCCL_ASSERT(symbol != nullptr, "symbol string must not be null");
341341

342342
using kernel_type = decltype(::std::declval<context>().cuda_kernel());
343-
auto* kernel_ptr = static_cast<kernel_type*>(k);
343+
auto* kernel_ptr = reinterpret_cast<kernel_type*>(k);
344344
kernel_ptr->set_symbol(symbol);
345345
}
346346

@@ -350,8 +350,8 @@ void stf_cuda_kernel_add_dep(stf_cuda_kernel_handle k, stf_logical_data_handle l
350350
_CCCL_ASSERT(ld != nullptr, "logical data handle must not be null");
351351

352352
using kernel_type = decltype(::std::declval<context>().cuda_kernel());
353-
auto* kernel_ptr = static_cast<kernel_type*>(k);
354-
auto* ld_ptr = static_cast<logical_data_untyped*>(ld);
353+
auto* kernel_ptr = reinterpret_cast<kernel_type*>(k);
354+
auto* ld_ptr = reinterpret_cast<logical_data_untyped*>(ld);
355355
kernel_ptr->add_deps(task_dep_untyped(*ld_ptr, access_mode(m)));
356356
}
357357

@@ -360,7 +360,7 @@ void stf_cuda_kernel_start(stf_cuda_kernel_handle k)
360360
_CCCL_ASSERT(k != nullptr, "cuda kernel handle must not be null");
361361

362362
using kernel_type = decltype(::std::declval<context>().cuda_kernel());
363-
auto* kernel_ptr = static_cast<kernel_type*>(k);
363+
auto* kernel_ptr = reinterpret_cast<kernel_type*>(k);
364364
kernel_ptr->start();
365365
}
366366

@@ -376,7 +376,7 @@ void stf_cuda_kernel_add_desc_cufunc(
376376
_CCCL_ASSERT(k != nullptr, "cuda kernel handle must not be null");
377377

378378
using kernel_type = decltype(::std::declval<context>().cuda_kernel());
379-
auto* kernel_ptr = static_cast<kernel_type*>(k);
379+
auto* kernel_ptr = reinterpret_cast<kernel_type*>(k);
380380

381381
cuda_kernel_desc desc;
382382
desc.configure_raw(cufunc, grid_dim_, block_dim_, shared_mem_, arg_cnt, args);
@@ -388,7 +388,7 @@ void* stf_cuda_kernel_get_arg(stf_cuda_kernel_handle k, int index)
388388
_CCCL_ASSERT(k != nullptr, "cuda kernel handle must not be null");
389389

390390
using kernel_type = decltype(::std::declval<context>().cuda_kernel());
391-
auto* kernel_ptr = static_cast<kernel_type*>(k);
391+
auto* kernel_ptr = reinterpret_cast<kernel_type*>(k);
392392
auto s = kernel_ptr->template get<slice<const char>>(index);
393393
return (void*) (s.data_handle());
394394
}
@@ -398,7 +398,7 @@ void stf_cuda_kernel_end(stf_cuda_kernel_handle k)
398398
_CCCL_ASSERT(k != nullptr, "cuda kernel handle must not be null");
399399

400400
using kernel_type = decltype(::std::declval<context>().cuda_kernel());
401-
auto kernel_ptr = static_cast<kernel_type*>(k);
401+
auto kernel_ptr = reinterpret_cast<kernel_type*>(k);
402402
kernel_ptr->end();
403403
}
404404

@@ -407,7 +407,7 @@ void stf_cuda_kernel_destroy(stf_cuda_kernel_handle t)
407407
_CCCL_ASSERT(t != nullptr, "cuda kernel handle must not be null");
408408

409409
using kernel_type = decltype(::std::declval<context>().cuda_kernel());
410-
auto* kernel_ptr = static_cast<kernel_type*>(t);
410+
auto* kernel_ptr = reinterpret_cast<kernel_type*>(t);
411411
delete kernel_ptr;
412412
}
413413

@@ -425,7 +425,7 @@ stf_exec_place_grid_handle stf_exec_place_grid_from_devices(const int* device_id
425425
{
426426
places.push_back(exec_place::device(device_ids[i]));
427427
}
428-
return new exec_place(make_grid(::std::move(places)));
428+
return reinterpret_cast<stf_exec_place_grid_handle>(new exec_place(make_grid(::std::move(places))));
429429
}
430430

431431
stf_exec_place_grid_handle
@@ -441,14 +441,14 @@ stf_exec_place_grid_create(const stf_exec_place* places, size_t count, const stf
441441
exec_place grid = (grid_dims != nullptr)
442442
? make_grid(::std::move(cpp_places), dim4(grid_dims->x, grid_dims->y, grid_dims->z, grid_dims->t))
443443
: make_grid(::std::move(cpp_places));
444-
return new exec_place(::std::move(grid));
444+
return reinterpret_cast<stf_exec_place_grid_handle>(new exec_place(::std::move(grid)));
445445
}
446446

447447
void stf_exec_place_grid_destroy(stf_exec_place_grid_handle grid)
448448
{
449449
if (grid != nullptr)
450450
{
451-
delete static_cast<exec_place*>(grid);
451+
delete reinterpret_cast<exec_place*>(grid);
452452
}
453453
}
454454

0 commit comments

Comments
 (0)