Files
nvfp4-megamoe-kernel/dsv4/kernels/cuda/compressor_reduce.cu
biondizzle 84ca520bfb fix: move compressor position_bias into CUDA kernel (was Python loop)
The compressor_reduce.cu kernel now adds position_bias to BOTH kv and
gate values, matching the PyTorch reference. Previously the kernel only
added it to gate, and a Python workaround loop was adding it to both
before the kernel call (then passing None to the kernel).

Changes:
- compressor_reduce.cu: add position_bias to kv_val in pass 2 (CSA + HCA)
- single_shot_inference.py: remove Python position_bias loop, pass
  self.ape directly to csa/hca_compress_production
- production_compress.py: already supports position_bias passthrough
2026-06-01 05:54:44 +00:00

349 lines
12 KiB
Plaintext

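/**
 * Compressor reduce kernels for DSV4 CSA and HCA.
 *
 * Takes the OUTPUT of the NVFP4 GEMM projections (kv_proj, gate_proj)
 * and performs the token-level softmax + weighted-sum reduction.
 *
 * CSA (paper eq. 11-12):
 *   kv_proj output:   (T, 2*hd) — Ca (first hd columns) and Cb (second hd columns)
 *   gate_proj output: (T, 2*hd) — Ga (first hd columns) and Gb (second hd columns)
 *   For block i > 0 the window is the m rows of Ca/Ga belonging to block i-1
 *   followed by the m rows of Cb/Gb belonging to block i (2*m tokens);
 *   block 0 uses only Cb[0] and Gb[0] (m tokens).
 *   compressed[i] = sum over window tokens of softmax(gate_block, dim=0) * kv_block
 *
 * HCA (paper eq. 9-10):
 *   kv_proj output:   (T, hd)
 *   gate_proj output: (T, hd)
 *   For block i: kv_block = kv[i*m : (i+1)*m], gate_block = gate[i*m : (i+1)*m]
 *   compressed[i] = sum over block tokens of softmax(gate_block, dim=0) * kv_block
 *
 * If position_bias is given ((m, 2*hd) for CSA, (m, hd) for HCA), the same bias
 * rows are added to BOTH the gate logits and the kv values of every block before
 * the reduction. Tokens at index >= T are skipped.
 *
 * kv_norm is not applied inside the reduce kernels. When kv_norm_weight is
 * non-empty, the host wrappers launch apply_kv_norm_kernel afterwards. That kernel
 * is an RMSNorm (eps 1e-6) scaled per column by kv_norm_weight.
 *
 * Launch layout: one CUDA block per compressed row, 128 threads per block;
 * each thread processes a strided subset of the hd columns.
 * FP32 accumulation throughout. No dynamic (extern) shared memory.
 */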
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cfloat>
#include <cmath>
#include <cstddef>
// Block-level sum reduction (used by apply_kv_norm_kernel).
//
// Two stages: every warp folds its 32 lanes with __shfl_down_sync, lane 0 of
// each warp stores the warp partial into smem[warp_id], and then warp 0 folds
// the n_warps partials the same way.
//
// Preconditions (not checked here):
//   - Must be reached by every thread of the block (it contains __syncthreads()).
//   - smem must point to at least n_warps floats (one slot per warp).
//   - The shuffles use the full-warp mask 0xffffffff, so blockDim.x is assumed
//     to be a multiple of 32.
//   - Only the first 32 partials are read, so n_warps must be <= 32.
// Returns: the full block sum in thread 0 only. Lanes 1..31 of warp 0 return
// partial sums, and threads outside warp 0 return 0.0f.
__device__ __forceinline__ float block_reduce_sum(float val, float* smem, int n_warps) {
// Stage 1: intra-warp tree reduction; lane 0 ends up with the warp total.
for (int offset = 16; offset > 0; offset >>= 1) {
val += __shfl_down_sync(0xffffffff, val, offset);
}
if (threadIdx.x % 32 == 0) {
smem[threadIdx.x / 32] = val;
}
__syncthreads();
float result = 0.0f;
// Stage 2: warp 0 reduces the per-warp partials (missing slots contribute 0).
if (threadIdx.x < 32) {
float v = (threadIdx.x < n_warps) ? smem[threadIdx.x] : 0.0f;
for (int offset = 16; offset > 0; offset >>= 1) {
v += __shfl_down_sync(0xffffffff, v, offset);
}
result = v;
}
// Second barrier: stage 2 reads are done, so the caller may reuse smem afterwards.
__syncthreads();
return result;
}
// ===========================================================================
// CSA compressor reduce kernel
// ===========================================================================
// Maps position t of compressed block `block_i`'s softmax window to:
//   token_idx  - source row in kv_proj / gate_proj,
//   col_offset - start column of the stream read (0 = a-stream Ca/Ga, hd = b-stream Cb/Gb),
//   bias_row   - row of position_bias to use.
// For block_i > 0 and t < m, the slot reads the previous block's token from the a-stream
// with bias row t. Otherwise it reads the current block's token from the b-stream with
// bias row t - m (bias row t for block 0, whose window only has m positions).
static __device__ __forceinline__ void csa_window_slot(
    int block_i, int t, int m, int hd,
    int& token_idx, int& col_offset, int& bias_row) {
    if (block_i > 0 && t < m) {
        token_idx = (block_i - 1) * m + t;
        col_offset = 0;
        bias_row = t;
    } else {
        const int r = (block_i > 0) ? (t - m) : t;
        token_idx = block_i * m + r;
        col_offset = hd;
        bias_row = r;
    }
}

/**
 * CSA compressor reduce: one CUDA block per compressed row.
 *
 * Softmax window for compressed row i:
 *   i == 0 : tokens [0, m) from the b-stream (columns [hd, 2*hd)),
 *   i  > 0 : tokens [(i-1)*m, i*m) from the a-stream (columns [0, hd)),
 *            followed by tokens [i*m, (i+1)*m) from the b-stream.
 * Tokens at or past T are skipped (partial trailing block).
 *
 * compressed[i, c] = sum_t softmax_t(gate + pb) * (kv + pb). Here pb is the position_bias
 * entry for (window row, stream column), or 0 when position_bias is nullptr.
 * A row with no valid tokens is written as 0.
 *
 * Launch: grid = n_blocks, any blockDim.x >= 1. Threads stride over the hd columns, so hd is
 * not limited by blockDim.x. No shared memory. Indices are computed in 64 bit.
 * kv_norm_weight is unused here; the host wrapper applies kv_norm separately.
 */
__global__ void csa_compress_reduce_kernel(
    const float* __restrict__ kv_proj,         // [T, 2*hd] FP32 (Ca | Cb)
    const float* __restrict__ gate_proj,       // [T, 2*hd] FP32 (Ga | Gb)
    const float* __restrict__ position_bias,   // [m, 2*hd] FP32 or nullptr
    const float* __restrict__ kv_norm_weight,  // [hd] FP32 or nullptr (unused here, applied separately)
    float* __restrict__ compressed,            // [n_blocks, hd] FP32
    int T, int hd, int m, int n_blocks
) {
    const int block_i = blockIdx.x;
    if (block_i >= n_blocks) return;

    const size_t kv_dim = 2 * static_cast<size_t>(hd);
    // Block 0 has no previous block to overlap with.
    const int n_tokens = (block_i > 0) ? 2 * m : m;
    const int stride = static_cast<int>(blockDim.x);

    // Columns are independent, so plain scalars per column suffice. (Fixed-size per-thread
    // scratch arrays would overflow whenever hd > 4 * blockDim.x.)
    for (int c = static_cast<int>(threadIdx.x); c < hd; c += stride) {
        // Pass 1: max biased gate logit over the window (numerically stable softmax).
        float max_logit = -FLT_MAX;
        for (int t = 0; t < n_tokens; ++t) {
            int token_idx, col_offset, bias_row;
            csa_window_slot(block_i, t, m, hd, token_idx, col_offset, bias_row);
            if (token_idx < 0 || token_idx >= T) continue;
            float g = gate_proj[static_cast<size_t>(token_idx) * kv_dim + col_offset + c];
            if (position_bias != nullptr) {
                g += position_bias[static_cast<size_t>(bias_row) * kv_dim + col_offset + c];
            }
            max_logit = fmaxf(max_logit, g);
        }

        // Pass 2: softmax denominator and weighted kv sum, in the same token order.
        float denom = 0.0f;
        float acc = 0.0f;
        for (int t = 0; t < n_tokens; ++t) {
            int token_idx, col_offset, bias_row;
            csa_window_slot(block_i, t, m, hd, token_idx, col_offset, bias_row);
            if (token_idx < 0 || token_idx >= T) continue;
            const size_t src = static_cast<size_t>(token_idx) * kv_dim + col_offset + c;
            float g = gate_proj[src];
            float kv_val = kv_proj[src];
            if (position_bias != nullptr) {
                // Gate and kv read the same stream column, so they share one bias entry;
                // the bias is added to BOTH (softmax logit and content), per the reference.
                const float pb = position_bias[static_cast<size_t>(bias_row) * kv_dim + col_offset + c];
                g += pb;
                kv_val += pb;
            }
            const float e = expf(g - max_logit);
            denom += e;
            acc += e * kv_val;
        }
        compressed[static_cast<size_t>(block_i) * hd + c] = (denom > 0.0f) ? (acc / denom) : 0.0f;
    }
}
// ===========================================================================
// HCA compressor reduce kernel (no overlap, single stream)
// ===========================================================================
/**
 * HCA compressor reduce (no overlap, single stream): one CUDA block per compressed row.
 *
 * compressed[i, c] = sum_t softmax_t(gate[i*m + t, c] + pb[t, c]) * (kv[i*m + t, c] + pb[t, c])
 * for t in [0, m) with i*m + t < T. Here pb is position_bias, or 0 when it is nullptr.
 * The same (m, hd) bias is used for every block. A row with no valid tokens
 * (i*m >= T) is written as 0.
 *
 * Launch: grid = n_blocks, any blockDim.x >= 1 (threads stride over the hd columns).
 * No shared memory. Row offsets are computed in 64 bit, so T * hd may exceed INT_MAX.
 * kv_norm_weight is unused here; the host wrapper applies kv_norm separately.
 */
__global__ void hca_compress_reduce_kernel(
    const float* __restrict__ kv_proj,         // [T, hd] FP32
    const float* __restrict__ gate_proj,       // [T, hd] FP32
    const float* __restrict__ position_bias,   // [m, hd] FP32 or nullptr
    const float* __restrict__ kv_norm_weight,  // [hd] FP32 or nullptr (unused here)
    float* __restrict__ compressed,            // [n_blocks, hd] FP32
    int T, int hd, int m, int n_blocks
) {
    const int block_i = blockIdx.x;
    if (block_i >= n_blocks) return;

    const int start = block_i * m;  // first token of this block's window
    // Valid tokens in the window: the trailing block may be partial, and a block that
    // starts at or past T has none (window <= 0).
    const int window = (T - start < m) ? (T - start) : m;
    const int stride = static_cast<int>(blockDim.x);

    for (int c = static_cast<int>(threadIdx.x); c < hd; c += stride) {
        // Pass 1: max biased gate logit (numerically stable softmax).
        float max_logit = -FLT_MAX;
        for (int t = 0; t < window; ++t) {
            float g = gate_proj[static_cast<size_t>(start + t) * hd + c];
            if (position_bias != nullptr) {
                g += position_bias[static_cast<size_t>(t) * hd + c];
            }
            max_logit = fmaxf(max_logit, g);
        }

        // Pass 2: softmax denominator and weighted kv sum.
        float denom = 0.0f;
        float acc = 0.0f;
        for (int t = 0; t < window; ++t) {
            const size_t src = static_cast<size_t>(start + t) * hd + c;
            float g = gate_proj[src];
            float kv_val = kv_proj[src];
            if (position_bias != nullptr) {
                // Bias is added to BOTH the softmax logit and the content, per the reference.
                const float pb = position_bias[static_cast<size_t>(t) * hd + c];
                g += pb;
                kv_val += pb;
            }
            const float e = expf(g - max_logit);
            denom += e;
            acc += e * kv_val;
        }
        compressed[static_cast<size_t>(block_i) * hd + c] = (denom > 0.0f) ? (acc / denom) : 0.0f;
    }
}
// ===========================================================================
// Unweighted RMSNorm kernel (applied after compress reduce)
// ===========================================================================
/**
 * RMSNorm of each compressed row, scaled per column by norm_weight:
 *   output[i, c] = input[i, c] * rsqrt(mean_c(input[i, :]^2) + 1e-6) * norm_weight[c]
 *
 * Launch: grid = n_blocks (one block per row). blockDim.x must be a multiple of 32 and at
 * most 1024, because block_reduce_sum uses full-warp shuffles and one shared slot per warp.
 * Uses static shared memory only.
 *
 * In-place use (input == output) is supported. Each thread reads and writes only the columns
 * it owns (c = tid, tid + blockDim.x, ...), so no thread reads a value another thread overwrites.
 * For that reason input/output are deliberately NOT __restrict__.
 */
__global__ void apply_kv_norm_kernel(
    const float* input,                     // [n_blocks, hd] FP32 (may alias output)
    const float* __restrict__ norm_weight,  // [hd] FP32
    float* output,                          // [n_blocks, hd] FP32 (can be same as input)
    int n_blocks, int hd
) {
    const int block_i = blockIdx.x;
    // Block-uniform exit, so no thread of a running block skips the barriers below.
    if (block_i >= n_blocks) return;

    const int tid = threadIdx.x;
    const int n_threads = blockDim.x;
    const int n_warps = n_threads / 32;
    const float* row_in = input + static_cast<size_t>(block_i) * hd;
    float* row_out = output + static_cast<size_t>(block_i) * hd;

    // Sum of squares of this row.
    float local_sq = 0.0f;
    for (int c = tid; c < hd; c += n_threads) {
        const float v = row_in[c];
        local_sq += v * v;
    }
    // block_reduce_sum writes smem[warp_id] for every warp, so it needs one slot per warp
    // (up to 32 for a 1024-thread block); a single float would be overrun.
    __shared__ float s_warp_sums[32];
    const float total_sq = block_reduce_sum(local_sq, s_warp_sums, n_warps);

    // Only thread 0 holds the full sum; broadcast the scale through shared memory.
    __shared__ float s_inv_rms;
    if (tid == 0) {
        const float mean_sq = total_sq / hd;
        s_inv_rms = rsqrtf(mean_sq + 1e-6f);
    }
    __syncthreads();

    const float inv_rms = s_inv_rms;
    for (int c = tid; c < hd; c += n_threads) {
        row_out[c] = row_in[c] * inv_rms * norm_weight[c];
    }
}
// ===========================================================================
// PyTorch bindings
// ===========================================================================
/**
 * PyTorch entry point for the CSA compressor reduce.
 *
 * Writes compressed[0:n_blocks] in place. It first runs the softmax-weighted reduction
 * (csa_compress_reduce_kernel). If kv_norm_weight is non-empty, it then runs the weighted
 * RMSNorm (apply_kv_norm_kernel) on the same buffer.
 * Empty position_bias / kv_norm_weight tensors mean "not provided".
 *
 * Kernels are launched on the current PyTorch CUDA stream of kv_proj's device, so they are
 * ordered with surrounding PyTorch work. The call does not synchronize.
 * Throws (via TORCH_CHECK) on wrong dtype, device, layout, shape, or n_blocks out of range.
 */
void csa_compress_reduce_cuda(
    torch::Tensor kv_proj,         // [T, 2*hd] FP32
    torch::Tensor gate_proj,       // [T, 2*hd] FP32
    torch::Tensor position_bias,   // [m, 2*hd] FP32 or empty
    torch::Tensor kv_norm_weight,  // [hd] FP32 or empty
    torch::Tensor compressed,      // [n_blocks, hd] FP32
    int64_t m, int64_t n_blocks
) {
    // The kernels index raw FP32 pointers assuming dense row-major storage on one device.
    auto check_dense_fp32 = [&kv_proj](const torch::Tensor& t, const char* name) {
        TORCH_CHECK(t.scalar_type() == torch::kFloat32, name, " must be float32");
        TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor");
        TORCH_CHECK(t.device() == kv_proj.device(), name, " must be on the same device as kv_proj");
        TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
    };
    check_dense_fp32(kv_proj, "kv_proj");
    check_dense_fp32(gate_proj, "gate_proj");
    check_dense_fp32(compressed, "compressed");
    TORCH_CHECK(kv_proj.dim() == 2, "kv_proj must be 2-D [T, 2*hd]");
    TORCH_CHECK(compressed.dim() == 2, "compressed must be 2-D [n_blocks, hd]");

    const int64_t hd = compressed.size(1);
    TORCH_CHECK(kv_proj.size(1) == 2 * hd, "kv_proj must have 2*hd columns (Ca | Cb)");
    TORCH_CHECK(gate_proj.sizes() == kv_proj.sizes(), "gate_proj must have the same shape as kv_proj");
    TORCH_CHECK(m > 0, "m must be positive");
    // The kernel writes rows [0, n_blocks) of compressed; more would be out of bounds.
    TORCH_CHECK(n_blocks >= 0 && n_blocks <= compressed.size(0),
                "n_blocks must be in [0, compressed.size(0)]");

    const float* pos_bias_ptr = nullptr;
    if (position_bias.numel() > 0) {
        check_dense_fp32(position_bias, "position_bias");
        TORCH_CHECK(position_bias.numel() == m * 2 * hd, "position_bias must hold m * 2*hd elements");
        pos_bias_ptr = position_bias.data_ptr<float>();
    }
    const float* norm_ptr = nullptr;
    if (kv_norm_weight.numel() > 0) {
        check_dense_fp32(kv_norm_weight, "kv_norm_weight");
        TORCH_CHECK(kv_norm_weight.numel() == hd, "kv_norm_weight must hold hd elements");
        norm_ptr = kv_norm_weight.data_ptr<float>();
    }

    // A zero-sized grid is an invalid launch configuration; there is nothing to compute.
    if (n_blocks == 0) return;

    const c10::cuda::CUDAGuard device_guard(kv_proj.device());
    const cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();
    constexpr int kThreads = 128;  // multiple of 32, as apply_kv_norm_kernel requires
    const dim3 grid(static_cast<unsigned int>(n_blocks));
    const int T = static_cast<int>(kv_proj.size(0));
    const int hd_i = static_cast<int>(hd);

    csa_compress_reduce_kernel<<<grid, kThreads, 0, stream>>>(
        kv_proj.data_ptr<float>(),
        gate_proj.data_ptr<float>(),
        pos_bias_ptr,
        norm_ptr,
        compressed.data_ptr<float>(),
        T, hd_i, static_cast<int>(m), static_cast<int>(n_blocks)
    );
    C10_CUDA_CHECK(cudaGetLastError());

    // Apply kv_norm in place if provided (same stream, so it runs after the reduce).
    if (norm_ptr != nullptr) {
        apply_kv_norm_kernel<<<grid, kThreads, 0, stream>>>(
            compressed.data_ptr<float>(),
            norm_ptr,
            compressed.data_ptr<float>(),
            static_cast<int>(n_blocks), hd_i
        );
        C10_CUDA_CHECK(cudaGetLastError());
    }
}
/**
 * PyTorch entry point for the HCA compressor reduce.
 *
 * Writes compressed[0:n_blocks] in place. It first runs the softmax-weighted reduction
 * (hca_compress_reduce_kernel). If kv_norm_weight is non-empty, it then runs the weighted
 * RMSNorm (apply_kv_norm_kernel) on the same buffer.
 * Empty position_bias / kv_norm_weight tensors mean "not provided".
 *
 * Kernels are launched on the current PyTorch CUDA stream of kv_proj's device, so they are
 * ordered with surrounding PyTorch work. The call does not synchronize.
 * Throws (via TORCH_CHECK) on wrong dtype, device, layout, shape, or n_blocks out of range.
 */
void hca_compress_reduce_cuda(
    torch::Tensor kv_proj,         // [T, hd] FP32
    torch::Tensor gate_proj,       // [T, hd] FP32
    torch::Tensor position_bias,   // [m, hd] FP32 or empty
    torch::Tensor kv_norm_weight,  // [hd] FP32 or empty
    torch::Tensor compressed,      // [n_blocks, hd] FP32
    int64_t m, int64_t n_blocks
) {
    // The kernels index raw FP32 pointers assuming dense row-major storage on one device.
    auto check_dense_fp32 = [&kv_proj](const torch::Tensor& t, const char* name) {
        TORCH_CHECK(t.scalar_type() == torch::kFloat32, name, " must be float32");
        TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor");
        TORCH_CHECK(t.device() == kv_proj.device(), name, " must be on the same device as kv_proj");
        TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
    };
    check_dense_fp32(kv_proj, "kv_proj");
    check_dense_fp32(gate_proj, "gate_proj");
    check_dense_fp32(compressed, "compressed");
    TORCH_CHECK(kv_proj.dim() == 2, "kv_proj must be 2-D [T, hd]");
    TORCH_CHECK(compressed.dim() == 2, "compressed must be 2-D [n_blocks, hd]");

    const int64_t hd = compressed.size(1);
    TORCH_CHECK(kv_proj.size(1) == hd, "kv_proj must have hd columns");
    TORCH_CHECK(gate_proj.sizes() == kv_proj.sizes(), "gate_proj must have the same shape as kv_proj");
    TORCH_CHECK(m > 0, "m must be positive");
    // The kernel writes rows [0, n_blocks) of compressed; more would be out of bounds.
    TORCH_CHECK(n_blocks >= 0 && n_blocks <= compressed.size(0),
                "n_blocks must be in [0, compressed.size(0)]");

    const float* pos_bias_ptr = nullptr;
    if (position_bias.numel() > 0) {
        check_dense_fp32(position_bias, "position_bias");
        TORCH_CHECK(position_bias.numel() == m * hd, "position_bias must hold m * hd elements");
        pos_bias_ptr = position_bias.data_ptr<float>();
    }
    const float* norm_ptr = nullptr;
    if (kv_norm_weight.numel() > 0) {
        check_dense_fp32(kv_norm_weight, "kv_norm_weight");
        TORCH_CHECK(kv_norm_weight.numel() == hd, "kv_norm_weight must hold hd elements");
        norm_ptr = kv_norm_weight.data_ptr<float>();
    }

    // A zero-sized grid is an invalid launch configuration; there is nothing to compute.
    if (n_blocks == 0) return;

    const c10::cuda::CUDAGuard device_guard(kv_proj.device());
    const cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();
    constexpr int kThreads = 128;  // multiple of 32, as apply_kv_norm_kernel requires
    const dim3 grid(static_cast<unsigned int>(n_blocks));
    const int T = static_cast<int>(kv_proj.size(0));
    const int hd_i = static_cast<int>(hd);

    hca_compress_reduce_kernel<<<grid, kThreads, 0, stream>>>(
        kv_proj.data_ptr<float>(),
        gate_proj.data_ptr<float>(),
        pos_bias_ptr,
        norm_ptr,
        compressed.data_ptr<float>(),
        T, hd_i, static_cast<int>(m), static_cast<int>(n_blocks)
    );
    C10_CUDA_CHECK(cudaGetLastError());

    // Apply kv_norm in place if provided (same stream, so it runs after the reduce).
    if (norm_ptr != nullptr) {
        apply_kv_norm_kernel<<<grid, kThreads, 0, stream>>>(
            compressed.data_ptr<float>(),
            norm_ptr,
            compressed.data_ptr<float>(),
            static_cast<int>(n_blocks), hd_i
        );
        C10_CUDA_CHECK(cudaGetLastError());
    }
}
// Python module exposing the two host wrappers. Both take
// (kv_proj, gate_proj, position_bias, kv_norm_weight, compressed, m, n_blocks)
// and write their result into `compressed` in place.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext_module) {
    ext_module.def("csa_compress_reduce", &csa_compress_reduce_cuda,
                   "CSA compress reduce kernel");
    ext_module.def("hca_compress_reduce", &hca_compress_reduce_cuda,
                   "HCA compress reduce kernel");
}