
Commit fc09814

merge precision_t kernels and prune dead code

1 parent 9e6e431 commit fc09814
3 files changed
Lines changed: 35 additions & 325 deletions

pufferlib/src/kernels.cu
Lines changed: 10 additions & 264 deletions
@@ -1023,122 +1023,6 @@ __global__ void sample_logits_kernel(
 
 
 
-// =============================================================================
-// FCMax: Fused FC -> Max kernel
-// Input: x (B, N, D_in), W (D_out, D_in), b (D_out)
-// Output: out (B, D_out) = max_over_N(x @ W.T + b)
-// Each thread computes one (b, d_out) output element
-// N-fold memory bandwidth reduction vs separate FC + Max kernels
-// W and b are always float32 (mixed precision for bf16 activations)
-// =============================================================================
-
-__global__ void fc_max_forward_kernel(
-    precision_t* __restrict__ out,             // (B, D_out)
-    int* __restrict__ argmax_indices,          // (B, D_out) - which N produced the max
-    const precision_t* __restrict__ x,         // (B, N, D_in)
-    const float* __restrict__ W,               // (D_out, D_in) - always float32
-    const float* __restrict__ b,               // (D_out) - always float32
-    int B, int N, int D_in, int D_out
-) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= B * D_out) return;
-
-    int batch = idx / D_out;
-    int d_out = idx % D_out;
-
-    float bias = b[d_out];
-    float max_val = -INFINITY;
-    int argmax_n = 0;
-
-    // Iterate over all N points, compute FC output, track max
-    for (int n = 0; n < N; n++) {
-        float val = bias;
-        for (int di = 0; di < D_in; di++) {
-            val += to_float(x[batch * N * D_in + n * D_in + di]) * W[d_out * D_in + di];
-        }
-        if (val > max_val) {
-            max_val = val;
-            argmax_n = n;
-        }
-    }
-
-    out[idx] = from_float(max_val);
-    argmax_indices[idx] = argmax_n;
-}
-
-// Deterministic backward: three separate kernels with no atomicAdd
-// Each kernel assigns threads so that each output element is written by exactly one thread.
-
-// grad_b[d_out] = sum over batch of grad_out[batch*D_out + d_out]
-// One thread per d_out, serial sum over batch dimension
-__global__ void fc_max_backward_grad_b_kernel(
-    float* __restrict__ grad_b,                // (D_out)
-    const precision_t* __restrict__ grad_out,  // (B, D_out)
-    int B, int D_out
-) {
-    int d_out = blockIdx.x * blockDim.x + threadIdx.x;
-    if (d_out >= D_out) return;
-
-    float sum = 0.0f;
-    for (int b = 0; b < B; b++) {
-        sum += to_float(grad_out[b * D_out + d_out]);
-    }
-    grad_b[d_out] = sum;
-}
-
-// grad_W[d_out, d_in] = sum over batch of grad_out[batch, d_out] * x[batch, argmax[batch, d_out], d_in]
-// One thread per (d_out, d_in) pair, serial sum over batch dimension
-__global__ void fc_max_backward_grad_W_kernel(
-    float* __restrict__ grad_W,                // (D_out, D_in)
-    const precision_t* __restrict__ grad_out,  // (B, D_out)
-    const precision_t* __restrict__ x,         // (B, N, D_in)
-    const int* __restrict__ argmax_indices,    // (B, D_out)
-    int B, int N, int D_in, int D_out
-) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= D_out * D_in) return;
-
-    int d_out = idx / D_in;
-    int d_in = idx % D_in;
-
-    float sum = 0.0f;
-    for (int b = 0; b < B; b++) {
-        float g_out = to_float(grad_out[b * D_out + d_out]);
-        int argmax_n = argmax_indices[b * D_out + d_out];
-        float x_val = to_float(x[b * N * D_in + argmax_n * D_in + d_in]);
-        sum += g_out * x_val;
-    }
-    grad_W[idx] = sum;
-}
-
-// grad_x[batch, n, d_in] = sum over d_out where argmax[batch, d_out]==n of grad_out[batch, d_out] * W[d_out, d_in]
-// One thread per (batch, d_in) pair, serial loop over d_out
-// Each thread writes to its own batch's grad_x — no cross-thread contention
-__global__ void fc_max_backward_grad_x_kernel(
-    float* __restrict__ grad_x,                // (B, N, D_in)
-    const precision_t* __restrict__ grad_out,  // (B, D_out)
-    const float* __restrict__ W,               // (D_out, D_in)
-    const int* __restrict__ argmax_indices,    // (B, D_out)
-    int B, int N, int D_in, int D_out
-) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= B * D_in) return;
-
-    int batch = idx / D_in;
-    int d_in = idx % D_in;
-
-    // Zero this batch's grad_x column first
-    for (int n = 0; n < N; n++) {
-        grad_x[batch * N * D_in + n * D_in + d_in] = 0.0f;
-    }
-
-    // Accumulate: for each d_out, add grad_out * W to the argmax position
-    for (int d_out = 0; d_out < D_out; d_out++) {
-        float g_out = to_float(grad_out[batch * D_out + d_out]);
-        int argmax_n = argmax_indices[batch * D_out + d_out];
-        grad_x[batch * N * D_in + argmax_n * D_in + d_in] += g_out * W[d_out * D_in + d_in];
-    }
-}
 
 
 #define SELECT_COPY_THREADS 256
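The deleted backward pass above is worth pausing on: determinism comes from giving every output element exactly one writing thread, so the three kernels need no atomicAdd and no run-to-run variation from floating-point accumulation order. The host wrappers that drove them lived in models.cu (per the comment in the next hunk) and are not part of this diff; the launcher below is a hypothetical sketch of that call sequence, with the function name and THREADS invented for illustration.

// Hypothetical launcher for the deleted FCMax backward pass. Each kernel
// covers a disjoint output: grad_b by d_out, grad_W by (d_out, d_in),
// grad_x by (batch, d_in), so results are bitwise reproducible run to run.
static void launch_fc_max_backward(
    float* grad_b, float* grad_W, float* grad_x,
    const precision_t* grad_out, const precision_t* x,
    const float* W, const int* argmax_indices,
    int B, int N, int D_in, int D_out, cudaStream_t stream)
{
    const int THREADS = 256;
    fc_max_backward_grad_b_kernel<<<(D_out + THREADS - 1) / THREADS, THREADS, 0, stream>>>(
        grad_b, grad_out, B, D_out);
    fc_max_backward_grad_W_kernel<<<(D_out * D_in + THREADS - 1) / THREADS, THREADS, 0, stream>>>(
        grad_W, grad_out, x, argmax_indices, B, N, D_in, D_out);
    fc_max_backward_grad_x_kernel<<<(B * D_in + THREADS - 1) / THREADS, THREADS, 0, stream>>>(
        grad_x, grad_out, W, argmax_indices, B, N, D_in, D_out);
}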
@@ -1410,39 +1294,11 @@ __global__ void puff_advantage_kernel_scalar(const precision_t* values, const pr
 // Host wrappers live in models.cu
 // ============================================================================
 
-__global__ void cast_bf16_to_f32_kernel(float* __restrict__ dst, const __nv_bfloat16* __restrict__ src, int n) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] = __bfloat162float(src[idx]);
-}
-
-__global__ void cast_f32_to_bf16_kernel(__nv_bfloat16* __restrict__ dst, const float* __restrict__ src, int n) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] = __float2bfloat16(src[idx]);
-}
-
-__global__ void cast_f32_to_bf16_transpose_kernel(__nv_bfloat16* __restrict__ dst, const float* __restrict__ src, int R, int C) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= R * C) return;
-    dst[(idx % C) * R + idx / C] = __float2bfloat16(src[idx]);
-}
-
-__global__ void transpose_f32_kernel(float* __restrict__ dst, const float* __restrict__ src, int R, int C) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= R * C) return;
-    dst[(idx % C) * R + idx / C] = src[idx];
-}
-
 __global__ void cast_f32_to_precision_kernel(precision_t* __restrict__ dst, const float* __restrict__ src, int n) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx < n) dst[idx] = from_float(src[idx]);
 }
 
-__global__ void cast_f32_transpose_to_precision_kernel(precision_t* __restrict__ dst, const float* __restrict__ src, int R, int C) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= R * C) return;
-    dst[(idx % C) * R + idx / C] = from_float(src[idx]);
-}
-
 template <typename T>
 __global__ void transpose_01_kernel(T* __restrict__ dst, const T* __restrict__ src, int A, int B, int C) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
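This hunk shows the commit's core move: bf16- and f32-specific kernels collapse into single precision_t kernels built on to_float/from_float. Those helpers are defined elsewhere in kernels.cu and are not shown in this diff; what follows is a minimal sketch of the compile-time switch they imply, with the PUFFER_BF16 flag name being an assumption, not something from this file.

#include <cuda_bf16.h>

// Assumed shape of the precision switch; the real definitions may differ.
#ifdef PUFFER_BF16  // hypothetical build flag selecting bf16 activations
typedef __nv_bfloat16 precision_t;
__device__ __forceinline__ float to_float(precision_t v) { return __bfloat162float(v); }
__device__ __forceinline__ precision_t from_float(float v) { return __float2bfloat16(v); }
#else
typedef float precision_t;
__device__ __forceinline__ float to_float(precision_t v) { return v; }
__device__ __forceinline__ precision_t from_float(float v) { return v; }
#endif

With one kernel body compiled per precision, each deleted _bf16/_f32 pair reduces to a single definition, which is where most of this file's 264 removed lines come from.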
@@ -1452,20 +1308,6 @@ __global__ void transpose_01_kernel(T* __restrict__ dst, const T* __restrict__ s
     dst[b * A * C + a * C + c] = src[idx];
 }
 
-__global__ void norm_bf16_kernel(float* __restrict__ partials, const __nv_bfloat16* __restrict__ src, int n) {
-    __shared__ float sdata[256];
-    int tid = threadIdx.x;
-    float sum = 0.0f;
-    for (int i = blockIdx.x * blockDim.x + tid; i < n; i += blockDim.x * gridDim.x) {
-        float v = __bfloat162float(src[i]);
-        sum += v * v;
-    }
-    sdata[tid] = sum;
-    __syncthreads();
-    for (int s = blockDim.x / 2; s > 0; s >>= 1) { if (tid < s) sdata[tid] += sdata[tid + s]; __syncthreads(); }
-    if (tid == 0) partials[blockIdx.x] = sdata[0];
-}
-
 __global__ void norm_f32_kernel(float* __restrict__ partials, const float* __restrict__ src, int n) {
     __shared__ float sdata[256];
     int tid = threadIdx.x;
@@ -1493,18 +1335,6 @@ __global__ void clip_by_norm_f32_kernel(float* __restrict__ dst, const float* __
     if (idx < n) dst[idx] *= clip_coef;
 }
 
-__global__ void normalize_bf16_kernel(__nv_bfloat16* __restrict__ dst, const float* __restrict__ norm_ptr, float eps, int n) {
-    float inv_norm = 1.0f / fmaxf(sqrtf(*norm_ptr), eps);
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] = __float2bfloat16(__bfloat162float(dst[idx]) * inv_norm);
-}
-
-__global__ void normalize_f32_kernel(float* __restrict__ dst, const float* __restrict__ norm_ptr, float eps, int n) {
-    float inv_norm = 1.0f / fmaxf(sqrtf(*norm_ptr), eps);
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] = dst[idx] * inv_norm;
-}
-
 __global__ void norm_precision_kernel(float* __restrict__ partials, const precision_t* __restrict__ src, int n) {
     __shared__ float sdata[256];
     int tid = threadIdx.x;
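norm_f32_kernel and norm_precision_kernel each write one partial sum of squares per block, so computing a norm needs a second reduction over those partials before anything (like the deleted normalize kernels, which read *norm_ptr) can divide by it. That second stage is not visible in this diff; below is a sketch of what it could look like, assuming a single 256-thread block and a hypothetical kernel name.

// Hypothetical second stage: fold per-block partials into one sum of
// squares. Launch as <<<1, 256>>>; the caller applies sqrtf to *out.
__global__ void reduce_partials_kernel(float* __restrict__ out,
                                       const float* __restrict__ partials, int num_blocks) {
    __shared__ float sdata[256];
    int tid = threadIdx.x;
    float sum = 0.0f;
    for (int i = tid; i < num_blocks; i += blockDim.x) sum += partials[i];
    sdata[tid] = sum;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) sdata[tid] += sdata[tid + s];
        __syncthreads();
    }
    if (tid == 0) *out = sdata[0];
}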
@@ -1530,17 +1360,6 @@ __global__ void cast_precision_to_f32_kernel(float* __restrict__ dst, const prec
     if (idx < n) dst[idx] = to_float(src[idx]);
 }
 
-__global__ void cast_precision_scale_to_f32_kernel(float* __restrict__ dst, const precision_t* __restrict__ src, float scale, int n) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] = to_float(src[idx]) * scale;
-}
-
-__global__ void cast_precision_scale_transpose_to_f32_kernel(float* __restrict__ dst, const precision_t* __restrict__ src, float scale, int R, int C) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= R * C) return;
-    dst[(idx % C) * R + idx / C] = to_float(src[idx]) * scale;
-}
-
 // Input: (R, C) f32 → (M, N) precision_t, optionally transposing
 __global__ void cast_f32_to_precision_2d_kernel(precision_t* __restrict__ dst, const float* __restrict__ src, bool do_transpose, int R, int C) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
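The surviving 2D kernels replace the separate plain/transpose variants with one runtime flag; only cast_f32_to_precision_2d_kernel's signature and first line are visible here. Below is a sketch of the elided body, consistent with the (idx % C) * R + idx / C pattern used by the deleted transpose casts above and the out_idx line visible in the next hunk; treat it as a reading aid, not the shipped body.

// Sketch of the elided body: flat-index transpose gated by a runtime flag.
__global__ void cast_f32_to_precision_2d_sketch(precision_t* __restrict__ dst,
                                                const float* __restrict__ src,
                                                bool do_transpose, int R, int C) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= R * C) return;
    int out_idx = do_transpose ? (idx % C) * R + idx / C : idx;  // (R, C) -> (C, R)
    dst[out_idx] = from_float(src[idx]);
}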
@@ -1557,34 +1376,14 @@ __global__ void cast_precision_scale_to_f32_2d_kernel(float* __restrict__ dst, c
     dst[out_idx] = to_float(src[idx]) * scale;
 }
 
-__global__ void fill_bf16_kernel(__nv_bfloat16* __restrict__ dst, __nv_bfloat16 val, int n) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] = val;
-}
-
-__global__ void fill_f32_kernel(float* __restrict__ dst, float val, int n) {
+__global__ void fill_precision_kernel(precision_t* __restrict__ dst, precision_t val, int n) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx < n) dst[idx] = val;
 }
 
-__global__ void clamp_bf16_kernel(__nv_bfloat16* __restrict__ dst, float lo, float hi, int n) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) { float v = __bfloat162float(dst[idx]); dst[idx] = __float2bfloat16(fminf(fmaxf(v, lo), hi)); }
-}
-
-__global__ void clamp_f32_kernel(float* __restrict__ dst, float lo, float hi, int n) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) { dst[idx] = fminf(fmaxf(dst[idx], lo), hi); }
-}
-
-__global__ void scale_f32_kernel(float* __restrict__ dst, float alpha, int n) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] *= alpha;
-}
-
-__global__ void axpy_f32_kernel(float* __restrict__ dst, const float* __restrict__ src, float alpha, int n) {
+__global__ void clamp_precision_kernel(precision_t* __restrict__ dst, float lo, float hi, int n) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] += alpha * src[idx];
+    if (idx < n) { float v = to_float(dst[idx]); dst[idx] = from_float(fminf(fmaxf(v, lo), hi)); }
 }
 
 // Fused Nesterov momentum: mb = mu*mb + gc; gc = gc + mu*mb
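The comment above fully specifies the fused update that nesterov_f32_kernel applies per element, though the kernel's body falls outside the visible hunks. As a reading aid, here is a sketch consistent with that formula, not necessarily the shipped body.

// Reading aid for "mb = mu*mb + gc; gc = gc + mu*mb":
// mb is the momentum buffer, gc the gradient; the second line applies the
// Nesterov lookahead so gc leaves holding the actual parameter step.
__global__ void nesterov_f32_sketch(float* __restrict__ mb, float* __restrict__ gc,
                                    float mu, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        float m = mu * mb[idx] + gc[idx];  // mb = mu*mb + gc
        mb[idx] = m;
        gc[idx] = gc[idx] + mu * m;        // gc = gc + mu*mb
    }
}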
@@ -1597,24 +1396,6 @@ __global__ void nesterov_f32_kernel(float* __restrict__ mb, float* __restrict__
     }
 }
 
-__global__ void scale_f32_dev_kernel(float* __restrict__ dst, const float* __restrict__ alpha_ptr, int n) {
-    float alpha = *alpha_ptr;
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] *= alpha;
-}
-
-__global__ void axpy_f32_dev_kernel(float* __restrict__ dst, const float* __restrict__ src, const float* __restrict__ alpha_ptr, int n) {
-    float alpha = *alpha_ptr;
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] += alpha * src[idx];
-}
-
-__global__ void compute_lr_scalars_kernel(const float* __restrict__ lr, float wd,
-                                          float* __restrict__ neg_lr, float* __restrict__ wd_scale) {
-    *neg_lr = -(*lr);
-    *wd_scale = 1.0f - (*lr) * wd;
-}
-
 // Fused weight update: wb = wb * (1 - lr*wd) - lr * up
 __global__ void muon_weight_update_kernel(float* __restrict__ wb, const float* __restrict__ up,
                                           const float* __restrict__ lr_ptr, float wd, int n) {
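muon_weight_update_kernel's body sits between hunks, but its signature plus the comment pin the update down: weights decay by (1 - lr*wd) and step against the update direction, with lr read from device memory rather than precomputed scalars, which is presumably what turned the deleted compute_lr_scalars_kernel and *_dev scale/axpy kernels above into dead code. A sketch of the elided body follows.

// Sketch of the elided body: decoupled weight decay and the lr step fused
// into one pass, lr fetched from a device pointer so no host sync is needed.
__global__ void muon_weight_update_sketch(float* __restrict__ wb, const float* __restrict__ up,
                                          const float* __restrict__ lr_ptr, float wd, int n) {
    float lr = *lr_ptr;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) wb[idx] = wb[idx] * (1.0f - lr * wd) - lr * up[idx];
}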
@@ -1626,63 +1407,32 @@ __global__ void muon_weight_update_kernel(float* __restrict__ wb, const float* _
     }
 }
 
-__global__ void add_bf16_to_f32_kernel(float* __restrict__ dst, const __nv_bfloat16* __restrict__ src, int n) {
+__global__ void add_precision_to_f32_kernel(float* __restrict__ dst, const precision_t* __restrict__ src, int n) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] += __bfloat162float(src[idx]);
+    if (idx < n) dst[idx] += to_float(src[idx]);
 }
 
 __global__ void add_precision_kernel(precision_t* __restrict__ dst, const precision_t* __restrict__ src, int n) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx < n) dst[idx] = from_float(to_float(dst[idx]) + to_float(src[idx]));
 }
 
-__global__ void add_f32_kernel(float* __restrict__ dst, const float* __restrict__ src, int n) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] += src[idx];
-}
-
-__global__ void sum_rows_add_kernel(float* __restrict__ dst, const float* __restrict__ src, int R, int C) {
-    int col = blockIdx.x * blockDim.x + threadIdx.x;
-    if (col >= C) return;
-    float sum = 0.0f;
-    for (int r = 0; r < R; r++) sum += src[r * C + col];
-    dst[col] += sum;
-}
-
 // Sum f32 rows → bf16 output (set, not accumulate)
-__global__ void sum_rows_to_bf16_kernel(__nv_bfloat16* __restrict__ dst, const float* __restrict__ src, int R, int C) {
-    int col = blockIdx.x * blockDim.x + threadIdx.x;
-    if (col >= C) return;
-    float sum = 0.0f;
-    for (int r = 0; r < R; r++) sum += src[r * C + col];
-    dst[col] = __float2bfloat16(sum);
-}
-
-// Sum f32 rows → f32 output (set, not accumulate)
-__global__ void sum_rows_to_f32_kernel(float* __restrict__ dst, const float* __restrict__ src, int R, int C) {
+__global__ void sum_rows_to_precision_kernel(precision_t* __restrict__ dst, const float* __restrict__ src, int R, int C) {
     int col = blockIdx.x * blockDim.x + threadIdx.x;
     if (col >= C) return;
     float sum = 0.0f;
     for (int r = 0; r < R; r++) sum += src[r * C + col];
-    dst[col] = sum;
+    dst[col] = from_float(sum);
 }
 
-__global__ void assemble_decoder_grad_bf16_kernel(
-    __nv_bfloat16* __restrict__ dst, const float* __restrict__ grad_logits,
+__global__ void assemble_decoder_grad_kernel(
+    precision_t* __restrict__ dst, const float* __restrict__ grad_logits,
     const float* __restrict__ grad_value, int B_TT, int od, int od_plus_1) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx >= B_TT * od_plus_1) return;
     int row = idx / od_plus_1, col = idx % od_plus_1;
-    dst[idx] = __float2bfloat16((col < od) ? grad_logits[row * od + col] : grad_value[row]);
-}
-
-__global__ void assemble_decoder_grad_f32_kernel(
-    float* __restrict__ dst, const float* __restrict__ grad_logits,
-    const float* __restrict__ grad_value, int B_TT, int od, int od_plus_1) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= B_TT * od_plus_1) return;
-    int row = idx / od_plus_1, col = idx % od_plus_1;
-    dst[idx] = (col < od) ? grad_logits[row * od + col] : grad_value[row];
+    dst[idx] = from_float((col < od) ? grad_logits[row * od + col] : grad_value[row]);
 }
 
 __global__ void var_mean_kernel(const float* __restrict__ src, float* __restrict__ var_out,
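The merged assemble_decoder_grad_kernel above packs the (B_TT, od) logits gradient and the (B_TT,) value gradient into one (B_TT, od+1) precision_t matrix, presumably so a single backward GEMM handles the joint policy/value head. A hedged launch sketch, with the wrapper name and THREADS being illustrative rather than from this commit:

// Hypothetical host wrapper: one thread per packed output element.
static void launch_assemble_decoder_grad(precision_t* dst, const float* grad_logits,
                                         const float* grad_value, int B_TT, int od) {
    int od_plus_1 = od + 1;
    const int THREADS = 256;
    assemble_decoder_grad_kernel<<<(B_TT * od_plus_1 + THREADS - 1) / THREADS, THREADS>>>(
        dst, grad_logits, grad_value, B_TT, od, od_plus_1);
}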
@@ -1705,10 +1455,6 @@ __global__ void var_mean_kernel(const float* __restrict__ src, float* __restrict
     if (tid == 0) *var_out = sdata[0] / (float)(n - 1);
 }
 
-__global__ void add_scalar_kernel(float* __restrict__ ptr, float val) {
-    *ptr += val;
-}
-
 __global__ void index_copy_kernel(char* __restrict__ dst, const int64_t* __restrict__ idx,
                                   const char* __restrict__ src, int num_idx, int row_bytes) {
     int i = blockIdx.x * blockDim.x + threadIdx.x;
