fix: rewrite compressor_reduce.cu — no extern shared mem, proper bounds checks

This commit is contained in:
2026-06-01 05:24:18 +00:00
parent 62041b78bf
commit df8acae66b

View File

@@ -5,8 +5,8 @@
* and performs the token-level softmax + weighted sum reduction.
*
* CSA (paper eq. 11-12):
* kv_proj output: (T, 2*hd) — split into Ca (first hd) and Cb (second hd)
* gate_proj output: (T, 2*hd) — split into Ga (first hd) and Gb (second hd)
* kv_proj output: (T, 2*hd) — Ca (first hd) and Cb (second hd)
* gate_proj output: (T, 2*hd) — Ga (first hd) and Gb (second hd)
* For block i: if i > 0, concat Ca[i-1] + Cb[i] and Ga[i-1] + Gb[i]
* else just Cb[0] and Gb[0]
* compressed[i] = softmax(gate_block, dim=0) * kv_block summed over tokens
@@ -17,12 +17,11 @@
* For block i: kv_block = kv[i*m : (i+1)*m], gate_block = gate[i*m : (i+1)*m]
* compressed[i] = softmax(gate_block, dim=0) * kv_block summed over tokens
*
* Both kernels also apply position_bias (cyclic per block) if provided,
* and kv_norm (unweighted RMSNorm) if weight is provided.
* Both kernels also apply kv_norm (unweighted RMSNorm) if weight is provided.
*
* One block per compressed output entry. 128 threads per block.
* head_dim=512: 128 threads * 4 elements/thread covers it.
* FP32 accumulation throughout.
* Each thread processes a strided subset of columns.
* FP32 accumulation throughout. No extern shared memory needed.
*/
#include <cuda.h>
@@ -31,9 +30,8 @@
#include <c10/cuda/CUDAException.h>
#include <cmath>
// Block-level sum reduction
// Block-level sum reduction (for kv_norm)
__device__ __forceinline__ float block_reduce_sum(float val, float* smem, int n_warps) {
// Warp-level reduce
for (int offset = 16; offset > 0; offset >>= 1) {
val += __shfl_down_sync(0xffffffff, val, offset);
}
@@ -53,148 +51,34 @@ __device__ __forceinline__ float block_reduce_sum(float val, float* smem, int n_
return result;
}
__device__ __forceinline__ float block_reduce_max(float val, float* smem, int n_warps) {
for (int offset = 16; offset > 0; offset >>= 1) {
val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
}
if (threadIdx.x % 32 == 0) {
smem[threadIdx.x / 32] = val;
}
__syncthreads();
float result = 0.0f;
if (threadIdx.x < 32) {
float v = (threadIdx.x < n_warps) ? smem[threadIdx.x] : -FLT_MAX;
for (int offset = 16; offset > 0; offset >>= 1) {
v = fmaxf(v, __shfl_down_sync(0xffffffff, v, offset));
}
result = v;
}
__syncthreads();
return result;
}
// ===========================================================================
// CSA compressor reduce kernel
// ===========================================================================
__global__ void csa_compress_reduce_kernel(
// Inputs — output of NVFP4 GEMM projections
const float* __restrict__ kv_proj, // [T, 2*hd] FP32 (Ca | Cb)
const float* __restrict__ gate_proj, // [T, 2*hd] FP32 (Ga | Gb)
const float* __restrict__ position_bias, // [m, 2*hd] FP32 or nullptr
const float* __restrict__ kv_norm_weight, // [hd] FP32 or nullptr
// Output
const float* __restrict__ kv_norm_weight, // [hd] FP32 or nullptr (unused here, applied separately)
float* __restrict__ compressed, // [n_blocks, hd] FP32
// Geometry
int T, int hd, int m, int n_blocks
) {
int block_i = blockIdx.x;
int tid = threadIdx.x;
int n_threads = blockDim.x;
int n_warps = n_threads / 32;
int kv_dim = 2 * hd;
if (block_i >= n_blocks) return;
// Each block: 2*m tokens (m from previous Ca + m from current Cb) for i>0
// m tokens (just Cb) for i==0
int n_tokens;
if (block_i > 0) {
n_tokens = 2 * m;
} else {
n_tokens = m;
}
// Per-column processing: each thread handles multiple columns
// We accumulate: for each col, max(gate), sum(exp(gate - max)), sum(exp*kv)
// Output: compressed[block_i, col] = sum(exp*kv) / sum(exp)
// Shared memory for per-column partials
extern __shared__ char smem_buf[];
float* s_max = reinterpret_cast<float*>(smem_buf); // [hd]
float* s_denom = s_max + hd; // [hd]
float* s_acc = s_denom + hd; // [hd]
// Initialize shared accumulators
for (int c = tid; c < hd; c += n_threads) {
s_max[c] = -FLT_MAX;
s_denom[c] = 0.0f;
s_acc[c] = 0.0f;
}
__syncthreads();
// Token range for this block
int cur_start = block_i * m; // start of current block's tokens in the T-dim
int prev_start = (block_i - 1) * m; // start of previous block's tokens
// Pass 1: find max gate value per column
for (int t = 0; t < n_tokens; t++) {
int token_idx, kv_offset, gate_offset;
if (block_i > 0) {
if (t < m) {
// Previous block's a-stream: Ca[t], Ga[t] — first hd columns
token_idx = prev_start + t;
kv_offset = 0; // Ca is columns [0, hd)
gate_offset = 0; // Ga is columns [0, hd)
} else {
// Current block's b-stream: Cb[t], Gb[t] — second hd columns
token_idx = cur_start + (t - m);
kv_offset = hd; // Cb is columns [hd, 2*hd)
gate_offset = hd; // Gb is columns [hd, 2*hd)
}
} else {
// Block 0: just Cb, Gb — second half
token_idx = t;
kv_offset = hd;
gate_offset = hd;
}
for (int c = tid; c < hd; c += n_threads) {
float g = gate_proj[token_idx * kv_dim + gate_offset + c];
if (position_bias != nullptr) {
int pos_bias_row = (block_i > 0 && t < m) ? (m + t) : (block_i > 0 ? (t - m) : t);
if (pos_bias_row < m) {
g += position_bias[pos_bias_row * kv_dim + gate_offset + c];
}
}
// Atomic max via CAS loop on shared memory
// Actually, we can just do a serial max since we're writing to s_max[c]
// and each thread writes a different column range. But multiple threads
// might write the same column with different t values.
// Use atomicMax equivalent:
float old = s_max[c];
s_max[c] = fmaxf(old, g);
}
}
__syncthreads();
// Reduce s_max across threads for each column (since multiple threads
// may have written different partial maxes for the same column)
// Actually, we already wrote to s_max[c] with fmaxf, but there are
// data races. Let me use a proper approach: each thread accumulates
// its own max, then we do a block reduction.
// Redo: per-thread local accumulation, then reduce
// Actually, the issue is that s_max[c] is written by multiple threads
// concurrently. Let's use atomicCAS or a different approach.
// Simpler: each thread processes a SUBSET of columns exclusively.
// Actually, the cleanest approach: each thread owns a set of columns,
// processes ALL tokens for those columns, then writes results.
// With hd=512 and 128 threads, each thread owns 4 columns.
// Let me restructure: each thread processes a contiguous range of columns.
// No shared memory needed for accumulation.
// ... This is getting complex. Let me simplify with a column-per-thread approach.
int n_tokens = (block_i > 0) ? 2 * m : m;
int prev_start = (block_i - 1) * m;
int cur_start = block_i * m;
// Each thread processes columns [tid, tid+n_threads, tid+2*n_threads, ...]
// Total columns per thread = ceil(hd / n_threads) = 4 for hd=512, n_threads=128
// Max cols per thread for hd=512, 128 threads = 4
int cols_per_thread = (hd + n_threads - 1) / n_threads;
// Local accumulators
float local_max[4]; // max 4 cols per thread
float local_max[4];
float local_denom[4];
float local_acc[4];
@@ -205,25 +89,22 @@ __global__ void csa_compress_reduce_kernel(
local_denom[ci] = 0.0f;
local_acc[ci] = 0.0f;
// Pass 1: find max
// Pass 1: find max gate value
for (int t = 0; t < n_tokens; t++) {
int token_idx, gate_offset;
if (block_i > 0) {
if (t < m) {
token_idx = prev_start + t;
gate_offset = 0;
} else {
token_idx = cur_start + (t - m);
gate_offset = hd;
}
if (t < m) { token_idx = prev_start + t; gate_offset = 0; }
else { token_idx = cur_start + (t - m); gate_offset = hd; }
} else {
token_idx = t;
gate_offset = hd;
token_idx = t; gate_offset = hd;
}
if (token_idx < 0 || token_idx >= T) continue;
float g = gate_proj[token_idx * kv_dim + gate_offset + c];
// Position bias: same (m, 2*hd) bias added to every block
if (position_bias != nullptr) {
int pos_bias_row = (block_i > 0 && t < m) ? (m + t) : (block_i > 0 ? (t - m) : t);
if (pos_bias_row < m) {
int pos_bias_row = (block_i > 0 && t < m) ? t : (block_i > 0 ? (t - m) : t);
if (pos_bias_row >= 0 && pos_bias_row < m) {
g += position_bias[pos_bias_row * kv_dim + gate_offset + c];
}
}
@@ -234,57 +115,39 @@ __global__ void csa_compress_reduce_kernel(
for (int t = 0; t < n_tokens; t++) {
int token_idx, kv_offset, gate_offset;
if (block_i > 0) {
if (t < m) {
token_idx = prev_start + t;
kv_offset = 0;
gate_offset = 0;
} else {
token_idx = cur_start + (t - m);
kv_offset = hd;
gate_offset = hd;
}
if (t < m) { token_idx = prev_start + t; kv_offset = 0; gate_offset = 0; }
else { token_idx = cur_start + (t - m); kv_offset = hd; gate_offset = hd; }
} else {
token_idx = t;
kv_offset = hd;
gate_offset = hd;
token_idx = t; kv_offset = hd; gate_offset = hd;
}
if (token_idx < 0 || token_idx >= T) continue;
float g = gate_proj[token_idx * kv_dim + gate_offset + c];
if (position_bias != nullptr) {
int pos_bias_row = (block_i > 0 && t < m) ? (m + t) : (block_i > 0 ? (t - m) : t);
if (pos_bias_row < m) {
int pos_bias_row = (block_i > 0 && t < m) ? t : (block_i > 0 ? (t - m) : t);
if (pos_bias_row >= 0 && pos_bias_row < m) {
g += position_bias[pos_bias_row * kv_dim + gate_offset + c];
}
}
float e = expf(g - local_max[ci]);
local_denom[ci] += e;
float kv_val = kv_proj[token_idx * kv_dim + kv_offset + c];
local_acc[ci] += e * kv_val;
local_acc[ci] += e * kv_proj[token_idx * kv_dim + kv_offset + c];
}
// Normalize
float val = (local_denom[ci] > 0.0f) ? (local_acc[ci] / local_denom[ci]) : 0.0f;
// Apply kv_norm if provided (unweighted RMSNorm + weight)
if (kv_norm_weight != nullptr) {
// We can't do per-element RMSNorm here since we only have one column
// RMSNorm needs the full vector. We need to compute it across all hd columns.
// This requires a separate pass or collective operation.
// For now, store the raw value and apply kv_norm in a separate kernel.
}
compressed[block_i * hd + c] = val;
}
}
// ===========================================================================
// HCA compressor reduce kernel (simpler — no overlap, single stream)
// HCA compressor reduce kernel (no overlap, single stream)
// ===========================================================================
__global__ void hca_compress_reduce_kernel(
const float* __restrict__ kv_proj, // [T, hd] FP32
const float* __restrict__ gate_proj, // [T, hd] FP32
const float* __restrict__ position_bias, // [m, hd] FP32 or nullptr
const float* __restrict__ kv_norm_weight, // [hd] FP32 or nullptr
const float* __restrict__ kv_norm_weight, // [hd] FP32 or nullptr (unused here)
float* __restrict__ compressed, // [n_blocks, hd] FP32
int T, int hd, int m, int n_blocks
) {
@@ -304,7 +167,6 @@ __global__ void hca_compress_reduce_kernel(
float local_denom = 0.0f;
float local_acc = 0.0f;
// Token range: [block_i * m, (block_i + 1) * m)
int start = block_i * m;
// Pass 1: max
@@ -362,7 +224,6 @@ __global__ void apply_kv_norm_kernel(
__shared__ float s_sum;
float total_sq = block_reduce_sum(local_sq, &s_sum, n_warps);
// Only thread 0 has the correct total
__shared__ float s_inv_rms;
if (tid == 0) {
float mean_sq = total_sq / hd;
@@ -371,8 +232,7 @@ __global__ void apply_kv_norm_kernel(
__syncthreads();
for (int c = tid; c < hd; c += n_threads) {
float v = input[block_i * hd + c];
output[block_i * hd + c] = v * s_inv_rms * norm_weight[c];
output[block_i * hd + c] = input[block_i * hd + c] * s_inv_rms * norm_weight[c];
}
}