From df8acae66b9f92afd8a61fa0ee3a784dcb90d93e Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 1 Jun 2026 05:24:18 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20rewrite=20compressor=5Freduce.cu=20?= =?UTF-8?q?=E2=80=94=20no=20extern=20shared=20mem,=20proper=20bounds=20che?= =?UTF-8?q?cks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dsv4/kernels/cuda/compressor_reduce.cu | 204 ++++--------------------- 1 file changed, 32 insertions(+), 172 deletions(-) diff --git a/dsv4/kernels/cuda/compressor_reduce.cu b/dsv4/kernels/cuda/compressor_reduce.cu index 40f06a94..3d7344ee 100644 --- a/dsv4/kernels/cuda/compressor_reduce.cu +++ b/dsv4/kernels/cuda/compressor_reduce.cu @@ -5,8 +5,8 @@ * and performs the token-level softmax + weighted sum reduction. * * CSA (paper eq. 11-12): - * kv_proj output: (T, 2*hd) — split into Ca (first hd) and Cb (second hd) - * gate_proj output: (T, 2*hd) — split into Ga (first hd) and Gb (second hd) + * kv_proj output: (T, 2*hd) — Ca (first hd) and Cb (second hd) + * gate_proj output: (T, 2*hd) — Ga (first hd) and Gb (second hd) * For block i: if i > 0, concat Ca[i-1] + Cb[i] and Ga[i-1] + Gb[i] * else just Cb[0] and Gb[0] * compressed[i] = softmax(gate_block, dim=0) * kv_block summed over tokens @@ -17,12 +17,11 @@ * For block i: kv_block = kv[i*m : (i+1)*m], gate_block = gate[i*m : (i+1)*m] * compressed[i] = softmax(gate_block, dim=0) * kv_block summed over tokens * - * Both kernels also apply position_bias (cyclic per block) if provided, - * and kv_norm (unweighted RMSNorm) if weight is provided. + * Both kernels also apply kv_norm (unweighted RMSNorm) if weight is provided. * * One block per compressed output entry. 128 threads per block. - * head_dim=512: 128 threads * 4 elements/thread covers it. - * FP32 accumulation throughout. + * Each thread processes a strided subset of columns. + * FP32 accumulation throughout. No extern shared memory needed. */ #include @@ -31,9 +30,8 @@ #include #include -// Block-level sum reduction +// Block-level sum reduction (for kv_norm) __device__ __forceinline__ float block_reduce_sum(float val, float* smem, int n_warps) { - // Warp-level reduce for (int offset = 16; offset > 0; offset >>= 1) { val += __shfl_down_sync(0xffffffff, val, offset); } @@ -53,148 +51,34 @@ __device__ __forceinline__ float block_reduce_sum(float val, float* smem, int n_ return result; } -__device__ __forceinline__ float block_reduce_max(float val, float* smem, int n_warps) { - for (int offset = 16; offset > 0; offset >>= 1) { - val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset)); - } - if (threadIdx.x % 32 == 0) { - smem[threadIdx.x / 32] = val; - } - __syncthreads(); - float result = 0.0f; - if (threadIdx.x < 32) { - float v = (threadIdx.x < n_warps) ? smem[threadIdx.x] : -FLT_MAX; - for (int offset = 16; offset > 0; offset >>= 1) { - v = fmaxf(v, __shfl_down_sync(0xffffffff, v, offset)); - } - result = v; - } - __syncthreads(); - return result; -} - // =========================================================================== // CSA compressor reduce kernel // =========================================================================== __global__ void csa_compress_reduce_kernel( - // Inputs — output of NVFP4 GEMM projections const float* __restrict__ kv_proj, // [T, 2*hd] FP32 (Ca | Cb) const float* __restrict__ gate_proj, // [T, 2*hd] FP32 (Ga | Gb) const float* __restrict__ position_bias, // [m, 2*hd] FP32 or nullptr - const float* __restrict__ kv_norm_weight, // [hd] FP32 or nullptr - // Output + const float* __restrict__ kv_norm_weight, // [hd] FP32 or nullptr (unused here, applied separately) float* __restrict__ compressed, // [n_blocks, hd] FP32 - // Geometry int T, int hd, int m, int n_blocks ) { int block_i = blockIdx.x; int tid = threadIdx.x; int n_threads = blockDim.x; - int n_warps = n_threads / 32; int kv_dim = 2 * hd; if (block_i >= n_blocks) return; - // Each block: 2*m tokens (m from previous Ca + m from current Cb) for i>0 - // m tokens (just Cb) for i==0 - int n_tokens; - if (block_i > 0) { - n_tokens = 2 * m; - } else { - n_tokens = m; - } - - // Per-column processing: each thread handles multiple columns - // We accumulate: for each col, max(gate), sum(exp(gate - max)), sum(exp*kv) - // Output: compressed[block_i, col] = sum(exp*kv) / sum(exp) - - // Shared memory for per-column partials - extern __shared__ char smem_buf[]; - float* s_max = reinterpret_cast(smem_buf); // [hd] - float* s_denom = s_max + hd; // [hd] - float* s_acc = s_denom + hd; // [hd] - - // Initialize shared accumulators - for (int c = tid; c < hd; c += n_threads) { - s_max[c] = -FLT_MAX; - s_denom[c] = 0.0f; - s_acc[c] = 0.0f; - } - __syncthreads(); - - // Token range for this block - int cur_start = block_i * m; // start of current block's tokens in the T-dim - int prev_start = (block_i - 1) * m; // start of previous block's tokens - - // Pass 1: find max gate value per column - for (int t = 0; t < n_tokens; t++) { - int token_idx, kv_offset, gate_offset; - if (block_i > 0) { - if (t < m) { - // Previous block's a-stream: Ca[t], Ga[t] — first hd columns - token_idx = prev_start + t; - kv_offset = 0; // Ca is columns [0, hd) - gate_offset = 0; // Ga is columns [0, hd) - } else { - // Current block's b-stream: Cb[t], Gb[t] — second hd columns - token_idx = cur_start + (t - m); - kv_offset = hd; // Cb is columns [hd, 2*hd) - gate_offset = hd; // Gb is columns [hd, 2*hd) - } - } else { - // Block 0: just Cb, Gb — second half - token_idx = t; - kv_offset = hd; - gate_offset = hd; - } - - for (int c = tid; c < hd; c += n_threads) { - float g = gate_proj[token_idx * kv_dim + gate_offset + c]; - if (position_bias != nullptr) { - int pos_bias_row = (block_i > 0 && t < m) ? (m + t) : (block_i > 0 ? (t - m) : t); - if (pos_bias_row < m) { - g += position_bias[pos_bias_row * kv_dim + gate_offset + c]; - } - } - // Atomic max via CAS loop on shared memory - // Actually, we can just do a serial max since we're writing to s_max[c] - // and each thread writes a different column range. But multiple threads - // might write the same column with different t values. - // Use atomicMax equivalent: - float old = s_max[c]; - s_max[c] = fmaxf(old, g); - } - } - __syncthreads(); - - // Reduce s_max across threads for each column (since multiple threads - // may have written different partial maxes for the same column) - // Actually, we already wrote to s_max[c] with fmaxf, but there are - // data races. Let me use a proper approach: each thread accumulates - // its own max, then we do a block reduction. - - // Redo: per-thread local accumulation, then reduce - // Actually, the issue is that s_max[c] is written by multiple threads - // concurrently. Let's use atomicCAS or a different approach. - // Simpler: each thread processes a SUBSET of columns exclusively. - - // Actually, the cleanest approach: each thread owns a set of columns, - // processes ALL tokens for those columns, then writes results. - // With hd=512 and 128 threads, each thread owns 4 columns. - - // Let me restructure: each thread processes a contiguous range of columns. - // No shared memory needed for accumulation. - - // ... This is getting complex. Let me simplify with a column-per-thread approach. + int n_tokens = (block_i > 0) ? 2 * m : m; + int prev_start = (block_i - 1) * m; + int cur_start = block_i * m; // Each thread processes columns [tid, tid+n_threads, tid+2*n_threads, ...] - // Total columns per thread = ceil(hd / n_threads) = 4 for hd=512, n_threads=128 - + // Max cols per thread for hd=512, 128 threads = 4 int cols_per_thread = (hd + n_threads - 1) / n_threads; - // Local accumulators - float local_max[4]; // max 4 cols per thread + float local_max[4]; float local_denom[4]; float local_acc[4]; @@ -205,25 +89,22 @@ __global__ void csa_compress_reduce_kernel( local_denom[ci] = 0.0f; local_acc[ci] = 0.0f; - // Pass 1: find max + // Pass 1: find max gate value for (int t = 0; t < n_tokens; t++) { int token_idx, gate_offset; if (block_i > 0) { - if (t < m) { - token_idx = prev_start + t; - gate_offset = 0; - } else { - token_idx = cur_start + (t - m); - gate_offset = hd; - } + if (t < m) { token_idx = prev_start + t; gate_offset = 0; } + else { token_idx = cur_start + (t - m); gate_offset = hd; } } else { - token_idx = t; - gate_offset = hd; + token_idx = t; gate_offset = hd; } + if (token_idx < 0 || token_idx >= T) continue; + float g = gate_proj[token_idx * kv_dim + gate_offset + c]; + // Position bias: same (m, 2*hd) bias added to every block if (position_bias != nullptr) { - int pos_bias_row = (block_i > 0 && t < m) ? (m + t) : (block_i > 0 ? (t - m) : t); - if (pos_bias_row < m) { + int pos_bias_row = (block_i > 0 && t < m) ? t : (block_i > 0 ? (t - m) : t); + if (pos_bias_row >= 0 && pos_bias_row < m) { g += position_bias[pos_bias_row * kv_dim + gate_offset + c]; } } @@ -234,57 +115,39 @@ __global__ void csa_compress_reduce_kernel( for (int t = 0; t < n_tokens; t++) { int token_idx, kv_offset, gate_offset; if (block_i > 0) { - if (t < m) { - token_idx = prev_start + t; - kv_offset = 0; - gate_offset = 0; - } else { - token_idx = cur_start + (t - m); - kv_offset = hd; - gate_offset = hd; - } + if (t < m) { token_idx = prev_start + t; kv_offset = 0; gate_offset = 0; } + else { token_idx = cur_start + (t - m); kv_offset = hd; gate_offset = hd; } } else { - token_idx = t; - kv_offset = hd; - gate_offset = hd; + token_idx = t; kv_offset = hd; gate_offset = hd; } + if (token_idx < 0 || token_idx >= T) continue; + float g = gate_proj[token_idx * kv_dim + gate_offset + c]; if (position_bias != nullptr) { - int pos_bias_row = (block_i > 0 && t < m) ? (m + t) : (block_i > 0 ? (t - m) : t); - if (pos_bias_row < m) { + int pos_bias_row = (block_i > 0 && t < m) ? t : (block_i > 0 ? (t - m) : t); + if (pos_bias_row >= 0 && pos_bias_row < m) { g += position_bias[pos_bias_row * kv_dim + gate_offset + c]; } } float e = expf(g - local_max[ci]); local_denom[ci] += e; - float kv_val = kv_proj[token_idx * kv_dim + kv_offset + c]; - local_acc[ci] += e * kv_val; + local_acc[ci] += e * kv_proj[token_idx * kv_dim + kv_offset + c]; } - // Normalize float val = (local_denom[ci] > 0.0f) ? (local_acc[ci] / local_denom[ci]) : 0.0f; - - // Apply kv_norm if provided (unweighted RMSNorm + weight) - if (kv_norm_weight != nullptr) { - // We can't do per-element RMSNorm here since we only have one column - // RMSNorm needs the full vector. We need to compute it across all hd columns. - // This requires a separate pass or collective operation. - // For now, store the raw value and apply kv_norm in a separate kernel. - } - compressed[block_i * hd + c] = val; } } // =========================================================================== -// HCA compressor reduce kernel (simpler — no overlap, single stream) +// HCA compressor reduce kernel (no overlap, single stream) // =========================================================================== __global__ void hca_compress_reduce_kernel( const float* __restrict__ kv_proj, // [T, hd] FP32 const float* __restrict__ gate_proj, // [T, hd] FP32 const float* __restrict__ position_bias, // [m, hd] FP32 or nullptr - const float* __restrict__ kv_norm_weight, // [hd] FP32 or nullptr + const float* __restrict__ kv_norm_weight, // [hd] FP32 or nullptr (unused here) float* __restrict__ compressed, // [n_blocks, hd] FP32 int T, int hd, int m, int n_blocks ) { @@ -304,7 +167,6 @@ __global__ void hca_compress_reduce_kernel( float local_denom = 0.0f; float local_acc = 0.0f; - // Token range: [block_i * m, (block_i + 1) * m) int start = block_i * m; // Pass 1: max @@ -362,7 +224,6 @@ __global__ void apply_kv_norm_kernel( __shared__ float s_sum; float total_sq = block_reduce_sum(local_sq, &s_sum, n_warps); - // Only thread 0 has the correct total __shared__ float s_inv_rms; if (tid == 0) { float mean_sq = total_sq / hd; @@ -371,8 +232,7 @@ __global__ void apply_kv_norm_kernel( __syncthreads(); for (int c = tid; c < hd; c += n_threads) { - float v = input[block_i * hd + c]; - output[block_i * hd + c] = v * s_inv_rms * norm_weight[c]; + output[block_i * hd + c] = input[block_i * hd + c] * s_inv_rms * norm_weight[c]; } }