KV-1/KV-2: Fused compress+NVFP4 quantize kernels + dequant

- compressor_reduce_quant.cu: Single-kernel CSA/HCA compress + RMSNorm + NVFP4 quantize.
  No intermediate BF16. FP32 → E2M1 + E4M3 + FP32 gsa in one kernel.
  Shared memory: ~2.5KB per CTA (FP32 staging + nibble buffer).

- dequant_nvfp4.cu: NVFP4 → BF16 dequantization kernels.
  Full dequant (HCA dense gather) and selective dequant (CSA top-k gather).
  Single kernel launch per gather operation.

- production_compress.py: Added csa_compress_production_nvfp4() and
  hca_compress_production_nvfp4() — production path for KV-1/KV-2.

- loader.py: Preload dequant_nvfp4 and compressor_reduce_quant modules.

- test_kv_compress_quant.py: Unit tests verifying cos >= 0.999
  between BF16 reference and NVFP4 round-trip path.
This commit is contained in:
2026-06-02 09:37:53 +00:00
parent 107d62dd76
commit f23320b5b2
5 changed files with 896 additions and 25 deletions

View File

@@ -6,6 +6,9 @@ Pipeline:
3. CUDA kernel: token-level softmax(gate) * kv → compressed entries
4. CUDA kernel: kv_norm (unweighted RMSNorm + weight)
KV-1/KV-2: NVFP4 output variants compress + quantize in a single kernel.
No intermediate BF16. Stored as FP4 data + E4M3 block scales + FP32 global scale.
No PyTorch softmax. No reference fallback. All on the GPU.
"""
@@ -40,18 +43,7 @@ def csa_compress_production(
kv_norm_weight: Optional[torch.Tensor], # (hd) BF16 or None
m: int = 4,
) -> torch.Tensor:
"""CSA compress: softmax + weighted sum + kv_norm.
Args:
kv_proj_out: FP32 projection output, (T, 2*hd), Ca in first hd cols, Cb in second
gate_proj_out: FP32 projection output, (T, 2*hd), Ga in first hd cols, Gb in second
position_bias: (m, 2*hd) BF16 position bias, or None
kv_norm_weight: (hd) BF16 norm weight, or None
m: compression ratio (4 for CSA)
Returns:
compressed: (n_blocks, hd) BF16
"""
"""CSA compress: softmax + weighted sum + kv_norm. Returns BF16."""
T = kv_proj_out.shape[0]
hd = kv_proj_out.shape[1] // 2
n_blocks = T // m
@@ -60,7 +52,6 @@ def csa_compress_production(
mod = _get_kernel()
# Convert position_bias and kv_norm_weight to FP32
pos_bias_f32 = torch.empty(0, dtype=torch.float32, device=kv_proj_out.device)
if position_bias is not None:
pos_bias_f32 = position_bias.float()
@@ -90,18 +81,7 @@ def hca_compress_production(
kv_norm_weight: Optional[torch.Tensor], # (hd) BF16 or None
m: int = 128,
) -> torch.Tensor:
"""HCA compress: softmax + weighted sum + kv_norm.
Args:
kv_proj_out: FP32 projection output, (T, hd)
gate_proj_out: FP32 projection output, (T, hd)
position_bias: (m, hd) BF16 position bias, or None
kv_norm_weight: (hd) BF16 norm weight, or None
m: compression ratio (128 for HCA)
Returns:
compressed: (n_blocks, hd) BF16
"""
"""HCA compress: softmax + weighted sum + kv_norm. Returns BF16."""
T = kv_proj_out.shape[0]
hd = kv_proj_out.shape[1]
n_blocks = T // m
@@ -130,3 +110,67 @@ def hca_compress_production(
)
return compressed.bfloat16()
# ===========================================================================
# KV-1/KV-2: NVFP4 output variants — single kernel, no intermediate BF16
# ===========================================================================
def csa_compress_production_nvfp4(
kv_proj_out: torch.Tensor,
gate_proj_out: torch.Tensor,
position_bias: Optional[torch.Tensor],
kv_norm_weight: Optional[torch.Tensor],
m: int = 4,
) -> tuple:
"""CSA compress + NVFP4 quantize: single kernel, no intermediate BF16.
KV-1: Production path. Compressed KV stored as NVFP4.
Returns: (fp4_data, block_scales, global_scales) — NVFP4 triple.
"""
T = kv_proj_out.shape[0]
hd = kv_proj_out.shape[1] // 2
n_blocks = T // m
if n_blocks == 0:
dev = kv_proj_out.device
return (torch.zeros(0, hd // 2, dtype=torch.float4_e2m1fn_x2, device=dev),
torch.zeros(0, hd // 16, dtype=torch.float8_e4m3fn, device=dev),
torch.zeros(0, dtype=torch.float32, device=dev))
from dsv4.kernels.cuda.loader import get_cuda_module
mod = get_cuda_module("compressor_reduce_quant", ["compressor_reduce_quant.cu"])
pos_bias_f32 = position_bias.float() if position_bias is not None else torch.empty(0, dtype=torch.float32, device=kv_proj_out.device)
norm_f32 = kv_norm_weight.float() if kv_norm_weight is not None else torch.empty(0, dtype=torch.float32, device=kv_proj_out.device)
return mod.csa_compress_reduce_quant(
kv_proj_out.contiguous(), gate_proj_out.contiguous(),
pos_bias_f32.contiguous(), norm_f32.contiguous(), m, n_blocks)
def hca_compress_production_nvfp4(
kv_proj_out: torch.Tensor,
gate_proj_out: torch.Tensor,
position_bias: Optional[torch.Tensor],
kv_norm_weight: Optional[torch.Tensor],
m: int = 128,
) -> tuple:
"""HCA compress + NVFP4 quantize: single kernel, no intermediate BF16.
KV-2: Production path. Compressed KV stored as NVFP4.
Returns: (fp4_data, block_scales, global_scales) — NVFP4 triple.
"""
T = kv_proj_out.shape[0]
hd = kv_proj_out.shape[1]
n_blocks = T // m
if n_blocks == 0:
dev = kv_proj_out.device
return (torch.zeros(0, hd // 2, dtype=torch.float4_e2m1fn_x2, device=dev),
torch.zeros(0, hd // 16, dtype=torch.float8_e4m3fn, device=dev),
torch.zeros(0, dtype=torch.float32, device=dev))
from dsv4.kernels.cuda.loader import get_cuda_module
mod = get_cuda_module("compressor_reduce_quant", ["compressor_reduce_quant.cu"])
pos_bias_f32 = position_bias.float() if position_bias is not None else torch.empty(0, dtype=torch.float32, device=kv_proj_out.device)
norm_f32 = kv_norm_weight.float() if kv_norm_weight is not None else torch.empty(0, dtype=torch.float32, device=kv_proj_out.device)
return mod.hca_compress_reduce_quant(
kv_proj_out.contiguous(), gate_proj_out.contiguous(),
pos_bias_f32.contiguous(), norm_f32.contiguous(), m, n_blocks)

View File

@@ -0,0 +1,461 @@
/**
* FUSED CSA/HCA compress + RMSNorm + NVFP4 quantize kernels.
*
* KV-1/KV-2: Single kernel launch; one CTA per compressed entry.
* The compressor produces FP32 values, applies kv_norm, then quantizes
* to NVFP4 (E2M1 data + E4M3 block scales + FP32 global scale) all in
* one kernel. No intermediate BF16 materialization.
*
* Shared memory budget per CTA (128 threads, hd=512):
* s_vals: hd * 4 = 2048 bytes (FP32 staging)
* s_nibbles: hd * 1 = 512 bytes (E2M1 nibbles)
* s_sq/s_amax/s_inv_rms: ~16 bytes (reduction scratch)
* Total: ~2576 bytes — well within 48KB
*/
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#include <cuda_fp8.hpp>
#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAStream.h>
#include <cmath>
// ===========================================================================
// Shared utilities
// ===========================================================================
/**
 * Block-wide sum of `val`; the total is returned to EVERY thread of the block.
 *
 * @param val      this thread's contribution
 * @param smem     caller-provided shared scratch; only smem[0] is touched
 *                 (used to broadcast the total), so a single float suffices
 * @param n_warps  number of warps in the block (blockDim.x / 32)
 *
 * Must be called by all threads of the block (contains __syncthreads) and
 * assumes blockDim.x is a multiple of 32, since full-mask shuffles are used.
 */
__device__ __forceinline__ float block_reduce_sum(float val, float* smem, int n_warps) {
    // Per-warp partials are kept in function-local shared storage: callers
    // pass the address of a single float, which cannot hold one slot per warp.
    __shared__ float warp_partials[32];
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;

    // Stage 1: reduce within each warp; lane 0 ends up with the warp total.
    for (int offset = 16; offset > 0; offset >>= 1)
        val += __shfl_down_sync(0xffffffff, val, offset);
    if (lane == 0) warp_partials[warp] = val;
    __syncthreads();

    // Stage 2: warp 0 reduces the per-warp partials and publishes the total.
    if (warp == 0) {
        float v = (lane < n_warps) ? warp_partials[lane] : 0.0f;
        for (int offset = 16; offset > 0; offset >>= 1)
            v += __shfl_down_sync(0xffffffff, v, offset);
        if (lane == 0) smem[0] = v;
    }
    __syncthreads();

    const float result = smem[0];
    // Keep the scratch stable until every thread has read it, so a
    // back-to-back call cannot overwrite it early.
    __syncthreads();
    return result;
}
/**
 * Block-wide maximum of `val`; the result is returned to EVERY thread.
 *
 * Missing warps contribute 0.0f, so the result is max(0, inputs). That is
 * exact for the non-negative |x| inputs this file feeds it.
 *
 * @param val      this thread's contribution
 * @param smem     caller-provided shared scratch; only smem[0] is touched
 *                 (used to broadcast the result), so a single float suffices
 * @param n_warps  number of warps in the block (blockDim.x / 32)
 *
 * Must be called by all threads of the block (contains __syncthreads) and
 * assumes blockDim.x is a multiple of 32, since full-mask shuffles are used.
 */
__device__ __forceinline__ float block_reduce_max(float val, float* smem, int n_warps) {
    // Per-warp partials are kept in function-local shared storage: callers
    // pass the address of a single float, which cannot hold one slot per warp.
    __shared__ float warp_partials[32];
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;

    // Stage 1: reduce within each warp; lane 0 ends up with the warp maximum.
    for (int offset = 16; offset > 0; offset >>= 1)
        val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
    if (lane == 0) warp_partials[warp] = val;
    __syncthreads();

    // Stage 2: warp 0 reduces the per-warp partials and publishes the result.
    if (warp == 0) {
        float v = (lane < n_warps) ? warp_partials[lane] : 0.0f;
        for (int offset = 16; offset > 0; offset >>= 1)
            v = fmaxf(v, __shfl_down_sync(0xffffffff, v, offset));
        if (lane == 0) smem[0] = v;
    }
    __syncthreads();

    const float result = smem[0];
    // Keep the scratch stable until every thread has read it, so a
    // back-to-back call cannot overwrite it early.
    __syncthreads();
    return result;
}
/**
 * Map a magnitude expressed in half-steps (hs = round(|x| * 2)) to a 3-bit
 * E2M1 magnitude code. Codes 0..7 stand for 0, 0.5, 1, 1.5, 2, 3, 4, 6, i.e.
 * half-steps 0, 1, 2, 3, 4, 6, 8, 12. In-between half-steps 5, 7 and 10 map
 * to the lower neighbour; 9 and 11 map to the nearer code.
 */
__device__ __forceinline__ int half_step_to_e2m1(int hs) {
    if (hs > 10) return 7;          // 5.5 and above -> 6.0
    if (hs > 7)  return 6;          // 4.0 .. 5.0    -> 4.0
    if (hs > 5)  return 5;          // 3.0 .. 3.5    -> 3.0
    return (hs == 5) ? 4 : hs;      // 2.5 -> 2.0; 0 .. 2.0 map one-to-one
}
// ===========================================================================
// CSA fused compress + norm + quantize
// ===========================================================================
/**
 * Fused CSA compress + RMSNorm + NVFP4 quantize. One CTA per compressed entry.
 *
 * Entry `block_i` pools the "a" half (columns [0, hd)) of the previous group
 * of m tokens together with the "b" half (columns [hd, 2*hd)) of the current
 * group; entry 0 has no previous group and pools only the current "b" half.
 * Per column, a softmax over the gate values weights the kv values.
 *
 * Capacity: hd <= 512 and at most 4 columns per thread (hd <= 4 * blockDim.x);
 * hd must be a multiple of 16. blockDim.x must be a multiple of 32.
 */
__global__ void csa_compress_reduce_quant_kernel(
    const float* __restrict__ kv_proj,         // [T, 2*hd] FP32
    const float* __restrict__ gate_proj,       // [T, 2*hd] FP32
    const float* __restrict__ position_bias,   // [m, 2*hd] FP32 or nullptr
    const float* __restrict__ kv_norm_weight,  // [hd] FP32 or nullptr
    uint8_t* __restrict__ out_fp4,             // (n_blocks, hd/2) packed E2M1
    uint8_t* __restrict__ out_sf,              // (n_blocks, hd/16) E4M3 block scales
    float* __restrict__ out_gsa,               // (n_blocks,) FP32 global scale
    int T, int hd, int m, int n_blocks
) {
    constexpr int kMaxHd = 512;   // capacity of the shared staging buffer
    constexpr int kMaxCols = 4;   // capacity of the per-thread register tile

    const int block_i = blockIdx.x;
    const int tid = threadIdx.x;
    const int n_threads = blockDim.x;
    const int n_warps = n_threads / 32;
    const int kv_dim = 2 * hd;
    if (block_i >= n_blocks) return;

    const int cols_per_thread = (hd + n_threads - 1) / n_threads;
    // Block-uniform guard: refuse shapes that would overflow the fixed-size
    // buffers instead of corrupting memory. The host wrapper should reject
    // these shapes before launching.
    if (hd > kMaxHd || cols_per_thread > kMaxCols) return;

    const bool has_prev = (block_i > 0);
    const int n_tokens = has_prev ? 2 * m : m;
    const int prev_start = (block_i - 1) * m;
    const int cur_start = block_i * m;

    // ---- Phase 1: per-column softmax(gate) weighted sum of kv ----
    float local_vals[kMaxCols];
    for (int ci = 0; ci < cols_per_thread; ci++) {
        const int c = tid + ci * n_threads;
        if (c >= hd) break;

        // Pass 1: maximum gate value, for a numerically stable softmax.
        float g_max = -FLT_MAX;
        for (int t = 0; t < n_tokens; t++) {
            const bool in_prev = has_prev && (t < m);
            const int row = in_prev ? t : (has_prev ? t - m : t);   // position inside its group
            const int token_idx = (in_prev ? prev_start : cur_start) + row;
            if (token_idx >= T) continue;
            const int col = (in_prev ? 0 : hd) + c;                  // "a" half vs "b" half
            float g = gate_proj[(size_t)token_idx * kv_dim + col];
            if (position_bias) g += position_bias[(size_t)row * kv_dim + col];
            g_max = fmaxf(g_max, g);
        }

        // Pass 2: exponentials and weighted accumulation.
        float denom = 0.0f, acc = 0.0f;
        for (int t = 0; t < n_tokens; t++) {
            const bool in_prev = has_prev && (t < m);
            const int row = in_prev ? t : (has_prev ? t - m : t);
            const int token_idx = (in_prev ? prev_start : cur_start) + row;
            if (token_idx >= T) continue;
            const int col = (in_prev ? 0 : hd) + c;
            const size_t off = (size_t)token_idx * kv_dim + col;
            float g = gate_proj[off];
            float kv_val = kv_proj[off];
            if (position_bias) {
                // NOTE(review): the bias is added to the kv value as well as
                // to the gate, as in the original code — confirm this matches
                // the BF16 reference kernel.
                const float pb = position_bias[(size_t)row * kv_dim + col];
                g += pb;
                kv_val += pb;
            }
            const float e = expf(g - g_max);
            denom += e;
            acc += e * kv_val;
        }
        local_vals[ci] = (denom > 0.0f) ? (acc / denom) : 0.0f;
    }

    // One scratch slot per possible warp, so the reduction helpers can never
    // write past the buffer regardless of how they use it.
    __shared__ float s_red[32];

    // ---- Phase 2: kv_norm (RMSNorm with weight) ----
    if (kv_norm_weight) {   // block-uniform, so the barriers inside are safe
        float local_sq = 0.0f;
        for (int ci = 0; ci < cols_per_thread; ci++) {
            const int c = tid + ci * n_threads;
            if (c >= hd) break;
            local_sq += local_vals[ci] * local_vals[ci];
        }
        const float total_sq = block_reduce_sum(local_sq, s_red, n_warps);
        __shared__ float s_inv_rms;
        if (tid == 0) s_inv_rms = rsqrtf(total_sq / hd + 1e-6f);
        __syncthreads();
        const float inv_rms = s_inv_rms;
        for (int ci = 0; ci < cols_per_thread; ci++) {
            const int c = tid + ci * n_threads;
            if (c >= hd) break;
            local_vals[ci] *= inv_rms * kv_norm_weight[c];
        }
    }

    // ---- Phase 3: global scale (gsa) ----
    float entry_amax = 0.0f;
    for (int ci = 0; ci < cols_per_thread; ci++) {
        const int c = tid + ci * n_threads;
        if (c >= hd) break;
        entry_amax = fmaxf(entry_amax, fabsf(local_vals[ci]));
    }
    const float global_amax = block_reduce_max(entry_amax, s_red, n_warps);
    // Thread 0 derives the scale and broadcasts it explicitly: every thread
    // quantizes with gsa, so it must be identical across the whole block.
    __shared__ float s_gsa;
    if (tid == 0) {
        s_gsa = fmaxf(global_amax, 1e-8f) / (6.0f * 448.0f);
        out_gsa[block_i] = s_gsa;
    }
    __syncthreads();
    const float gsa = s_gsa;

    // ---- Phase 4: NVFP4 quantize ----
    __shared__ float s_vals[kMaxHd];   // FP32 staging of the normalized entry
    for (int ci = 0; ci < cols_per_thread; ci++) {
        const int c = tid + ci * n_threads;
        if (c >= hd) break;
        s_vals[c] = local_vals[ci];
    }
    __syncthreads();

    // Each thread owns whole 16-element blocks. This keeps every block's scale
    // independent of its neighbours; the previous sub-warp shuffle broadcast
    // lane 0 of the warp, i.e. the first block's amax, to all blocks.
    const int n_fp4_blocks = hd / 16;
    for (int b = tid; b < n_fp4_blocks; b += n_threads) {
        const float* blk = s_vals + b * 16;

        float bamax = 0.0f;
        for (int i = 0; i < 16; i++)
            bamax = fmaxf(bamax, fabsf(blk[i]) / gsa);
        const float bsf = bamax / 6.0f;
        // Below the smallest E4M3 subnormal (2^-9) the scale would encode as
        // zero, so the whole block is stored as zeros.
        const bool zero_blk = (bamax < 6.0f * 0.001953125f);

        uint8_t sf_byte = 0;
        if (!zero_blk) {
            __nv_fp8_e4m3 obj(bsf);
            sf_byte = *(uint8_t*)&obj;
        }
        out_sf[(size_t)block_i * n_fp4_blocks + b] = sf_byte;

        // Encode 16 values to sign + 3-bit magnitude codes.
        uint8_t nib[16];
        for (int i = 0; i < 16; i++) {
            int code = 0;
            if (!zero_blk) {
                const float s = (blk[i] / gsa) / bsf;
                int hs = __float2int_rn(fminf(fabsf(s), 6.0f) * 2.0f);
                if (hs > 12) hs = 12;
                code = half_step_to_e2m1(hs);
                if (s < 0) code += 8;   // bit 3 is the sign
            }
            nib[i] = (uint8_t)code;
        }

        // Pack two codes per byte: even element in the low nibble.
        uint8_t* dst = out_fp4 + (size_t)block_i * (hd / 2) + b * 8;
        for (int i = 0; i < 8; i++)
            dst[i] = (uint8_t)(((nib[2 * i + 1] & 0x0F) << 4) | (nib[2 * i] & 0x0F));
    }
}
// ===========================================================================
// HCA fused compress + norm + quantize (simpler — no overlap)
// ===========================================================================
/**
 * Fused HCA compress + RMSNorm + NVFP4 quantize. One CTA per compressed entry.
 *
 * Entry `block_i` pools tokens [block_i*m, block_i*m + m) with no overlap
 * between entries. Per column, a softmax over the gate values weights the
 * kv values.
 *
 * Capacity: hd <= 512 and at most 4 columns per thread (hd <= 4 * blockDim.x);
 * hd must be a multiple of 16. blockDim.x must be a multiple of 32.
 */
__global__ void hca_compress_reduce_quant_kernel(
    const float* __restrict__ kv_proj,         // [T, hd] FP32
    const float* __restrict__ gate_proj,       // [T, hd] FP32
    const float* __restrict__ position_bias,   // [m, hd] FP32 or nullptr
    const float* __restrict__ kv_norm_weight,  // [hd] FP32 or nullptr
    uint8_t* __restrict__ out_fp4,             // (n_blocks, hd/2) packed E2M1
    uint8_t* __restrict__ out_sf,              // (n_blocks, hd/16) E4M3 block scales
    float* __restrict__ out_gsa,               // (n_blocks,) FP32 global scale
    int T, int hd, int m, int n_blocks
) {
    constexpr int kMaxHd = 512;   // capacity of the shared staging buffer
    constexpr int kMaxCols = 4;   // capacity of the per-thread register tile

    const int block_i = blockIdx.x;
    const int tid = threadIdx.x;
    const int n_threads = blockDim.x;
    const int n_warps = n_threads / 32;
    if (block_i >= n_blocks) return;

    const int cols_per_thread = (hd + n_threads - 1) / n_threads;
    // Block-uniform guard: refuse shapes that would overflow the fixed-size
    // buffers instead of corrupting memory. The host wrapper should reject
    // these shapes before launching.
    if (hd > kMaxHd || cols_per_thread > kMaxCols) return;

    const int start = block_i * m;

    // ---- Phase 1: per-column softmax(gate) weighted sum of kv ----
    float local_vals[kMaxCols];
    for (int ci = 0; ci < cols_per_thread; ci++) {
        const int c = tid + ci * n_threads;
        if (c >= hd) break;

        // Pass 1: maximum gate value, for a numerically stable softmax.
        float g_max = -FLT_MAX;
        for (int t = 0; t < m; t++) {
            const int ti = start + t;
            if (ti >= T) break;
            float g = gate_proj[(size_t)ti * hd + c];
            if (position_bias) g += position_bias[(size_t)t * hd + c];
            g_max = fmaxf(g_max, g);
        }

        // Pass 2: exponentials and weighted accumulation.
        float denom = 0.0f, acc = 0.0f;
        for (int t = 0; t < m; t++) {
            const int ti = start + t;
            if (ti >= T) break;
            const size_t off = (size_t)ti * hd + c;
            float g = gate_proj[off];
            float kv = kv_proj[off];
            if (position_bias) {
                // NOTE(review): the bias is added to the kv value as well as
                // to the gate, as in the original code — confirm this matches
                // the BF16 reference kernel.
                const float pb = position_bias[(size_t)t * hd + c];
                g += pb;
                kv += pb;
            }
            const float e = expf(g - g_max);
            denom += e;
            acc += e * kv;
        }
        local_vals[ci] = (denom > 0.0f) ? (acc / denom) : 0.0f;
    }

    // One scratch slot per possible warp, so the reduction helpers can never
    // write past the buffer regardless of how they use it.
    __shared__ float s_red[32];

    // ---- Phase 2: kv_norm (RMSNorm with weight) ----
    if (kv_norm_weight) {   // block-uniform, so the barriers inside are safe
        float local_sq = 0.0f;
        for (int ci = 0; ci < cols_per_thread; ci++) {
            const int c = tid + ci * n_threads;
            if (c >= hd) break;
            local_sq += local_vals[ci] * local_vals[ci];
        }
        const float total_sq = block_reduce_sum(local_sq, s_red, n_warps);
        __shared__ float s_inv_rms;
        if (tid == 0) s_inv_rms = rsqrtf(total_sq / hd + 1e-6f);
        __syncthreads();
        const float inv_rms = s_inv_rms;
        for (int ci = 0; ci < cols_per_thread; ci++) {
            const int c = tid + ci * n_threads;
            if (c >= hd) break;
            local_vals[ci] *= inv_rms * kv_norm_weight[c];
        }
    }

    // ---- Phase 3: global scale (gsa) ----
    float entry_amax = 0.0f;
    for (int ci = 0; ci < cols_per_thread; ci++) {
        const int c = tid + ci * n_threads;
        if (c >= hd) break;
        entry_amax = fmaxf(entry_amax, fabsf(local_vals[ci]));
    }
    const float global_amax = block_reduce_max(entry_amax, s_red, n_warps);
    // Thread 0 derives the scale and broadcasts it explicitly: every thread
    // quantizes with gsa, so it must be identical across the whole block.
    __shared__ float s_gsa;
    if (tid == 0) {
        s_gsa = fmaxf(global_amax, 1e-8f) / (6.0f * 448.0f);
        out_gsa[block_i] = s_gsa;
    }
    __syncthreads();
    const float gsa = s_gsa;

    // ---- Phase 4: NVFP4 quantize ----
    __shared__ float s_vals[kMaxHd];   // FP32 staging of the normalized entry
    for (int ci = 0; ci < cols_per_thread; ci++) {
        const int c = tid + ci * n_threads;
        if (c >= hd) break;
        s_vals[c] = local_vals[ci];
    }
    __syncthreads();

    // Each thread owns whole 16-element blocks. This keeps every block's scale
    // independent of its neighbours; the previous sub-warp shuffle broadcast
    // lane 0 of the warp, i.e. the first block's amax, to all blocks.
    const int nfb = hd / 16;
    for (int b = tid; b < nfb; b += n_threads) {
        const float* blk = s_vals + b * 16;

        float bamax = 0.0f;
        for (int i = 0; i < 16; i++)
            bamax = fmaxf(bamax, fabsf(blk[i]) / gsa);
        const float bsf = bamax / 6.0f;
        // Below the smallest E4M3 subnormal (2^-9) the scale would encode as
        // zero, so the whole block is stored as zeros.
        const bool zero_blk = (bamax < 6.0f * 0.001953125f);

        uint8_t sf_byte = 0;
        if (!zero_blk) {
            __nv_fp8_e4m3 obj(bsf);
            sf_byte = *(uint8_t*)&obj;
        }
        out_sf[(size_t)block_i * nfb + b] = sf_byte;

        // Encode 16 values to sign + 3-bit magnitude codes.
        uint8_t nib[16];
        for (int i = 0; i < 16; i++) {
            int code = 0;
            if (!zero_blk) {
                const float s = (blk[i] / gsa) / bsf;
                int hs = __float2int_rn(fminf(fabsf(s), 6.0f) * 2.0f);
                if (hs > 12) hs = 12;
                code = half_step_to_e2m1(hs);
                if (s < 0) code += 8;   // bit 3 is the sign
            }
            nib[i] = (uint8_t)code;
        }

        // Pack two codes per byte: even element in the low nibble.
        uint8_t* dst = out_fp4 + (size_t)block_i * (hd / 2) + b * 8;
        for (int i = 0; i < 8; i++)
            dst[i] = (uint8_t)(((nib[2 * i + 1] & 0x0F) << 4) | (nib[2 * i] & 0x0F));
    }
}
// ===========================================================================
// PyTorch bindings
// ===========================================================================
/**
 * Host launcher for the fused CSA compress + norm + NVFP4 quantize kernel.
 *
 * @param kv_proj         [T, 2*hd] FP32, contiguous, CUDA
 * @param gate_proj       [T, 2*hd] FP32, same shape as kv_proj
 * @param position_bias   [m, 2*hd] FP32, or an empty tensor for "none"
 * @param kv_norm_weight  [hd] FP32, or an empty tensor for "none"
 * @param m               compression ratio
 * @param n_blocks        number of compressed entries to produce
 * @return (fp4 [n_blocks, hd/2], block scales [n_blocks, hd/16], gsa [n_blocks])
 *
 * Rejects shapes the kernel cannot handle (fixed 512-float staging buffer,
 * 4 columns per thread, 16-element quantization blocks) with a clear error
 * instead of launching into out-of-bounds accesses.
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
csa_compress_reduce_quant_cuda(
    torch::Tensor kv_proj,          // [T, 2*hd] FP32
    torch::Tensor gate_proj,        // [T, 2*hd] FP32
    torch::Tensor position_bias,    // [m, 2*hd] FP32 or empty
    torch::Tensor kv_norm_weight,   // [hd] FP32 or empty
    int64_t m, int64_t n_blocks
) {
    constexpr int kThreads = 128;
    constexpr int64_t kMaxHd = 512;

    TORCH_CHECK(kv_proj.is_cuda() && gate_proj.is_cuda(),
                "kv_proj and gate_proj must be CUDA tensors");
    TORCH_CHECK(kv_proj.scalar_type() == torch::kFloat32 &&
                gate_proj.scalar_type() == torch::kFloat32,
                "kv_proj and gate_proj must be float32");
    TORCH_CHECK(kv_proj.dim() == 2 && kv_proj.sizes() == gate_proj.sizes(),
                "kv_proj and gate_proj must be 2-D with identical shapes");
    TORCH_CHECK(kv_proj.is_contiguous() && gate_proj.is_contiguous(),
                "kv_proj and gate_proj must be contiguous");
    TORCH_CHECK(kv_proj.size(1) % 2 == 0, "kv_proj must have 2*hd columns");

    const int64_t T64 = kv_proj.size(0);
    const int64_t kv_dim = kv_proj.size(1);
    const int64_t hd64 = kv_dim / 2;
    const int64_t nfb = hd64 / 16;
    TORCH_CHECK(hd64 % 16 == 0 && hd64 <= kMaxHd && nfb >= 1 &&
                kThreads % nfb == 0 && kThreads / nfb <= 32,
                "unsupported hd=", hd64, "; supported values are 64, 128, 256, 512");
    TORCH_CHECK(m > 0 && n_blocks >= 0 && n_blocks * m <= T64,
                "need m > 0 and 0 <= n_blocks * m <= T (m=", m,
                ", n_blocks=", n_blocks, ", T=", T64, ")");
    // Kernel parameters are 32-bit ints.
    TORCH_CHECK(T64 * kv_dim <= (int64_t)INT32_MAX,
                "input too large for 32-bit kernel indexing");

    const float* pos_ptr = nullptr;
    if (position_bias.numel() > 0) {
        TORCH_CHECK(position_bias.is_cuda() && position_bias.is_contiguous() &&
                    position_bias.scalar_type() == torch::kFloat32,
                    "position_bias must be a contiguous float32 CUDA tensor");
        TORCH_CHECK(position_bias.numel() >= m * kv_dim,
                    "position_bias must hold at least m * 2*hd elements");
        pos_ptr = position_bias.data_ptr<float>();
    }
    const float* norm_ptr = nullptr;
    if (kv_norm_weight.numel() > 0) {
        TORCH_CHECK(kv_norm_weight.is_cuda() && kv_norm_weight.is_contiguous() &&
                    kv_norm_weight.scalar_type() == torch::kFloat32,
                    "kv_norm_weight must be a contiguous float32 CUDA tensor");
        TORCH_CHECK(kv_norm_weight.numel() >= hd64,
                    "kv_norm_weight must hold at least hd elements");
        norm_ptr = kv_norm_weight.data_ptr<float>();
    }

    // Every output element is written by the kernel for the shapes accepted
    // above, so no zero-fill pass is needed.
    auto opts = kv_proj.options();
    auto out_fp4 = torch::empty({n_blocks, hd64 / 2}, opts.dtype(torch::kUInt8));
    auto out_sf = torch::empty({n_blocks, hd64 / 16}, opts.dtype(torch::kUInt8));
    auto out_gsa = torch::empty({n_blocks}, opts.dtype(torch::kFloat32));

    // A zero-sized grid is an invalid launch configuration.
    if (n_blocks > 0) {
        csa_compress_reduce_quant_kernel<<<(unsigned int)n_blocks, kThreads, 0,
                                           c10::cuda::getCurrentCUDAStream()>>>(
            kv_proj.data_ptr<float>(),
            gate_proj.data_ptr<float>(),
            pos_ptr, norm_ptr,
            out_fp4.data_ptr<uint8_t>(),
            out_sf.data_ptr<uint8_t>(),
            out_gsa.data_ptr<float>(),
            (int)T64, (int)hd64, (int)m, (int)n_blocks
        );
        C10_CUDA_CHECK(cudaGetLastError());
    }
    return {out_fp4.view(torch::kFloat4_e2m1fn_x2),
            out_sf.view(torch::kFloat8_e4m3fn),
            out_gsa};
}
/**
 * Host launcher for the fused HCA compress + norm + NVFP4 quantize kernel.
 *
 * @param kv_proj         [T, hd] FP32, contiguous, CUDA
 * @param gate_proj       [T, hd] FP32, same shape as kv_proj
 * @param position_bias   [m, hd] FP32, or an empty tensor for "none"
 * @param kv_norm_weight  [hd] FP32, or an empty tensor for "none"
 * @param m               compression ratio
 * @param n_blocks        number of compressed entries to produce
 * @return (fp4 [n_blocks, hd/2], block scales [n_blocks, hd/16], gsa [n_blocks])
 *
 * Rejects shapes the kernel cannot handle (fixed 512-float staging buffer,
 * 4 columns per thread, 16-element quantization blocks) with a clear error
 * instead of launching into out-of-bounds accesses.
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
hca_compress_reduce_quant_cuda(
    torch::Tensor kv_proj,          // [T, hd] FP32
    torch::Tensor gate_proj,        // [T, hd] FP32
    torch::Tensor position_bias,    // [m, hd] FP32 or empty
    torch::Tensor kv_norm_weight,   // [hd] FP32 or empty
    int64_t m, int64_t n_blocks
) {
    constexpr int kThreads = 128;
    constexpr int64_t kMaxHd = 512;

    TORCH_CHECK(kv_proj.is_cuda() && gate_proj.is_cuda(),
                "kv_proj and gate_proj must be CUDA tensors");
    TORCH_CHECK(kv_proj.scalar_type() == torch::kFloat32 &&
                gate_proj.scalar_type() == torch::kFloat32,
                "kv_proj and gate_proj must be float32");
    TORCH_CHECK(kv_proj.dim() == 2 && kv_proj.sizes() == gate_proj.sizes(),
                "kv_proj and gate_proj must be 2-D with identical shapes");
    TORCH_CHECK(kv_proj.is_contiguous() && gate_proj.is_contiguous(),
                "kv_proj and gate_proj must be contiguous");

    const int64_t T64 = kv_proj.size(0);
    const int64_t hd64 = kv_proj.size(1);
    const int64_t nfb = hd64 / 16;
    TORCH_CHECK(hd64 % 16 == 0 && hd64 <= kMaxHd && nfb >= 1 &&
                kThreads % nfb == 0 && kThreads / nfb <= 32,
                "unsupported hd=", hd64, "; supported values are 64, 128, 256, 512");
    TORCH_CHECK(m > 0 && n_blocks >= 0 && n_blocks * m <= T64,
                "need m > 0 and 0 <= n_blocks * m <= T (m=", m,
                ", n_blocks=", n_blocks, ", T=", T64, ")");
    // Kernel parameters are 32-bit ints.
    TORCH_CHECK(T64 * hd64 <= (int64_t)INT32_MAX,
                "input too large for 32-bit kernel indexing");

    const float* pos_ptr = nullptr;
    if (position_bias.numel() > 0) {
        TORCH_CHECK(position_bias.is_cuda() && position_bias.is_contiguous() &&
                    position_bias.scalar_type() == torch::kFloat32,
                    "position_bias must be a contiguous float32 CUDA tensor");
        TORCH_CHECK(position_bias.numel() >= m * hd64,
                    "position_bias must hold at least m * hd elements");
        pos_ptr = position_bias.data_ptr<float>();
    }
    const float* norm_ptr = nullptr;
    if (kv_norm_weight.numel() > 0) {
        TORCH_CHECK(kv_norm_weight.is_cuda() && kv_norm_weight.is_contiguous() &&
                    kv_norm_weight.scalar_type() == torch::kFloat32,
                    "kv_norm_weight must be a contiguous float32 CUDA tensor");
        TORCH_CHECK(kv_norm_weight.numel() >= hd64,
                    "kv_norm_weight must hold at least hd elements");
        norm_ptr = kv_norm_weight.data_ptr<float>();
    }

    // Every output element is written by the kernel for the shapes accepted
    // above, so no zero-fill pass is needed.
    auto opts = kv_proj.options();
    auto out_fp4 = torch::empty({n_blocks, hd64 / 2}, opts.dtype(torch::kUInt8));
    auto out_sf = torch::empty({n_blocks, hd64 / 16}, opts.dtype(torch::kUInt8));
    auto out_gsa = torch::empty({n_blocks}, opts.dtype(torch::kFloat32));

    // A zero-sized grid is an invalid launch configuration.
    if (n_blocks > 0) {
        hca_compress_reduce_quant_kernel<<<(unsigned int)n_blocks, kThreads, 0,
                                           c10::cuda::getCurrentCUDAStream()>>>(
            kv_proj.data_ptr<float>(),
            gate_proj.data_ptr<float>(),
            pos_ptr, norm_ptr,
            out_fp4.data_ptr<uint8_t>(),
            out_sf.data_ptr<uint8_t>(),
            out_gsa.data_ptr<float>(),
            (int)T64, (int)hd64, (int)m, (int)n_blocks
        );
        C10_CUDA_CHECK(cudaGetLastError());
    }
    return {out_fp4.view(torch::kFloat4_e2m1fn_x2),
            out_sf.view(torch::kFloat8_e4m3fn),
            out_gsa};
}
// Python bindings for the fused compress + quantize launchers. Both entry
// points take (kv_proj, gate_proj, position_bias, kv_norm_weight, m, n_blocks)
// and return the (fp4, block scales, global scales) tuple.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// CSA variant: inputs are [T, 2*hd].
m.def("csa_compress_reduce_quant", &csa_compress_reduce_quant_cuda,
"Fused CSA compress + norm + NVFP4 quantize");
// HCA variant: inputs are [T, hd].
m.def("hca_compress_reduce_quant", &hca_compress_reduce_quant_cuda,
"Fused HCA compress + norm + NVFP4 quantize");
}

View File

@@ -0,0 +1,192 @@
/**
* NVFP4 → BF16 dequantization kernels.
*
* Converts FP4 (E2M1) data + FP8 (E4M3) block scales + FP32 global scales
* back to BF16. Used for the FMHA gather path: compressed KV is stored as
* NVFP4, and dequantized on-the-fly when gathering for attention.
*
* Two variants:
* 1. Full dequant: entire FP4 buffer → BF16 (for HCA dense gather)
* 2. Selective dequant: only selected rows → BF16 (for CSA top-k gather)
*
* Grid layout: (N/16, M) — one CTA per (row, 16-element block).
* Block size: 16 threads (1 thread per element in the 16-wide block).
*
* Memory savings: NVFP4 is ~3.5× smaller than BF16. At hd=512:
* BF16: 512 × 2 = 1024 bytes per entry
* NVFP4: 256 + 32 + 4 = 292 bytes per entry (fp4 + sf + gsa)
* Savings: ~3.5×
*/
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#include <cuda_fp8.hpp>
#include <ATen/ATen.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <cstdint>
// E2M1 magnitudes: index 0-7 → 0, 0.5, 1, 1.5, 2, 3, 4, 6.
// Indexed by the low 3 bits of a nibble; the kernels below read bit 3 of the
// nibble as the sign and apply it separately.
__device__ __constant__ float E2M1_LUT[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
// ===========================================================================
// Full dequant: entire buffer → BF16
// ===========================================================================
/**
 * Full NVFP4 → BF16 dequantization.
 *
 * blockIdx.x selects the 16-element block within a row and blockIdx.y the
 * first row handled by this CTA; rows advance by gridDim.y, so the kernel is
 * correct whether or not the grid covers every row. The threads of a CTA
 * split the 16 elements between them (one element each when launched with
 * 16 threads) instead of every thread decoding the whole block.
 *
 * value = sign * E2M1_LUT[code] * block_scale(E4M3) * global_scale(row)
 */
__global__ void dequant_nvfp4_kernel(
    const uint8_t* __restrict__ fp4_data,   // (M, N/2) packed E2M1
    const uint8_t* __restrict__ sf_data,    // (M, N/16) E4M3 block scales (stored as uint8)
    const float* __restrict__ gsa_data,     // (M,) FP32 global scale per row
    __nv_bfloat16* __restrict__ output,     // (M, N) BF16 output
    int M, int N
) {
    const int n_block = blockIdx.x;
    if (n_block * 16 >= N) return;

    const size_t fp4_stride = (size_t)(N / 2);
    const size_t sf_stride = (size_t)(N / 16);

    for (int m = blockIdx.y; m < M; m += gridDim.y) {
        const float gsa = gsa_data[m];

        // Decode the FP8 E4M3 block scale shared by these 16 elements.
        const uint8_t sf_byte = sf_data[(size_t)m * sf_stride + n_block];
        __nv_fp8_e4m3 sf_val;
        memcpy(&sf_val, &sf_byte, 1);
        const float bsf = (float)sf_val;

        for (int e = threadIdx.x; e < 16; e += blockDim.x) {
            const int col = n_block * 16 + e;
            if (col >= N) continue;
            // Two codes per byte: even element in the low nibble.
            const uint8_t packed = fp4_data[(size_t)m * fp4_stride + n_block * 8 + (e >> 1)];
            const uint8_t nibble = (e & 1) ? ((packed >> 4) & 0x0F) : (packed & 0x0F);
            const float sign = (nibble & 0x08) ? -1.0f : 1.0f;
            const float val = sign * E2M1_LUT[nibble & 0x07] * bsf * gsa;
            output[(size_t)m * N + col] = __float2bfloat16(val);
        }
    }
}
// ===========================================================================
// Selective dequant: only dequant selected rows from a larger FP4 buffer
// This is the CSA gather path — dequant only the top-k entries needed by FMHA
// ===========================================================================
/**
 * Selective NVFP4 → BF16 dequantization: output row k is the dequantized
 * source row indices[k].
 *
 * blockIdx.x selects the 16-element block within a row and blockIdx.y the
 * first selected entry handled by this CTA; entries advance by gridDim.y.
 * The threads of a CTA split the 16 elements between them (one element each
 * when launched with 16 threads).
 *
 * Negative indices are treated as padding and produce a zero row rather than
 * an out-of-bounds read. The upper bound cannot be checked here because the
 * source row count is not a kernel parameter — callers must pass indices
 * below the number of rows in fp4_data.
 */
__global__ void dequant_nvfp4_selective_kernel(
    const uint8_t* __restrict__ fp4_data,   // (max_comp, N/2) packed E2M1
    const uint8_t* __restrict__ sf_data,    // (max_comp, N/16) E4M3 block scales
    const float* __restrict__ gsa_data,     // (max_comp,) FP32 global scale per row
    const int32_t* __restrict__ indices,    // (K,) int32 — which rows to dequant
    __nv_bfloat16* __restrict__ output,     // (K, N) BF16 output
    int K, int N
) {
    const int n_block = blockIdx.x;   // which 16-element block
    if (n_block * 16 >= N) return;

    const size_t N_half = (size_t)(N / 2);
    const size_t N_sf = (size_t)(N / 16);

    for (int k = blockIdx.y; k < K; k += gridDim.y) {   // which selected entry
        const int src_row = indices[k];
        const bool valid = (src_row >= 0);

        float gsa = 0.0f, bsf = 0.0f;
        if (valid) {
            gsa = gsa_data[src_row];
            // Decode the FP8 E4M3 block scale for this row and block.
            const uint8_t sf_byte = sf_data[(size_t)src_row * N_sf + n_block];
            __nv_fp8_e4m3 sf_val;
            memcpy(&sf_val, &sf_byte, 1);
            bsf = (float)sf_val;
        }

        for (int e = threadIdx.x; e < 16; e += blockDim.x) {
            const int col = n_block * 16 + e;
            if (col >= N) continue;
            float val = 0.0f;
            if (valid) {
                // Two codes per byte: even element in the low nibble.
                const uint8_t packed = fp4_data[(size_t)src_row * N_half + n_block * 8 + (e >> 1)];
                const uint8_t nibble = (e & 1) ? ((packed >> 4) & 0x0F) : (packed & 0x0F);
                const float sign = (nibble & 0x08) ? -1.0f : 1.0f;
                val = sign * E2M1_LUT[nibble & 0x07] * bsf * gsa;
            }
            output[(size_t)k * N + col] = __float2bfloat16(val);
        }
    }
}
// ===========================================================================
// PyTorch bindings
// ===========================================================================
/**
 * Dequantize an entire NVFP4 buffer to BF16.
 *
 * @param fp4_data  (M, N/2) uint8, two E2M1 codes per byte
 * @param sf_data   (M, N/16) uint8, E4M3 block scales
 * @param gsa_data  (M,) float32 global scale per row
 * @return (M, N) BF16 tensor on the same device as fp4_data
 *
 * Rows are launched in chunks of at most 65535, the CUDA limit for gridDim.y,
 * so large caches do not fail at launch. A single launch is used whenever
 * M <= 65535.
 */
torch::Tensor dequant_nvfp4_cuda(
    torch::Tensor fp4_data,   // (M, N/2) uint8 packed E2M1
    torch::Tensor sf_data,    // (M, N/16) uint8 (viewed as E4M3)
    torch::Tensor gsa_data    // (M,) float32 global scale
) {
    TORCH_CHECK(fp4_data.is_cuda() && sf_data.is_cuda() && gsa_data.is_cuda(),
                "fp4_data, sf_data and gsa_data must be CUDA tensors");
    TORCH_CHECK(fp4_data.dim() == 2 && sf_data.dim() == 2 && gsa_data.dim() == 1,
                "expected fp4_data (M, N/2), sf_data (M, N/16), gsa_data (M,)");
    TORCH_CHECK(fp4_data.is_contiguous() && sf_data.is_contiguous() && gsa_data.is_contiguous(),
                "fp4_data, sf_data and gsa_data must be contiguous");

    const int64_t M = fp4_data.size(0);
    const int64_t N = fp4_data.size(1) * 2;   // N/2 packed → N actual
    TORCH_CHECK(N % 16 == 0, "row length N=", N, " must be a multiple of 16");
    TORCH_CHECK(N <= (int64_t)INT32_MAX, "row length N=", N, " is too large");
    TORCH_CHECK(sf_data.size(0) == M, "sf_data row count must match fp4_data");
    TORCH_CHECK(sf_data.size(1) == N / 16, "sf_data must have N/16 columns");
    TORCH_CHECK(gsa_data.size(0) == M, "gsa_data row count must match fp4_data");

    // With N a multiple of 16 every output element is written by the kernel,
    // so no zero-fill pass is needed.
    auto output = torch::empty({M, N}, fp4_data.options().dtype(torch::kBFloat16));
    if (M == 0 || N == 0) return output;   // a zero-sized grid is an invalid launch

    const uint8_t* fp4_ptr = fp4_data.data_ptr<uint8_t>();
    const uint8_t* sf_ptr = sf_data.data_ptr<uint8_t>();
    const float* gsa_ptr = gsa_data.data_ptr<float>();
    __nv_bfloat16* out_ptr = reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>());
    auto stream = c10::cuda::getCurrentCUDAStream();

    constexpr int64_t kMaxGridY = 65535;
    const dim3 block(16);
    for (int64_t row0 = 0; row0 < M; row0 += kMaxGridY) {
        const int64_t rows = (M - row0 < kMaxGridY) ? (M - row0) : kMaxGridY;
        const dim3 grid((unsigned int)(N / 16), (unsigned int)rows);
        // Each chunk sees its own rows as rows [0, rows) via offset pointers.
        dequant_nvfp4_kernel<<<grid, block, 0, stream>>>(
            fp4_ptr + row0 * (N / 2),
            sf_ptr + row0 * (N / 16),
            gsa_ptr + row0,
            out_ptr + row0 * N,
            (int)rows, (int)N
        );
    }
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "dequant_nvfp4 launch failed: ", cudaGetErrorString(err));
    return output;
}
/**
 * Dequantize only the selected rows of an NVFP4 buffer to BF16.
 *
 * @param fp4_data  (max_comp, N/2) uint8, two E2M1 codes per byte
 * @param sf_data   (max_comp, N/16) uint8, E4M3 block scales
 * @param gsa_data  (max_comp,) float32 global scale per row
 * @param indices   (K,) int32 source-row indices
 * @return (K, N) BF16 tensor; row k is the dequantized row indices[k]
 *
 * Index values are not range-checked here (that would need a device→host
 * sync); callers must keep them below max_comp. Entries are launched in
 * chunks of at most 65535, the CUDA limit for gridDim.y.
 */
torch::Tensor dequant_nvfp4_selective_cuda(
    torch::Tensor fp4_data,   // (max_comp, N/2) uint8 packed E2M1
    torch::Tensor sf_data,    // (max_comp, N/16) uint8 (viewed as E4M3)
    torch::Tensor gsa_data,   // (max_comp,) float32 global scale
    torch::Tensor indices     // (K,) int32
) {
    TORCH_CHECK(indices.scalar_type() == torch::kInt32, "indices must be int32");
    TORCH_CHECK(fp4_data.is_cuda() && sf_data.is_cuda() && gsa_data.is_cuda() && indices.is_cuda(),
                "fp4_data, sf_data, gsa_data and indices must be CUDA tensors");
    TORCH_CHECK(fp4_data.dim() == 2 && sf_data.dim() == 2 && gsa_data.dim() == 1 && indices.dim() == 1,
                "expected fp4_data (max_comp, N/2), sf_data (max_comp, N/16), gsa_data (max_comp,), indices (K,)");
    TORCH_CHECK(fp4_data.is_contiguous() && sf_data.is_contiguous() &&
                gsa_data.is_contiguous() && indices.is_contiguous(),
                "fp4_data, sf_data, gsa_data and indices must be contiguous");

    const int64_t max_comp = fp4_data.size(0);
    const int64_t K = indices.size(0);
    const int64_t N = fp4_data.size(1) * 2;   // N/2 packed → N actual
    TORCH_CHECK(N % 16 == 0, "row length N=", N, " must be a multiple of 16");
    TORCH_CHECK(N <= (int64_t)INT32_MAX, "row length N=", N, " is too large");
    TORCH_CHECK(sf_data.size(0) == max_comp && sf_data.size(1) == N / 16,
                "sf_data must be (max_comp, N/16)");
    TORCH_CHECK(gsa_data.size(0) == max_comp, "gsa_data row count must match fp4_data");

    // With N a multiple of 16 every output element is written by the kernel,
    // so no zero-fill pass is needed.
    auto output = torch::empty({K, N}, fp4_data.options().dtype(torch::kBFloat16));
    if (K == 0 || N == 0) return output;   // a zero-sized grid is an invalid launch

    const uint8_t* fp4_ptr = fp4_data.data_ptr<uint8_t>();
    const uint8_t* sf_ptr = sf_data.data_ptr<uint8_t>();
    const float* gsa_ptr = gsa_data.data_ptr<float>();
    const int32_t* idx_ptr = indices.data_ptr<int32_t>();
    __nv_bfloat16* out_ptr = reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>());
    auto stream = c10::cuda::getCurrentCUDAStream();

    constexpr int64_t kMaxGridY = 65535;
    const dim3 block(16);
    for (int64_t k0 = 0; k0 < K; k0 += kMaxGridY) {
        const int64_t count = (K - k0 < kMaxGridY) ? (K - k0) : kMaxGridY;
        const dim3 grid((unsigned int)(N / 16), (unsigned int)count);
        // Source buffers are shared by all chunks; only the index list and
        // the output window advance.
        dequant_nvfp4_selective_kernel<<<grid, block, 0, stream>>>(
            fp4_ptr, sf_ptr, gsa_ptr,
            idx_ptr + k0,
            out_ptr + k0 * N,
            (int)count, (int)N
        );
    }
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "dequant_nvfp4_selective launch failed: ", cudaGetErrorString(err));
    return output;
}
// Python bindings for the NVFP4 → BF16 dequantization launchers.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// dequant_nvfp4(fp4_data, sf_data, gsa_data) -> BF16 tensor for every row.
m.def("dequant_nvfp4", &dequant_nvfp4_cuda, "NVFP4 → BF16 dequant");
// dequant_nvfp4_selective(fp4_data, sf_data, gsa_data, indices) -> BF16 tensor for the rows listed in indices.
m.def("dequant_nvfp4_selective", &dequant_nvfp4_selective_cuda, "Selective NVFP4 → BF16 dequant for CSA gather");
}

View File

@@ -75,3 +75,7 @@ def preload_all():
get_cuda_module("quantize_nvfp4", ["quantize_nvfp4.cu"])
# Sampler
get_cuda_module("sampler", ["sampler.cu"])
# Dequant NVFP4
get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"])
# Fused compress + quantize
get_cuda_module("compressor_reduce_quant", ["compressor_reduce_quant.cu"])

View File

@@ -0,0 +1,170 @@
#!/usr/bin/env python3
"""Test KV-1/KV-2: Fused compress + NVFP4 quantize kernel.
Verifies that the single-kernel compress+quantize path produces output
with cos >= 0.999 vs the BF16 reference path.
Production values:
- hd=512, m=4 (CSA), m=128 (HCA)
- T=32 (CSA: 8 blocks), T=256 (HCA: 2 blocks)
- kv_dim=1024 (CSA: 2*hd), kv_dim=512 (HCA: hd)
"""
import torch
import math
def test_csa_compress_quant():
    """KV-1: the fused CSA compress + NVFP4 path must track the BF16 reference.

    Runs both paths on the same random inputs (hd=512, m=4, T=32), dequantizes
    the NVFP4 triple and requires cosine similarity >= 0.999.
    """
    torch.manual_seed(42)
    device = 'cuda'
    hd, m, T = 512, 4, 32            # T=32 gives 8 compressed blocks
    kv_dim = 2 * hd                  # CSA projections carry two halves of width hd

    # Inputs are drawn in this fixed order so the seed pins every tensor.
    kv_proj = 0.5 * torch.randn(T, kv_dim, device=device)
    gate_proj = 0.3 * torch.randn(T, kv_dim, device=device)
    position_bias = 0.1 * torch.randn(m, kv_dim, device=device)
    kv_norm_weight = torch.randn(hd, device=device).abs() + 0.5

    # Reference: unfused path returning BF16.
    from dsv4.kernels.compressor.production_compress import csa_compress_production
    ref_bf16 = csa_compress_production(
        kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight, m=m)

    # Path under test: single fused kernel returning the NVFP4 triple.
    from dsv4.kernels.compressor.production_compress import csa_compress_production_nvfp4
    fp4_data, sf, gsa = csa_compress_production_nvfp4(
        kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight, m=m)

    # Round-trip the triple back to BF16 through the dequant kernel.
    from dsv4.kernels.cuda.loader import get_cuda_module
    dequant_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"])
    nvfp4_bf16 = dequant_mod.dequant_nvfp4(
        fp4_data.view(torch.uint8), sf.view(torch.uint8), gsa)

    expected = ref_bf16.float()
    actual = nvfp4_bf16.float()
    cos = torch.nn.functional.cosine_similarity(
        expected.flatten(), actual.flatten(), dim=0).item()
    max_err = (expected - actual).abs().max().item()
    ref_max = expected.abs().max().item()

    print(f"CSA compress + NVFP4 quantize:")
    print(f" ref shape: {tuple(ref_bf16.shape)}, nvfp4 shape: {tuple(nvfp4_bf16.shape)}")
    print(f" fp4 shape: {tuple(fp4_data.shape)}, sf shape: {tuple(sf.shape)}, gsa shape: {tuple(gsa.shape)}")
    print(f" |ref|_max: {ref_max:.4f}, |nvfp4|_max: {actual.abs().max().item():.4f}")
    print(f" max_error: {max_err:.6f}")
    print(f" cosine: {cos:.6f}")
    assert cos >= 0.999, f"CSA compress+quant cos={cos:.6f} < 0.999"
    print(f" ✅ PASS (cos={cos:.6f})")
def test_hca_compress_quant():
    """KV-2: the fused HCA compress + NVFP4 path must track the BF16 reference.

    Runs both paths on the same random inputs (hd=512, m=128, T=256),
    dequantizes the NVFP4 triple and requires cosine similarity >= 0.999.
    """
    torch.manual_seed(42)
    device = 'cuda'
    hd, m, T = 512, 128, 256         # T=256 gives 2 compressed blocks

    # Inputs are drawn in this fixed order so the seed pins every tensor.
    kv_proj = 0.5 * torch.randn(T, hd, device=device)
    gate_proj = 0.3 * torch.randn(T, hd, device=device)
    position_bias = 0.1 * torch.randn(m, hd, device=device)
    kv_norm_weight = torch.randn(hd, device=device).abs() + 0.5

    # Reference: unfused path returning BF16.
    from dsv4.kernels.compressor.production_compress import hca_compress_production
    ref_bf16 = hca_compress_production(
        kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight, m=m)

    # Path under test: single fused kernel returning the NVFP4 triple.
    from dsv4.kernels.compressor.production_compress import hca_compress_production_nvfp4
    fp4_data, sf, gsa = hca_compress_production_nvfp4(
        kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight, m=m)

    # Round-trip the triple back to BF16 through the dequant kernel.
    from dsv4.kernels.cuda.loader import get_cuda_module
    dequant_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"])
    nvfp4_bf16 = dequant_mod.dequant_nvfp4(
        fp4_data.view(torch.uint8), sf.view(torch.uint8), gsa)

    expected = ref_bf16.float()
    actual = nvfp4_bf16.float()
    cos = torch.nn.functional.cosine_similarity(
        expected.flatten(), actual.flatten(), dim=0).item()
    max_err = (expected - actual).abs().max().item()
    ref_max = expected.abs().max().item()

    print(f"HCA compress + NVFP4 quantize:")
    print(f" ref shape: {tuple(ref_bf16.shape)}, nvfp4 shape: {tuple(nvfp4_bf16.shape)}")
    print(f" |ref|_max: {ref_max:.4f}, |nvfp4|_max: {actual.abs().max().item():.4f}")
    print(f" max_error: {max_err:.6f}")
    print(f" cosine: {cos:.6f}")
    assert cos >= 0.999, f"HCA compress+quant cos={cos:.6f} < 0.999"
    print(f" ✅ PASS (cos={cos:.6f})")
def test_dequant_selective():
    """Test selective dequant: only top-k entries from a larger FP4 buffer.

    The selective kernel applies the same arithmetic as the full dequant to the
    rows it is asked for, so its output must match the corresponding rows of
    the full dequant exactly — not merely to within a cosine threshold.
    """
    torch.manual_seed(42)
    device = 'cuda'
    M = 64   # total entries in cache
    N = 512  # hd
    K = 8    # top-k

    # Create BF16 data
    bf16_data = torch.randn(M, N, device=device, dtype=torch.bfloat16) * 2.0

    # Quantize to NVFP4
    from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
    fp4, sf, gsa = quantize_nvfp4_gpu_fused(bf16_data)

    # Select K random indices
    indices = torch.randperm(M, device=device)[:K].to(torch.int32)

    from dsv4.kernels.cuda.loader import get_cuda_module
    dequant_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"])
    fp4_u8 = fp4.view(torch.uint8)
    sf_u8 = sf.view(torch.uint8)

    # Selective dequant
    sel_bf16 = dequant_mod.dequant_nvfp4_selective(fp4_u8, sf_u8, gsa, indices)
    # Full dequant for comparison
    full_bf16 = dequant_mod.dequant_nvfp4(fp4_u8, sf_u8, gsa)

    # Gather the reference rows on the device; a host numpy round-trip is
    # unnecessary and forces a synchronisation.
    ref = full_bf16[indices.long()]
    assert tuple(sel_bf16.shape) == (K, N), f"unexpected shape {tuple(sel_bf16.shape)}"

    cos = torch.nn.functional.cosine_similarity(
        sel_bf16.float().flatten(), ref.float().flatten(), dim=0).item()
    print(f"Selective dequant (M={M}, K={K}, N={N}):")
    print(f" sel shape: {tuple(sel_bf16.shape)}")
    print(f" cosine vs full dequant: {cos:.6f}")
    assert cos >= 0.9999, f"Selective dequant cos={cos:.6f} < 0.9999"
    assert torch.equal(sel_bf16, ref), "Selective dequant differs from full dequant rows"
    print(f" ✅ PASS (cos={cos:.6f})")
if __name__ == "__main__":
    # Script entry point: run the three tests in order between banner lines.
    rule = "=" * 60
    print(rule)
    print("KV-1/KV-2: Compress + NVFP4 Quantize Tests")
    print("Production values: hd=512, m=4 (CSA), m=128 (HCA)")
    print(rule)
    for run_test in (test_csa_compress_quant, test_hca_compress_quant):
        run_test()
        print()
    test_dequant_selective()
    print("\n" + rule)
    print("ALL TESTS PASSED")
    print(rule)