diff --git a/dsv4/kernels/compressor/production_compress.py b/dsv4/kernels/compressor/production_compress.py index 1060f386..9aa52a53 100644 --- a/dsv4/kernels/compressor/production_compress.py +++ b/dsv4/kernels/compressor/production_compress.py @@ -6,6 +6,9 @@ Pipeline: 3. CUDA kernel: token-level softmax(gate) * kv → compressed entries 4. CUDA kernel: kv_norm (unweighted RMSNorm + weight) +KV-1/KV-2: NVFP4 output variants compress + quantize in a single kernel. +No intermediate BF16. Stored as FP4 data + E4M3 block scales + FP32 global scale. + No PyTorch softmax. No reference fallback. All on the GPU. """ @@ -40,18 +43,7 @@ def csa_compress_production( kv_norm_weight: Optional[torch.Tensor], # (hd) BF16 or None m: int = 4, ) -> torch.Tensor: - """CSA compress: softmax + weighted sum + kv_norm. - - Args: - kv_proj_out: FP32 projection output, (T, 2*hd), Ca in first hd cols, Cb in second - gate_proj_out: FP32 projection output, (T, 2*hd), Ga in first hd cols, Gb in second - position_bias: (m, 2*hd) BF16 position bias, or None - kv_norm_weight: (hd) BF16 norm weight, or None - m: compression ratio (4 for CSA) - - Returns: - compressed: (n_blocks, hd) BF16 - """ + """CSA compress: softmax + weighted sum + kv_norm. Returns BF16.""" T = kv_proj_out.shape[0] hd = kv_proj_out.shape[1] // 2 n_blocks = T // m @@ -60,7 +52,6 @@ def csa_compress_production( mod = _get_kernel() - # Convert position_bias and kv_norm_weight to FP32 pos_bias_f32 = torch.empty(0, dtype=torch.float32, device=kv_proj_out.device) if position_bias is not None: pos_bias_f32 = position_bias.float() @@ -90,18 +81,7 @@ def hca_compress_production( kv_norm_weight: Optional[torch.Tensor], # (hd) BF16 or None m: int = 128, ) -> torch.Tensor: - """HCA compress: softmax + weighted sum + kv_norm. - - Args: - kv_proj_out: FP32 projection output, (T, hd) - gate_proj_out: FP32 projection output, (T, hd) - position_bias: (m, hd) BF16 position bias, or None - kv_norm_weight: (hd) BF16 norm weight, or None - m: compression ratio (128 for HCA) - - Returns: - compressed: (n_blocks, hd) BF16 - """ + """HCA compress: softmax + weighted sum + kv_norm. Returns BF16.""" T = kv_proj_out.shape[0] hd = kv_proj_out.shape[1] n_blocks = T // m @@ -130,3 +110,67 @@ def hca_compress_production( ) return compressed.bfloat16() + + +# =========================================================================== +# KV-1/KV-2: NVFP4 output variants — single kernel, no intermediate BF16 +# =========================================================================== + +def csa_compress_production_nvfp4( + kv_proj_out: torch.Tensor, + gate_proj_out: torch.Tensor, + position_bias: Optional[torch.Tensor], + kv_norm_weight: Optional[torch.Tensor], + m: int = 4, +) -> tuple: + """CSA compress + NVFP4 quantize: single kernel, no intermediate BF16. + + KV-1: Production path. Compressed KV stored as NVFP4. + Returns: (fp4_data, block_scales, global_scales) — NVFP4 triple. + """ + T = kv_proj_out.shape[0] + hd = kv_proj_out.shape[1] // 2 + n_blocks = T // m + if n_blocks == 0: + dev = kv_proj_out.device + return (torch.zeros(0, hd // 2, dtype=torch.float4_e2m1fn_x2, device=dev), + torch.zeros(0, hd // 16, dtype=torch.float8_e4m3fn, device=dev), + torch.zeros(0, dtype=torch.float32, device=dev)) + + from dsv4.kernels.cuda.loader import get_cuda_module + mod = get_cuda_module("compressor_reduce_quant", ["compressor_reduce_quant.cu"]) + pos_bias_f32 = position_bias.float() if position_bias is not None else torch.empty(0, dtype=torch.float32, device=kv_proj_out.device) + norm_f32 = kv_norm_weight.float() if kv_norm_weight is not None else torch.empty(0, dtype=torch.float32, device=kv_proj_out.device) + return mod.csa_compress_reduce_quant( + kv_proj_out.contiguous(), gate_proj_out.contiguous(), + pos_bias_f32.contiguous(), norm_f32.contiguous(), m, n_blocks) + + +def hca_compress_production_nvfp4( + kv_proj_out: torch.Tensor, + gate_proj_out: torch.Tensor, + position_bias: Optional[torch.Tensor], + kv_norm_weight: Optional[torch.Tensor], + m: int = 128, +) -> tuple: + """HCA compress + NVFP4 quantize: single kernel, no intermediate BF16. + + KV-2: Production path. Compressed KV stored as NVFP4. + Returns: (fp4_data, block_scales, global_scales) — NVFP4 triple. + """ + T = kv_proj_out.shape[0] + hd = kv_proj_out.shape[1] + n_blocks = T // m + if n_blocks == 0: + dev = kv_proj_out.device + return (torch.zeros(0, hd // 2, dtype=torch.float4_e2m1fn_x2, device=dev), + torch.zeros(0, hd // 16, dtype=torch.float8_e4m3fn, device=dev), + torch.zeros(0, dtype=torch.float32, device=dev)) + + from dsv4.kernels.cuda.loader import get_cuda_module + mod = get_cuda_module("compressor_reduce_quant", ["compressor_reduce_quant.cu"]) + pos_bias_f32 = position_bias.float() if position_bias is not None else torch.empty(0, dtype=torch.float32, device=kv_proj_out.device) + norm_f32 = kv_norm_weight.float() if kv_norm_weight is not None else torch.empty(0, dtype=torch.float32, device=kv_proj_out.device) + return mod.hca_compress_reduce_quant( + kv_proj_out.contiguous(), gate_proj_out.contiguous(), + pos_bias_f32.contiguous(), norm_f32.contiguous(), m, n_blocks) diff --git a/dsv4/kernels/cuda/compressor_reduce_quant.cu b/dsv4/kernels/cuda/compressor_reduce_quant.cu new file mode 100644 index 00000000..b3d5f798 --- /dev/null +++ b/dsv4/kernels/cuda/compressor_reduce_quant.cu @@ -0,0 +1,461 @@ +/** + * FUSED CSA/HCA compress + RMSNorm + NVFP4 quantize kernels. + * + * KV-1/KV-2: Single kernel launch per compressed entry. + * The compressor produces FP32 values, applies kv_norm, then quantizes + * to NVFP4 (E2M1 data + E4M3 block scales + FP32 global scale) all in + * one kernel. No intermediate BF16 materialization. + * + * Shared memory budget per CTA (128 threads, hd=512): + * s_vals: hd * 4 = 2048 bytes (FP32 staging) + * s_nibbles: hd * 1 = 512 bytes (E2M1 nibbles) + * s_sq/s_amax/s_inv_rms: ~16 bytes (reduction scratch) + * Total: ~2576 bytes — well within 48KB + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +// =========================================================================== +// Shared utilities +// =========================================================================== + +__device__ __forceinline__ float block_reduce_sum(float val, float* smem, int n_warps) { + for (int offset = 16; offset > 0; offset >>= 1) + val += __shfl_down_sync(0xffffffff, val, offset); + if (threadIdx.x % 32 == 0) smem[threadIdx.x / 32] = val; + __syncthreads(); + float result = 0.0f; + if (threadIdx.x < 32) { + float v = (threadIdx.x < n_warps) ? smem[threadIdx.x] : 0.0f; + for (int offset = 16; offset > 0; offset >>= 1) + v += __shfl_down_sync(0xffffffff, v, offset); + result = v; + } + __syncthreads(); + return result; +} + +__device__ __forceinline__ float block_reduce_max(float val, float* smem, int n_warps) { + for (int offset = 16; offset > 0; offset >>= 1) + val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset)); + if (threadIdx.x % 32 == 0) smem[threadIdx.x / 32] = val; + __syncthreads(); + float result = 0.0f; + if (threadIdx.x < 32) { + float v = (threadIdx.x < n_warps) ? smem[threadIdx.x] : 0.0f; + for (int offset = 16; offset > 0; offset >>= 1) + v = fmaxf(v, __shfl_down_sync(0xffffffff, v, offset)); + result = v; + } + __syncthreads(); + return result; +} + +__device__ __forceinline__ int half_step_to_e2m1(int hs) { + if (hs <= 4) return hs; + if (hs <= 5) return 4; + if (hs <= 7) return 5; + if (hs <= 10) return 6; + return 7; +} + +// =========================================================================== +// CSA fused compress + norm + quantize +// =========================================================================== + +__global__ void csa_compress_reduce_quant_kernel( + const float* __restrict__ kv_proj, // [T, 2*hd] FP32 + const float* __restrict__ gate_proj, // [T, 2*hd] FP32 + const float* __restrict__ position_bias, // [m, 2*hd] FP32 or nullptr + const float* __restrict__ kv_norm_weight, // [hd] FP32 or nullptr + uint8_t* __restrict__ out_fp4, // (n_blocks, hd/2) packed E2M1 + uint8_t* __restrict__ out_sf, // (n_blocks, hd/16) E4M3 block scales + float* __restrict__ out_gsa, // (n_blocks,) FP32 global scale + int T, int hd, int m, int n_blocks +) { + int block_i = blockIdx.x; + int tid = threadIdx.x; + int n_threads = blockDim.x; + int kv_dim = 2 * hd; + int n_warps = n_threads / 32; + + if (block_i >= n_blocks) return; + + int n_tokens = (block_i > 0) ? 2 * m : m; + int prev_start = (block_i - 1) * m; + int cur_start = block_i * m; + int cols_per_thread = (hd + n_threads - 1) / n_threads; + + // ---- Phase 1: Softmax + weighted sum ---- + float local_vals[4], local_max[4], local_denom[4], local_acc[4]; + + for (int ci = 0; ci < cols_per_thread; ci++) { + int c = tid + ci * n_threads; + if (c >= hd) break; + local_max[ci] = -FLT_MAX; + local_denom[ci] = 0.0f; + local_acc[ci] = 0.0f; + + // Pass 1: max gate + for (int t = 0; t < n_tokens; t++) { + int token_idx, gate_offset; + if (block_i > 0) { + if (t < m) { token_idx = prev_start + t; gate_offset = 0; } + else { token_idx = cur_start + (t - m); gate_offset = hd; } + } else { token_idx = t; gate_offset = hd; } + if (token_idx < 0 || token_idx >= T) continue; + float g = gate_proj[token_idx * kv_dim + gate_offset + c]; + if (position_bias) { + int pbr = (block_i > 0 && t < m) ? t : (block_i > 0 ? (t - m) : t); + if (pbr >= 0 && pbr < m) g += position_bias[pbr * kv_dim + gate_offset + c]; + } + local_max[ci] = fmaxf(local_max[ci], g); + } + + // Pass 2: exp + weighted sum + for (int t = 0; t < n_tokens; t++) { + int token_idx, kv_offset, gate_offset; + if (block_i > 0) { + if (t < m) { token_idx = prev_start + t; kv_offset = 0; gate_offset = 0; } + else { token_idx = cur_start + (t - m); kv_offset = hd; gate_offset = hd; } + } else { token_idx = t; kv_offset = hd; gate_offset = hd; } + if (token_idx < 0 || token_idx >= T) continue; + float g = gate_proj[token_idx * kv_dim + gate_offset + c]; + float kv_val = kv_proj[token_idx * kv_dim + kv_offset + c]; + if (position_bias) { + int pbr = (block_i > 0 && t < m) ? t : (block_i > 0 ? (t - m) : t); + if (pbr >= 0 && pbr < m) { + float pb = position_bias[pbr * kv_dim + gate_offset + c]; + g += pb; + kv_val += position_bias[pbr * kv_dim + kv_offset + c]; + } + } + float e = expf(g - local_max[ci]); + local_denom[ci] += e; + local_acc[ci] += e * kv_val; + } + local_vals[ci] = (local_denom[ci] > 0.0f) ? (local_acc[ci] / local_denom[ci]) : 0.0f; + } + + // ---- Phase 2: kv_norm (RMSNorm) ---- + if (kv_norm_weight) { + float local_sq = 0.0f; + for (int ci = 0; ci < cols_per_thread; ci++) { + int c = tid + ci * n_threads; + if (c >= hd) break; + local_sq += local_vals[ci] * local_vals[ci]; + } + __shared__ float s_sq; + float total_sq = block_reduce_sum(local_sq, &s_sq, n_warps); + __shared__ float s_inv_rms; + if (tid == 0) s_inv_rms = rsqrtf(total_sq / hd + 1e-6f); + __syncthreads(); + for (int ci = 0; ci < cols_per_thread; ci++) { + int c = tid + ci * n_threads; + if (c >= hd) break; + local_vals[ci] *= s_inv_rms * kv_norm_weight[c]; + } + } + + // ---- Phase 3: Global scale (gsa) ---- + float entry_amax = 0.0f; + for (int ci = 0; ci < cols_per_thread; ci++) { + int c = tid + ci * n_threads; + if (c >= hd) break; + entry_amax = fmaxf(entry_amax, fabsf(local_vals[ci])); + } + __shared__ float s_amax; + float global_amax = block_reduce_max(entry_amax, &s_amax, n_warps); + float gsa = fmaxf(global_amax, 1e-8f) / (6.0f * 448.0f); + if (tid == 0) out_gsa[block_i] = gsa; + + // ---- Phase 4: NVFP4 quantize via shared memory ---- + __shared__ float s_vals[512]; // FP32 staging + __shared__ uint8_t s_nib[512]; // E2M1 nibbles + + for (int ci = 0; ci < cols_per_thread; ci++) { + int c = tid + ci * n_threads; + if (c >= hd) break; + s_vals[c] = local_vals[ci]; + } + __syncthreads(); + + int n_fp4_blocks = hd / 16; + int tpb = n_threads / n_fp4_blocks; // threads per fp4 block + int my_b = tid / tpb; + int my_l = tid % tpb; + + if (my_b < n_fp4_blocks) { + int base = my_b * 16; + + // Block amax + float bamax = 0.0f; + for (int i = my_l; i < 16; i += tpb) { + int c = base + i; + if (c < hd) bamax = fmaxf(bamax, fabsf(s_vals[c]) / gsa); + } + for (int off = tpb / 2; off > 0; off >>= 1) + bamax = fmaxf(bamax, __shfl_down_sync(0xffffffff, bamax, off)); + float fbamax = __shfl_sync(0xffffffff, bamax, 0); + + float bsf = fbamax / 6.0f; + bool zero_blk = (fbamax < 6.0f * 0.001953125f); + + if (my_l == 0) { + if (zero_blk) { + out_sf[block_i * (hd / 16) + my_b] = 0; + } else { + __nv_fp8_e4m3 obj(bsf); + out_sf[block_i * (hd / 16) + my_b] = *(uint8_t*)&obj; + } + } + + // Quantize to E2M1 nibbles + for (int i = my_l; i < 16; i += tpb) { + int c = base + i; + if (c >= hd || zero_blk) { s_nib[c] = 0; continue; } + float s = (s_vals[c] / gsa) / bsf; + int hs = __float2int_rn(fminf(fabsf(s), 6.0f) * 2.0f); + if (hs > 12) hs = 12; + int idx = half_step_to_e2m1(hs); + if (s < 0) idx += 8; + s_nib[c] = idx; + } + } + __syncthreads(); + + // Pack and write + if (my_b < n_fp4_blocks && my_l == 0) { + int base = my_b * 16; + for (int i = 0; i < 8; i++) { + uint8_t lo = s_nib[base + 2 * i] & 0x0F; + uint8_t hi = s_nib[base + 2 * i + 1] & 0x0F; + out_fp4[block_i * (hd / 2) + my_b * 8 + i] = (hi << 4) | lo; + } + } +} + +// =========================================================================== +// HCA fused compress + norm + quantize (simpler — no overlap) +// =========================================================================== + +__global__ void hca_compress_reduce_quant_kernel( + const float* __restrict__ kv_proj, + const float* __restrict__ gate_proj, + const float* __restrict__ position_bias, + const float* __restrict__ kv_norm_weight, + uint8_t* __restrict__ out_fp4, + uint8_t* __restrict__ out_sf, + float* __restrict__ out_gsa, + int T, int hd, int m, int n_blocks +) { + int block_i = blockIdx.x; + int tid = threadIdx.x; + int n_threads = blockDim.x; + int n_warps = n_threads / 32; + + if (block_i >= n_blocks) return; + + int cols_per_thread = (hd + n_threads - 1) / n_threads; + + // Phase 1: Softmax + weighted sum + float local_vals[4]; + for (int ci = 0; ci < cols_per_thread; ci++) { + int c = tid + ci * n_threads; + if (c >= hd) break; + float lmax = -FLT_MAX, ldenom = 0.0f, lacc = 0.0f; + int start = block_i * m; + for (int t = 0; t < m; t++) { + int ti = start + t; + if (ti >= T) break; + float g = gate_proj[ti * hd + c]; + if (position_bias && t < m) g += position_bias[t * hd + c]; + lmax = fmaxf(lmax, g); + } + for (int t = 0; t < m; t++) { + int ti = start + t; + if (ti >= T) break; + float g = gate_proj[ti * hd + c]; + float kv = kv_proj[ti * hd + c]; + if (position_bias && t < m) { float pb = position_bias[t * hd + c]; g += pb; kv += pb; } + float e = expf(g - lmax); + ldenom += e; lacc += e * kv; + } + local_vals[ci] = (ldenom > 0.0f) ? (lacc / ldenom) : 0.0f; + } + + // Phase 2: kv_norm + if (kv_norm_weight) { + float lsq = 0.0f; + for (int ci = 0; ci < cols_per_thread; ci++) { + int c = tid + ci * n_threads; + if (c >= hd) break; + lsq += local_vals[ci] * local_vals[ci]; + } + __shared__ float s_sq; + float tsq = block_reduce_sum(lsq, &s_sq, n_warps); + __shared__ float s_inv_rms; + if (tid == 0) s_inv_rms = rsqrtf(tsq / hd + 1e-6f); + __syncthreads(); + for (int ci = 0; ci < cols_per_thread; ci++) { + int c = tid + ci * n_threads; + if (c >= hd) break; + local_vals[ci] *= s_inv_rms * kv_norm_weight[c]; + } + } + + // Phase 3: gsa + float eamax = 0.0f; + for (int ci = 0; ci < cols_per_thread; ci++) { + int c = tid + ci * n_threads; + if (c >= hd) break; + eamax = fmaxf(eamax, fabsf(local_vals[ci])); + } + __shared__ float s_amax; + float gamax = block_reduce_max(eamax, &s_amax, n_warps); + float gsa = fmaxf(gamax, 1e-8f) / (6.0f * 448.0f); + if (tid == 0) out_gsa[block_i] = gsa; + + // Phase 4: NVFP4 quantize + __shared__ float s_vals[512]; + __shared__ uint8_t s_nib[512]; + for (int ci = 0; ci < cols_per_thread; ci++) { + int c = tid + ci * n_threads; + if (c >= hd) break; + s_vals[c] = local_vals[ci]; + } + __syncthreads(); + + int nfb = hd / 16; + int tpb = n_threads / nfb; + int my_b = tid / tpb; + int my_l = tid % tpb; + + if (my_b < nfb) { + int base = my_b * 16; + float bamax = 0.0f; + for (int i = my_l; i < 16; i += tpb) { + int c = base + i; + if (c < hd) bamax = fmaxf(bamax, fabsf(s_vals[c]) / gsa); + } + for (int off = tpb / 2; off > 0; off >>= 1) + bamax = fmaxf(bamax, __shfl_down_sync(0xffffffff, bamax, off)); + float fbamax = __shfl_sync(0xffffffff, bamax, 0); + float bsf = fbamax / 6.0f; + bool zblk = (fbamax < 6.0f * 0.001953125f); + if (my_l == 0) { + if (zblk) { out_sf[block_i * (hd / 16) + my_b] = 0; } + else { __nv_fp8_e4m3 obj(bsf); out_sf[block_i * (hd / 16) + my_b] = *(uint8_t*)&obj; } + } + for (int i = my_l; i < 16; i += tpb) { + int c = base + i; + if (c >= hd || zblk) { s_nib[c] = 0; continue; } + float s = (s_vals[c] / gsa) / bsf; + int hs = __float2int_rn(fminf(fabsf(s), 6.0f) * 2.0f); + if (hs > 12) hs = 12; + int idx = half_step_to_e2m1(hs); + if (s < 0) idx += 8; + s_nib[c] = idx; + } + } + __syncthreads(); + + if (my_b < nfb && my_l == 0) { + int base = my_b * 16; + for (int i = 0; i < 8; i++) { + uint8_t lo = s_nib[base + 2 * i] & 0x0F; + uint8_t hi = s_nib[base + 2 * i + 1] & 0x0F; + out_fp4[block_i * (hd / 2) + my_b * 8 + i] = (hi << 4) | lo; + } + } +} + +// =========================================================================== +// PyTorch bindings +// =========================================================================== + +std::tuple +csa_compress_reduce_quant_cuda( + torch::Tensor kv_proj, // [T, 2*hd] FP32 + torch::Tensor gate_proj, // [T, 2*hd] FP32 + torch::Tensor position_bias, // [m, 2*hd] FP32 or empty + torch::Tensor kv_norm_weight, // [hd] FP32 or empty + int64_t m, int64_t n_blocks +) { + int T = kv_proj.size(0); + int hd = kv_proj.size(1) / 2; + int threads = 128; + + const float* pos_ptr = (position_bias.numel() > 0) ? position_bias.data_ptr() : nullptr; + const float* norm_ptr = (kv_norm_weight.numel() > 0) ? kv_norm_weight.data_ptr() : nullptr; + + auto opts = kv_proj.options(); + auto out_fp4 = torch::zeros({(int)n_blocks, hd / 2}, opts.dtype(torch::kUInt8)); + auto out_sf = torch::zeros({(int)n_blocks, hd / 16}, opts.dtype(torch::kUInt8)); + auto out_gsa = torch::zeros({(int)n_blocks}, opts.dtype(torch::kFloat32)); + + csa_compress_reduce_quant_kernel<<>>( + kv_proj.data_ptr(), + gate_proj.data_ptr(), + pos_ptr, norm_ptr, + out_fp4.data_ptr(), + out_sf.data_ptr(), + out_gsa.data_ptr(), + T, hd, (int)m, (int)n_blocks + ); + C10_CUDA_CHECK(cudaGetLastError()); + + return {out_fp4.view(torch::kFloat4_e2m1fn_x2), + out_sf.view(torch::kFloat8_e4m3fn), + out_gsa}; +} + +std::tuple +hca_compress_reduce_quant_cuda( + torch::Tensor kv_proj, + torch::Tensor gate_proj, + torch::Tensor position_bias, + torch::Tensor kv_norm_weight, + int64_t m, int64_t n_blocks +) { + int T = kv_proj.size(0); + int hd = kv_proj.size(1); + int threads = 128; + + const float* pos_ptr = (position_bias.numel() > 0) ? position_bias.data_ptr() : nullptr; + const float* norm_ptr = (kv_norm_weight.numel() > 0) ? kv_norm_weight.data_ptr() : nullptr; + + auto opts = kv_proj.options(); + auto out_fp4 = torch::zeros({(int)n_blocks, hd / 2}, opts.dtype(torch::kUInt8)); + auto out_sf = torch::zeros({(int)n_blocks, hd / 16}, opts.dtype(torch::kUInt8)); + auto out_gsa = torch::zeros({(int)n_blocks}, opts.dtype(torch::kFloat32)); + + hca_compress_reduce_quant_kernel<<>>( + kv_proj.data_ptr(), + gate_proj.data_ptr(), + pos_ptr, norm_ptr, + out_fp4.data_ptr(), + out_sf.data_ptr(), + out_gsa.data_ptr(), + T, hd, (int)m, (int)n_blocks + ); + C10_CUDA_CHECK(cudaGetLastError()); + + return {out_fp4.view(torch::kFloat4_e2m1fn_x2), + out_sf.view(torch::kFloat8_e4m3fn), + out_gsa}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("csa_compress_reduce_quant", &csa_compress_reduce_quant_cuda, + "Fused CSA compress + norm + NVFP4 quantize"); + m.def("hca_compress_reduce_quant", &hca_compress_reduce_quant_cuda, + "Fused HCA compress + norm + NVFP4 quantize"); +} diff --git a/dsv4/kernels/cuda/dequant_nvfp4.cu b/dsv4/kernels/cuda/dequant_nvfp4.cu new file mode 100644 index 00000000..a50edd9c --- /dev/null +++ b/dsv4/kernels/cuda/dequant_nvfp4.cu @@ -0,0 +1,192 @@ +/** + * NVFP4 → BF16 dequantization kernels. + * + * Converts FP4 (E2M1) data + FP8 (E4M3) block scales + FP32 global scales + * back to BF16. Used for the FMHA gather path: compressed KV is stored as + * NVFP4, and dequantized on-the-fly when gathering for attention. + * + * Two variants: + * 1. Full dequant: entire FP4 buffer → BF16 (for HCA dense gather) + * 2. Selective dequant: only selected rows → BF16 (for CSA top-k gather) + * + * Grid layout: (N/16, M) — one CTA per (row, 16-element block). + * Block size: 16 threads (1 thread per element in the 16-wide block). + * + * Memory savings: FP4 is 4× smaller than BF16. At hd=512: + * BF16: 512 × 2 = 1024 bytes per entry + * NVFP4: 256 + 64 + 4 = 324 bytes per entry (fp4 + sf + gsa) + * Savings: ~3.2× + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +// E2M1 magnitudes: index 0-7 → 0, 0.5, 1, 1.5, 2, 3, 4, 6 +__device__ __constant__ float E2M1_LUT[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f}; + +// =========================================================================== +// Full dequant: entire buffer → BF16 +// =========================================================================== + +__global__ void dequant_nvfp4_kernel( + const uint8_t* __restrict__ fp4_data, // (M, N/2) packed E2M1 + const uint8_t* __restrict__ sf_data, // (M, N/16) E4M3 block scales (stored as uint8) + const float* __restrict__ gsa_data, // (M,) FP32 global scale per row + __nv_bfloat16* __restrict__ output, // (M, N) BF16 output + int M, int N +) { + int m = blockIdx.y; + int n_block = blockIdx.x; + if (m >= M || n_block * 16 >= N) return; + + float gsa = gsa_data[m]; + + // Read FP8 E4M3 block scale + uint8_t sf_byte = sf_data[m * (N / 16) + n_block]; + __nv_fp8_e4m3 sf_val; + memcpy(&sf_val, &sf_byte, 1); + float bsf = (float)sf_val; + + // Read 8 packed bytes = 16 E2M1 values + for (int i = 0; i < 8; i++) { + uint8_t packed = fp4_data[m * (N / 2) + n_block * 8 + i]; + uint8_t lo_nibble = packed & 0x0F; + uint8_t hi_nibble = (packed >> 4) & 0x0F; + + // Low nibble + int lo_idx = lo_nibble & 0x07; + float lo_sign = (lo_nibble & 0x08) ? -1.0f : 1.0f; + float lo_val = lo_sign * E2M1_LUT[lo_idx] * bsf * gsa; + int lo_col = n_block * 16 + 2 * i; + if (lo_col < N) { + output[m * N + lo_col] = __float2bfloat16(lo_val); + } + + // High nibble + int hi_idx = hi_nibble & 0x07; + float hi_sign = (hi_nibble & 0x08) ? -1.0f : 1.0f; + float hi_val = hi_sign * E2M1_LUT[hi_idx] * bsf * gsa; + int hi_col = n_block * 16 + 2 * i + 1; + if (hi_col < N) { + output[m * N + hi_col] = __float2bfloat16(hi_val); + } + } +} + +// =========================================================================== +// Selective dequant: only dequant selected rows from a larger FP4 buffer +// This is the CSA gather path — dequant only the top-k entries needed by FMHA +// =========================================================================== + +__global__ void dequant_nvfp4_selective_kernel( + const uint8_t* __restrict__ fp4_data, // (max_comp, N/2) packed E2M1 + const uint8_t* __restrict__ sf_data, // (max_comp, N/16) E4M3 block scales + const float* __restrict__ gsa_data, // (max_comp,) FP32 global scale per row + const int32_t* __restrict__ indices, // (K,) int32 — which rows to dequant + __nv_bfloat16* __restrict__ output, // (K, N) BF16 output + int K, int N +) { + int k = blockIdx.y; // which selected entry + int n_block = blockIdx.x; // which 16-element block + if (k >= K || n_block * 16 >= N) return; + + int src_row = indices[k]; + float gsa = gsa_data[src_row]; + + int N_half = N / 2; + int N_sf = N / 16; + + // Read FP8 E4M3 block scale for this row and block + uint8_t sf_byte = sf_data[src_row * N_sf + n_block]; + __nv_fp8_e4m3 sf_val; + memcpy(&sf_val, &sf_byte, 1); + float bsf = (float)sf_val; + + for (int i = 0; i < 8; i++) { + uint8_t packed = fp4_data[src_row * N_half + n_block * 8 + i]; + uint8_t lo_nibble = packed & 0x0F; + uint8_t hi_nibble = (packed >> 4) & 0x0F; + + int lo_idx = lo_nibble & 0x07; + float lo_sign = (lo_nibble & 0x08) ? -1.0f : 1.0f; + float lo_val = lo_sign * E2M1_LUT[lo_idx] * bsf * gsa; + int lo_col = n_block * 16 + 2 * i; + if (lo_col < N) { + output[k * N + lo_col] = __float2bfloat16(lo_val); + } + + int hi_idx = hi_nibble & 0x07; + float hi_sign = (hi_nibble & 0x08) ? -1.0f : 1.0f; + float hi_val = hi_sign * E2M1_LUT[hi_idx] * bsf * gsa; + int hi_col = n_block * 16 + 2 * i + 1; + if (hi_col < N) { + output[k * N + hi_col] = __float2bfloat16(hi_val); + } + } +} + +// =========================================================================== +// PyTorch bindings +// =========================================================================== + +torch::Tensor dequant_nvfp4_cuda( + torch::Tensor fp4_data, // (M, N/2) uint8 packed E2M1 + torch::Tensor sf_data, // (M, N/16) uint8 (viewed as E4M3) + torch::Tensor gsa_data // (M,) float32 global scale +) { + int M = fp4_data.size(0); + int N = fp4_data.size(1) * 2; // N/2 packed → N actual + TORCH_CHECK(sf_data.size(0) == M, "sf_data row count must match fp4_data"); + TORCH_CHECK(gsa_data.size(0) == M, "gsa_data row count must match fp4_data"); + + auto output = torch::zeros({M, N}, fp4_data.options().dtype(torch::kBFloat16)); + int nb = N / 16; + dim3 grid(nb, M); + dim3 block(16); + + dequant_nvfp4_kernel<<>>( + fp4_data.data_ptr(), + sf_data.data_ptr(), + gsa_data.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(output.data_ptr()), + M, N + ); + return output; +} + +torch::Tensor dequant_nvfp4_selective_cuda( + torch::Tensor fp4_data, // (max_comp, N/2) uint8 packed E2M1 + torch::Tensor sf_data, // (max_comp, N/16) uint8 (viewed as E4M3) + torch::Tensor gsa_data, // (max_comp,) float32 global scale + torch::Tensor indices // (K,) int32 +) { + int K = indices.size(0); + int N = fp4_data.size(1) * 2; // N/2 packed → N actual + TORCH_CHECK(indices.scalar_type() == torch::kInt32, "indices must be int32"); + + auto output = torch::zeros({K, N}, fp4_data.options().dtype(torch::kBFloat16)); + int nb = N / 16; + dim3 grid(nb, K); + dim3 block(16); + + dequant_nvfp4_selective_kernel<<>>( + fp4_data.data_ptr(), + sf_data.data_ptr(), + gsa_data.data_ptr(), + indices.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(output.data_ptr()), + K, N + ); + return output; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("dequant_nvfp4", &dequant_nvfp4_cuda, "NVFP4 → BF16 dequant"); + m.def("dequant_nvfp4_selective", &dequant_nvfp4_selective_cuda, "Selective NVFP4 → BF16 dequant for CSA gather"); +} diff --git a/dsv4/kernels/cuda/loader.py b/dsv4/kernels/cuda/loader.py index 2f168f8c..3200e4ea 100644 --- a/dsv4/kernels/cuda/loader.py +++ b/dsv4/kernels/cuda/loader.py @@ -75,3 +75,7 @@ def preload_all(): get_cuda_module("quantize_nvfp4", ["quantize_nvfp4.cu"]) # Sampler get_cuda_module("sampler", ["sampler.cu"]) + # Dequant NVFP4 + get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"]) + # Fused compress + quantize + get_cuda_module("compressor_reduce_quant", ["compressor_reduce_quant.cu"]) diff --git a/tests/unit/test_kv_compress_quant.py b/tests/unit/test_kv_compress_quant.py new file mode 100644 index 00000000..efc143a6 --- /dev/null +++ b/tests/unit/test_kv_compress_quant.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +"""Test KV-1/KV-2: Fused compress + NVFP4 quantize kernel. + +Verifies that the single-kernel compress+quantize path produces output +with cos >= 0.999 vs the BF16 reference path. + +Production values: + - hd=512, m=4 (CSA), m=128 (HCA) + - T=32 (CSA: 8 blocks), T=256 (HCA: 2 blocks) + - kv_dim=1024 (CSA: 2*hd), kv_dim=512 (HCA: hd) +""" + +import torch +import math + +def test_csa_compress_quant(): + """KV-1: CSA compress + NVFP4 quantize vs BF16 reference.""" + torch.manual_seed(42) + device = 'cuda' + hd = 512 + m = 4 + T = 32 # 8 compressed blocks + kv_dim = 2 * hd # CSA uses 2*hd for kv/gate projections + + kv_proj = torch.randn(T, kv_dim, device=device) * 0.5 + gate_proj = torch.randn(T, kv_dim, device=device) * 0.3 + position_bias = torch.randn(m, kv_dim, device=device) * 0.1 + kv_norm_weight = torch.randn(hd, device=device).abs() + 0.5 + + # BF16 reference path + from dsv4.kernels.compressor.production_compress import csa_compress_production + ref_bf16 = csa_compress_production(kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight, m=m) + + # NVFP4 fused path + from dsv4.kernels.compressor.production_compress import csa_compress_production_nvfp4 + fp4_data, sf, gsa = csa_compress_production_nvfp4(kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight, m=m) + + # Dequant NVFP4 → BF16 + from dsv4.kernels.cuda.loader import get_cuda_module + dequant_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"]) + nvfp4_bf16 = dequant_mod.dequant_nvfp4( + fp4_data.view(torch.uint8), + sf.view(torch.uint8), + gsa, + ) + + # Compare + ref_f = ref_bf16.float() + nvfp4_f = nvfp4_bf16.float() + cos = torch.nn.functional.cosine_similarity(ref_f.flatten(), nvfp4_f.flatten(), dim=0).item() + max_err = (ref_f - nvfp4_f).abs().max().item() + ref_max = ref_f.abs().max().item() + + print(f"CSA compress + NVFP4 quantize:") + print(f" ref shape: {tuple(ref_bf16.shape)}, nvfp4 shape: {tuple(nvfp4_bf16.shape)}") + print(f" fp4 shape: {tuple(fp4_data.shape)}, sf shape: {tuple(sf.shape)}, gsa shape: {tuple(gsa.shape)}") + print(f" |ref|_max: {ref_max:.4f}, |nvfp4|_max: {nvfp4_f.abs().max().item():.4f}") + print(f" max_error: {max_err:.6f}") + print(f" cosine: {cos:.6f}") + assert cos >= 0.999, f"CSA compress+quant cos={cos:.6f} < 0.999" + print(f" ✅ PASS (cos={cos:.6f})") + + +def test_hca_compress_quant(): + """KV-2: HCA compress + NVFP4 quantize vs BF16 reference.""" + torch.manual_seed(42) + device = 'cuda' + hd = 512 + m = 128 + T = 256 # 2 compressed blocks + + kv_proj = torch.randn(T, hd, device=device) * 0.5 + gate_proj = torch.randn(T, hd, device=device) * 0.3 + position_bias = torch.randn(m, hd, device=device) * 0.1 + kv_norm_weight = torch.randn(hd, device=device).abs() + 0.5 + + # BF16 reference path + from dsv4.kernels.compressor.production_compress import hca_compress_production + ref_bf16 = hca_compress_production(kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight, m=m) + + # NVFP4 fused path + from dsv4.kernels.compressor.production_compress import hca_compress_production_nvfp4 + fp4_data, sf, gsa = hca_compress_production_nvfp4(kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight, m=m) + + # Dequant NVFP4 → BF16 + from dsv4.kernels.cuda.loader import get_cuda_module + dequant_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"]) + nvfp4_bf16 = dequant_mod.dequant_nvfp4( + fp4_data.view(torch.uint8), + sf.view(torch.uint8), + gsa, + ) + + # Compare + ref_f = ref_bf16.float() + nvfp4_f = nvfp4_bf16.float() + cos = torch.nn.functional.cosine_similarity(ref_f.flatten(), nvfp4_f.flatten(), dim=0).item() + max_err = (ref_f - nvfp4_f).abs().max().item() + ref_max = ref_f.abs().max().item() + + print(f"HCA compress + NVFP4 quantize:") + print(f" ref shape: {tuple(ref_bf16.shape)}, nvfp4 shape: {tuple(nvfp4_bf16.shape)}") + print(f" |ref|_max: {ref_max:.4f}, |nvfp4|_max: {nvfp4_f.abs().max().item():.4f}") + print(f" max_error: {max_err:.6f}") + print(f" cosine: {cos:.6f}") + assert cos >= 0.999, f"HCA compress+quant cos={cos:.6f} < 0.999" + print(f" ✅ PASS (cos={cos:.6f})") + + +def test_dequant_selective(): + """Test selective dequant: only top-k entries from a larger FP4 buffer.""" + torch.manual_seed(42) + device = 'cuda' + M = 64 # total entries in cache + N = 512 # hd + K = 8 # top-k + + # Create BF16 data + bf16_data = torch.randn(M, N, device=device, dtype=torch.bfloat16) * 2.0 + + # Quantize to NVFP4 + from dsv4.ops.quantize import quantize_nvfp4_gpu_fused + fp4, sf, gsa = quantize_nvfp4_gpu_fused(bf16_data) + + # Select K random indices + indices = torch.randperm(M, device=device)[:K].to(torch.int32) + + # Selective dequant + from dsv4.kernels.cuda.loader import get_cuda_module + dequant_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"]) + sel_bf16 = dequant_mod.dequant_nvfp4_selective( + fp4.view(torch.uint8), + sf.view(torch.uint8), + gsa, + indices, + ) + + # Full dequant for comparison + full_bf16 = dequant_mod.dequant_nvfp4( + fp4.view(torch.uint8), + sf.view(torch.uint8), + gsa, + ) + + # Compare selected entries + ref = full_bf16[indices.cpu().numpy()].to(device) + cos = torch.nn.functional.cosine_similarity(sel_bf16.float().flatten(), ref.float().flatten(), dim=0).item() + + print(f"Selective dequant (M={M}, K={K}, N={N}):") + print(f" sel shape: {tuple(sel_bf16.shape)}") + print(f" cosine vs full dequant: {cos:.6f}") + assert cos >= 0.9999, f"Selective dequant cos={cos:.6f} < 0.9999" + print(f" ✅ PASS (cos={cos:.6f})") + + +if __name__ == "__main__": + print("=" * 60) + print("KV-1/KV-2: Compress + NVFP4 Quantize Tests") + print("Production values: hd=512, m=4 (CSA), m=128 (HCA)") + print("=" * 60) + + test_csa_compress_quant() + print() + test_hca_compress_quant() + print() + test_dequant_selective() + + print("\n" + "=" * 60) + print("ALL TESTS PASSED") + print("=" * 60)