diff --git a/dsv4/kernels/compressor/production_compress.py b/dsv4/kernels/compressor/production_compress.py index 9aa52a53..e9190135 100644 --- a/dsv4/kernels/compressor/production_compress.py +++ b/dsv4/kernels/compressor/production_compress.py @@ -44,11 +44,24 @@ def csa_compress_production( m: int = 4, ) -> torch.Tensor: """CSA compress: softmax + weighted sum + kv_norm. Returns BF16.""" + return csa_compress_production_fp32( + kv_proj_out, gate_proj_out, position_bias, kv_norm_weight, m + ).bfloat16() + + +def csa_compress_production_fp32( + kv_proj_out: torch.Tensor, + gate_proj_out: torch.Tensor, + position_bias: Optional[torch.Tensor], + kv_norm_weight: Optional[torch.Tensor], + m: int = 4, +) -> torch.Tensor: + """CSA compress: softmax + weighted sum + kv_norm. Returns FP32.""" T = kv_proj_out.shape[0] hd = kv_proj_out.shape[1] // 2 n_blocks = T // m if n_blocks == 0: - return torch.zeros(0, hd, dtype=torch.bfloat16, device=kv_proj_out.device) + return torch.zeros(0, hd, dtype=torch.float32, device=kv_proj_out.device) mod = _get_kernel() @@ -71,7 +84,7 @@ def csa_compress_production( m, n_blocks, ) - return compressed.bfloat16() + return compressed def hca_compress_production( @@ -82,11 +95,24 @@ def hca_compress_production( m: int = 128, ) -> torch.Tensor: """HCA compress: softmax + weighted sum + kv_norm. Returns BF16.""" + return hca_compress_production_fp32( + kv_proj_out, gate_proj_out, position_bias, kv_norm_weight, m + ).bfloat16() + + +def hca_compress_production_fp32( + kv_proj_out: torch.Tensor, + gate_proj_out: torch.Tensor, + position_bias: Optional[torch.Tensor], + kv_norm_weight: Optional[torch.Tensor], + m: int = 128, +) -> torch.Tensor: + """HCA compress: softmax + weighted sum + kv_norm. Returns FP32.""" T = kv_proj_out.shape[0] hd = kv_proj_out.shape[1] n_blocks = T // m if n_blocks == 0: - return torch.zeros(0, hd, dtype=torch.bfloat16, device=kv_proj_out.device) + return torch.zeros(0, hd, dtype=torch.float32, device=kv_proj_out.device) mod = _get_kernel() @@ -109,13 +135,43 @@ def hca_compress_production( m, n_blocks, ) - return compressed.bfloat16() + return compressed # =========================================================================== -# KV-1/KV-2: NVFP4 output variants — single kernel, no intermediate BF16 +# KV-1/KV-2: NVFP4 output — two proven kernels, no BF16 intermediate +# +# Architecture: +# 1. CUDA compress kernel (compressor_reduce.cu) → FP32 compressed output +# 2. CUDA amax_gsa_fp32 → per-row gsa (GPU-only, no CPU sync) +# 3. CUDA quantize_nvfp4_from_fp32 → NVFP4 triple (fp4 + sf + gsa) +# +# This is the same two-kernel pattern that works everywhere else in the +# pipeline (quantize_nvfp4_gpu_fused). The previous single-kernel fused +# approach had shared memory corruption bugs. Two kernels is correct. +# +# Storage: NVFP4 (E2M1 data + E4M3 block scales + FP32 global scale) +# Read path: dequant_nvfp4 / dequant_nvfp4_selective → BF16 for FMHA # =========================================================================== +def _quantize_fp32_to_nvfp4(compressed_fp32: torch.Tensor) -> tuple: + """Quantize FP32 compressed output → NVFP4. Two-kernel, GPU-only. + + Uses the same proven pattern as quantize_nvfp4_gpu_fused (amax_gsa + + quantize_from_buffer) but with FP32 input instead of BF16. + No BF16 intermediate. No CPU sync. + + Returns: (fp4_data, block_scales, global_scales) — NVFP4 triple. + """ + from dsv4.kernels.cuda.loader import get_cuda_module + mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"]) + # Kernel 1: Compute per-row gsa from FP32 input (GPU-only) + gsa = mod.compute_amax_gsa_fp32(compressed_fp32.contiguous(), 6.0 * 448.0) + # Kernel 2: Quantize FP32 → NVFP4 using GPU gsa buffer + fp4, sf = mod.quantize_nvfp4_from_fp32(compressed_fp32.contiguous(), gsa) + return fp4, sf, gsa + + def csa_compress_production_nvfp4( kv_proj_out: torch.Tensor, gate_proj_out: torch.Tensor, @@ -123,27 +179,23 @@ def csa_compress_production_nvfp4( kv_norm_weight: Optional[torch.Tensor], m: int = 4, ) -> tuple: - """CSA compress + NVFP4 quantize: single kernel, no intermediate BF16. + """CSA compress → NVFP4. No BF16 intermediate. KV-1: Production path. Compressed KV stored as NVFP4. + Pipeline: compress (FP32) → amax_gsa (GPU) → quantize (GPU) → NVFP4 triple. Returns: (fp4_data, block_scales, global_scales) — NVFP4 triple. """ - T = kv_proj_out.shape[0] - hd = kv_proj_out.shape[1] // 2 - n_blocks = T // m - if n_blocks == 0: + # Step 1: Compress → FP32 (same proven kernel as BF16 path) + compressed_fp32 = csa_compress_production_fp32( + kv_proj_out, gate_proj_out, position_bias, kv_norm_weight, m) + if compressed_fp32.shape[0] == 0: dev = kv_proj_out.device + hd = kv_proj_out.shape[1] // 2 return (torch.zeros(0, hd // 2, dtype=torch.float4_e2m1fn_x2, device=dev), torch.zeros(0, hd // 16, dtype=torch.float8_e4m3fn, device=dev), torch.zeros(0, dtype=torch.float32, device=dev)) - - from dsv4.kernels.cuda.loader import get_cuda_module - mod = get_cuda_module("compressor_reduce_quant", ["compressor_reduce_quant.cu"]) - pos_bias_f32 = position_bias.float() if position_bias is not None else torch.empty(0, dtype=torch.float32, device=kv_proj_out.device) - norm_f32 = kv_norm_weight.float() if kv_norm_weight is not None else torch.empty(0, dtype=torch.float32, device=kv_proj_out.device) - return mod.csa_compress_reduce_quant( - kv_proj_out.contiguous(), gate_proj_out.contiguous(), - pos_bias_f32.contiguous(), norm_f32.contiguous(), m, n_blocks) + # Step 2-3: FP32 → NVFP4 (two proven kernels) + return _quantize_fp32_to_nvfp4(compressed_fp32) def hca_compress_production_nvfp4( @@ -153,24 +205,20 @@ def hca_compress_production_nvfp4( kv_norm_weight: Optional[torch.Tensor], m: int = 128, ) -> tuple: - """HCA compress + NVFP4 quantize: single kernel, no intermediate BF16. + """HCA compress → NVFP4. No BF16 intermediate. KV-2: Production path. Compressed KV stored as NVFP4. + Pipeline: compress (FP32) → amax_gsa (GPU) → quantize (GPU) → NVFP4 triple. Returns: (fp4_data, block_scales, global_scales) — NVFP4 triple. """ - T = kv_proj_out.shape[0] - hd = kv_proj_out.shape[1] - n_blocks = T // m - if n_blocks == 0: + # Step 1: Compress → FP32 + compressed_fp32 = hca_compress_production_fp32( + kv_proj_out, gate_proj_out, position_bias, kv_norm_weight, m) + if compressed_fp32.shape[0] == 0: dev = kv_proj_out.device + hd = kv_proj_out.shape[1] return (torch.zeros(0, hd // 2, dtype=torch.float4_e2m1fn_x2, device=dev), torch.zeros(0, hd // 16, dtype=torch.float8_e4m3fn, device=dev), torch.zeros(0, dtype=torch.float32, device=dev)) - - from dsv4.kernels.cuda.loader import get_cuda_module - mod = get_cuda_module("compressor_reduce_quant", ["compressor_reduce_quant.cu"]) - pos_bias_f32 = position_bias.float() if position_bias is not None else torch.empty(0, dtype=torch.float32, device=kv_proj_out.device) - norm_f32 = kv_norm_weight.float() if kv_norm_weight is not None else torch.empty(0, dtype=torch.float32, device=kv_proj_out.device) - return mod.hca_compress_reduce_quant( - kv_proj_out.contiguous(), gate_proj_out.contiguous(), - pos_bias_f32.contiguous(), norm_f32.contiguous(), m, n_blocks) + # Step 2-3: FP32 → NVFP4 + return _quantize_fp32_to_nvfp4(compressed_fp32) diff --git a/dsv4/kernels/cuda/compressor_reduce_quant.cu b/dsv4/kernels/cuda/compressor_reduce_quant.cu deleted file mode 100644 index 08fcc80a..00000000 --- a/dsv4/kernels/cuda/compressor_reduce_quant.cu +++ /dev/null @@ -1,226 +0,0 @@ -/** - * FUSED CSA/HCA compress + RMSNorm + NVFP4 quantize kernels. - * KV-1/KV-2: Single kernel launch. FP32 -> E2M1 + E4M3 + FP32 gsa. - * - * FIX: block_reduce_sum/max need n_warps shared memory slots, not a single - * float. Previous version passed &s_amax (1 float) but the functions write - * smem[0..n_warps-1], corrupting adjacent shared variables. - */ - -#include -#include -#include -#include -#include -#include -#include -#include - -__device__ __forceinline__ float block_reduce_sum(float val, float* smem, int nw) { - for (int o = 16; o > 0; o >>= 1) val += __shfl_down_sync(0xffffffff, val, o); - if (threadIdx.x % 32 == 0) smem[threadIdx.x / 32] = val; - __syncthreads(); - float r = 0.0f; - if (threadIdx.x < 32) { - float v = (threadIdx.x < nw) ? smem[threadIdx.x] : 0.0f; - for (int o = 16; o > 0; o >>= 1) v += __shfl_down_sync(0xffffffff, v, o); - r = v; - } - __syncthreads(); - return r; -} - -__device__ __forceinline__ float block_reduce_max(float val, float* smem, int nw) { - for (int o = 16; o > 0; o >>= 1) val = fmaxf(val, __shfl_down_sync(0xffffffff, val, o)); - if (threadIdx.x % 32 == 0) smem[threadIdx.x / 32] = val; - __syncthreads(); - float r = 0.0f; - if (threadIdx.x < 32) { - float v = (threadIdx.x < nw) ? smem[threadIdx.x] : 0.0f; - for (int o = 16; o > 0; o >>= 1) v = fmaxf(v, __shfl_down_sync(0xffffffff, v, o)); - r = v; - } - __syncthreads(); - return r; -} - -__device__ __forceinline__ int half_step_to_e2m1(int hs) { - if (hs <= 4) return hs; if (hs <= 5) return 4; - if (hs <= 7) return 5; if (hs <= 10) return 6; return 7; -} - -// === CSA === -__global__ void csa_compress_reduce_quant_kernel( - const float* kv_proj, const float* gate_proj, - const float* position_bias, const float* kv_norm_weight, - uint8_t* out_fp4, uint8_t* out_sf, float* out_gsa, - int T, int hd, int m, int n_blocks -) { - int bi = blockIdx.x, tid = threadIdx.x, nt = blockDim.x; - int kd = 2*hd, nw = nt/32; - if (bi >= n_blocks) return; - - int ntok = (bi > 0) ? 2*m : m; - int ps = (bi-1)*m, cs = bi*m; - int cpt = (hd+nt-1)/nt; - - // Shared memory: reduction scratch (8 floats for n_warps=4) + FP32 staging - __shared__ float s_scratch[8]; // 0-3: sum/norm, 4-7: max/gsa - __shared__ float s_vals[512]; // FP32 staging for quantize - - float lv[4],lm[4],ld[4],la[4]; - for (int ci=0;ci=hd) break; - lm[ci]=-FLT_MAX; ld[ci]=0; la[ci]=0; - for (int t=0;t0){if(t=T) continue; - float g=gate_proj[ti*kd+go+c]; - if(position_bias){int p=(bi>0&&t0?(t-m):t);if(p>=0&&p0){if(t=T) continue; - float g=gate_proj[ti*kd+go+c], kv=kv_proj[ti*kd+ko+c]; - if(position_bias){int p=(bi>0&&t0?(t-m):t);if(p>=0&&p0)?(la[ci]/ld[ci]):0; - } - - // kv_norm - if(kv_norm_weight) { - float ls=0; for(int ci=0;ci=hd)break;ls+=lv[ci]*lv[ci];} - float ts=block_reduce_sum(ls,&s_scratch[0],nw); - if(tid==0) s_scratch[0]=rsqrtf(ts/hd+1e-6f); // s_inv_rms - __syncthreads(); - float sir=s_scratch[0]; - for(int ci=0;ci=hd)break;lv[ci]*=sir*kv_norm_weight[c];} - } - - // gsa - float ea=0; for(int ci=0;ci=hd)break;ea=fmaxf(ea,fabsf(lv[ci]));} - float ga=block_reduce_max(ea,&s_scratch[4],nw); - float gsa=fmaxf(ga,1e-8f)/(6.0f*448.0f); - if(tid==0) out_gsa[bi]=gsa; - - // Write to shared memory for quantize - for(int ci=0;ci=hd)break;s_vals[c]=lv[ci];} - __syncthreads(); - - // Quantize: each thread handles one or more 16-element blocks - int nfb=hd/16; - for(int b=tid;b12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;} - if(!z&&c112)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;} - out_fp4[bi*(hd/2)+b*8+i]=(hi<<4)|lo; - } - } -} - -// === HCA === -__global__ void hca_compress_reduce_quant_kernel( - const float* kv_proj, const float* gate_proj, - const float* position_bias, const float* kv_norm_weight, - uint8_t* out_fp4, uint8_t* out_sf, float* out_gsa, - int T, int hd, int m, int n_blocks -) { - int bi=blockIdx.x,tid=threadIdx.x,nt=blockDim.x,nw=nt/32; - if(bi>=n_blocks) return; - int cpt=(hd+nt-1)/nt; - - __shared__ float s_scratch[8]; - __shared__ float s_vals[512]; - - float lv[4]; - for(int ci=0;ci=hd)break; - float lm=-FLT_MAX,ld=0,la=0; int st=bi*m; - for(int t=0;t=T)break;float g=gate_proj[ti*hd+c];if(position_bias&&t=T)break;float g=gate_proj[ti*hd+c],kv=kv_proj[ti*hd+c];if(position_bias&&t0)?(la/ld):0; - } - - if(kv_norm_weight){ - float ls=0;for(int ci=0;ci=hd)break;ls+=lv[ci]*lv[ci];} - float ts=block_reduce_sum(ls,&s_scratch[0],nw); - if(tid==0)s_scratch[0]=rsqrtf(ts/hd+1e-6f); - __syncthreads(); - float sir=s_scratch[0]; - for(int ci=0;ci=hd)break;lv[ci]*=sir*kv_norm_weight[c];} - } - - float ea=0;for(int ci=0;ci=hd)break;ea=fmaxf(ea,fabsf(lv[ci]));} - float ga=block_reduce_max(ea,&s_scratch[4],nw); - float gsa=fmaxf(ga,1e-8f)/(6.0f*448.0f); - if(tid==0)out_gsa[bi]=gsa; - - for(int ci=0;ci=hd)break;s_vals[c]=lv[ci];} - __syncthreads(); - - int nfb=hd/16; - for(int b=tid;b12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;} - if(!z&&c112)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;} - out_fp4[bi*(hd/2)+b*8+i]=(hi<<4)|lo; - } - } -} - -// === Bindings === -std::tuple -csa_compress_reduce_quant_cuda(torch::Tensor kp,torch::Tensor gp,torch::Tensor pb,torch::Tensor nw,int64_t m,int64_t nb){ - int T=kp.size(0),hd=kp.size(1)/2,th=128; - const float*pp=(pb.numel()>0)?pb.data_ptr():nullptr; - const float*np=(nw.numel()>0)?nw.data_ptr():nullptr; - auto o=kp.options(); - auto of4=torch::zeros({(int)nb,hd/2},o.dtype(torch::kUInt8)); - auto osf=torch::zeros({(int)nb,hd/16},o.dtype(torch::kUInt8)); - auto ogs=torch::zeros({(int)nb},o.dtype(torch::kFloat32)); - csa_compress_reduce_quant_kernel<<>>( - kp.data_ptr(),gp.data_ptr(),pp,np, - of4.data_ptr(),osf.data_ptr(),ogs.data_ptr(), - T,hd,(int)m,(int)nb); - C10_CUDA_CHECK(cudaGetLastError()); - return {of4.view(torch::kFloat4_e2m1fn_x2),osf.view(torch::kFloat8_e4m3fn),ogs}; -} - -std::tuple -hca_compress_reduce_quant_cuda(torch::Tensor kp,torch::Tensor gp,torch::Tensor pb,torch::Tensor nw,int64_t m,int64_t nb){ - int T=kp.size(0),hd=kp.size(1),th=128; - const float*pp=(pb.numel()>0)?pb.data_ptr():nullptr; - const float*np=(nw.numel()>0)?nw.data_ptr():nullptr; - auto o=kp.options(); - auto of4=torch::zeros({(int)nb,hd/2},o.dtype(torch::kUInt8)); - auto osf=torch::zeros({(int)nb,hd/16},o.dtype(torch::kUInt8)); - auto ogs=torch::zeros({(int)nb},o.dtype(torch::kFloat32)); - hca_compress_reduce_quant_kernel<<>>( - kp.data_ptr(),gp.data_ptr(),pp,np, - of4.data_ptr(),osf.data_ptr(),ogs.data_ptr(), - T,hd,(int)m,(int)nb); - C10_CUDA_CHECK(cudaGetLastError()); - return {of4.view(torch::kFloat4_e2m1fn_x2),osf.view(torch::kFloat8_e4m3fn),ogs}; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME,m){ - m.def("csa_compress_reduce_quant",&csa_compress_reduce_quant_cuda); - m.def("hca_compress_reduce_quant",&hca_compress_reduce_quant_cuda); -} diff --git a/dsv4/kernels/cuda/kv_quantize.cu b/dsv4/kernels/cuda/kv_quantize.cu new file mode 100644 index 00000000..d29246ee --- /dev/null +++ b/dsv4/kernels/cuda/kv_quantize.cu @@ -0,0 +1,372 @@ +/** + * Quantize FP32 tensor to NVFP4. + * + * Same proven pattern as quantize_nvfp4.cu (which reads BF16), + * but takes FP32 input directly — avoids BF16 intermediate. + * + * This is the correct path for compressor output → NVFP4: + * Compressor produces FP32 → this kernel → NVFP4 stored in KV cache + * No BF16 anywhere in the pipeline. + * + * Two-kernel approach (proven correct in fused_amax_quantize.cu): + * Kernel 1: amax_gsa_fp32 — compute per-row gsa from FP32 input (GPU-only) + * Kernel 2: quantize_nvfp4_from_fp32 — quantize FP32 → NVFP4 using GPU gsa buffer + * + * Grid: (N/16, M, 1) — each CTA processes one 16-element block in one row. + * Block: 16 threads (1 thread per element, warp amax reduction). + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +__device__ __forceinline__ int half_step_to_e2m1(int hs) { + if (hs <= 4) return hs; + if (hs <= 5) return 4; + if (hs <= 7) return 5; + if (hs <= 10) return 6; + return 7; +} + +// =========================================================================== +// Kernel 1: Compute per-row amax → gsa from FP32 input +// Same pattern as amax_gsa.cu but for FP32 (not BF16) input +// =========================================================================== + +__global__ void compute_amax_gsa_fp32_kernel( + const float* __restrict__ input, + int M, int N, + float divisor, + float* __restrict__ out_gsa +) { + int m = blockIdx.x; + if (m >= M) return; + + float local_max = 0.0f; + for (int i = threadIdx.x; i < N; i += 256) { + float v = fabsf(input[m * N + i]); + local_max = fmaxf(local_max, v); + } + + // Warp-level reduction + for (int offset = 128; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_down_sync(0xffffffff, local_max, offset)); + + // Block-level reduction using shared memory + __shared__ float s_max[8]; + if (threadIdx.x % 32 == 0) + s_max[threadIdx.x / 32] = local_max; + __syncthreads(); + + if (threadIdx.x < 32) { + float v = (threadIdx.x < 8) ? s_max[threadIdx.x] : 0.0f; + for (int offset = 16; offset > 0; offset >>= 1) + v = fmaxf(v, __shfl_down_sync(0xffffffff, v, offset)); + if (threadIdx.x == 0) + out_gsa[m] = v / divisor; + } +} + +// =========================================================================== +// Kernel 2: Quantize FP32 → NVFP4 using gsa from GPU buffer +// Same proven pattern as quantize_nvfp4_from_buffer_kernel (fused_amax_quantize.cu) +// but reads FP32 instead of BF16 +// =========================================================================== + +__global__ void quantize_nvfp4_from_fp32_kernel( + const float* __restrict__ input, + int M, int N, + const float* __restrict__ gsa_buffer, // (M,) GPU buffer with per-row gsa + uint8_t* __restrict__ out_fp4, + uint8_t* __restrict__ out_sf +) { + int m = blockIdx.y; + int n_block = blockIdx.x; + if (m >= M || n_block * 16 >= N) return; + + float gsa = gsa_buffer[m]; + + float vals[16]; + float block_amax = 0.0f; + + // Step 1: Read 16 FP32 elements and compute block amax + for (int i = 0; i < 16; i++) { + int col = n_block * 16 + i; + if (col < N) { + vals[i] = input[m * N + col] / gsa; + } else { + vals[i] = 0; + } + block_amax = fmaxf(block_amax, fabsf(vals[i])); + } + + // Step 2: Compute FP8 E4M3 block scale (with FP8 round-trip) + float bsf = block_amax / 6.0f; + if (block_amax < 6.0f * 0.001953125f) { + // Zero/underflow block + bsf = 0; + for (int i = 0; i < 16; i++) vals[i] = 0; + } + __nv_fp8_e4m3 bsf8_obj(bsf); + float bs = (float)bsf8_obj; // FP8 round-trip — matches dequant + uint8_t bsf8 = *(uint8_t*)&bsf8_obj; + + // Step 3: Quantize each value to FP4 E2M1 + uint8_t nibbles[16]; + for (int i = 0; i < 16; i++) { + if (bs < 1e-8f) { nibbles[i] = 0; continue; } + float s = vals[i] / bs; + int hs = __float2int_rn(fminf(fabsf(s), 6.0f) * 2.0f); + if (hs > 12) hs = 12; + int idx = half_step_to_e2m1(hs); + if (s < 0) idx += 8; + nibbles[i] = idx; + } + + // Step 4: Pack pairs: (nibbles[1] << 4) | nibbles[0], etc. + for (int i = 0; i < 8; i++) + out_fp4[m * (N / 2) + n_block * 8 + i] = (nibbles[2*i+1] << 4) | nibbles[2*i]; + + // Step 5: Write FP8 block scale + out_sf[m * (N / 16) + n_block] = bsf8; +} + +// =========================================================================== +// FP32 GPT-J interleaved RoPE (for compressed KV — no BF16 intermediate) +// Same math as rope_cuda.cu but operates on FP32 directly. +// =========================================================================== + +__global__ void rope_fp32_kernel( + float* __restrict__ x, // (M, 1, N) FP32 — modified in-place + const float* __restrict__ cos_c, // (max_pos, rope_dim/2) FP32 + const float* __restrict__ sin_c, // (max_pos, rope_dim/2) FP32 + const int64_t* __restrict__ pos, // (M,) positions + int N, int rope_dim, bool inverse +) { + int m = blockIdx.x; + if (m >= gridDim.x) return; + int64_t p = pos[m]; + int nope = N - rope_dim; + for (int i = threadIdx.x; i < rope_dim / 2; i += 256) { + float c = cos_c[p * (rope_dim / 2) + i]; + float s = sin_c[p * (rope_dim / 2) + i]; + int ev_idx = m * N + nope + 2 * i; + int od_idx = m * N + nope + 2 * i + 1; + float ev = x[ev_idx]; + float od = x[od_idx]; + if (inverse) { + x[ev_idx] = ev * c + od * s; + x[od_idx] = -ev * s + od * c; + } else { + x[ev_idx] = ev * c - od * s; + x[od_idx] = ev * s + od * c; + } + } +} + +// =========================================================================== +// FP8 E4M3 quantize FP32 → FP8 (for indexer keys — higher precision) +// =========================================================================== + +__global__ void quantize_fp8_e4m3_from_fp32_kernel( + const float* __restrict__ input, + int M, int N, + float* __restrict__ out_scale, // (M,) per-row scale + uint8_t* __restrict__ out_fp8 // (M, N) packed FP8 E4M3 +) { + int m = blockIdx.x; + if (m >= M) return; + + // Per-row amax → scale = amax / 448.0 (E4M3 max = 448) + float local_max = 0.0f; + for (int i = threadIdx.x; i < N; i += 256) { + float v = fabsf(input[m * N + i]); + local_max = fmaxf(local_max, v); + } + for (int offset = 128; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_down_sync(0xffffffff, local_max, offset)); + __shared__ float s_max[8]; + if (threadIdx.x % 32 == 0) s_max[threadIdx.x / 32] = local_max; + __syncthreads(); + if (threadIdx.x < 32) { + float v = (threadIdx.x < 8) ? s_max[threadIdx.x] : 0.0f; + for (int offset = 16; offset > 0; offset >>= 1) + v = fmaxf(v, __shfl_down_sync(0xffffffff, v, offset)); + if (threadIdx.x == 0) { + float scale = v / 448.0f; + if (scale < 1e-8f) scale = 1e-8f; + out_scale[m] = scale; + } + } + __syncthreads(); + + // Quantize each element + float scale = out_scale[m]; + float inv_scale = 1.0f / scale; + for (int i = threadIdx.x; i < N; i += 256) { + float v = input[m * N + i] * inv_scale; + v = fmaxf(v, -448.0f); + v = fminf(v, 448.0f); + __nv_fp8_e4m3 obj(v); + out_fp8[m * N + i] = *(uint8_t*)&obj; + } +} + +// =========================================================================== +// FP8 E4M3 dequant → BF16 (for indexer key gather) +// =========================================================================== + +__global__ void dequant_fp8_e4m3_kernel( + const uint8_t* __restrict__ fp8_data, + const float* __restrict__ scale_data, + int M, int N, + __nv_bfloat16* __restrict__ output +) { + int m = blockIdx.x; + if (m >= M) return; + float scale = scale_data[m]; + for (int i = threadIdx.x; i < N; i += 256) { + uint8_t byte = fp8_data[m * N + i]; + __nv_fp8_e4m3 val; + memcpy(&val, &byte, 1); + float v = (float)val * scale; + output[m * N + i] = __float2bfloat16(v); + } +} + +__global__ void dequant_fp8_e4m3_selective_kernel( + const uint8_t* __restrict__ fp8_data, + const float* __restrict__ scale_data, + const int32_t* __restrict__ indices, + int K, int N, + __nv_bfloat16* __restrict__ output +) { + int k = blockIdx.x; + if (k >= K) return; + int src_row = indices[k]; + float scale = scale_data[src_row]; + for (int i = threadIdx.x; i < N; i += 256) { + uint8_t byte = fp8_data[src_row * N + i]; + __nv_fp8_e4m3 val; + memcpy(&val, &byte, 1); + float v = (float)val * scale; + output[k * N + i] = __float2bfloat16(v); + } +} + +// =========================================================================== +// PyTorch bindings +// =========================================================================== + +torch::Tensor compute_amax_gsa_fp32_cuda(torch::Tensor input, double divisor) { + int M = input.size(0); + int N = input.size(1); + auto out_gsa = torch::zeros({M}, input.options().dtype(torch::kFloat32)); + compute_amax_gsa_fp32_kernel<<>>( + input.data_ptr(), M, N, (float)divisor, out_gsa.data_ptr()); + return out_gsa; +} + +std::tuple quantize_nvfp4_from_fp32_cuda( + torch::Tensor input, torch::Tensor gsa_buffer +) { + int M = input.size(0); + int N = input.size(1); + TORCH_CHECK(N % 16 == 0, "N must be a multiple of 16 for NVFP4 quantization"); + TORCH_CHECK(gsa_buffer.size(0) == M, "gsa_buffer size must match M"); + auto opts = input.options(); + auto out_fp4 = torch::zeros({M, N / 2}, opts.dtype(torch::kUInt8)); + auto out_sf = torch::zeros({M, N / 16}, opts.dtype(torch::kUInt8)); + int nb = N / 16; + dim3 grid(nb, M); + dim3 block(16); + quantize_nvfp4_from_fp32_kernel<<>>( + input.data_ptr(), M, N, gsa_buffer.data_ptr(), + out_fp4.data_ptr(), out_sf.data_ptr() + ); + return {out_fp4.view(torch::kFloat4_e2m1fn_x2), out_sf.view(torch::kFloat8_e4m3fn)}; +} + +std::tuple quantize_fp8_e4m3_from_fp32_cuda( + torch::Tensor input +) { + int M = input.size(0); + int N = input.size(1); + auto opts = input.options(); + auto out_scale = torch::zeros({M}, opts.dtype(torch::kFloat32)); + auto out_fp8 = torch::zeros({M, N}, opts.dtype(torch::kUInt8)); + quantize_fp8_e4m3_from_fp32_kernel<<>>( + input.data_ptr(), M, N, + out_scale.data_ptr(), out_fp8.data_ptr() + ); + return {out_fp8.view(torch::kFloat8_e4m3fn), out_scale}; +} + +torch::Tensor dequant_fp8_e4m3_cuda( + torch::Tensor fp8_data, torch::Tensor scale_data +) { + int M = fp8_data.size(0); + int N = fp8_data.size(1); + auto output = torch::zeros({M, N}, fp8_data.options().dtype(torch::kBFloat16)); + dequant_fp8_e4m3_kernel<<>>( + fp8_data.data_ptr(), scale_data.data_ptr(), M, N, + reinterpret_cast<__nv_bfloat16*>(output.data_ptr()) + ); + return output; +} + +torch::Tensor dequant_fp8_e4m3_selective_cuda( + torch::Tensor fp8_data, torch::Tensor scale_data, torch::Tensor indices +) { + int K = indices.size(0); + int N = fp8_data.size(1); + TORCH_CHECK(indices.scalar_type() == torch::kInt32, "indices must be int32"); + auto output = torch::zeros({K, N}, fp8_data.options().dtype(torch::kBFloat16)); + dequant_fp8_e4m3_selective_kernel<<>>( + fp8_data.data_ptr(), scale_data.data_ptr(), + indices.data_ptr(), K, N, + reinterpret_cast<__nv_bfloat16*>(output.data_ptr()) + ); + return output; +} + +void rope_fp32_cuda( + torch::Tensor x, // (M, N) FP32 — modified in-place + torch::Tensor positions, // (M,) int64 + torch::Tensor cos_cache, // (max_pos, rope_dim/2) FP32 + torch::Tensor sin_cache, // (max_pos, rope_dim/2) FP32 + int64_t rope_dim, + bool inverse +) { + int M = x.size(0); + int N = x.size(1); + TORCH_CHECK(x.scalar_type() == torch::kFloat32, "x must be float32"); + rope_fp32_kernel<<>>( + x.data_ptr(), + cos_cache.data_ptr(), + sin_cache.data_ptr(), + positions.data_ptr(), + N, (int)rope_dim, inverse + ); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("compute_amax_gsa_fp32", &compute_amax_gsa_fp32_cuda, + "Compute per-row gsa from FP32 input (GPU-only, no CPU sync)"); + m.def("quantize_nvfp4_from_fp32", &quantize_nvfp4_from_fp32_cuda, + "Quantize FP32 → NVFP4 using gsa from GPU buffer"); + m.def("quantize_fp8_e4m3_from_fp32", &quantize_fp8_e4m3_from_fp32_cuda, + "Quantize FP32 → FP8 E4M3 (for indexer keys)"); + m.def("dequant_fp8_e4m3", &dequant_fp8_e4m3_cuda, + "Dequant FP8 E4M3 → BF16"); + m.def("dequant_fp8_e4m3_selective", &dequant_fp8_e4m3_selective_cuda, + "Selective dequant FP8 E4M3 → BF16 (for CSA indexer gather)"); + m.def("rope_fp32", &rope_fp32_cuda, + "FP32 GPT-J interleaved RoPE (for compressed KV)"); +} diff --git a/single_shot_inference.py b/single_shot_inference.py index dff625ca..2e31296f 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -344,26 +344,30 @@ class Compressor: n_complete = T // r if n_complete == 0: return None, None, None - # Step 1-2: NVFP4 GEMM projections → BF16, then cast to FP32 for reduce + # Step 1-2: NVFP4 GEMM projections → FP32 for compress kv = self.kv_lin(hidden_states).float() # (T, kv_dim) FP32 gate = self.gate_lin(hidden_states).float() # (T, kv_dim) FP32 - # Position bias is handled inside the CUDA kernel (added to both kv and gate) - # Step 3: CUDA softmax/reduce kernel - from dsv4.kernels.compressor.production_compress import csa_compress_production, hca_compress_production + # Step 3: CUDA softmax/reduce kernel → FP32 + # KV-1/KV-2: Return FP32. Caller applies RoPE, then quantizes to NVFP4. + from dsv4.kernels.compressor.production_compress import csa_compress_production_fp32, hca_compress_production_fp32 if self.is_csa: - compressed = csa_compress_production( + compressed = csa_compress_production_fp32( kv, gate, self.ape, self.kv_norm_w, m=r) else: - compressed = hca_compress_production( + compressed = hca_compress_production_fp32( kv, gate, self.ape, self.kv_norm_w, m=r) if compressed.shape[0] == 0: return None, None, None + n_comp = compressed.shape[0] + # Vectorized position computation — no Python loop, no .item() - bi = torch.arange(n_complete, device=dev) + bi = torch.arange(n_comp, device=dev) pos_idx = ((bi + 1) * r - 1).clamp(max=positions.numel() - 1) comp_pos = positions[pos_idx] - return compressed, comp_pos, torch.zeros(1, T, n_complete, dtype=torch.float32, device=dev) + + # Return FP32 compressed output — caller handles RoPE + NVFP4 quantize + return compressed, comp_pos, torch.zeros(1, T, n_comp, dtype=torch.float32, device=dev) # ===================================================================== # Indexer — CSA top-k [PRODUCTION NVFP4 GEMMs] @@ -440,26 +444,68 @@ class Indexer: # KV Cache # ===================================================================== class KVCache: + """KV Cache with NVFP4 compressed KV and FP8_E4M3 indexer keys. + + KV-1/KV-2: Compressed KV is stored as NVFP4 (E2M1 + E4M3 + FP32 gsa). + KV-3: Indexer keys are stored as FP8_E4M3 (1 byte + per-row scale). + SWA: BF16 (only 128 tokens × 512 × 61 layers = 8MB, fits in L2). + + Storage savings vs BF16: + NVFP4: 0.5 bytes/val + 0.125 bytes/val (sf) + 4 bytes/row (gsa) + = hd/2 + hd/16 + 1 scalars per entry + = 256 + 32 + 1 = 289 bytes/entry at hd=512 + vs 1024 bytes/entry BF16 → 3.5× savings + FP8_E4M3: 1 byte/val + 4 bytes/row (scale) + = 128 + 4 = 132 bytes/entry at ihd=128 + vs 256 bytes/entry BF16 → 1.9× savings + """ def __init__(self, head_dim, window_size=128, max_comp=65536, device='cuda:0', indexer_key_dim=128, compress_ratio=4, indexer_top_k=1024): self.hd, self.ws, self.dev = head_dim, window_size, device self.idx_key_dim = indexer_key_dim self.ratio = compress_ratio + self.max_comp = max_comp + + # SWA: BF16 (small, fits in L2) self.swa = torch.zeros(window_size, head_dim, dtype=torch.bfloat16, device=device) self.swa_pos = torch.zeros(window_size, dtype=torch.long, device=device) self.swa_len, self.swa_head = 0, 0 - # P3: Pre-allocate compressed KV buffers (no more torch.cat / O(N²) growth) - self.comp_kv_buf = torch.zeros(max_comp, head_dim, dtype=torch.bfloat16, device=device) + + # Compressed KV: NVFP4 storage + self.comp_kv_fp4 = torch.zeros(max_comp, head_dim // 2, dtype=torch.uint8, device=device) + self.comp_kv_sf = torch.zeros(max_comp, head_dim // 16, dtype=torch.uint8, device=device) + self.comp_kv_gsa = torch.zeros(max_comp, dtype=torch.float32, device=device) self.comp_pos_buf = torch.zeros(max_comp, dtype=torch.long, device=device) - # Indexer compressed keys: width = ihd (c_I in the paper), NOT head_dim - self.comp_idx_buf = torch.zeros(max_comp, indexer_key_dim, dtype=torch.bfloat16, device=device) - # Pre-allocated gather buffer — top_k compressed + SWA window, zero torch.cat on hot path + + # Indexer compressed keys: FP8_E4M3 + self.comp_idx_fp8 = torch.zeros(max_comp, indexer_key_dim, dtype=torch.uint8, device=device) + self.comp_idx_scale = torch.zeros(max_comp, dtype=torch.float32, device=device) + + # Pre-allocated gather buffer — top_k compressed + SWA window self.gather_buf = torch.zeros(indexer_top_k + window_size, head_dim, dtype=torch.bfloat16, device=device) self.n_comp = 0 self._has_idx = False + # Cache dequant modules (loaded once) + self._dequant_mod = None + self._kv_quant_mod = None + + def _get_dequant_mod(self): + if self._dequant_mod is None: + from dsv4.kernels.cuda.loader import get_cuda_module + self._dequant_mod = get_cuda_module( + "dequant_nvfp4", ["dequant_nvfp4.cu"]) + return self._dequant_mod + + def _get_kv_quant_mod(self): + if self._kv_quant_mod is None: + from dsv4.kernels.cuda.loader import get_cuda_module + self._kv_quant_mod = get_cuda_module( + "kv_quantize", ["kv_quantize.cu"]) + return self._kv_quant_mod + def append_swa(self, kv, pos): - """P2: Vectorized SWA append — 2 kernel launches instead of 2T.""" + """Vectorized SWA append — 2 kernel launches instead of 2T.""" T = kv.shape[0] idx = (self.swa_head + torch.arange(T, device=self.dev)) % self.ws self.swa.index_copy_(0, idx, kv) @@ -468,20 +514,77 @@ class KVCache: self.swa_len = min(self.swa_len + T, self.ws) def add_compressed(self, ckv, cpos, idx_kv=None): - """P3: Pre-allocated buffer — O(1) instead of O(N) per call.""" + """Add compressed KV entries to NVFP4 cache. + + ckv can be: + - BF16 tensor (n_comp, hd) — will be quantized to NVFP4 + - NVFP4 triple (fp4, sf, gsa) — stored directly + idx_kv can be: + - BF16 tensor (n_comp, ihd) — will be quantized to FP8_E4M3 + - FP8 triple (fp8, scale) — stored directly + """ if ckv is None: return - T = ckv.shape[0] - end = self.n_comp + T - self.comp_kv_buf[self.n_comp:end] = ckv - self.comp_pos_buf[self.n_comp:end] = cpos + end = self.n_comp + + # Handle compressed KV + if isinstance(ckv, tuple) and len(ckv) == 3: + # NVFP4 triple: (fp4, sf, gsa) + fp4, sf, gsa = ckv + T = fp4.shape[0] + self.comp_kv_fp4[end:end+T] = fp4.view(torch.uint8) if fp4.dtype != torch.uint8 else fp4 + self.comp_kv_sf[end:end+T] = sf.view(torch.uint8) if sf.dtype != torch.uint8 else sf + self.comp_kv_gsa[end:end+T] = gsa + elif isinstance(ckv, torch.Tensor): + # BF16 tensor — quantize to NVFP4 using proven two-kernel path + T = ckv.shape[0] + from dsv4.ops.quantize import quantize_nvfp4_gpu_fused + fp4, sf, gsa = quantize_nvfp4_gpu_fused(ckv) + self.comp_kv_fp4[end:end+T] = fp4.view(torch.uint8) + self.comp_kv_sf[end:end+T] = sf.view(torch.uint8) + self.comp_kv_gsa[end:end+T] = gsa + else: + raise ValueError(f"Unexpected ckv type: {type(ckv)}") + + self.comp_pos_buf[end:end+ckv.shape[0] if isinstance(ckv, torch.Tensor) else ckv[0].shape[0]] = cpos + T = ckv.shape[0] if isinstance(ckv, torch.Tensor) else ckv[0].shape[0] + + # Handle indexer keys if idx_kv is not None: - self.comp_idx_buf[self.n_comp:end] = idx_kv + if isinstance(idx_kv, tuple) and len(idx_kv) == 2: + # FP8 triple: (fp8, scale) + fp8, scale = idx_kv + self.comp_idx_fp8[end:end+T] = fp8.view(torch.uint8) if fp8.dtype != torch.uint8 else fp8 + self.comp_idx_scale[end:end+T] = scale + elif isinstance(idx_kv, torch.Tensor): + # BF16 tensor — quantize to FP8_E4M3 + mod = self._get_kv_quant_mod() + fp8, scale = mod.quantize_fp8_e4m3_from_fp32(idx_kv.float().contiguous()) + self.comp_idx_fp8[end:end+T] = fp8.view(torch.uint8) + self.comp_idx_scale[end:end+T] = scale self._has_idx = True - self.n_comp = end + + self.n_comp = end + T @property def comp_kv(self): - return self.comp_kv_buf[:self.n_comp] if self.n_comp > 0 else None + """Dequantize NVFP4 → BF16 for FMHA. Returns (n_comp, hd) BF16.""" + if self.n_comp == 0: return None + mod = self._get_dequant_mod() + return mod.dequant_nvfp4( + self.comp_kv_fp4[:self.n_comp], + self.comp_kv_sf[:self.n_comp], + self.comp_kv_gsa[:self.n_comp], + ) + + def comp_kv_selective(self, indices): + """Dequantize selected NVFP4 entries → BF16 for CSA top-k gather.""" + mod = self._get_dequant_mod() + return mod.dequant_nvfp4_selective( + self.comp_kv_fp4, + self.comp_kv_sf, + self.comp_kv_gsa, + indices.int(), + ) @property def comp_pos(self): @@ -489,7 +592,13 @@ class KVCache: @property def comp_idx_kv(self): - return self.comp_idx_buf[:self.n_comp] if self._has_idx and self.n_comp > 0 else None + """Dequantize FP8 indexer keys → BF16 for scoring.""" + if not self._has_idx or self.n_comp == 0: return None + mod = self._get_kv_quant_mod() + return mod.dequant_fp8_e4m3( + self.comp_idx_fp8[:self.n_comp], + self.comp_idx_scale[:self.n_comp], + ) def get_swa(self): """Return SWA KV and positions as views (no clone). Caller copies into gather_buf.""" @@ -582,18 +691,26 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, _pt('rope_kv_end') kv_roped = kv_3d.reshape(T, hd); kv_cache.append_swa(kv_roped, positions) - # 3. Compressor → compressed KV + # 3. Compressor → compressed KV → FP32 RoPE → NVFP4 _pt('compress_start') - comp_kv, comp_pos, block_bias = None, None, None; comp_idx_kv = None + comp_nvfp4, comp_pos, block_bias = None, None, None; comp_idx_kv = None if compressor is not None and compressor.ratio > 0: - comp_kv, comp_pos, block_bias = compressor.forward(x_normed, positions) - if comp_kv is not None: - comp_kv_3d = comp_kv.unsqueeze(1) - comp_kv_3d = _apply_rope(comp_kv_3d, comp_pos, rope_cos, rope_sin, rd) - comp_kv = comp_kv_3d.squeeze(1) + comp_kv_fp32, comp_pos, block_bias = compressor.forward(x_normed, positions) + if comp_kv_fp32 is not None: + # Apply RoPE on FP32 (no BF16 intermediate) + from dsv4.kernels.cuda.loader import get_cuda_module + kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"]) + comp_kv_fp32_contig = comp_kv_fp32.contiguous() + c = rope_cos[comp_pos].contiguous() + s = rope_sin[comp_pos].contiguous() + kv_mod.rope_fp32(comp_kv_fp32_contig, comp_pos.contiguous(), c, s, rd, False) + # Quantize FP32 → NVFP4 (two-kernel, proven pattern) + gsa = kv_mod.compute_amax_gsa_fp32(comp_kv_fp32_contig, 6.0 * 448.0) + fp4, sf = kv_mod.quantize_nvfp4_from_fp32(comp_kv_fp32_contig, gsa) + comp_nvfp4 = (fp4, sf, gsa) if compressor.is_csa and indexer is not None and indexer.compressor is not None: comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, positions) - kv_cache.add_compressed(comp_kv, comp_pos, comp_idx_kv) + kv_cache.add_compressed(comp_nvfp4, comp_pos, comp_idx_kv) _pt('compress_end') # 4. Indexer top-k (CSA) @@ -601,22 +718,24 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, if indexer is not None and ratio == 4: topk_idx = indexer.forward(q_a, x_normed, kv_cache.comp_idx_kv, positions, layer_idx=li) - # 5. Gather KV — pre-allocated buffer, zero torch.cat on hot path + # 5. Gather KV — NVFP4 dequant for compressed KV _pt('gather_start') swa_kv, _swa_pos = kv_cache.get_swa() swa_len = swa_kv.shape[0] - gbuf = kv_cache.gather_buf # (indexer_top_k + window_size, hd) pre-allocated - if kv_cache.comp_kv is not None and kv_cache.n_comp > 0: + gbuf = kv_cache.gather_buf # (indexer_top_k + window_size, hd) pre-allocated BF16 + if kv_cache.n_comp > 0: if ratio == 4: + # CSA: dequant only top-k entries (bandwidth savings) assert topk_idx is not None, f"CSA layer {li}: indexer returned no top-k — indexer is broken" - tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1) + tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1).int() n_tk = tk.shape[0] - gbuf[:n_tk] = kv_cache.comp_kv[tk] + gbuf[:n_tk] = kv_cache.comp_kv_selective(tk) # NVFP4 → BF16 gbuf[n_tk:n_tk + swa_len] = swa_kv all_kv = gbuf[:n_tk + swa_len] elif ratio > 4: + # HCA: dequant all entries (dense gather) n_comp = kv_cache.n_comp - gbuf[:n_comp] = kv_cache.comp_kv + gbuf[:n_comp] = kv_cache.comp_kv # NVFP4 → BF16 gbuf[n_comp:n_comp + swa_len] = swa_kv all_kv = gbuf[:n_comp + swa_len] else: