fix: rewrite compressor_reduce.cu — no extern shared mem, proper bounds checks
This commit is contained in:
@@ -5,8 +5,8 @@
|
||||
* and performs the token-level softmax + weighted sum reduction.
|
||||
*
|
||||
* CSA (paper eq. 11-12):
|
||||
* kv_proj output: (T, 2*hd) — split into Ca (first hd) and Cb (second hd)
|
||||
* gate_proj output: (T, 2*hd) — split into Ga (first hd) and Gb (second hd)
|
||||
* kv_proj output: (T, 2*hd) — Ca (first hd) and Cb (second hd)
|
||||
* gate_proj output: (T, 2*hd) — Ga (first hd) and Gb (second hd)
|
||||
* For block i: if i > 0, concat Ca[i-1] + Cb[i] and Ga[i-1] + Gb[i]
|
||||
* else just Cb[0] and Gb[0]
|
||||
* compressed[i] = softmax(gate_block, dim=0) * kv_block summed over tokens
|
||||
@@ -17,12 +17,11 @@
|
||||
* For block i: kv_block = kv[i*m : (i+1)*m], gate_block = gate[i*m : (i+1)*m]
|
||||
* compressed[i] = softmax(gate_block, dim=0) * kv_block summed over tokens
|
||||
*
|
||||
* Both kernels also apply position_bias (cyclic per block) if provided,
|
||||
* and kv_norm (unweighted RMSNorm) if weight is provided.
|
||||
* Both kernels also apply kv_norm (unweighted RMSNorm) if weight is provided.
|
||||
*
|
||||
* One block per compressed output entry. 128 threads per block.
|
||||
* head_dim=512: 128 threads * 4 elements/thread covers it.
|
||||
* FP32 accumulation throughout.
|
||||
* Each thread processes a strided subset of columns.
|
||||
* FP32 accumulation throughout. No extern shared memory needed.
|
||||
*/
|
||||
|
||||
#include <cuda.h>
|
||||
@@ -31,9 +30,8 @@
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
#include <cmath>
|
||||
|
||||
// Block-level sum reduction
|
||||
// Block-level sum reduction (for kv_norm)
|
||||
__device__ __forceinline__ float block_reduce_sum(float val, float* smem, int n_warps) {
|
||||
// Warp-level reduce
|
||||
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||
val += __shfl_down_sync(0xffffffff, val, offset);
|
||||
}
|
||||
@@ -53,148 +51,34 @@ __device__ __forceinline__ float block_reduce_sum(float val, float* smem, int n_
|
||||
return result;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float block_reduce_max(float val, float* smem, int n_warps) {
|
||||
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||
val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
|
||||
}
|
||||
if (threadIdx.x % 32 == 0) {
|
||||
smem[threadIdx.x / 32] = val;
|
||||
}
|
||||
__syncthreads();
|
||||
float result = 0.0f;
|
||||
if (threadIdx.x < 32) {
|
||||
float v = (threadIdx.x < n_warps) ? smem[threadIdx.x] : -FLT_MAX;
|
||||
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||
v = fmaxf(v, __shfl_down_sync(0xffffffff, v, offset));
|
||||
}
|
||||
result = v;
|
||||
}
|
||||
__syncthreads();
|
||||
return result;
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// CSA compressor reduce kernel
|
||||
// ===========================================================================
|
||||
|
||||
__global__ void csa_compress_reduce_kernel(
|
||||
// Inputs — output of NVFP4 GEMM projections
|
||||
const float* __restrict__ kv_proj, // [T, 2*hd] FP32 (Ca | Cb)
|
||||
const float* __restrict__ gate_proj, // [T, 2*hd] FP32 (Ga | Gb)
|
||||
const float* __restrict__ position_bias, // [m, 2*hd] FP32 or nullptr
|
||||
const float* __restrict__ kv_norm_weight, // [hd] FP32 or nullptr
|
||||
// Output
|
||||
const float* __restrict__ kv_norm_weight, // [hd] FP32 or nullptr (unused here, applied separately)
|
||||
float* __restrict__ compressed, // [n_blocks, hd] FP32
|
||||
// Geometry
|
||||
int T, int hd, int m, int n_blocks
|
||||
) {
|
||||
int block_i = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
int n_threads = blockDim.x;
|
||||
int n_warps = n_threads / 32;
|
||||
int kv_dim = 2 * hd;
|
||||
|
||||
if (block_i >= n_blocks) return;
|
||||
|
||||
// Each block: 2*m tokens (m from previous Ca + m from current Cb) for i>0
|
||||
// m tokens (just Cb) for i==0
|
||||
int n_tokens;
|
||||
if (block_i > 0) {
|
||||
n_tokens = 2 * m;
|
||||
} else {
|
||||
n_tokens = m;
|
||||
}
|
||||
|
||||
// Per-column processing: each thread handles multiple columns
|
||||
// We accumulate: for each col, max(gate), sum(exp(gate - max)), sum(exp*kv)
|
||||
// Output: compressed[block_i, col] = sum(exp*kv) / sum(exp)
|
||||
|
||||
// Shared memory for per-column partials
|
||||
extern __shared__ char smem_buf[];
|
||||
float* s_max = reinterpret_cast<float*>(smem_buf); // [hd]
|
||||
float* s_denom = s_max + hd; // [hd]
|
||||
float* s_acc = s_denom + hd; // [hd]
|
||||
|
||||
// Initialize shared accumulators
|
||||
for (int c = tid; c < hd; c += n_threads) {
|
||||
s_max[c] = -FLT_MAX;
|
||||
s_denom[c] = 0.0f;
|
||||
s_acc[c] = 0.0f;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Token range for this block
|
||||
int cur_start = block_i * m; // start of current block's tokens in the T-dim
|
||||
int prev_start = (block_i - 1) * m; // start of previous block's tokens
|
||||
|
||||
// Pass 1: find max gate value per column
|
||||
for (int t = 0; t < n_tokens; t++) {
|
||||
int token_idx, kv_offset, gate_offset;
|
||||
if (block_i > 0) {
|
||||
if (t < m) {
|
||||
// Previous block's a-stream: Ca[t], Ga[t] — first hd columns
|
||||
token_idx = prev_start + t;
|
||||
kv_offset = 0; // Ca is columns [0, hd)
|
||||
gate_offset = 0; // Ga is columns [0, hd)
|
||||
} else {
|
||||
// Current block's b-stream: Cb[t], Gb[t] — second hd columns
|
||||
token_idx = cur_start + (t - m);
|
||||
kv_offset = hd; // Cb is columns [hd, 2*hd)
|
||||
gate_offset = hd; // Gb is columns [hd, 2*hd)
|
||||
}
|
||||
} else {
|
||||
// Block 0: just Cb, Gb — second half
|
||||
token_idx = t;
|
||||
kv_offset = hd;
|
||||
gate_offset = hd;
|
||||
}
|
||||
|
||||
for (int c = tid; c < hd; c += n_threads) {
|
||||
float g = gate_proj[token_idx * kv_dim + gate_offset + c];
|
||||
if (position_bias != nullptr) {
|
||||
int pos_bias_row = (block_i > 0 && t < m) ? (m + t) : (block_i > 0 ? (t - m) : t);
|
||||
if (pos_bias_row < m) {
|
||||
g += position_bias[pos_bias_row * kv_dim + gate_offset + c];
|
||||
}
|
||||
}
|
||||
// Atomic max via CAS loop on shared memory
|
||||
// Actually, we can just do a serial max since we're writing to s_max[c]
|
||||
// and each thread writes a different column range. But multiple threads
|
||||
// might write the same column with different t values.
|
||||
// Use atomicMax equivalent:
|
||||
float old = s_max[c];
|
||||
s_max[c] = fmaxf(old, g);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Reduce s_max across threads for each column (since multiple threads
|
||||
// may have written different partial maxes for the same column)
|
||||
// Actually, we already wrote to s_max[c] with fmaxf, but there are
|
||||
// data races. Let me use a proper approach: each thread accumulates
|
||||
// its own max, then we do a block reduction.
|
||||
|
||||
// Redo: per-thread local accumulation, then reduce
|
||||
// Actually, the issue is that s_max[c] is written by multiple threads
|
||||
// concurrently. Let's use atomicCAS or a different approach.
|
||||
// Simpler: each thread processes a SUBSET of columns exclusively.
|
||||
|
||||
// Actually, the cleanest approach: each thread owns a set of columns,
|
||||
// processes ALL tokens for those columns, then writes results.
|
||||
// With hd=512 and 128 threads, each thread owns 4 columns.
|
||||
|
||||
// Let me restructure: each thread processes a contiguous range of columns.
|
||||
// No shared memory needed for accumulation.
|
||||
|
||||
// ... This is getting complex. Let me simplify with a column-per-thread approach.
|
||||
int n_tokens = (block_i > 0) ? 2 * m : m;
|
||||
int prev_start = (block_i - 1) * m;
|
||||
int cur_start = block_i * m;
|
||||
|
||||
// Each thread processes columns [tid, tid+n_threads, tid+2*n_threads, ...]
|
||||
// Total columns per thread = ceil(hd / n_threads) = 4 for hd=512, n_threads=128
|
||||
|
||||
// Max cols per thread for hd=512, 128 threads = 4
|
||||
int cols_per_thread = (hd + n_threads - 1) / n_threads;
|
||||
|
||||
// Local accumulators
|
||||
float local_max[4]; // max 4 cols per thread
|
||||
float local_max[4];
|
||||
float local_denom[4];
|
||||
float local_acc[4];
|
||||
|
||||
@@ -205,25 +89,22 @@ __global__ void csa_compress_reduce_kernel(
|
||||
local_denom[ci] = 0.0f;
|
||||
local_acc[ci] = 0.0f;
|
||||
|
||||
// Pass 1: find max
|
||||
// Pass 1: find max gate value
|
||||
for (int t = 0; t < n_tokens; t++) {
|
||||
int token_idx, gate_offset;
|
||||
if (block_i > 0) {
|
||||
if (t < m) {
|
||||
token_idx = prev_start + t;
|
||||
gate_offset = 0;
|
||||
} else {
|
||||
token_idx = cur_start + (t - m);
|
||||
gate_offset = hd;
|
||||
}
|
||||
if (t < m) { token_idx = prev_start + t; gate_offset = 0; }
|
||||
else { token_idx = cur_start + (t - m); gate_offset = hd; }
|
||||
} else {
|
||||
token_idx = t;
|
||||
gate_offset = hd;
|
||||
token_idx = t; gate_offset = hd;
|
||||
}
|
||||
if (token_idx < 0 || token_idx >= T) continue;
|
||||
|
||||
float g = gate_proj[token_idx * kv_dim + gate_offset + c];
|
||||
// Position bias: same (m, 2*hd) bias added to every block
|
||||
if (position_bias != nullptr) {
|
||||
int pos_bias_row = (block_i > 0 && t < m) ? (m + t) : (block_i > 0 ? (t - m) : t);
|
||||
if (pos_bias_row < m) {
|
||||
int pos_bias_row = (block_i > 0 && t < m) ? t : (block_i > 0 ? (t - m) : t);
|
||||
if (pos_bias_row >= 0 && pos_bias_row < m) {
|
||||
g += position_bias[pos_bias_row * kv_dim + gate_offset + c];
|
||||
}
|
||||
}
|
||||
@@ -234,57 +115,39 @@ __global__ void csa_compress_reduce_kernel(
|
||||
for (int t = 0; t < n_tokens; t++) {
|
||||
int token_idx, kv_offset, gate_offset;
|
||||
if (block_i > 0) {
|
||||
if (t < m) {
|
||||
token_idx = prev_start + t;
|
||||
kv_offset = 0;
|
||||
gate_offset = 0;
|
||||
} else {
|
||||
token_idx = cur_start + (t - m);
|
||||
kv_offset = hd;
|
||||
gate_offset = hd;
|
||||
}
|
||||
if (t < m) { token_idx = prev_start + t; kv_offset = 0; gate_offset = 0; }
|
||||
else { token_idx = cur_start + (t - m); kv_offset = hd; gate_offset = hd; }
|
||||
} else {
|
||||
token_idx = t;
|
||||
kv_offset = hd;
|
||||
gate_offset = hd;
|
||||
token_idx = t; kv_offset = hd; gate_offset = hd;
|
||||
}
|
||||
if (token_idx < 0 || token_idx >= T) continue;
|
||||
|
||||
float g = gate_proj[token_idx * kv_dim + gate_offset + c];
|
||||
if (position_bias != nullptr) {
|
||||
int pos_bias_row = (block_i > 0 && t < m) ? (m + t) : (block_i > 0 ? (t - m) : t);
|
||||
if (pos_bias_row < m) {
|
||||
int pos_bias_row = (block_i > 0 && t < m) ? t : (block_i > 0 ? (t - m) : t);
|
||||
if (pos_bias_row >= 0 && pos_bias_row < m) {
|
||||
g += position_bias[pos_bias_row * kv_dim + gate_offset + c];
|
||||
}
|
||||
}
|
||||
float e = expf(g - local_max[ci]);
|
||||
local_denom[ci] += e;
|
||||
float kv_val = kv_proj[token_idx * kv_dim + kv_offset + c];
|
||||
local_acc[ci] += e * kv_val;
|
||||
local_acc[ci] += e * kv_proj[token_idx * kv_dim + kv_offset + c];
|
||||
}
|
||||
|
||||
// Normalize
|
||||
float val = (local_denom[ci] > 0.0f) ? (local_acc[ci] / local_denom[ci]) : 0.0f;
|
||||
|
||||
// Apply kv_norm if provided (unweighted RMSNorm + weight)
|
||||
if (kv_norm_weight != nullptr) {
|
||||
// We can't do per-element RMSNorm here since we only have one column
|
||||
// RMSNorm needs the full vector. We need to compute it across all hd columns.
|
||||
// This requires a separate pass or collective operation.
|
||||
// For now, store the raw value and apply kv_norm in a separate kernel.
|
||||
}
|
||||
|
||||
compressed[block_i * hd + c] = val;
|
||||
}
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// HCA compressor reduce kernel (simpler — no overlap, single stream)
|
||||
// HCA compressor reduce kernel (no overlap, single stream)
|
||||
// ===========================================================================
|
||||
|
||||
__global__ void hca_compress_reduce_kernel(
|
||||
const float* __restrict__ kv_proj, // [T, hd] FP32
|
||||
const float* __restrict__ gate_proj, // [T, hd] FP32
|
||||
const float* __restrict__ position_bias, // [m, hd] FP32 or nullptr
|
||||
const float* __restrict__ kv_norm_weight, // [hd] FP32 or nullptr
|
||||
const float* __restrict__ kv_norm_weight, // [hd] FP32 or nullptr (unused here)
|
||||
float* __restrict__ compressed, // [n_blocks, hd] FP32
|
||||
int T, int hd, int m, int n_blocks
|
||||
) {
|
||||
@@ -304,7 +167,6 @@ __global__ void hca_compress_reduce_kernel(
|
||||
float local_denom = 0.0f;
|
||||
float local_acc = 0.0f;
|
||||
|
||||
// Token range: [block_i * m, (block_i + 1) * m)
|
||||
int start = block_i * m;
|
||||
|
||||
// Pass 1: max
|
||||
@@ -362,7 +224,6 @@ __global__ void apply_kv_norm_kernel(
|
||||
|
||||
__shared__ float s_sum;
|
||||
float total_sq = block_reduce_sum(local_sq, &s_sum, n_warps);
|
||||
// Only thread 0 has the correct total
|
||||
__shared__ float s_inv_rms;
|
||||
if (tid == 0) {
|
||||
float mean_sq = total_sq / hd;
|
||||
@@ -371,8 +232,7 @@ __global__ void apply_kv_norm_kernel(
|
||||
__syncthreads();
|
||||
|
||||
for (int c = tid; c < hd; c += n_threads) {
|
||||
float v = input[block_i * hd + c];
|
||||
output[block_i * hd + c] = v * s_inv_rms * norm_weight[c];
|
||||
output[block_i * hd + c] = input[block_i * hd + c] * s_inv_rms * norm_weight[c];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user