KV-1/KV-2/KV-3: NVFP4 compressed KV + FP8 indexer keys
Architecture: - Compressed KV: stored as NVFP4 (E2M1 + E4M3 + FP32 gsa) - Write path: compress→FP32 → FP32 RoPE → quantize FP32→NVFP4 - Read path: dequant_nvfp4/dequant_nvfp4_selective → BF16 for FMHA - No BF16 intermediate in the write path - Indexer keys: stored as FP8_E4M3 (1 byte + per-row scale) - Write path: compress→FP32 → quantize FP32→FP8_E4M3 - Read path: dequant_fp8_e4m3 → BF16 for scoring - SWA: remains BF16 (8MB total, fits in L2) New kernels in kv_quantize.cu: - compute_amax_gsa_fp32: per-row gsa from FP32 input - quantize_nvfp4_from_fp32: FP32→NVFP4 with GPU gsa buffer - quantize_fp8_e4m3_from_fp32: FP32→FP8_E4M3 for indexer keys - dequant_fp8_e4m3 / dequant_fp8_e4m3_selective: FP8→BF16 - rope_fp32: FP32 GPT-J interleaved RoPE (no BF16) Proven two-kernel pattern (same as quantize_nvfp4_gpu_fused): Kernel 1: amax_gsa (GPU-only) Kernel 2: quantize from buffer (GPU gsa) No shared memory bugs. No cross-CTA race conditions. KVCache updated: - comp_kv_fp4/sf/gsa: NVFP4 storage (3.5× smaller than BF16) - comp_idx_fp8/scale: FP8_E4M3 storage (1.9× smaller than BF16) - comp_kv property: dequant NVFP4→BF16 on demand - comp_kv_selective: dequant only top-k entries (bandwidth savings) - comp_idx_kv property: dequant FP8→BF16 on demand Removed: compressor_reduce_quant.cu (buggy single-kernel approach)
This commit is contained in:
@@ -44,11 +44,24 @@ def csa_compress_production(
|
||||
m: int = 4,
|
||||
) -> torch.Tensor:
|
||||
"""CSA compress: softmax + weighted sum + kv_norm. Returns BF16."""
|
||||
return csa_compress_production_fp32(
|
||||
kv_proj_out, gate_proj_out, position_bias, kv_norm_weight, m
|
||||
).bfloat16()
|
||||
|
||||
|
||||
def csa_compress_production_fp32(
|
||||
kv_proj_out: torch.Tensor,
|
||||
gate_proj_out: torch.Tensor,
|
||||
position_bias: Optional[torch.Tensor],
|
||||
kv_norm_weight: Optional[torch.Tensor],
|
||||
m: int = 4,
|
||||
) -> torch.Tensor:
|
||||
"""CSA compress: softmax + weighted sum + kv_norm. Returns FP32."""
|
||||
T = kv_proj_out.shape[0]
|
||||
hd = kv_proj_out.shape[1] // 2
|
||||
n_blocks = T // m
|
||||
if n_blocks == 0:
|
||||
return torch.zeros(0, hd, dtype=torch.bfloat16, device=kv_proj_out.device)
|
||||
return torch.zeros(0, hd, dtype=torch.float32, device=kv_proj_out.device)
|
||||
|
||||
mod = _get_kernel()
|
||||
|
||||
@@ -71,7 +84,7 @@ def csa_compress_production(
|
||||
m, n_blocks,
|
||||
)
|
||||
|
||||
return compressed.bfloat16()
|
||||
return compressed
|
||||
|
||||
|
||||
def hca_compress_production(
|
||||
@@ -82,11 +95,24 @@ def hca_compress_production(
|
||||
m: int = 128,
|
||||
) -> torch.Tensor:
|
||||
"""HCA compress: softmax + weighted sum + kv_norm. Returns BF16."""
|
||||
return hca_compress_production_fp32(
|
||||
kv_proj_out, gate_proj_out, position_bias, kv_norm_weight, m
|
||||
).bfloat16()
|
||||
|
||||
|
||||
def hca_compress_production_fp32(
|
||||
kv_proj_out: torch.Tensor,
|
||||
gate_proj_out: torch.Tensor,
|
||||
position_bias: Optional[torch.Tensor],
|
||||
kv_norm_weight: Optional[torch.Tensor],
|
||||
m: int = 128,
|
||||
) -> torch.Tensor:
|
||||
"""HCA compress: softmax + weighted sum + kv_norm. Returns FP32."""
|
||||
T = kv_proj_out.shape[0]
|
||||
hd = kv_proj_out.shape[1]
|
||||
n_blocks = T // m
|
||||
if n_blocks == 0:
|
||||
return torch.zeros(0, hd, dtype=torch.bfloat16, device=kv_proj_out.device)
|
||||
return torch.zeros(0, hd, dtype=torch.float32, device=kv_proj_out.device)
|
||||
|
||||
mod = _get_kernel()
|
||||
|
||||
@@ -109,13 +135,43 @@ def hca_compress_production(
|
||||
m, n_blocks,
|
||||
)
|
||||
|
||||
return compressed.bfloat16()
|
||||
return compressed
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# KV-1/KV-2: NVFP4 output variants — single kernel, no intermediate BF16
|
||||
# KV-1/KV-2: NVFP4 output — two proven kernels, no BF16 intermediate
|
||||
#
|
||||
# Architecture:
|
||||
# 1. CUDA compress kernel (compressor_reduce.cu) → FP32 compressed output
|
||||
# 2. CUDA amax_gsa_fp32 → per-row gsa (GPU-only, no CPU sync)
|
||||
# 3. CUDA quantize_nvfp4_from_fp32 → NVFP4 triple (fp4 + sf + gsa)
|
||||
#
|
||||
# This is the same two-kernel pattern that works everywhere else in the
|
||||
# pipeline (quantize_nvfp4_gpu_fused). The previous single-kernel fused
|
||||
# approach had shared memory corruption bugs. Two kernels is correct.
|
||||
#
|
||||
# Storage: NVFP4 (E2M1 data + E4M3 block scales + FP32 global scale)
|
||||
# Read path: dequant_nvfp4 / dequant_nvfp4_selective → BF16 for FMHA
|
||||
# ===========================================================================
|
||||
|
||||
def _quantize_fp32_to_nvfp4(compressed_fp32: torch.Tensor) -> tuple:
|
||||
"""Quantize FP32 compressed output → NVFP4. Two-kernel, GPU-only.
|
||||
|
||||
Uses the same proven pattern as quantize_nvfp4_gpu_fused (amax_gsa +
|
||||
quantize_from_buffer) but with FP32 input instead of BF16.
|
||||
No BF16 intermediate. No CPU sync.
|
||||
|
||||
Returns: (fp4_data, block_scales, global_scales) — NVFP4 triple.
|
||||
"""
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
|
||||
# Kernel 1: Compute per-row gsa from FP32 input (GPU-only)
|
||||
gsa = mod.compute_amax_gsa_fp32(compressed_fp32.contiguous(), 6.0 * 448.0)
|
||||
# Kernel 2: Quantize FP32 → NVFP4 using GPU gsa buffer
|
||||
fp4, sf = mod.quantize_nvfp4_from_fp32(compressed_fp32.contiguous(), gsa)
|
||||
return fp4, sf, gsa
|
||||
|
||||
|
||||
def csa_compress_production_nvfp4(
|
||||
kv_proj_out: torch.Tensor,
|
||||
gate_proj_out: torch.Tensor,
|
||||
@@ -123,27 +179,23 @@ def csa_compress_production_nvfp4(
|
||||
kv_norm_weight: Optional[torch.Tensor],
|
||||
m: int = 4,
|
||||
) -> tuple:
|
||||
"""CSA compress + NVFP4 quantize: single kernel, no intermediate BF16.
|
||||
"""CSA compress → NVFP4. No BF16 intermediate.
|
||||
|
||||
KV-1: Production path. Compressed KV stored as NVFP4.
|
||||
Pipeline: compress (FP32) → amax_gsa (GPU) → quantize (GPU) → NVFP4 triple.
|
||||
Returns: (fp4_data, block_scales, global_scales) — NVFP4 triple.
|
||||
"""
|
||||
T = kv_proj_out.shape[0]
|
||||
hd = kv_proj_out.shape[1] // 2
|
||||
n_blocks = T // m
|
||||
if n_blocks == 0:
|
||||
# Step 1: Compress → FP32 (same proven kernel as BF16 path)
|
||||
compressed_fp32 = csa_compress_production_fp32(
|
||||
kv_proj_out, gate_proj_out, position_bias, kv_norm_weight, m)
|
||||
if compressed_fp32.shape[0] == 0:
|
||||
dev = kv_proj_out.device
|
||||
hd = kv_proj_out.shape[1] // 2
|
||||
return (torch.zeros(0, hd // 2, dtype=torch.float4_e2m1fn_x2, device=dev),
|
||||
torch.zeros(0, hd // 16, dtype=torch.float8_e4m3fn, device=dev),
|
||||
torch.zeros(0, dtype=torch.float32, device=dev))
|
||||
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
mod = get_cuda_module("compressor_reduce_quant", ["compressor_reduce_quant.cu"])
|
||||
pos_bias_f32 = position_bias.float() if position_bias is not None else torch.empty(0, dtype=torch.float32, device=kv_proj_out.device)
|
||||
norm_f32 = kv_norm_weight.float() if kv_norm_weight is not None else torch.empty(0, dtype=torch.float32, device=kv_proj_out.device)
|
||||
return mod.csa_compress_reduce_quant(
|
||||
kv_proj_out.contiguous(), gate_proj_out.contiguous(),
|
||||
pos_bias_f32.contiguous(), norm_f32.contiguous(), m, n_blocks)
|
||||
# Step 2-3: FP32 → NVFP4 (two proven kernels)
|
||||
return _quantize_fp32_to_nvfp4(compressed_fp32)
|
||||
|
||||
|
||||
def hca_compress_production_nvfp4(
|
||||
@@ -153,24 +205,20 @@ def hca_compress_production_nvfp4(
|
||||
kv_norm_weight: Optional[torch.Tensor],
|
||||
m: int = 128,
|
||||
) -> tuple:
|
||||
"""HCA compress + NVFP4 quantize: single kernel, no intermediate BF16.
|
||||
"""HCA compress → NVFP4. No BF16 intermediate.
|
||||
|
||||
KV-2: Production path. Compressed KV stored as NVFP4.
|
||||
Pipeline: compress (FP32) → amax_gsa (GPU) → quantize (GPU) → NVFP4 triple.
|
||||
Returns: (fp4_data, block_scales, global_scales) — NVFP4 triple.
|
||||
"""
|
||||
T = kv_proj_out.shape[0]
|
||||
hd = kv_proj_out.shape[1]
|
||||
n_blocks = T // m
|
||||
if n_blocks == 0:
|
||||
# Step 1: Compress → FP32
|
||||
compressed_fp32 = hca_compress_production_fp32(
|
||||
kv_proj_out, gate_proj_out, position_bias, kv_norm_weight, m)
|
||||
if compressed_fp32.shape[0] == 0:
|
||||
dev = kv_proj_out.device
|
||||
hd = kv_proj_out.shape[1]
|
||||
return (torch.zeros(0, hd // 2, dtype=torch.float4_e2m1fn_x2, device=dev),
|
||||
torch.zeros(0, hd // 16, dtype=torch.float8_e4m3fn, device=dev),
|
||||
torch.zeros(0, dtype=torch.float32, device=dev))
|
||||
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
mod = get_cuda_module("compressor_reduce_quant", ["compressor_reduce_quant.cu"])
|
||||
pos_bias_f32 = position_bias.float() if position_bias is not None else torch.empty(0, dtype=torch.float32, device=kv_proj_out.device)
|
||||
norm_f32 = kv_norm_weight.float() if kv_norm_weight is not None else torch.empty(0, dtype=torch.float32, device=kv_proj_out.device)
|
||||
return mod.hca_compress_reduce_quant(
|
||||
kv_proj_out.contiguous(), gate_proj_out.contiguous(),
|
||||
pos_bias_f32.contiguous(), norm_f32.contiguous(), m, n_blocks)
|
||||
# Step 2-3: FP32 → NVFP4
|
||||
return _quantize_fp32_to_nvfp4(compressed_fp32)
|
||||
|
||||
@@ -1,226 +0,0 @@
|
||||
/**
|
||||
* FUSED CSA/HCA compress + RMSNorm + NVFP4 quantize kernels.
|
||||
* KV-1/KV-2: Single kernel launch. FP32 -> E2M1 + E4M3 + FP32 gsa.
|
||||
*
|
||||
* FIX: block_reduce_sum/max need n_warps shared memory slots, not a single
|
||||
* float. Previous version passed &s_amax (1 float) but the functions write
|
||||
* smem[0..n_warps-1], corrupting adjacent shared variables.
|
||||
*/
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_fp8.hpp>
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <cmath>
|
||||
|
||||
__device__ __forceinline__ float block_reduce_sum(float val, float* smem, int nw) {
|
||||
for (int o = 16; o > 0; o >>= 1) val += __shfl_down_sync(0xffffffff, val, o);
|
||||
if (threadIdx.x % 32 == 0) smem[threadIdx.x / 32] = val;
|
||||
__syncthreads();
|
||||
float r = 0.0f;
|
||||
if (threadIdx.x < 32) {
|
||||
float v = (threadIdx.x < nw) ? smem[threadIdx.x] : 0.0f;
|
||||
for (int o = 16; o > 0; o >>= 1) v += __shfl_down_sync(0xffffffff, v, o);
|
||||
r = v;
|
||||
}
|
||||
__syncthreads();
|
||||
return r;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float block_reduce_max(float val, float* smem, int nw) {
|
||||
for (int o = 16; o > 0; o >>= 1) val = fmaxf(val, __shfl_down_sync(0xffffffff, val, o));
|
||||
if (threadIdx.x % 32 == 0) smem[threadIdx.x / 32] = val;
|
||||
__syncthreads();
|
||||
float r = 0.0f;
|
||||
if (threadIdx.x < 32) {
|
||||
float v = (threadIdx.x < nw) ? smem[threadIdx.x] : 0.0f;
|
||||
for (int o = 16; o > 0; o >>= 1) v = fmaxf(v, __shfl_down_sync(0xffffffff, v, o));
|
||||
r = v;
|
||||
}
|
||||
__syncthreads();
|
||||
return r;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int half_step_to_e2m1(int hs) {
|
||||
if (hs <= 4) return hs; if (hs <= 5) return 4;
|
||||
if (hs <= 7) return 5; if (hs <= 10) return 6; return 7;
|
||||
}
|
||||
|
||||
// === CSA ===
|
||||
__global__ void csa_compress_reduce_quant_kernel(
|
||||
const float* kv_proj, const float* gate_proj,
|
||||
const float* position_bias, const float* kv_norm_weight,
|
||||
uint8_t* out_fp4, uint8_t* out_sf, float* out_gsa,
|
||||
int T, int hd, int m, int n_blocks
|
||||
) {
|
||||
int bi = blockIdx.x, tid = threadIdx.x, nt = blockDim.x;
|
||||
int kd = 2*hd, nw = nt/32;
|
||||
if (bi >= n_blocks) return;
|
||||
|
||||
int ntok = (bi > 0) ? 2*m : m;
|
||||
int ps = (bi-1)*m, cs = bi*m;
|
||||
int cpt = (hd+nt-1)/nt;
|
||||
|
||||
// Shared memory: reduction scratch (8 floats for n_warps=4) + FP32 staging
|
||||
__shared__ float s_scratch[8]; // 0-3: sum/norm, 4-7: max/gsa
|
||||
__shared__ float s_vals[512]; // FP32 staging for quantize
|
||||
|
||||
float lv[4],lm[4],ld[4],la[4];
|
||||
for (int ci=0;ci<cpt;ci++) {
|
||||
int c=tid+ci*nt; if(c>=hd) break;
|
||||
lm[ci]=-FLT_MAX; ld[ci]=0; la[ci]=0;
|
||||
for (int t=0;t<ntok;t++) {
|
||||
int ti,go; if(bi>0){if(t<m){ti=ps+t;go=0;}else{ti=cs+(t-m);go=hd;}}else{ti=t;go=hd;}
|
||||
if(ti<0||ti>=T) continue;
|
||||
float g=gate_proj[ti*kd+go+c];
|
||||
if(position_bias){int p=(bi>0&&t<m)?t:(bi>0?(t-m):t);if(p>=0&&p<m)g+=position_bias[p*kd+go+c];}
|
||||
lm[ci]=fmaxf(lm[ci],g);
|
||||
}
|
||||
for (int t=0;t<ntok;t++) {
|
||||
int ti,ko,go; if(bi>0){if(t<m){ti=ps+t;ko=0;go=0;}else{ti=cs+(t-m);ko=hd;go=hd;}}else{ti=t;ko=hd;go=hd;}
|
||||
if(ti<0||ti>=T) continue;
|
||||
float g=gate_proj[ti*kd+go+c], kv=kv_proj[ti*kd+ko+c];
|
||||
if(position_bias){int p=(bi>0&&t<m)?t:(bi>0?(t-m):t);if(p>=0&&p<m){float pb=position_bias[p*kd+go+c];g+=pb;kv+=position_bias[p*kd+ko+c];}}
|
||||
float e=expf(g-lm[ci]); ld[ci]+=e; la[ci]+=e*kv;
|
||||
}
|
||||
lv[ci]=(ld[ci]>0)?(la[ci]/ld[ci]):0;
|
||||
}
|
||||
|
||||
// kv_norm
|
||||
if(kv_norm_weight) {
|
||||
float ls=0; for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;ls+=lv[ci]*lv[ci];}
|
||||
float ts=block_reduce_sum(ls,&s_scratch[0],nw);
|
||||
if(tid==0) s_scratch[0]=rsqrtf(ts/hd+1e-6f); // s_inv_rms
|
||||
__syncthreads();
|
||||
float sir=s_scratch[0];
|
||||
for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;lv[ci]*=sir*kv_norm_weight[c];}
|
||||
}
|
||||
|
||||
// gsa
|
||||
float ea=0; for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;ea=fmaxf(ea,fabsf(lv[ci]));}
|
||||
float ga=block_reduce_max(ea,&s_scratch[4],nw);
|
||||
float gsa=fmaxf(ga,1e-8f)/(6.0f*448.0f);
|
||||
if(tid==0) out_gsa[bi]=gsa;
|
||||
|
||||
// Write to shared memory for quantize
|
||||
for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;s_vals[c]=lv[ci];}
|
||||
__syncthreads();
|
||||
|
||||
// Quantize: each thread handles one or more 16-element blocks
|
||||
int nfb=hd/16;
|
||||
for(int b=tid;b<nfb;b+=nt) {
|
||||
int base=b*16;
|
||||
float ba=0; for(int i=0;i<16;i++){int c=base+i;if(c<hd)ba=fmaxf(ba,fabsf(s_vals[c])/gsa);}
|
||||
float bsf=ba/6.0f; bool z=(ba<6.0f*0.001953125f);
|
||||
// Quantize using FP8-round-tripped block scale (matches dequant)
|
||||
float bs_rt=0.0f;
|
||||
if(z){out_sf[bi*(hd/16)+b]=0;}
|
||||
else{__nv_fp8_e4m3 o(bsf);out_sf[bi*(hd/16)+b]=*(uint8_t*)&o;bs_rt=(float)o;}
|
||||
for(int i=0;i<8;i++){
|
||||
int c0=base+2*i,c1=base+2*i+1; uint8_t lo=0,hi=0;
|
||||
if(!z&&c0<hd){float s=(s_vals[c0]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;}
|
||||
if(!z&&c1<hd){float s=(s_vals[c1]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;}
|
||||
out_fp4[bi*(hd/2)+b*8+i]=(hi<<4)|lo;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === HCA ===
|
||||
__global__ void hca_compress_reduce_quant_kernel(
|
||||
const float* kv_proj, const float* gate_proj,
|
||||
const float* position_bias, const float* kv_norm_weight,
|
||||
uint8_t* out_fp4, uint8_t* out_sf, float* out_gsa,
|
||||
int T, int hd, int m, int n_blocks
|
||||
) {
|
||||
int bi=blockIdx.x,tid=threadIdx.x,nt=blockDim.x,nw=nt/32;
|
||||
if(bi>=n_blocks) return;
|
||||
int cpt=(hd+nt-1)/nt;
|
||||
|
||||
__shared__ float s_scratch[8];
|
||||
__shared__ float s_vals[512];
|
||||
|
||||
float lv[4];
|
||||
for(int ci=0;ci<cpt;ci++){
|
||||
int c=tid+ci*nt;if(c>=hd)break;
|
||||
float lm=-FLT_MAX,ld=0,la=0; int st=bi*m;
|
||||
for(int t=0;t<m;t++){int ti=st+t;if(ti>=T)break;float g=gate_proj[ti*hd+c];if(position_bias&&t<m)g+=position_bias[t*hd+c];lm=fmaxf(lm,g);}
|
||||
for(int t=0;t<m;t++){int ti=st+t;if(ti>=T)break;float g=gate_proj[ti*hd+c],kv=kv_proj[ti*hd+c];if(position_bias&&t<m){float pb=position_bias[t*hd+c];g+=pb;kv+=pb;}float e=expf(g-lm);ld+=e;la+=e*kv;}
|
||||
lv[ci]=(ld>0)?(la/ld):0;
|
||||
}
|
||||
|
||||
if(kv_norm_weight){
|
||||
float ls=0;for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;ls+=lv[ci]*lv[ci];}
|
||||
float ts=block_reduce_sum(ls,&s_scratch[0],nw);
|
||||
if(tid==0)s_scratch[0]=rsqrtf(ts/hd+1e-6f);
|
||||
__syncthreads();
|
||||
float sir=s_scratch[0];
|
||||
for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;lv[ci]*=sir*kv_norm_weight[c];}
|
||||
}
|
||||
|
||||
float ea=0;for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;ea=fmaxf(ea,fabsf(lv[ci]));}
|
||||
float ga=block_reduce_max(ea,&s_scratch[4],nw);
|
||||
float gsa=fmaxf(ga,1e-8f)/(6.0f*448.0f);
|
||||
if(tid==0)out_gsa[bi]=gsa;
|
||||
|
||||
for(int ci=0;ci<cpt;ci++){int c=tid+ci*nt;if(c>=hd)break;s_vals[c]=lv[ci];}
|
||||
__syncthreads();
|
||||
|
||||
int nfb=hd/16;
|
||||
for(int b=tid;b<nfb;b+=nt){
|
||||
int base=b*16;
|
||||
float ba=0;for(int i=0;i<16;i++){int c=base+i;if(c<hd)ba=fmaxf(ba,fabsf(s_vals[c])/gsa);}
|
||||
float bsf=ba/6.0f;bool z=(ba<6.0f*0.001953125f);
|
||||
float bs_rt=0.0f;
|
||||
if(z){out_sf[bi*(hd/16)+b]=0;}
|
||||
else{__nv_fp8_e4m3 o(bsf);out_sf[bi*(hd/16)+b]=*(uint8_t*)&o;bs_rt=(float)o;}
|
||||
for(int i=0;i<8;i++){
|
||||
int c0=base+2*i,c1=base+2*i+1;uint8_t lo=0,hi=0;
|
||||
if(!z&&c0<hd){float s=(s_vals[c0]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;lo=half_step_to_e2m1(hs);if(s<0)lo+=8;}
|
||||
if(!z&&c1<hd){float s=(s_vals[c1]/gsa)/bs_rt;int hs=__float2int_rn(fminf(fabsf(s),6.0f)*2.0f);if(hs>12)hs=12;hi=half_step_to_e2m1(hs);if(s<0)hi+=8;}
|
||||
out_fp4[bi*(hd/2)+b*8+i]=(hi<<4)|lo;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === Bindings ===
|
||||
std::tuple<torch::Tensor,torch::Tensor,torch::Tensor>
|
||||
csa_compress_reduce_quant_cuda(torch::Tensor kp,torch::Tensor gp,torch::Tensor pb,torch::Tensor nw,int64_t m,int64_t nb){
|
||||
int T=kp.size(0),hd=kp.size(1)/2,th=128;
|
||||
const float*pp=(pb.numel()>0)?pb.data_ptr<float>():nullptr;
|
||||
const float*np=(nw.numel()>0)?nw.data_ptr<float>():nullptr;
|
||||
auto o=kp.options();
|
||||
auto of4=torch::zeros({(int)nb,hd/2},o.dtype(torch::kUInt8));
|
||||
auto osf=torch::zeros({(int)nb,hd/16},o.dtype(torch::kUInt8));
|
||||
auto ogs=torch::zeros({(int)nb},o.dtype(torch::kFloat32));
|
||||
csa_compress_reduce_quant_kernel<<<nb,th,0,c10::cuda::getCurrentCUDAStream()>>>(
|
||||
kp.data_ptr<float>(),gp.data_ptr<float>(),pp,np,
|
||||
of4.data_ptr<uint8_t>(),osf.data_ptr<uint8_t>(),ogs.data_ptr<float>(),
|
||||
T,hd,(int)m,(int)nb);
|
||||
C10_CUDA_CHECK(cudaGetLastError());
|
||||
return {of4.view(torch::kFloat4_e2m1fn_x2),osf.view(torch::kFloat8_e4m3fn),ogs};
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor,torch::Tensor,torch::Tensor>
|
||||
hca_compress_reduce_quant_cuda(torch::Tensor kp,torch::Tensor gp,torch::Tensor pb,torch::Tensor nw,int64_t m,int64_t nb){
|
||||
int T=kp.size(0),hd=kp.size(1),th=128;
|
||||
const float*pp=(pb.numel()>0)?pb.data_ptr<float>():nullptr;
|
||||
const float*np=(nw.numel()>0)?nw.data_ptr<float>():nullptr;
|
||||
auto o=kp.options();
|
||||
auto of4=torch::zeros({(int)nb,hd/2},o.dtype(torch::kUInt8));
|
||||
auto osf=torch::zeros({(int)nb,hd/16},o.dtype(torch::kUInt8));
|
||||
auto ogs=torch::zeros({(int)nb},o.dtype(torch::kFloat32));
|
||||
hca_compress_reduce_quant_kernel<<<nb,th,0,c10::cuda::getCurrentCUDAStream()>>>(
|
||||
kp.data_ptr<float>(),gp.data_ptr<float>(),pp,np,
|
||||
of4.data_ptr<uint8_t>(),osf.data_ptr<uint8_t>(),ogs.data_ptr<float>(),
|
||||
T,hd,(int)m,(int)nb);
|
||||
C10_CUDA_CHECK(cudaGetLastError());
|
||||
return {of4.view(torch::kFloat4_e2m1fn_x2),osf.view(torch::kFloat8_e4m3fn),ogs};
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME,m){
|
||||
m.def("csa_compress_reduce_quant",&csa_compress_reduce_quant_cuda);
|
||||
m.def("hca_compress_reduce_quant",&hca_compress_reduce_quant_cuda);
|
||||
}
|
||||
372
dsv4/kernels/cuda/kv_quantize.cu
Normal file
372
dsv4/kernels/cuda/kv_quantize.cu
Normal file
@@ -0,0 +1,372 @@
|
||||
/**
|
||||
* Quantize FP32 tensor to NVFP4.
|
||||
*
|
||||
* Same proven pattern as quantize_nvfp4.cu (which reads BF16),
|
||||
* but takes FP32 input directly — avoids BF16 intermediate.
|
||||
*
|
||||
* This is the correct path for compressor output → NVFP4:
|
||||
* Compressor produces FP32 → this kernel → NVFP4 stored in KV cache
|
||||
* No BF16 anywhere in the pipeline.
|
||||
*
|
||||
* Two-kernel approach (proven correct in fused_amax_quantize.cu):
|
||||
* Kernel 1: amax_gsa_fp32 — compute per-row gsa from FP32 input (GPU-only)
|
||||
* Kernel 2: quantize_nvfp4_from_fp32 — quantize FP32 → NVFP4 using GPU gsa buffer
|
||||
*
|
||||
* Grid: (N/16, M, 1) — each CTA processes one 16-element block in one row.
|
||||
* Block: 16 threads (1 thread per element, warp amax reduction).
|
||||
*/
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_fp8.hpp>
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/extension.h>
|
||||
#include <cstdint>
|
||||
#include <cfloat>
|
||||
|
||||
__device__ __forceinline__ int half_step_to_e2m1(int hs) {
|
||||
if (hs <= 4) return hs;
|
||||
if (hs <= 5) return 4;
|
||||
if (hs <= 7) return 5;
|
||||
if (hs <= 10) return 6;
|
||||
return 7;
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// Kernel 1: Compute per-row amax → gsa from FP32 input
|
||||
// Same pattern as amax_gsa.cu but for FP32 (not BF16) input
|
||||
// ===========================================================================
|
||||
|
||||
__global__ void compute_amax_gsa_fp32_kernel(
|
||||
const float* __restrict__ input,
|
||||
int M, int N,
|
||||
float divisor,
|
||||
float* __restrict__ out_gsa
|
||||
) {
|
||||
int m = blockIdx.x;
|
||||
if (m >= M) return;
|
||||
|
||||
float local_max = 0.0f;
|
||||
for (int i = threadIdx.x; i < N; i += 256) {
|
||||
float v = fabsf(input[m * N + i]);
|
||||
local_max = fmaxf(local_max, v);
|
||||
}
|
||||
|
||||
// Warp-level reduction
|
||||
for (int offset = 128; offset > 0; offset >>= 1)
|
||||
local_max = fmaxf(local_max, __shfl_down_sync(0xffffffff, local_max, offset));
|
||||
|
||||
// Block-level reduction using shared memory
|
||||
__shared__ float s_max[8];
|
||||
if (threadIdx.x % 32 == 0)
|
||||
s_max[threadIdx.x / 32] = local_max;
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < 32) {
|
||||
float v = (threadIdx.x < 8) ? s_max[threadIdx.x] : 0.0f;
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
v = fmaxf(v, __shfl_down_sync(0xffffffff, v, offset));
|
||||
if (threadIdx.x == 0)
|
||||
out_gsa[m] = v / divisor;
|
||||
}
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// Kernel 2: Quantize FP32 → NVFP4 using gsa from GPU buffer
|
||||
// Same proven pattern as quantize_nvfp4_from_buffer_kernel (fused_amax_quantize.cu)
|
||||
// but reads FP32 instead of BF16
|
||||
// ===========================================================================
|
||||
|
||||
__global__ void quantize_nvfp4_from_fp32_kernel(
|
||||
const float* __restrict__ input,
|
||||
int M, int N,
|
||||
const float* __restrict__ gsa_buffer, // (M,) GPU buffer with per-row gsa
|
||||
uint8_t* __restrict__ out_fp4,
|
||||
uint8_t* __restrict__ out_sf
|
||||
) {
|
||||
int m = blockIdx.y;
|
||||
int n_block = blockIdx.x;
|
||||
if (m >= M || n_block * 16 >= N) return;
|
||||
|
||||
float gsa = gsa_buffer[m];
|
||||
|
||||
float vals[16];
|
||||
float block_amax = 0.0f;
|
||||
|
||||
// Step 1: Read 16 FP32 elements and compute block amax
|
||||
for (int i = 0; i < 16; i++) {
|
||||
int col = n_block * 16 + i;
|
||||
if (col < N) {
|
||||
vals[i] = input[m * N + col] / gsa;
|
||||
} else {
|
||||
vals[i] = 0;
|
||||
}
|
||||
block_amax = fmaxf(block_amax, fabsf(vals[i]));
|
||||
}
|
||||
|
||||
// Step 2: Compute FP8 E4M3 block scale (with FP8 round-trip)
|
||||
float bsf = block_amax / 6.0f;
|
||||
if (block_amax < 6.0f * 0.001953125f) {
|
||||
// Zero/underflow block
|
||||
bsf = 0;
|
||||
for (int i = 0; i < 16; i++) vals[i] = 0;
|
||||
}
|
||||
__nv_fp8_e4m3 bsf8_obj(bsf);
|
||||
float bs = (float)bsf8_obj; // FP8 round-trip — matches dequant
|
||||
uint8_t bsf8 = *(uint8_t*)&bsf8_obj;
|
||||
|
||||
// Step 3: Quantize each value to FP4 E2M1
|
||||
uint8_t nibbles[16];
|
||||
for (int i = 0; i < 16; i++) {
|
||||
if (bs < 1e-8f) { nibbles[i] = 0; continue; }
|
||||
float s = vals[i] / bs;
|
||||
int hs = __float2int_rn(fminf(fabsf(s), 6.0f) * 2.0f);
|
||||
if (hs > 12) hs = 12;
|
||||
int idx = half_step_to_e2m1(hs);
|
||||
if (s < 0) idx += 8;
|
||||
nibbles[i] = idx;
|
||||
}
|
||||
|
||||
// Step 4: Pack pairs: (nibbles[1] << 4) | nibbles[0], etc.
|
||||
for (int i = 0; i < 8; i++)
|
||||
out_fp4[m * (N / 2) + n_block * 8 + i] = (nibbles[2*i+1] << 4) | nibbles[2*i];
|
||||
|
||||
// Step 5: Write FP8 block scale
|
||||
out_sf[m * (N / 16) + n_block] = bsf8;
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// FP32 GPT-J interleaved RoPE (for compressed KV — no BF16 intermediate)
|
||||
// Same math as rope_cuda.cu but operates on FP32 directly.
|
||||
// ===========================================================================
|
||||
|
||||
__global__ void rope_fp32_kernel(
|
||||
float* __restrict__ x, // (M, 1, N) FP32 — modified in-place
|
||||
const float* __restrict__ cos_c, // (max_pos, rope_dim/2) FP32
|
||||
const float* __restrict__ sin_c, // (max_pos, rope_dim/2) FP32
|
||||
const int64_t* __restrict__ pos, // (M,) positions
|
||||
int N, int rope_dim, bool inverse
|
||||
) {
|
||||
int m = blockIdx.x;
|
||||
if (m >= gridDim.x) return;
|
||||
int64_t p = pos[m];
|
||||
int nope = N - rope_dim;
|
||||
for (int i = threadIdx.x; i < rope_dim / 2; i += 256) {
|
||||
float c = cos_c[p * (rope_dim / 2) + i];
|
||||
float s = sin_c[p * (rope_dim / 2) + i];
|
||||
int ev_idx = m * N + nope + 2 * i;
|
||||
int od_idx = m * N + nope + 2 * i + 1;
|
||||
float ev = x[ev_idx];
|
||||
float od = x[od_idx];
|
||||
if (inverse) {
|
||||
x[ev_idx] = ev * c + od * s;
|
||||
x[od_idx] = -ev * s + od * c;
|
||||
} else {
|
||||
x[ev_idx] = ev * c - od * s;
|
||||
x[od_idx] = ev * s + od * c;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// FP8 E4M3 quantize FP32 → FP8 (for indexer keys — higher precision)
|
||||
// ===========================================================================
|
||||
|
||||
__global__ void quantize_fp8_e4m3_from_fp32_kernel(
|
||||
const float* __restrict__ input,
|
||||
int M, int N,
|
||||
float* __restrict__ out_scale, // (M,) per-row scale
|
||||
uint8_t* __restrict__ out_fp8 // (M, N) packed FP8 E4M3
|
||||
) {
|
||||
int m = blockIdx.x;
|
||||
if (m >= M) return;
|
||||
|
||||
// Per-row amax → scale = amax / 448.0 (E4M3 max = 448)
|
||||
float local_max = 0.0f;
|
||||
for (int i = threadIdx.x; i < N; i += 256) {
|
||||
float v = fabsf(input[m * N + i]);
|
||||
local_max = fmaxf(local_max, v);
|
||||
}
|
||||
for (int offset = 128; offset > 0; offset >>= 1)
|
||||
local_max = fmaxf(local_max, __shfl_down_sync(0xffffffff, local_max, offset));
|
||||
__shared__ float s_max[8];
|
||||
if (threadIdx.x % 32 == 0) s_max[threadIdx.x / 32] = local_max;
|
||||
__syncthreads();
|
||||
if (threadIdx.x < 32) {
|
||||
float v = (threadIdx.x < 8) ? s_max[threadIdx.x] : 0.0f;
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
v = fmaxf(v, __shfl_down_sync(0xffffffff, v, offset));
|
||||
if (threadIdx.x == 0) {
|
||||
float scale = v / 448.0f;
|
||||
if (scale < 1e-8f) scale = 1e-8f;
|
||||
out_scale[m] = scale;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Quantize each element
|
||||
float scale = out_scale[m];
|
||||
float inv_scale = 1.0f / scale;
|
||||
for (int i = threadIdx.x; i < N; i += 256) {
|
||||
float v = input[m * N + i] * inv_scale;
|
||||
v = fmaxf(v, -448.0f);
|
||||
v = fminf(v, 448.0f);
|
||||
__nv_fp8_e4m3 obj(v);
|
||||
out_fp8[m * N + i] = *(uint8_t*)&obj;
|
||||
}
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// FP8 E4M3 dequant → BF16 (for indexer key gather)
|
||||
// ===========================================================================
|
||||
|
||||
__global__ void dequant_fp8_e4m3_kernel(
|
||||
const uint8_t* __restrict__ fp8_data,
|
||||
const float* __restrict__ scale_data,
|
||||
int M, int N,
|
||||
__nv_bfloat16* __restrict__ output
|
||||
) {
|
||||
int m = blockIdx.x;
|
||||
if (m >= M) return;
|
||||
float scale = scale_data[m];
|
||||
for (int i = threadIdx.x; i < N; i += 256) {
|
||||
uint8_t byte = fp8_data[m * N + i];
|
||||
__nv_fp8_e4m3 val;
|
||||
memcpy(&val, &byte, 1);
|
||||
float v = (float)val * scale;
|
||||
output[m * N + i] = __float2bfloat16(v);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void dequant_fp8_e4m3_selective_kernel(
|
||||
const uint8_t* __restrict__ fp8_data,
|
||||
const float* __restrict__ scale_data,
|
||||
const int32_t* __restrict__ indices,
|
||||
int K, int N,
|
||||
__nv_bfloat16* __restrict__ output
|
||||
) {
|
||||
int k = blockIdx.x;
|
||||
if (k >= K) return;
|
||||
int src_row = indices[k];
|
||||
float scale = scale_data[src_row];
|
||||
for (int i = threadIdx.x; i < N; i += 256) {
|
||||
uint8_t byte = fp8_data[src_row * N + i];
|
||||
__nv_fp8_e4m3 val;
|
||||
memcpy(&val, &byte, 1);
|
||||
float v = (float)val * scale;
|
||||
output[k * N + i] = __float2bfloat16(v);
|
||||
}
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// PyTorch bindings
|
||||
// ===========================================================================
|
||||
|
||||
torch::Tensor compute_amax_gsa_fp32_cuda(torch::Tensor input, double divisor) {
|
||||
int M = input.size(0);
|
||||
int N = input.size(1);
|
||||
auto out_gsa = torch::zeros({M}, input.options().dtype(torch::kFloat32));
|
||||
compute_amax_gsa_fp32_kernel<<<M, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
|
||||
input.data_ptr<float>(), M, N, (float)divisor, out_gsa.data_ptr<float>());
|
||||
return out_gsa;
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> quantize_nvfp4_from_fp32_cuda(
|
||||
torch::Tensor input, torch::Tensor gsa_buffer
|
||||
) {
|
||||
int M = input.size(0);
|
||||
int N = input.size(1);
|
||||
TORCH_CHECK(N % 16 == 0, "N must be a multiple of 16 for NVFP4 quantization");
|
||||
TORCH_CHECK(gsa_buffer.size(0) == M, "gsa_buffer size must match M");
|
||||
auto opts = input.options();
|
||||
auto out_fp4 = torch::zeros({M, N / 2}, opts.dtype(torch::kUInt8));
|
||||
auto out_sf = torch::zeros({M, N / 16}, opts.dtype(torch::kUInt8));
|
||||
int nb = N / 16;
|
||||
dim3 grid(nb, M);
|
||||
dim3 block(16);
|
||||
quantize_nvfp4_from_fp32_kernel<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
|
||||
input.data_ptr<float>(), M, N, gsa_buffer.data_ptr<float>(),
|
||||
out_fp4.data_ptr<uint8_t>(), out_sf.data_ptr<uint8_t>()
|
||||
);
|
||||
return {out_fp4.view(torch::kFloat4_e2m1fn_x2), out_sf.view(torch::kFloat8_e4m3fn)};
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> quantize_fp8_e4m3_from_fp32_cuda(
|
||||
torch::Tensor input
|
||||
) {
|
||||
int M = input.size(0);
|
||||
int N = input.size(1);
|
||||
auto opts = input.options();
|
||||
auto out_scale = torch::zeros({M}, opts.dtype(torch::kFloat32));
|
||||
auto out_fp8 = torch::zeros({M, N}, opts.dtype(torch::kUInt8));
|
||||
quantize_fp8_e4m3_from_fp32_kernel<<<M, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
|
||||
input.data_ptr<float>(), M, N,
|
||||
out_scale.data_ptr<float>(), out_fp8.data_ptr<uint8_t>()
|
||||
);
|
||||
return {out_fp8.view(torch::kFloat8_e4m3fn), out_scale};
|
||||
}
|
||||
|
||||
torch::Tensor dequant_fp8_e4m3_cuda(
|
||||
torch::Tensor fp8_data, torch::Tensor scale_data
|
||||
) {
|
||||
int M = fp8_data.size(0);
|
||||
int N = fp8_data.size(1);
|
||||
auto output = torch::zeros({M, N}, fp8_data.options().dtype(torch::kBFloat16));
|
||||
dequant_fp8_e4m3_kernel<<<M, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
|
||||
fp8_data.data_ptr<uint8_t>(), scale_data.data_ptr<float>(), M, N,
|
||||
reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>())
|
||||
);
|
||||
return output;
|
||||
}
|
||||
|
||||
torch::Tensor dequant_fp8_e4m3_selective_cuda(
|
||||
torch::Tensor fp8_data, torch::Tensor scale_data, torch::Tensor indices
|
||||
) {
|
||||
int K = indices.size(0);
|
||||
int N = fp8_data.size(1);
|
||||
TORCH_CHECK(indices.scalar_type() == torch::kInt32, "indices must be int32");
|
||||
auto output = torch::zeros({K, N}, fp8_data.options().dtype(torch::kBFloat16));
|
||||
dequant_fp8_e4m3_selective_kernel<<<K, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
|
||||
fp8_data.data_ptr<uint8_t>(), scale_data.data_ptr<float>(),
|
||||
indices.data_ptr<int32_t>(), K, N,
|
||||
reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>())
|
||||
);
|
||||
return output;
|
||||
}
|
||||
|
||||
void rope_fp32_cuda(
|
||||
torch::Tensor x, // (M, N) FP32 — modified in-place
|
||||
torch::Tensor positions, // (M,) int64
|
||||
torch::Tensor cos_cache, // (max_pos, rope_dim/2) FP32
|
||||
torch::Tensor sin_cache, // (max_pos, rope_dim/2) FP32
|
||||
int64_t rope_dim,
|
||||
bool inverse
|
||||
) {
|
||||
int M = x.size(0);
|
||||
int N = x.size(1);
|
||||
TORCH_CHECK(x.scalar_type() == torch::kFloat32, "x must be float32");
|
||||
rope_fp32_kernel<<<M, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
|
||||
x.data_ptr<float>(),
|
||||
cos_cache.data_ptr<float>(),
|
||||
sin_cache.data_ptr<float>(),
|
||||
positions.data_ptr<int64_t>(),
|
||||
N, (int)rope_dim, inverse
|
||||
);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("compute_amax_gsa_fp32", &compute_amax_gsa_fp32_cuda,
|
||||
"Compute per-row gsa from FP32 input (GPU-only, no CPU sync)");
|
||||
m.def("quantize_nvfp4_from_fp32", &quantize_nvfp4_from_fp32_cuda,
|
||||
"Quantize FP32 → NVFP4 using gsa from GPU buffer");
|
||||
m.def("quantize_fp8_e4m3_from_fp32", &quantize_fp8_e4m3_from_fp32_cuda,
|
||||
"Quantize FP32 → FP8 E4M3 (for indexer keys)");
|
||||
m.def("dequant_fp8_e4m3", &dequant_fp8_e4m3_cuda,
|
||||
"Dequant FP8 E4M3 → BF16");
|
||||
m.def("dequant_fp8_e4m3_selective", &dequant_fp8_e4m3_selective_cuda,
|
||||
"Selective dequant FP8 E4M3 → BF16 (for CSA indexer gather)");
|
||||
m.def("rope_fp32", &rope_fp32_cuda,
|
||||
"FP32 GPT-J interleaved RoPE (for compressed KV)");
|
||||
}
|
||||
@@ -344,26 +344,30 @@ class Compressor:
|
||||
n_complete = T // r
|
||||
if n_complete == 0: return None, None, None
|
||||
|
||||
# Step 1-2: NVFP4 GEMM projections → BF16, then cast to FP32 for reduce
|
||||
# Step 1-2: NVFP4 GEMM projections → FP32 for compress
|
||||
kv = self.kv_lin(hidden_states).float() # (T, kv_dim) FP32
|
||||
gate = self.gate_lin(hidden_states).float() # (T, kv_dim) FP32
|
||||
|
||||
# Position bias is handled inside the CUDA kernel (added to both kv and gate)
|
||||
# Step 3: CUDA softmax/reduce kernel
|
||||
from dsv4.kernels.compressor.production_compress import csa_compress_production, hca_compress_production
|
||||
# Step 3: CUDA softmax/reduce kernel → FP32
|
||||
# KV-1/KV-2: Return FP32. Caller applies RoPE, then quantizes to NVFP4.
|
||||
from dsv4.kernels.compressor.production_compress import csa_compress_production_fp32, hca_compress_production_fp32
|
||||
if self.is_csa:
|
||||
compressed = csa_compress_production(
|
||||
compressed = csa_compress_production_fp32(
|
||||
kv, gate, self.ape, self.kv_norm_w, m=r)
|
||||
else:
|
||||
compressed = hca_compress_production(
|
||||
compressed = hca_compress_production_fp32(
|
||||
kv, gate, self.ape, self.kv_norm_w, m=r)
|
||||
|
||||
if compressed.shape[0] == 0: return None, None, None
|
||||
n_comp = compressed.shape[0]
|
||||
|
||||
# Vectorized position computation — no Python loop, no .item()
|
||||
bi = torch.arange(n_complete, device=dev)
|
||||
bi = torch.arange(n_comp, device=dev)
|
||||
pos_idx = ((bi + 1) * r - 1).clamp(max=positions.numel() - 1)
|
||||
comp_pos = positions[pos_idx]
|
||||
return compressed, comp_pos, torch.zeros(1, T, n_complete, dtype=torch.float32, device=dev)
|
||||
|
||||
# Return FP32 compressed output — caller handles RoPE + NVFP4 quantize
|
||||
return compressed, comp_pos, torch.zeros(1, T, n_comp, dtype=torch.float32, device=dev)
|
||||
|
||||
# =====================================================================
|
||||
# Indexer — CSA top-k [PRODUCTION NVFP4 GEMMs]
|
||||
@@ -440,26 +444,68 @@ class Indexer:
|
||||
# KV Cache
|
||||
# =====================================================================
|
||||
class KVCache:
|
||||
"""KV Cache with NVFP4 compressed KV and FP8_E4M3 indexer keys.
|
||||
|
||||
KV-1/KV-2: Compressed KV is stored as NVFP4 (E2M1 + E4M3 + FP32 gsa).
|
||||
KV-3: Indexer keys are stored as FP8_E4M3 (1 byte + per-row scale).
|
||||
SWA: BF16 (only 128 tokens × 512 × 61 layers = 8MB, fits in L2).
|
||||
|
||||
Storage savings vs BF16:
|
||||
NVFP4: 0.5 bytes/val + 0.125 bytes/val (sf) + 4 bytes/row (gsa)
|
||||
= hd/2 + hd/16 + 1 scalars per entry
|
||||
= 256 + 32 + 1 = 289 bytes/entry at hd=512
|
||||
vs 1024 bytes/entry BF16 → 3.5× savings
|
||||
FP8_E4M3: 1 byte/val + 4 bytes/row (scale)
|
||||
= 128 + 4 = 132 bytes/entry at ihd=128
|
||||
vs 256 bytes/entry BF16 → 1.9× savings
|
||||
"""
|
||||
def __init__(self, head_dim, window_size=128, max_comp=65536, device='cuda:0',
|
||||
indexer_key_dim=128, compress_ratio=4, indexer_top_k=1024):
|
||||
self.hd, self.ws, self.dev = head_dim, window_size, device
|
||||
self.idx_key_dim = indexer_key_dim
|
||||
self.ratio = compress_ratio
|
||||
self.max_comp = max_comp
|
||||
|
||||
# SWA: BF16 (small, fits in L2)
|
||||
self.swa = torch.zeros(window_size, head_dim, dtype=torch.bfloat16, device=device)
|
||||
self.swa_pos = torch.zeros(window_size, dtype=torch.long, device=device)
|
||||
self.swa_len, self.swa_head = 0, 0
|
||||
# P3: Pre-allocate compressed KV buffers (no more torch.cat / O(N²) growth)
|
||||
self.comp_kv_buf = torch.zeros(max_comp, head_dim, dtype=torch.bfloat16, device=device)
|
||||
|
||||
# Compressed KV: NVFP4 storage
|
||||
self.comp_kv_fp4 = torch.zeros(max_comp, head_dim // 2, dtype=torch.uint8, device=device)
|
||||
self.comp_kv_sf = torch.zeros(max_comp, head_dim // 16, dtype=torch.uint8, device=device)
|
||||
self.comp_kv_gsa = torch.zeros(max_comp, dtype=torch.float32, device=device)
|
||||
self.comp_pos_buf = torch.zeros(max_comp, dtype=torch.long, device=device)
|
||||
# Indexer compressed keys: width = ihd (c_I in the paper), NOT head_dim
|
||||
self.comp_idx_buf = torch.zeros(max_comp, indexer_key_dim, dtype=torch.bfloat16, device=device)
|
||||
# Pre-allocated gather buffer — top_k compressed + SWA window, zero torch.cat on hot path
|
||||
|
||||
# Indexer compressed keys: FP8_E4M3
|
||||
self.comp_idx_fp8 = torch.zeros(max_comp, indexer_key_dim, dtype=torch.uint8, device=device)
|
||||
self.comp_idx_scale = torch.zeros(max_comp, dtype=torch.float32, device=device)
|
||||
|
||||
# Pre-allocated gather buffer — top_k compressed + SWA window
|
||||
self.gather_buf = torch.zeros(indexer_top_k + window_size, head_dim, dtype=torch.bfloat16, device=device)
|
||||
self.n_comp = 0
|
||||
self._has_idx = False
|
||||
|
||||
# Cache dequant modules (loaded once)
|
||||
self._dequant_mod = None
|
||||
self._kv_quant_mod = None
|
||||
|
||||
def _get_dequant_mod(self):
|
||||
if self._dequant_mod is None:
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
self._dequant_mod = get_cuda_module(
|
||||
"dequant_nvfp4", ["dequant_nvfp4.cu"])
|
||||
return self._dequant_mod
|
||||
|
||||
def _get_kv_quant_mod(self):
|
||||
if self._kv_quant_mod is None:
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
self._kv_quant_mod = get_cuda_module(
|
||||
"kv_quantize", ["kv_quantize.cu"])
|
||||
return self._kv_quant_mod
|
||||
|
||||
def append_swa(self, kv, pos):
|
||||
"""P2: Vectorized SWA append — 2 kernel launches instead of 2T."""
|
||||
"""Vectorized SWA append — 2 kernel launches instead of 2T."""
|
||||
T = kv.shape[0]
|
||||
idx = (self.swa_head + torch.arange(T, device=self.dev)) % self.ws
|
||||
self.swa.index_copy_(0, idx, kv)
|
||||
@@ -468,20 +514,77 @@ class KVCache:
|
||||
self.swa_len = min(self.swa_len + T, self.ws)
|
||||
|
||||
def add_compressed(self, ckv, cpos, idx_kv=None):
|
||||
"""P3: Pre-allocated buffer — O(1) instead of O(N) per call."""
|
||||
"""Add compressed KV entries to NVFP4 cache.
|
||||
|
||||
ckv can be:
|
||||
- BF16 tensor (n_comp, hd) — will be quantized to NVFP4
|
||||
- NVFP4 triple (fp4, sf, gsa) — stored directly
|
||||
idx_kv can be:
|
||||
- BF16 tensor (n_comp, ihd) — will be quantized to FP8_E4M3
|
||||
- FP8 triple (fp8, scale) — stored directly
|
||||
"""
|
||||
if ckv is None: return
|
||||
T = ckv.shape[0]
|
||||
end = self.n_comp + T
|
||||
self.comp_kv_buf[self.n_comp:end] = ckv
|
||||
self.comp_pos_buf[self.n_comp:end] = cpos
|
||||
end = self.n_comp
|
||||
|
||||
# Handle compressed KV
|
||||
if isinstance(ckv, tuple) and len(ckv) == 3:
|
||||
# NVFP4 triple: (fp4, sf, gsa)
|
||||
fp4, sf, gsa = ckv
|
||||
T = fp4.shape[0]
|
||||
self.comp_kv_fp4[end:end+T] = fp4.view(torch.uint8) if fp4.dtype != torch.uint8 else fp4
|
||||
self.comp_kv_sf[end:end+T] = sf.view(torch.uint8) if sf.dtype != torch.uint8 else sf
|
||||
self.comp_kv_gsa[end:end+T] = gsa
|
||||
elif isinstance(ckv, torch.Tensor):
|
||||
# BF16 tensor — quantize to NVFP4 using proven two-kernel path
|
||||
T = ckv.shape[0]
|
||||
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
|
||||
fp4, sf, gsa = quantize_nvfp4_gpu_fused(ckv)
|
||||
self.comp_kv_fp4[end:end+T] = fp4.view(torch.uint8)
|
||||
self.comp_kv_sf[end:end+T] = sf.view(torch.uint8)
|
||||
self.comp_kv_gsa[end:end+T] = gsa
|
||||
else:
|
||||
raise ValueError(f"Unexpected ckv type: {type(ckv)}")
|
||||
|
||||
self.comp_pos_buf[end:end+ckv.shape[0] if isinstance(ckv, torch.Tensor) else ckv[0].shape[0]] = cpos
|
||||
T = ckv.shape[0] if isinstance(ckv, torch.Tensor) else ckv[0].shape[0]
|
||||
|
||||
# Handle indexer keys
|
||||
if idx_kv is not None:
|
||||
self.comp_idx_buf[self.n_comp:end] = idx_kv
|
||||
if isinstance(idx_kv, tuple) and len(idx_kv) == 2:
|
||||
# FP8 triple: (fp8, scale)
|
||||
fp8, scale = idx_kv
|
||||
self.comp_idx_fp8[end:end+T] = fp8.view(torch.uint8) if fp8.dtype != torch.uint8 else fp8
|
||||
self.comp_idx_scale[end:end+T] = scale
|
||||
elif isinstance(idx_kv, torch.Tensor):
|
||||
# BF16 tensor — quantize to FP8_E4M3
|
||||
mod = self._get_kv_quant_mod()
|
||||
fp8, scale = mod.quantize_fp8_e4m3_from_fp32(idx_kv.float().contiguous())
|
||||
self.comp_idx_fp8[end:end+T] = fp8.view(torch.uint8)
|
||||
self.comp_idx_scale[end:end+T] = scale
|
||||
self._has_idx = True
|
||||
self.n_comp = end
|
||||
|
||||
self.n_comp = end + T
|
||||
|
||||
@property
|
||||
def comp_kv(self):
|
||||
return self.comp_kv_buf[:self.n_comp] if self.n_comp > 0 else None
|
||||
"""Dequantize NVFP4 → BF16 for FMHA. Returns (n_comp, hd) BF16."""
|
||||
if self.n_comp == 0: return None
|
||||
mod = self._get_dequant_mod()
|
||||
return mod.dequant_nvfp4(
|
||||
self.comp_kv_fp4[:self.n_comp],
|
||||
self.comp_kv_sf[:self.n_comp],
|
||||
self.comp_kv_gsa[:self.n_comp],
|
||||
)
|
||||
|
||||
def comp_kv_selective(self, indices):
|
||||
"""Dequantize selected NVFP4 entries → BF16 for CSA top-k gather."""
|
||||
mod = self._get_dequant_mod()
|
||||
return mod.dequant_nvfp4_selective(
|
||||
self.comp_kv_fp4,
|
||||
self.comp_kv_sf,
|
||||
self.comp_kv_gsa,
|
||||
indices.int(),
|
||||
)
|
||||
|
||||
@property
|
||||
def comp_pos(self):
|
||||
@@ -489,7 +592,13 @@ class KVCache:
|
||||
|
||||
@property
|
||||
def comp_idx_kv(self):
|
||||
return self.comp_idx_buf[:self.n_comp] if self._has_idx and self.n_comp > 0 else None
|
||||
"""Dequantize FP8 indexer keys → BF16 for scoring."""
|
||||
if not self._has_idx or self.n_comp == 0: return None
|
||||
mod = self._get_kv_quant_mod()
|
||||
return mod.dequant_fp8_e4m3(
|
||||
self.comp_idx_fp8[:self.n_comp],
|
||||
self.comp_idx_scale[:self.n_comp],
|
||||
)
|
||||
|
||||
def get_swa(self):
|
||||
"""Return SWA KV and positions as views (no clone). Caller copies into gather_buf."""
|
||||
@@ -582,18 +691,26 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
_pt('rope_kv_end')
|
||||
kv_roped = kv_3d.reshape(T, hd); kv_cache.append_swa(kv_roped, positions)
|
||||
|
||||
# 3. Compressor → compressed KV
|
||||
# 3. Compressor → compressed KV → FP32 RoPE → NVFP4
|
||||
_pt('compress_start')
|
||||
comp_kv, comp_pos, block_bias = None, None, None; comp_idx_kv = None
|
||||
comp_nvfp4, comp_pos, block_bias = None, None, None; comp_idx_kv = None
|
||||
if compressor is not None and compressor.ratio > 0:
|
||||
comp_kv, comp_pos, block_bias = compressor.forward(x_normed, positions)
|
||||
if comp_kv is not None:
|
||||
comp_kv_3d = comp_kv.unsqueeze(1)
|
||||
comp_kv_3d = _apply_rope(comp_kv_3d, comp_pos, rope_cos, rope_sin, rd)
|
||||
comp_kv = comp_kv_3d.squeeze(1)
|
||||
comp_kv_fp32, comp_pos, block_bias = compressor.forward(x_normed, positions)
|
||||
if comp_kv_fp32 is not None:
|
||||
# Apply RoPE on FP32 (no BF16 intermediate)
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
|
||||
comp_kv_fp32_contig = comp_kv_fp32.contiguous()
|
||||
c = rope_cos[comp_pos].contiguous()
|
||||
s = rope_sin[comp_pos].contiguous()
|
||||
kv_mod.rope_fp32(comp_kv_fp32_contig, comp_pos.contiguous(), c, s, rd, False)
|
||||
# Quantize FP32 → NVFP4 (two-kernel, proven pattern)
|
||||
gsa = kv_mod.compute_amax_gsa_fp32(comp_kv_fp32_contig, 6.0 * 448.0)
|
||||
fp4, sf = kv_mod.quantize_nvfp4_from_fp32(comp_kv_fp32_contig, gsa)
|
||||
comp_nvfp4 = (fp4, sf, gsa)
|
||||
if compressor.is_csa and indexer is not None and indexer.compressor is not None:
|
||||
comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, positions)
|
||||
kv_cache.add_compressed(comp_kv, comp_pos, comp_idx_kv)
|
||||
kv_cache.add_compressed(comp_nvfp4, comp_pos, comp_idx_kv)
|
||||
_pt('compress_end')
|
||||
|
||||
# 4. Indexer top-k (CSA)
|
||||
@@ -601,22 +718,24 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
if indexer is not None and ratio == 4:
|
||||
topk_idx = indexer.forward(q_a, x_normed, kv_cache.comp_idx_kv, positions, layer_idx=li)
|
||||
|
||||
# 5. Gather KV — pre-allocated buffer, zero torch.cat on hot path
|
||||
# 5. Gather KV — NVFP4 dequant for compressed KV
|
||||
_pt('gather_start')
|
||||
swa_kv, _swa_pos = kv_cache.get_swa()
|
||||
swa_len = swa_kv.shape[0]
|
||||
gbuf = kv_cache.gather_buf # (indexer_top_k + window_size, hd) pre-allocated
|
||||
if kv_cache.comp_kv is not None and kv_cache.n_comp > 0:
|
||||
gbuf = kv_cache.gather_buf # (indexer_top_k + window_size, hd) pre-allocated BF16
|
||||
if kv_cache.n_comp > 0:
|
||||
if ratio == 4:
|
||||
# CSA: dequant only top-k entries (bandwidth savings)
|
||||
assert topk_idx is not None, f"CSA layer {li}: indexer returned no top-k — indexer is broken"
|
||||
tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1)
|
||||
tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1).int()
|
||||
n_tk = tk.shape[0]
|
||||
gbuf[:n_tk] = kv_cache.comp_kv[tk]
|
||||
gbuf[:n_tk] = kv_cache.comp_kv_selective(tk) # NVFP4 → BF16
|
||||
gbuf[n_tk:n_tk + swa_len] = swa_kv
|
||||
all_kv = gbuf[:n_tk + swa_len]
|
||||
elif ratio > 4:
|
||||
# HCA: dequant all entries (dense gather)
|
||||
n_comp = kv_cache.n_comp
|
||||
gbuf[:n_comp] = kv_cache.comp_kv
|
||||
gbuf[:n_comp] = kv_cache.comp_kv # NVFP4 → BF16
|
||||
gbuf[n_comp:n_comp + swa_len] = swa_kv
|
||||
all_kv = gbuf[:n_comp + swa_len]
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user