@@ -1023,122 +1023,6 @@ __global__ void sample_logits_kernel(
 
 
 
-// =============================================================================
-// FCMax: Fused FC -> Max kernel
-// Input: x (B, N, D_in), W (D_out, D_in), b (D_out)
-// Output: out (B, D_out) = max_over_N(x @ W.T + b)
-// Each thread computes one (b, d_out) output element
-// N-fold memory bandwidth reduction vs separate FC + Max kernels
-// W and b are always float32 (mixed precision for bf16 activations)
-// =============================================================================
-
-__global__ void fc_max_forward_kernel(
-    precision_t* __restrict__ out,         // (B, D_out)
-    int* __restrict__ argmax_indices,      // (B, D_out) - which N produced the max
-    const precision_t* __restrict__ x,     // (B, N, D_in)
-    const float* __restrict__ W,           // (D_out, D_in) - always float32
-    const float* __restrict__ b,           // (D_out) - always float32
-    int B, int N, int D_in, int D_out
-) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= B * D_out) return;
-
-    int batch = idx / D_out;
-    int d_out = idx % D_out;
-
-    float bias = b[d_out];
-    float max_val = -INFINITY;
-    int argmax_n = 0;
-
-    // Iterate over all N points, compute FC output, track max
-    for (int n = 0; n < N; n++) {
-        float val = bias;
-        for (int di = 0; di < D_in; di++) {
-            val += to_float(x[batch * N * D_in + n * D_in + di]) * W[d_out * D_in + di];
-        }
-        if (val > max_val) {
-            max_val = val;
-            argmax_n = n;
-        }
-    }
-
-    out[idx] = from_float(max_val);
-    argmax_indices[idx] = argmax_n;
-}
-
-// Deterministic backward: three separate kernels with no atomicAdd.
-// Each kernel assigns threads so that each output element is written by exactly one thread.
-
-// grad_b[d_out] = sum over batch of grad_out[batch * D_out + d_out]
-// One thread per d_out, serial sum over batch dimension
-__global__ void fc_max_backward_grad_b_kernel(
-    float* __restrict__ grad_b,                // (D_out)
-    const precision_t* __restrict__ grad_out,  // (B, D_out)
-    int B, int D_out
-) {
-    int d_out = blockIdx.x * blockDim.x + threadIdx.x;
-    if (d_out >= D_out) return;
-
-    float sum = 0.0f;
-    for (int b = 0; b < B; b++) {
-        sum += to_float(grad_out[b * D_out + d_out]);
-    }
-    grad_b[d_out] = sum;
-}
-
-// grad_W[d_out, d_in] = sum over batch of grad_out[batch, d_out] * x[batch, argmax[batch, d_out], d_in]
-// One thread per (d_out, d_in) pair, serial sum over batch dimension
-__global__ void fc_max_backward_grad_W_kernel(
-    float* __restrict__ grad_W,                // (D_out, D_in)
-    const precision_t* __restrict__ grad_out,  // (B, D_out)
-    const precision_t* __restrict__ x,         // (B, N, D_in)
-    const int* __restrict__ argmax_indices,    // (B, D_out)
-    int B, int N, int D_in, int D_out
-) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= D_out * D_in) return;
-
-    int d_out = idx / D_in;
-    int d_in = idx % D_in;
-
-    float sum = 0.0f;
-    for (int b = 0; b < B; b++) {
-        float g_out = to_float(grad_out[b * D_out + d_out]);
-        int argmax_n = argmax_indices[b * D_out + d_out];
-        float x_val = to_float(x[b * N * D_in + argmax_n * D_in + d_in]);
-        sum += g_out * x_val;
-    }
-    grad_W[idx] = sum;
-}
-
-// grad_x[batch, n, d_in] = sum over d_out where argmax[batch, d_out]==n of grad_out[batch, d_out] * W[d_out, d_in]
-// One thread per (batch, d_in) pair, serial loop over d_out
-// Each thread writes to its own batch's grad_x — no cross-thread contention
-__global__ void fc_max_backward_grad_x_kernel(
-    float* __restrict__ grad_x,                // (B, N, D_in)
-    const precision_t* __restrict__ grad_out,  // (B, D_out)
-    const float* __restrict__ W,               // (D_out, D_in)
-    const int* __restrict__ argmax_indices,    // (B, D_out)
-    int B, int N, int D_in, int D_out
-) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= B * D_in) return;
-
-    int batch = idx / D_in;
-    int d_in = idx % D_in;
-
-    // Zero this batch's grad_x column first
-    for (int n = 0; n < N; n++) {
-        grad_x[batch * N * D_in + n * D_in + d_in] = 0.0f;
-    }
-
-    // Accumulate: for each d_out, add grad_out * W to the argmax position
-    for (int d_out = 0; d_out < D_out; d_out++) {
-        float g_out = to_float(grad_out[batch * D_out + d_out]);
-        int argmax_n = argmax_indices[batch * D_out + d_out];
-        grad_x[batch * N * D_in + argmax_n * D_in + d_in] += g_out * W[d_out * D_in + d_in];
-    }
-}
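// [Editor's sketch, not part of this commit] A plausible host-side launcher for
// the fused FCMax forward pass removed above. The 256-thread block size, the
// function name, and the `stream` argument are assumptions; the repo's real
// host wrappers live in models.cu.
//
//   static void launch_fc_max_forward(precision_t* out, int* argmax,
//                                     const precision_t* x, const float* W, const float* b,
//                                     int B, int N, int D_in, int D_out, cudaStream_t stream) {
//       const int threads = 256;
//       const int blocks = (B * D_out + threads - 1) / threads;  // one thread per (batch, d_out)
//       fc_max_forward_kernel<<<blocks, threads, 0, stream>>>(
//           out, argmax, x, W, b, B, N, D_in, D_out);
//   }
//
// The backward kernels launch the same way, sized over D_out, D_out*D_in, and
// B*D_in threads respectively, so every output element has exactly one writer.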
 
 
 #define SELECT_COPY_THREADS 256
@@ -1410,39 +1294,11 @@ __global__ void puff_advantage_kernel_scalar(const precision_t* values, const pr
 // Host wrappers live in models.cu
 // ============================================================================
 
-__global__ void cast_bf16_to_f32_kernel(float* __restrict__ dst, const __nv_bfloat16* __restrict__ src, int n) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] = __bfloat162float(src[idx]);
-}
-
-__global__ void cast_f32_to_bf16_kernel(__nv_bfloat16* __restrict__ dst, const float* __restrict__ src, int n) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] = __float2bfloat16(src[idx]);
-}
-
-__global__ void cast_f32_to_bf16_transpose_kernel(__nv_bfloat16* __restrict__ dst, const float* __restrict__ src, int R, int C) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= R * C) return;
-    dst[(idx % C) * R + idx / C] = __float2bfloat16(src[idx]);
-}
-
-__global__ void transpose_f32_kernel(float* __restrict__ dst, const float* __restrict__ src, int R, int C) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= R * C) return;
-    dst[(idx % C) * R + idx / C] = src[idx];
-}
-
 __global__ void cast_f32_to_precision_kernel(precision_t* __restrict__ dst, const float* __restrict__ src, int n) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx < n) dst[idx] = from_float(src[idx]);
 }
 
-__global__ void cast_f32_transpose_to_precision_kernel(precision_t* __restrict__ dst, const float* __restrict__ src, int R, int C) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= R * C) return;
-    dst[(idx % C) * R + idx / C] = from_float(src[idx]);
-}
-
 template <typename T>
 __global__ void transpose_01_kernel(T* __restrict__ dst, const T* __restrict__ src, int A, int B, int C) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -1452,20 +1308,6 @@ __global__ void transpose_01_kernel(T* __restrict__ dst, const T* __restrict__ s
     dst[b * A * C + a * C + c] = src[idx];
 }
 
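// [Editor's note, illustrative and not from this commit] transpose_01_kernel
// swaps the leading two axes of a contiguous (A, B, C) tensor, e.g. from a
// (batch, time, dim) layout to (time, batch, dim):
//
//   int n = A * B * C;
//   transpose_01_kernel<float><<<(n + 255) / 256, 256>>>(dst, src, A, B, C);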
-__global__ void norm_bf16_kernel(float* __restrict__ partials, const __nv_bfloat16* __restrict__ src, int n) {
-    __shared__ float sdata[256];
-    int tid = threadIdx.x;
-    float sum = 0.0f;
-    for (int i = blockIdx.x * blockDim.x + tid; i < n; i += blockDim.x * gridDim.x) {
-        float v = __bfloat162float(src[i]);
-        sum += v * v;
-    }
-    sdata[tid] = sum;
-    __syncthreads();
-    for (int s = blockDim.x / 2; s > 0; s >>= 1) { if (tid < s) sdata[tid] += sdata[tid + s]; __syncthreads(); }
-    if (tid == 0) partials[blockIdx.x] = sdata[0];
-}
-
 __global__ void norm_f32_kernel(float* __restrict__ partials, const float* __restrict__ src, int n) {
     __shared__ float sdata[256];
     int tid = threadIdx.x;
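// [Editor's sketch, assumptions flagged] The norm_* kernels here emit one
// partial sum of squares per block; a second pass still has to reduce the
// partials and take the square root. A minimal host-side finish, assuming
// `num_blocks` partials, <vector>/<cmath>, and that a synchronous copy is
// acceptable:
//
//   std::vector<float> h(num_blocks);
//   cudaMemcpy(h.data(), partials, num_blocks * sizeof(float), cudaMemcpyDeviceToHost);
//   float sq = 0.0f;
//   for (int i = 0; i < num_blocks; i++) sq += h[i];
//   float norm = sqrtf(sq);
//
// The (removed) normalize_* kernels instead read the squared norm from device
// memory and apply sqrtf in-kernel, avoiding the round trip.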
@@ -1493,18 +1335,6 @@ __global__ void clip_by_norm_f32_kernel(float* __restrict__ dst, const float* __
     if (idx < n) dst[idx] *= clip_coef;
 }
 
-__global__ void normalize_bf16_kernel(__nv_bfloat16* __restrict__ dst, const float* __restrict__ norm_ptr, float eps, int n) {
-    float inv_norm = 1.0f / fmaxf(sqrtf(*norm_ptr), eps);
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] = __float2bfloat16(__bfloat162float(dst[idx]) * inv_norm);
-}
-
-__global__ void normalize_f32_kernel(float* __restrict__ dst, const float* __restrict__ norm_ptr, float eps, int n) {
-    float inv_norm = 1.0f / fmaxf(sqrtf(*norm_ptr), eps);
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] = dst[idx] * inv_norm;
-}
-
 __global__ void norm_precision_kernel(float* __restrict__ partials, const precision_t* __restrict__ src, int n) {
     __shared__ float sdata[256];
     int tid = threadIdx.x;
@@ -1530,17 +1360,6 @@ __global__ void cast_precision_to_f32_kernel(float* __restrict__ dst, const prec
     if (idx < n) dst[idx] = to_float(src[idx]);
 }
 
-__global__ void cast_precision_scale_to_f32_kernel(float* __restrict__ dst, const precision_t* __restrict__ src, float scale, int n) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] = to_float(src[idx]) * scale;
-}
-
-__global__ void cast_precision_scale_transpose_to_f32_kernel(float* __restrict__ dst, const precision_t* __restrict__ src, float scale, int R, int C) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= R * C) return;
-    dst[(idx % C) * R + idx / C] = to_float(src[idx]) * scale;
-}
-
 // Input: (R, C) f32 → (M, N) precision_t, optionally transposing
 __global__ void cast_f32_to_precision_2d_kernel(precision_t* __restrict__ dst, const float* __restrict__ src, bool do_transpose, int R, int C) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -1557,34 +1376,14 @@ __global__ void cast_precision_scale_to_f32_2d_kernel(float* __restrict__ dst, c
     dst[out_idx] = to_float(src[idx]) * scale;
 }
 
-__global__ void fill_bf16_kernel(__nv_bfloat16* __restrict__ dst, __nv_bfloat16 val, int n) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] = val;
-}
-
-__global__ void fill_f32_kernel(float* __restrict__ dst, float val, int n) {
+__global__ void fill_precision_kernel(precision_t* __restrict__ dst, precision_t val, int n) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx < n) dst[idx] = val;
 }
 
-__global__ void clamp_bf16_kernel(__nv_bfloat16* __restrict__ dst, float lo, float hi, int n) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) { float v = __bfloat162float(dst[idx]); dst[idx] = __float2bfloat16(fminf(fmaxf(v, lo), hi)); }
-}
-
-__global__ void clamp_f32_kernel(float* __restrict__ dst, float lo, float hi, int n) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) { dst[idx] = fminf(fmaxf(dst[idx], lo), hi); }
-}
-
-__global__ void scale_f32_kernel(float* __restrict__ dst, float alpha, int n) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] *= alpha;
-}
-
-__global__ void axpy_f32_kernel(float* __restrict__ dst, const float* __restrict__ src, float alpha, int n) {
+__global__ void clamp_precision_kernel(precision_t* __restrict__ dst, float lo, float hi, int n) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] += alpha * src[idx];
+    if (idx < n) { float v = to_float(dst[idx]); dst[idx] = from_float(fminf(fmaxf(v, lo), hi)); }
 }
 
 // Fused Nesterov momentum: mb = mu*mb + gc; gc = gc + mu*mb
@@ -1597,24 +1396,6 @@ __global__ void nesterov_f32_kernel(float* __restrict__ mb, float* __restrict__
     }
 }
 
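// [Editor's note] Written out, the fused step above is the standard Nesterov
// lookahead form. With momentum mu, buffer m, and gradient g:
//
//   m_new = mu * m + g         // stored back to mb
//   g_out = g + mu * m_new     // stored back to gc; this is the applied step
//
// One kernel does both read-modify-writes per element instead of two passes.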
-__global__ void scale_f32_dev_kernel(float* __restrict__ dst, const float* __restrict__ alpha_ptr, int n) {
-    float alpha = *alpha_ptr;
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] *= alpha;
-}
-
-__global__ void axpy_f32_dev_kernel(float* __restrict__ dst, const float* __restrict__ src, const float* __restrict__ alpha_ptr, int n) {
-    float alpha = *alpha_ptr;
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] += alpha * src[idx];
-}
-
-__global__ void compute_lr_scalars_kernel(const float* __restrict__ lr, float wd,
-                                          float* __restrict__ neg_lr, float* __restrict__ wd_scale) {
-    *neg_lr = -(*lr);
-    *wd_scale = 1.0f - (*lr) * wd;
-}
-
 // Fused weight update: wb = wb * (1 - lr*wd) - lr * up
 __global__ void muon_weight_update_kernel(float* __restrict__ wb, const float* __restrict__ up,
                                           const float* __restrict__ lr_ptr, float wd, int n) {
@@ -1626,63 +1407,32 @@ __global__ void muon_weight_update_kernel(float* __restrict__ wb, const float* _
     }
 }
 
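// [Editor's note] The fused update above is decoupled (AdamW-style) weight
// decay plus the scaled step,
//
//   w_new = w * (1 - lr * wd) - lr * u
//
// with lr read through lr_ptr on the device, so the host never needs to
// synchronize on the learning-rate schedule to launch the update.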
-__global__ void add_bf16_to_f32_kernel(float* __restrict__ dst, const __nv_bfloat16* __restrict__ src, int n) {
+__global__ void add_precision_to_f32_kernel(float* __restrict__ dst, const precision_t* __restrict__ src, int n) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] += __bfloat162float(src[idx]);
+    if (idx < n) dst[idx] += to_float(src[idx]);
 }
 
 __global__ void add_precision_kernel(precision_t* __restrict__ dst, const precision_t* __restrict__ src, int n) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx < n) dst[idx] = from_float(to_float(dst[idx]) + to_float(src[idx]));
 }
 
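// [Editor's note] add_precision_kernel upcasts both operands, adds in f32, and
// rounds once on the store. The to_float/from_float pair keeps a single kernel
// source correct for both builds: presumably identity conversions when
// precision_t is float, and bf16<->f32 converts when it is __nv_bfloat16.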
-__global__ void add_f32_kernel(float* __restrict__ dst, const float* __restrict__ src, int n) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < n) dst[idx] += src[idx];
-}
-
-__global__ void sum_rows_add_kernel(float* __restrict__ dst, const float* __restrict__ src, int R, int C) {
-    int col = blockIdx.x * blockDim.x + threadIdx.x;
-    if (col >= C) return;
-    float sum = 0.0f;
-    for (int r = 0; r < R; r++) sum += src[r * C + col];
-    dst[col] += sum;
-}
-
 // Sum f32 rows → precision_t output (set, not accumulate)
-__global__ void sum_rows_to_bf16_kernel(__nv_bfloat16* __restrict__ dst, const float* __restrict__ src, int R, int C) {
-    int col = blockIdx.x * blockDim.x + threadIdx.x;
-    if (col >= C) return;
-    float sum = 0.0f;
-    for (int r = 0; r < R; r++) sum += src[r * C + col];
-    dst[col] = __float2bfloat16(sum);
-}
-
-// Sum f32 rows → f32 output (set, not accumulate)
-__global__ void sum_rows_to_f32_kernel(float* __restrict__ dst, const float* __restrict__ src, int R, int C) {
+__global__ void sum_rows_to_precision_kernel(precision_t* __restrict__ dst, const float* __restrict__ src, int R, int C) {
     int col = blockIdx.x * blockDim.x + threadIdx.x;
     if (col >= C) return;
     float sum = 0.0f;
     for (int r = 0; r < R; r++) sum += src[r * C + col];
-    dst[col] = sum;
+    dst[col] = from_float(sum);
 }
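// [Editor's note] One thread per column with a serial loop over rows fixes the
// summation order, so this reduction is bitwise deterministic across runs:
// the same atomic-free pattern the removed FCMax backward kernels used.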
 
-__global__ void assemble_decoder_grad_bf16_kernel(
-    __nv_bfloat16* __restrict__ dst, const float* __restrict__ grad_logits,
+__global__ void assemble_decoder_grad_kernel(
+    precision_t* __restrict__ dst, const float* __restrict__ grad_logits,
     const float* __restrict__ grad_value, int B_TT, int od, int od_plus_1) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx >= B_TT * od_plus_1) return;
     int row = idx / od_plus_1, col = idx % od_plus_1;
-    dst[idx] = __float2bfloat16((col < od) ? grad_logits[row * od + col] : grad_value[row]);
-}
-
-__global__ void assemble_decoder_grad_f32_kernel(
-    float* __restrict__ dst, const float* __restrict__ grad_logits,
-    const float* __restrict__ grad_value, int B_TT, int od, int od_plus_1) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= B_TT * od_plus_1) return;
-    int row = idx / od_plus_1, col = idx % od_plus_1;
-    dst[idx] = (col < od) ? grad_logits[row * od + col] : grad_value[row];
+    dst[idx] = from_float((col < od) ? grad_logits[row * od + col] : grad_value[row]);
 }
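// [Editor's note] Per row of the (B_TT, od + 1) buffer assembled above:
//
//   [ grad_logits[row, 0] ... grad_logits[row, od-1] | grad_value[row] ]
//
// i.e. the value-head gradient rides along as one extra trailing column after
// the od logit columns, presumably matching a decoder whose final output unit
// is the value head.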
 
 __global__ void var_mean_kernel(const float* __restrict__ src, float* __restrict__ var_out,
@@ -1705,10 +1455,6 @@ __global__ void var_mean_kernel(const float* __restrict__ src, float* __restrict
     if (tid == 0) *var_out = sdata[0] / (float)(n - 1);
 }
 
-__global__ void add_scalar_kernel(float* __restrict__ ptr, float val) {
-    *ptr += val;
-}
-
 __global__ void index_copy_kernel(char* __restrict__ dst, const int64_t* __restrict__ idx,
                                   const char* __restrict__ src, int num_idx, int row_bytes) {
     int i = blockIdx.x * blockDim.x + threadIdx.x;