KV-1/KV-2: Fused compress+NVFP4 quantize kernels + dequant
- compressor_reduce_quant.cu: Single-kernel CSA/HCA compress + RMSNorm + NVFP4 quantize. No intermediate BF16. FP32 → E2M1 + E4M3 + FP32 gsa in one kernel. Shared memory: ~2.5KB per CTA (FP32 staging + nibble buffer). - dequant_nvfp4.cu: NVFP4 → BF16 dequantization kernels. Full dequant (HCA dense gather) and selective dequant (CSA top-k gather). Single kernel launch per gather operation. - production_compress.py: Added csa_compress_production_nvfp4() and hca_compress_production_nvfp4() — production path for KV-1/KV-2. - loader.py: Preload dequant_nvfp4 and compressor_reduce_quant modules. - test_kv_compress_quant.py: Unit tests verifying cos >= 0.999 between BF16 reference and NVFP4 round-trip path.
This commit is contained in:
@@ -6,6 +6,9 @@ Pipeline:
|
||||
3. CUDA kernel: token-level softmax(gate) * kv → compressed entries
|
||||
4. CUDA kernel: kv_norm (unweighted RMSNorm + weight)
|
||||
|
||||
KV-1/KV-2: NVFP4 output variants compress + quantize in a single kernel.
|
||||
No intermediate BF16. Stored as FP4 data + E4M3 block scales + FP32 global scale.
|
||||
|
||||
No PyTorch softmax. No reference fallback. All on the GPU.
|
||||
"""
|
||||
|
||||
@@ -40,18 +43,7 @@ def csa_compress_production(
|
||||
kv_norm_weight: Optional[torch.Tensor], # (hd) BF16 or None
|
||||
m: int = 4,
|
||||
) -> torch.Tensor:
|
||||
"""CSA compress: softmax + weighted sum + kv_norm.
|
||||
|
||||
Args:
|
||||
kv_proj_out: FP32 projection output, (T, 2*hd), Ca in first hd cols, Cb in second
|
||||
gate_proj_out: FP32 projection output, (T, 2*hd), Ga in first hd cols, Gb in second
|
||||
position_bias: (m, 2*hd) BF16 position bias, or None
|
||||
kv_norm_weight: (hd) BF16 norm weight, or None
|
||||
m: compression ratio (4 for CSA)
|
||||
|
||||
Returns:
|
||||
compressed: (n_blocks, hd) BF16
|
||||
"""
|
||||
"""CSA compress: softmax + weighted sum + kv_norm. Returns BF16."""
|
||||
T = kv_proj_out.shape[0]
|
||||
hd = kv_proj_out.shape[1] // 2
|
||||
n_blocks = T // m
|
||||
@@ -60,7 +52,6 @@ def csa_compress_production(
|
||||
|
||||
mod = _get_kernel()
|
||||
|
||||
# Convert position_bias and kv_norm_weight to FP32
|
||||
pos_bias_f32 = torch.empty(0, dtype=torch.float32, device=kv_proj_out.device)
|
||||
if position_bias is not None:
|
||||
pos_bias_f32 = position_bias.float()
|
||||
@@ -90,18 +81,7 @@ def hca_compress_production(
|
||||
kv_norm_weight: Optional[torch.Tensor], # (hd) BF16 or None
|
||||
m: int = 128,
|
||||
) -> torch.Tensor:
|
||||
"""HCA compress: softmax + weighted sum + kv_norm.
|
||||
|
||||
Args:
|
||||
kv_proj_out: FP32 projection output, (T, hd)
|
||||
gate_proj_out: FP32 projection output, (T, hd)
|
||||
position_bias: (m, hd) BF16 position bias, or None
|
||||
kv_norm_weight: (hd) BF16 norm weight, or None
|
||||
m: compression ratio (128 for HCA)
|
||||
|
||||
Returns:
|
||||
compressed: (n_blocks, hd) BF16
|
||||
"""
|
||||
"""HCA compress: softmax + weighted sum + kv_norm. Returns BF16."""
|
||||
T = kv_proj_out.shape[0]
|
||||
hd = kv_proj_out.shape[1]
|
||||
n_blocks = T // m
|
||||
@@ -130,3 +110,67 @@ def hca_compress_production(
|
||||
)
|
||||
|
||||
return compressed.bfloat16()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# KV-1/KV-2: NVFP4 output variants — single kernel, no intermediate BF16
|
||||
# ===========================================================================
|
||||
|
||||
def csa_compress_production_nvfp4(
|
||||
kv_proj_out: torch.Tensor,
|
||||
gate_proj_out: torch.Tensor,
|
||||
position_bias: Optional[torch.Tensor],
|
||||
kv_norm_weight: Optional[torch.Tensor],
|
||||
m: int = 4,
|
||||
) -> tuple:
|
||||
"""CSA compress + NVFP4 quantize: single kernel, no intermediate BF16.
|
||||
|
||||
KV-1: Production path. Compressed KV stored as NVFP4.
|
||||
Returns: (fp4_data, block_scales, global_scales) — NVFP4 triple.
|
||||
"""
|
||||
T = kv_proj_out.shape[0]
|
||||
hd = kv_proj_out.shape[1] // 2
|
||||
n_blocks = T // m
|
||||
if n_blocks == 0:
|
||||
dev = kv_proj_out.device
|
||||
return (torch.zeros(0, hd // 2, dtype=torch.float4_e2m1fn_x2, device=dev),
|
||||
torch.zeros(0, hd // 16, dtype=torch.float8_e4m3fn, device=dev),
|
||||
torch.zeros(0, dtype=torch.float32, device=dev))
|
||||
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
mod = get_cuda_module("compressor_reduce_quant", ["compressor_reduce_quant.cu"])
|
||||
pos_bias_f32 = position_bias.float() if position_bias is not None else torch.empty(0, dtype=torch.float32, device=kv_proj_out.device)
|
||||
norm_f32 = kv_norm_weight.float() if kv_norm_weight is not None else torch.empty(0, dtype=torch.float32, device=kv_proj_out.device)
|
||||
return mod.csa_compress_reduce_quant(
|
||||
kv_proj_out.contiguous(), gate_proj_out.contiguous(),
|
||||
pos_bias_f32.contiguous(), norm_f32.contiguous(), m, n_blocks)
|
||||
|
||||
|
||||
def hca_compress_production_nvfp4(
|
||||
kv_proj_out: torch.Tensor,
|
||||
gate_proj_out: torch.Tensor,
|
||||
position_bias: Optional[torch.Tensor],
|
||||
kv_norm_weight: Optional[torch.Tensor],
|
||||
m: int = 128,
|
||||
) -> tuple:
|
||||
"""HCA compress + NVFP4 quantize: single kernel, no intermediate BF16.
|
||||
|
||||
KV-2: Production path. Compressed KV stored as NVFP4.
|
||||
Returns: (fp4_data, block_scales, global_scales) — NVFP4 triple.
|
||||
"""
|
||||
T = kv_proj_out.shape[0]
|
||||
hd = kv_proj_out.shape[1]
|
||||
n_blocks = T // m
|
||||
if n_blocks == 0:
|
||||
dev = kv_proj_out.device
|
||||
return (torch.zeros(0, hd // 2, dtype=torch.float4_e2m1fn_x2, device=dev),
|
||||
torch.zeros(0, hd // 16, dtype=torch.float8_e4m3fn, device=dev),
|
||||
torch.zeros(0, dtype=torch.float32, device=dev))
|
||||
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
mod = get_cuda_module("compressor_reduce_quant", ["compressor_reduce_quant.cu"])
|
||||
pos_bias_f32 = position_bias.float() if position_bias is not None else torch.empty(0, dtype=torch.float32, device=kv_proj_out.device)
|
||||
norm_f32 = kv_norm_weight.float() if kv_norm_weight is not None else torch.empty(0, dtype=torch.float32, device=kv_proj_out.device)
|
||||
return mod.hca_compress_reduce_quant(
|
||||
kv_proj_out.contiguous(), gate_proj_out.contiguous(),
|
||||
pos_bias_f32.contiguous(), norm_f32.contiguous(), m, n_blocks)
|
||||
|
||||
461
dsv4/kernels/cuda/compressor_reduce_quant.cu
Normal file
461
dsv4/kernels/cuda/compressor_reduce_quant.cu
Normal file
@@ -0,0 +1,461 @@
|
||||
/**
|
||||
* FUSED CSA/HCA compress + RMSNorm + NVFP4 quantize kernels.
|
||||
*
|
||||
* KV-1/KV-2: Single kernel launch per compressed entry.
|
||||
* The compressor produces FP32 values, applies kv_norm, then quantizes
|
||||
* to NVFP4 (E2M1 data + E4M3 block scales + FP32 global scale) all in
|
||||
* one kernel. No intermediate BF16 materialization.
|
||||
*
|
||||
* Shared memory budget per CTA (128 threads, hd=512):
|
||||
* s_vals: hd * 4 = 2048 bytes (FP32 staging)
|
||||
* s_nibbles: hd * 1 = 512 bytes (E2M1 nibbles)
|
||||
* s_sq/s_amax/s_inv_rms: ~16 bytes (reduction scratch)
|
||||
* Total: ~2576 bytes — well within 48KB
|
||||
*/
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_fp8.hpp>
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <cmath>
|
||||
|
||||
// ===========================================================================
|
||||
// Shared utilities
|
||||
// ===========================================================================
|
||||
|
||||
__device__ __forceinline__ float block_reduce_sum(float val, float* smem, int n_warps) {
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
val += __shfl_down_sync(0xffffffff, val, offset);
|
||||
if (threadIdx.x % 32 == 0) smem[threadIdx.x / 32] = val;
|
||||
__syncthreads();
|
||||
float result = 0.0f;
|
||||
if (threadIdx.x < 32) {
|
||||
float v = (threadIdx.x < n_warps) ? smem[threadIdx.x] : 0.0f;
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
v += __shfl_down_sync(0xffffffff, v, offset);
|
||||
result = v;
|
||||
}
|
||||
__syncthreads();
|
||||
return result;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float block_reduce_max(float val, float* smem, int n_warps) {
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
|
||||
if (threadIdx.x % 32 == 0) smem[threadIdx.x / 32] = val;
|
||||
__syncthreads();
|
||||
float result = 0.0f;
|
||||
if (threadIdx.x < 32) {
|
||||
float v = (threadIdx.x < n_warps) ? smem[threadIdx.x] : 0.0f;
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
v = fmaxf(v, __shfl_down_sync(0xffffffff, v, offset));
|
||||
result = v;
|
||||
}
|
||||
__syncthreads();
|
||||
return result;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int half_step_to_e2m1(int hs) {
|
||||
if (hs <= 4) return hs;
|
||||
if (hs <= 5) return 4;
|
||||
if (hs <= 7) return 5;
|
||||
if (hs <= 10) return 6;
|
||||
return 7;
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// CSA fused compress + norm + quantize
|
||||
// ===========================================================================
|
||||
|
||||
__global__ void csa_compress_reduce_quant_kernel(
|
||||
const float* __restrict__ kv_proj, // [T, 2*hd] FP32
|
||||
const float* __restrict__ gate_proj, // [T, 2*hd] FP32
|
||||
const float* __restrict__ position_bias, // [m, 2*hd] FP32 or nullptr
|
||||
const float* __restrict__ kv_norm_weight, // [hd] FP32 or nullptr
|
||||
uint8_t* __restrict__ out_fp4, // (n_blocks, hd/2) packed E2M1
|
||||
uint8_t* __restrict__ out_sf, // (n_blocks, hd/16) E4M3 block scales
|
||||
float* __restrict__ out_gsa, // (n_blocks,) FP32 global scale
|
||||
int T, int hd, int m, int n_blocks
|
||||
) {
|
||||
int block_i = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
int n_threads = blockDim.x;
|
||||
int kv_dim = 2 * hd;
|
||||
int n_warps = n_threads / 32;
|
||||
|
||||
if (block_i >= n_blocks) return;
|
||||
|
||||
int n_tokens = (block_i > 0) ? 2 * m : m;
|
||||
int prev_start = (block_i - 1) * m;
|
||||
int cur_start = block_i * m;
|
||||
int cols_per_thread = (hd + n_threads - 1) / n_threads;
|
||||
|
||||
// ---- Phase 1: Softmax + weighted sum ----
|
||||
float local_vals[4], local_max[4], local_denom[4], local_acc[4];
|
||||
|
||||
for (int ci = 0; ci < cols_per_thread; ci++) {
|
||||
int c = tid + ci * n_threads;
|
||||
if (c >= hd) break;
|
||||
local_max[ci] = -FLT_MAX;
|
||||
local_denom[ci] = 0.0f;
|
||||
local_acc[ci] = 0.0f;
|
||||
|
||||
// Pass 1: max gate
|
||||
for (int t = 0; t < n_tokens; t++) {
|
||||
int token_idx, gate_offset;
|
||||
if (block_i > 0) {
|
||||
if (t < m) { token_idx = prev_start + t; gate_offset = 0; }
|
||||
else { token_idx = cur_start + (t - m); gate_offset = hd; }
|
||||
} else { token_idx = t; gate_offset = hd; }
|
||||
if (token_idx < 0 || token_idx >= T) continue;
|
||||
float g = gate_proj[token_idx * kv_dim + gate_offset + c];
|
||||
if (position_bias) {
|
||||
int pbr = (block_i > 0 && t < m) ? t : (block_i > 0 ? (t - m) : t);
|
||||
if (pbr >= 0 && pbr < m) g += position_bias[pbr * kv_dim + gate_offset + c];
|
||||
}
|
||||
local_max[ci] = fmaxf(local_max[ci], g);
|
||||
}
|
||||
|
||||
// Pass 2: exp + weighted sum
|
||||
for (int t = 0; t < n_tokens; t++) {
|
||||
int token_idx, kv_offset, gate_offset;
|
||||
if (block_i > 0) {
|
||||
if (t < m) { token_idx = prev_start + t; kv_offset = 0; gate_offset = 0; }
|
||||
else { token_idx = cur_start + (t - m); kv_offset = hd; gate_offset = hd; }
|
||||
} else { token_idx = t; kv_offset = hd; gate_offset = hd; }
|
||||
if (token_idx < 0 || token_idx >= T) continue;
|
||||
float g = gate_proj[token_idx * kv_dim + gate_offset + c];
|
||||
float kv_val = kv_proj[token_idx * kv_dim + kv_offset + c];
|
||||
if (position_bias) {
|
||||
int pbr = (block_i > 0 && t < m) ? t : (block_i > 0 ? (t - m) : t);
|
||||
if (pbr >= 0 && pbr < m) {
|
||||
float pb = position_bias[pbr * kv_dim + gate_offset + c];
|
||||
g += pb;
|
||||
kv_val += position_bias[pbr * kv_dim + kv_offset + c];
|
||||
}
|
||||
}
|
||||
float e = expf(g - local_max[ci]);
|
||||
local_denom[ci] += e;
|
||||
local_acc[ci] += e * kv_val;
|
||||
}
|
||||
local_vals[ci] = (local_denom[ci] > 0.0f) ? (local_acc[ci] / local_denom[ci]) : 0.0f;
|
||||
}
|
||||
|
||||
// ---- Phase 2: kv_norm (RMSNorm) ----
|
||||
if (kv_norm_weight) {
|
||||
float local_sq = 0.0f;
|
||||
for (int ci = 0; ci < cols_per_thread; ci++) {
|
||||
int c = tid + ci * n_threads;
|
||||
if (c >= hd) break;
|
||||
local_sq += local_vals[ci] * local_vals[ci];
|
||||
}
|
||||
__shared__ float s_sq;
|
||||
float total_sq = block_reduce_sum(local_sq, &s_sq, n_warps);
|
||||
__shared__ float s_inv_rms;
|
||||
if (tid == 0) s_inv_rms = rsqrtf(total_sq / hd + 1e-6f);
|
||||
__syncthreads();
|
||||
for (int ci = 0; ci < cols_per_thread; ci++) {
|
||||
int c = tid + ci * n_threads;
|
||||
if (c >= hd) break;
|
||||
local_vals[ci] *= s_inv_rms * kv_norm_weight[c];
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Phase 3: Global scale (gsa) ----
|
||||
float entry_amax = 0.0f;
|
||||
for (int ci = 0; ci < cols_per_thread; ci++) {
|
||||
int c = tid + ci * n_threads;
|
||||
if (c >= hd) break;
|
||||
entry_amax = fmaxf(entry_amax, fabsf(local_vals[ci]));
|
||||
}
|
||||
__shared__ float s_amax;
|
||||
float global_amax = block_reduce_max(entry_amax, &s_amax, n_warps);
|
||||
float gsa = fmaxf(global_amax, 1e-8f) / (6.0f * 448.0f);
|
||||
if (tid == 0) out_gsa[block_i] = gsa;
|
||||
|
||||
// ---- Phase 4: NVFP4 quantize via shared memory ----
|
||||
__shared__ float s_vals[512]; // FP32 staging
|
||||
__shared__ uint8_t s_nib[512]; // E2M1 nibbles
|
||||
|
||||
for (int ci = 0; ci < cols_per_thread; ci++) {
|
||||
int c = tid + ci * n_threads;
|
||||
if (c >= hd) break;
|
||||
s_vals[c] = local_vals[ci];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int n_fp4_blocks = hd / 16;
|
||||
int tpb = n_threads / n_fp4_blocks; // threads per fp4 block
|
||||
int my_b = tid / tpb;
|
||||
int my_l = tid % tpb;
|
||||
|
||||
if (my_b < n_fp4_blocks) {
|
||||
int base = my_b * 16;
|
||||
|
||||
// Block amax
|
||||
float bamax = 0.0f;
|
||||
for (int i = my_l; i < 16; i += tpb) {
|
||||
int c = base + i;
|
||||
if (c < hd) bamax = fmaxf(bamax, fabsf(s_vals[c]) / gsa);
|
||||
}
|
||||
for (int off = tpb / 2; off > 0; off >>= 1)
|
||||
bamax = fmaxf(bamax, __shfl_down_sync(0xffffffff, bamax, off));
|
||||
float fbamax = __shfl_sync(0xffffffff, bamax, 0);
|
||||
|
||||
float bsf = fbamax / 6.0f;
|
||||
bool zero_blk = (fbamax < 6.0f * 0.001953125f);
|
||||
|
||||
if (my_l == 0) {
|
||||
if (zero_blk) {
|
||||
out_sf[block_i * (hd / 16) + my_b] = 0;
|
||||
} else {
|
||||
__nv_fp8_e4m3 obj(bsf);
|
||||
out_sf[block_i * (hd / 16) + my_b] = *(uint8_t*)&obj;
|
||||
}
|
||||
}
|
||||
|
||||
// Quantize to E2M1 nibbles
|
||||
for (int i = my_l; i < 16; i += tpb) {
|
||||
int c = base + i;
|
||||
if (c >= hd || zero_blk) { s_nib[c] = 0; continue; }
|
||||
float s = (s_vals[c] / gsa) / bsf;
|
||||
int hs = __float2int_rn(fminf(fabsf(s), 6.0f) * 2.0f);
|
||||
if (hs > 12) hs = 12;
|
||||
int idx = half_step_to_e2m1(hs);
|
||||
if (s < 0) idx += 8;
|
||||
s_nib[c] = idx;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Pack and write
|
||||
if (my_b < n_fp4_blocks && my_l == 0) {
|
||||
int base = my_b * 16;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
uint8_t lo = s_nib[base + 2 * i] & 0x0F;
|
||||
uint8_t hi = s_nib[base + 2 * i + 1] & 0x0F;
|
||||
out_fp4[block_i * (hd / 2) + my_b * 8 + i] = (hi << 4) | lo;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// HCA fused compress + norm + quantize (simpler — no overlap)
|
||||
// ===========================================================================
|
||||
|
||||
__global__ void hca_compress_reduce_quant_kernel(
|
||||
const float* __restrict__ kv_proj,
|
||||
const float* __restrict__ gate_proj,
|
||||
const float* __restrict__ position_bias,
|
||||
const float* __restrict__ kv_norm_weight,
|
||||
uint8_t* __restrict__ out_fp4,
|
||||
uint8_t* __restrict__ out_sf,
|
||||
float* __restrict__ out_gsa,
|
||||
int T, int hd, int m, int n_blocks
|
||||
) {
|
||||
int block_i = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
int n_threads = blockDim.x;
|
||||
int n_warps = n_threads / 32;
|
||||
|
||||
if (block_i >= n_blocks) return;
|
||||
|
||||
int cols_per_thread = (hd + n_threads - 1) / n_threads;
|
||||
|
||||
// Phase 1: Softmax + weighted sum
|
||||
float local_vals[4];
|
||||
for (int ci = 0; ci < cols_per_thread; ci++) {
|
||||
int c = tid + ci * n_threads;
|
||||
if (c >= hd) break;
|
||||
float lmax = -FLT_MAX, ldenom = 0.0f, lacc = 0.0f;
|
||||
int start = block_i * m;
|
||||
for (int t = 0; t < m; t++) {
|
||||
int ti = start + t;
|
||||
if (ti >= T) break;
|
||||
float g = gate_proj[ti * hd + c];
|
||||
if (position_bias && t < m) g += position_bias[t * hd + c];
|
||||
lmax = fmaxf(lmax, g);
|
||||
}
|
||||
for (int t = 0; t < m; t++) {
|
||||
int ti = start + t;
|
||||
if (ti >= T) break;
|
||||
float g = gate_proj[ti * hd + c];
|
||||
float kv = kv_proj[ti * hd + c];
|
||||
if (position_bias && t < m) { float pb = position_bias[t * hd + c]; g += pb; kv += pb; }
|
||||
float e = expf(g - lmax);
|
||||
ldenom += e; lacc += e * kv;
|
||||
}
|
||||
local_vals[ci] = (ldenom > 0.0f) ? (lacc / ldenom) : 0.0f;
|
||||
}
|
||||
|
||||
// Phase 2: kv_norm
|
||||
if (kv_norm_weight) {
|
||||
float lsq = 0.0f;
|
||||
for (int ci = 0; ci < cols_per_thread; ci++) {
|
||||
int c = tid + ci * n_threads;
|
||||
if (c >= hd) break;
|
||||
lsq += local_vals[ci] * local_vals[ci];
|
||||
}
|
||||
__shared__ float s_sq;
|
||||
float tsq = block_reduce_sum(lsq, &s_sq, n_warps);
|
||||
__shared__ float s_inv_rms;
|
||||
if (tid == 0) s_inv_rms = rsqrtf(tsq / hd + 1e-6f);
|
||||
__syncthreads();
|
||||
for (int ci = 0; ci < cols_per_thread; ci++) {
|
||||
int c = tid + ci * n_threads;
|
||||
if (c >= hd) break;
|
||||
local_vals[ci] *= s_inv_rms * kv_norm_weight[c];
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: gsa
|
||||
float eamax = 0.0f;
|
||||
for (int ci = 0; ci < cols_per_thread; ci++) {
|
||||
int c = tid + ci * n_threads;
|
||||
if (c >= hd) break;
|
||||
eamax = fmaxf(eamax, fabsf(local_vals[ci]));
|
||||
}
|
||||
__shared__ float s_amax;
|
||||
float gamax = block_reduce_max(eamax, &s_amax, n_warps);
|
||||
float gsa = fmaxf(gamax, 1e-8f) / (6.0f * 448.0f);
|
||||
if (tid == 0) out_gsa[block_i] = gsa;
|
||||
|
||||
// Phase 4: NVFP4 quantize
|
||||
__shared__ float s_vals[512];
|
||||
__shared__ uint8_t s_nib[512];
|
||||
for (int ci = 0; ci < cols_per_thread; ci++) {
|
||||
int c = tid + ci * n_threads;
|
||||
if (c >= hd) break;
|
||||
s_vals[c] = local_vals[ci];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int nfb = hd / 16;
|
||||
int tpb = n_threads / nfb;
|
||||
int my_b = tid / tpb;
|
||||
int my_l = tid % tpb;
|
||||
|
||||
if (my_b < nfb) {
|
||||
int base = my_b * 16;
|
||||
float bamax = 0.0f;
|
||||
for (int i = my_l; i < 16; i += tpb) {
|
||||
int c = base + i;
|
||||
if (c < hd) bamax = fmaxf(bamax, fabsf(s_vals[c]) / gsa);
|
||||
}
|
||||
for (int off = tpb / 2; off > 0; off >>= 1)
|
||||
bamax = fmaxf(bamax, __shfl_down_sync(0xffffffff, bamax, off));
|
||||
float fbamax = __shfl_sync(0xffffffff, bamax, 0);
|
||||
float bsf = fbamax / 6.0f;
|
||||
bool zblk = (fbamax < 6.0f * 0.001953125f);
|
||||
if (my_l == 0) {
|
||||
if (zblk) { out_sf[block_i * (hd / 16) + my_b] = 0; }
|
||||
else { __nv_fp8_e4m3 obj(bsf); out_sf[block_i * (hd / 16) + my_b] = *(uint8_t*)&obj; }
|
||||
}
|
||||
for (int i = my_l; i < 16; i += tpb) {
|
||||
int c = base + i;
|
||||
if (c >= hd || zblk) { s_nib[c] = 0; continue; }
|
||||
float s = (s_vals[c] / gsa) / bsf;
|
||||
int hs = __float2int_rn(fminf(fabsf(s), 6.0f) * 2.0f);
|
||||
if (hs > 12) hs = 12;
|
||||
int idx = half_step_to_e2m1(hs);
|
||||
if (s < 0) idx += 8;
|
||||
s_nib[c] = idx;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (my_b < nfb && my_l == 0) {
|
||||
int base = my_b * 16;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
uint8_t lo = s_nib[base + 2 * i] & 0x0F;
|
||||
uint8_t hi = s_nib[base + 2 * i + 1] & 0x0F;
|
||||
out_fp4[block_i * (hd / 2) + my_b * 8 + i] = (hi << 4) | lo;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// PyTorch bindings
|
||||
// ===========================================================================
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
csa_compress_reduce_quant_cuda(
|
||||
torch::Tensor kv_proj, // [T, 2*hd] FP32
|
||||
torch::Tensor gate_proj, // [T, 2*hd] FP32
|
||||
torch::Tensor position_bias, // [m, 2*hd] FP32 or empty
|
||||
torch::Tensor kv_norm_weight, // [hd] FP32 or empty
|
||||
int64_t m, int64_t n_blocks
|
||||
) {
|
||||
int T = kv_proj.size(0);
|
||||
int hd = kv_proj.size(1) / 2;
|
||||
int threads = 128;
|
||||
|
||||
const float* pos_ptr = (position_bias.numel() > 0) ? position_bias.data_ptr<float>() : nullptr;
|
||||
const float* norm_ptr = (kv_norm_weight.numel() > 0) ? kv_norm_weight.data_ptr<float>() : nullptr;
|
||||
|
||||
auto opts = kv_proj.options();
|
||||
auto out_fp4 = torch::zeros({(int)n_blocks, hd / 2}, opts.dtype(torch::kUInt8));
|
||||
auto out_sf = torch::zeros({(int)n_blocks, hd / 16}, opts.dtype(torch::kUInt8));
|
||||
auto out_gsa = torch::zeros({(int)n_blocks}, opts.dtype(torch::kFloat32));
|
||||
|
||||
csa_compress_reduce_quant_kernel<<<n_blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
|
||||
kv_proj.data_ptr<float>(),
|
||||
gate_proj.data_ptr<float>(),
|
||||
pos_ptr, norm_ptr,
|
||||
out_fp4.data_ptr<uint8_t>(),
|
||||
out_sf.data_ptr<uint8_t>(),
|
||||
out_gsa.data_ptr<float>(),
|
||||
T, hd, (int)m, (int)n_blocks
|
||||
);
|
||||
C10_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
return {out_fp4.view(torch::kFloat4_e2m1fn_x2),
|
||||
out_sf.view(torch::kFloat8_e4m3fn),
|
||||
out_gsa};
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
hca_compress_reduce_quant_cuda(
|
||||
torch::Tensor kv_proj,
|
||||
torch::Tensor gate_proj,
|
||||
torch::Tensor position_bias,
|
||||
torch::Tensor kv_norm_weight,
|
||||
int64_t m, int64_t n_blocks
|
||||
) {
|
||||
int T = kv_proj.size(0);
|
||||
int hd = kv_proj.size(1);
|
||||
int threads = 128;
|
||||
|
||||
const float* pos_ptr = (position_bias.numel() > 0) ? position_bias.data_ptr<float>() : nullptr;
|
||||
const float* norm_ptr = (kv_norm_weight.numel() > 0) ? kv_norm_weight.data_ptr<float>() : nullptr;
|
||||
|
||||
auto opts = kv_proj.options();
|
||||
auto out_fp4 = torch::zeros({(int)n_blocks, hd / 2}, opts.dtype(torch::kUInt8));
|
||||
auto out_sf = torch::zeros({(int)n_blocks, hd / 16}, opts.dtype(torch::kUInt8));
|
||||
auto out_gsa = torch::zeros({(int)n_blocks}, opts.dtype(torch::kFloat32));
|
||||
|
||||
hca_compress_reduce_quant_kernel<<<n_blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
|
||||
kv_proj.data_ptr<float>(),
|
||||
gate_proj.data_ptr<float>(),
|
||||
pos_ptr, norm_ptr,
|
||||
out_fp4.data_ptr<uint8_t>(),
|
||||
out_sf.data_ptr<uint8_t>(),
|
||||
out_gsa.data_ptr<float>(),
|
||||
T, hd, (int)m, (int)n_blocks
|
||||
);
|
||||
C10_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
return {out_fp4.view(torch::kFloat4_e2m1fn_x2),
|
||||
out_sf.view(torch::kFloat8_e4m3fn),
|
||||
out_gsa};
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("csa_compress_reduce_quant", &csa_compress_reduce_quant_cuda,
|
||||
"Fused CSA compress + norm + NVFP4 quantize");
|
||||
m.def("hca_compress_reduce_quant", &hca_compress_reduce_quant_cuda,
|
||||
"Fused HCA compress + norm + NVFP4 quantize");
|
||||
}
|
||||
192
dsv4/kernels/cuda/dequant_nvfp4.cu
Normal file
192
dsv4/kernels/cuda/dequant_nvfp4.cu
Normal file
@@ -0,0 +1,192 @@
|
||||
/**
|
||||
* NVFP4 → BF16 dequantization kernels.
|
||||
*
|
||||
* Converts FP4 (E2M1) data + FP8 (E4M3) block scales + FP32 global scales
|
||||
* back to BF16. Used for the FMHA gather path: compressed KV is stored as
|
||||
* NVFP4, and dequantized on-the-fly when gathering for attention.
|
||||
*
|
||||
* Two variants:
|
||||
* 1. Full dequant: entire FP4 buffer → BF16 (for HCA dense gather)
|
||||
* 2. Selective dequant: only selected rows → BF16 (for CSA top-k gather)
|
||||
*
|
||||
* Grid layout: (N/16, M) — one CTA per (row, 16-element block).
|
||||
* Block size: 16 threads (1 thread per element in the 16-wide block).
|
||||
*
|
||||
* Memory savings: FP4 is 4× smaller than BF16. At hd=512:
|
||||
* BF16: 512 × 2 = 1024 bytes per entry
|
||||
* NVFP4: 256 + 64 + 4 = 324 bytes per entry (fp4 + sf + gsa)
|
||||
* Savings: ~3.2×
|
||||
*/
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_fp8.hpp>
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/extension.h>
|
||||
#include <cstdint>
|
||||
|
||||
// E2M1 magnitudes: index 0-7 → 0, 0.5, 1, 1.5, 2, 3, 4, 6
|
||||
__device__ __constant__ float E2M1_LUT[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
|
||||
|
||||
// ===========================================================================
|
||||
// Full dequant: entire buffer → BF16
|
||||
// ===========================================================================
|
||||
|
||||
__global__ void dequant_nvfp4_kernel(
|
||||
const uint8_t* __restrict__ fp4_data, // (M, N/2) packed E2M1
|
||||
const uint8_t* __restrict__ sf_data, // (M, N/16) E4M3 block scales (stored as uint8)
|
||||
const float* __restrict__ gsa_data, // (M,) FP32 global scale per row
|
||||
__nv_bfloat16* __restrict__ output, // (M, N) BF16 output
|
||||
int M, int N
|
||||
) {
|
||||
int m = blockIdx.y;
|
||||
int n_block = blockIdx.x;
|
||||
if (m >= M || n_block * 16 >= N) return;
|
||||
|
||||
float gsa = gsa_data[m];
|
||||
|
||||
// Read FP8 E4M3 block scale
|
||||
uint8_t sf_byte = sf_data[m * (N / 16) + n_block];
|
||||
__nv_fp8_e4m3 sf_val;
|
||||
memcpy(&sf_val, &sf_byte, 1);
|
||||
float bsf = (float)sf_val;
|
||||
|
||||
// Read 8 packed bytes = 16 E2M1 values
|
||||
for (int i = 0; i < 8; i++) {
|
||||
uint8_t packed = fp4_data[m * (N / 2) + n_block * 8 + i];
|
||||
uint8_t lo_nibble = packed & 0x0F;
|
||||
uint8_t hi_nibble = (packed >> 4) & 0x0F;
|
||||
|
||||
// Low nibble
|
||||
int lo_idx = lo_nibble & 0x07;
|
||||
float lo_sign = (lo_nibble & 0x08) ? -1.0f : 1.0f;
|
||||
float lo_val = lo_sign * E2M1_LUT[lo_idx] * bsf * gsa;
|
||||
int lo_col = n_block * 16 + 2 * i;
|
||||
if (lo_col < N) {
|
||||
output[m * N + lo_col] = __float2bfloat16(lo_val);
|
||||
}
|
||||
|
||||
// High nibble
|
||||
int hi_idx = hi_nibble & 0x07;
|
||||
float hi_sign = (hi_nibble & 0x08) ? -1.0f : 1.0f;
|
||||
float hi_val = hi_sign * E2M1_LUT[hi_idx] * bsf * gsa;
|
||||
int hi_col = n_block * 16 + 2 * i + 1;
|
||||
if (hi_col < N) {
|
||||
output[m * N + hi_col] = __float2bfloat16(hi_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// Selective dequant: only dequant selected rows from a larger FP4 buffer
|
||||
// This is the CSA gather path — dequant only the top-k entries needed by FMHA
|
||||
// ===========================================================================
|
||||
|
||||
__global__ void dequant_nvfp4_selective_kernel(
|
||||
const uint8_t* __restrict__ fp4_data, // (max_comp, N/2) packed E2M1
|
||||
const uint8_t* __restrict__ sf_data, // (max_comp, N/16) E4M3 block scales
|
||||
const float* __restrict__ gsa_data, // (max_comp,) FP32 global scale per row
|
||||
const int32_t* __restrict__ indices, // (K,) int32 — which rows to dequant
|
||||
__nv_bfloat16* __restrict__ output, // (K, N) BF16 output
|
||||
int K, int N
|
||||
) {
|
||||
int k = blockIdx.y; // which selected entry
|
||||
int n_block = blockIdx.x; // which 16-element block
|
||||
if (k >= K || n_block * 16 >= N) return;
|
||||
|
||||
int src_row = indices[k];
|
||||
float gsa = gsa_data[src_row];
|
||||
|
||||
int N_half = N / 2;
|
||||
int N_sf = N / 16;
|
||||
|
||||
// Read FP8 E4M3 block scale for this row and block
|
||||
uint8_t sf_byte = sf_data[src_row * N_sf + n_block];
|
||||
__nv_fp8_e4m3 sf_val;
|
||||
memcpy(&sf_val, &sf_byte, 1);
|
||||
float bsf = (float)sf_val;
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
uint8_t packed = fp4_data[src_row * N_half + n_block * 8 + i];
|
||||
uint8_t lo_nibble = packed & 0x0F;
|
||||
uint8_t hi_nibble = (packed >> 4) & 0x0F;
|
||||
|
||||
int lo_idx = lo_nibble & 0x07;
|
||||
float lo_sign = (lo_nibble & 0x08) ? -1.0f : 1.0f;
|
||||
float lo_val = lo_sign * E2M1_LUT[lo_idx] * bsf * gsa;
|
||||
int lo_col = n_block * 16 + 2 * i;
|
||||
if (lo_col < N) {
|
||||
output[k * N + lo_col] = __float2bfloat16(lo_val);
|
||||
}
|
||||
|
||||
int hi_idx = hi_nibble & 0x07;
|
||||
float hi_sign = (hi_nibble & 0x08) ? -1.0f : 1.0f;
|
||||
float hi_val = hi_sign * E2M1_LUT[hi_idx] * bsf * gsa;
|
||||
int hi_col = n_block * 16 + 2 * i + 1;
|
||||
if (hi_col < N) {
|
||||
output[k * N + hi_col] = __float2bfloat16(hi_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// PyTorch bindings
|
||||
// ===========================================================================
|
||||
|
||||
torch::Tensor dequant_nvfp4_cuda(
|
||||
torch::Tensor fp4_data, // (M, N/2) uint8 packed E2M1
|
||||
torch::Tensor sf_data, // (M, N/16) uint8 (viewed as E4M3)
|
||||
torch::Tensor gsa_data // (M,) float32 global scale
|
||||
) {
|
||||
int M = fp4_data.size(0);
|
||||
int N = fp4_data.size(1) * 2; // N/2 packed → N actual
|
||||
TORCH_CHECK(sf_data.size(0) == M, "sf_data row count must match fp4_data");
|
||||
TORCH_CHECK(gsa_data.size(0) == M, "gsa_data row count must match fp4_data");
|
||||
|
||||
auto output = torch::zeros({M, N}, fp4_data.options().dtype(torch::kBFloat16));
|
||||
int nb = N / 16;
|
||||
dim3 grid(nb, M);
|
||||
dim3 block(16);
|
||||
|
||||
dequant_nvfp4_kernel<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
|
||||
fp4_data.data_ptr<uint8_t>(),
|
||||
sf_data.data_ptr<uint8_t>(),
|
||||
gsa_data.data_ptr<float>(),
|
||||
reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>()),
|
||||
M, N
|
||||
);
|
||||
return output;
|
||||
}
|
||||
|
||||
torch::Tensor dequant_nvfp4_selective_cuda(
|
||||
torch::Tensor fp4_data, // (max_comp, N/2) uint8 packed E2M1
|
||||
torch::Tensor sf_data, // (max_comp, N/16) uint8 (viewed as E4M3)
|
||||
torch::Tensor gsa_data, // (max_comp,) float32 global scale
|
||||
torch::Tensor indices // (K,) int32
|
||||
) {
|
||||
int K = indices.size(0);
|
||||
int N = fp4_data.size(1) * 2; // N/2 packed → N actual
|
||||
TORCH_CHECK(indices.scalar_type() == torch::kInt32, "indices must be int32");
|
||||
|
||||
auto output = torch::zeros({K, N}, fp4_data.options().dtype(torch::kBFloat16));
|
||||
int nb = N / 16;
|
||||
dim3 grid(nb, K);
|
||||
dim3 block(16);
|
||||
|
||||
dequant_nvfp4_selective_kernel<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
|
||||
fp4_data.data_ptr<uint8_t>(),
|
||||
sf_data.data_ptr<uint8_t>(),
|
||||
gsa_data.data_ptr<float>(),
|
||||
indices.data_ptr<int32_t>(),
|
||||
reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>()),
|
||||
K, N
|
||||
);
|
||||
return output;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("dequant_nvfp4", &dequant_nvfp4_cuda, "NVFP4 → BF16 dequant");
|
||||
m.def("dequant_nvfp4_selective", &dequant_nvfp4_selective_cuda, "Selective NVFP4 → BF16 dequant for CSA gather");
|
||||
}
|
||||
@@ -75,3 +75,7 @@ def preload_all():
|
||||
get_cuda_module("quantize_nvfp4", ["quantize_nvfp4.cu"])
|
||||
# Sampler
|
||||
get_cuda_module("sampler", ["sampler.cu"])
|
||||
# Dequant NVFP4
|
||||
get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"])
|
||||
# Fused compress + quantize
|
||||
get_cuda_module("compressor_reduce_quant", ["compressor_reduce_quant.cu"])
|
||||
|
||||
170
tests/unit/test_kv_compress_quant.py
Normal file
170
tests/unit/test_kv_compress_quant.py
Normal file
@@ -0,0 +1,170 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test KV-1/KV-2: Fused compress + NVFP4 quantize kernel.
|
||||
|
||||
Verifies that the single-kernel compress+quantize path produces output
|
||||
with cos >= 0.999 vs the BF16 reference path.
|
||||
|
||||
Production values:
|
||||
- hd=512, m=4 (CSA), m=128 (HCA)
|
||||
- T=32 (CSA: 8 blocks), T=256 (HCA: 2 blocks)
|
||||
- kv_dim=1024 (CSA: 2*hd), kv_dim=512 (HCA: hd)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import math
|
||||
|
||||
def test_csa_compress_quant():
|
||||
"""KV-1: CSA compress + NVFP4 quantize vs BF16 reference."""
|
||||
torch.manual_seed(42)
|
||||
device = 'cuda'
|
||||
hd = 512
|
||||
m = 4
|
||||
T = 32 # 8 compressed blocks
|
||||
kv_dim = 2 * hd # CSA uses 2*hd for kv/gate projections
|
||||
|
||||
kv_proj = torch.randn(T, kv_dim, device=device) * 0.5
|
||||
gate_proj = torch.randn(T, kv_dim, device=device) * 0.3
|
||||
position_bias = torch.randn(m, kv_dim, device=device) * 0.1
|
||||
kv_norm_weight = torch.randn(hd, device=device).abs() + 0.5
|
||||
|
||||
# BF16 reference path
|
||||
from dsv4.kernels.compressor.production_compress import csa_compress_production
|
||||
ref_bf16 = csa_compress_production(kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight, m=m)
|
||||
|
||||
# NVFP4 fused path
|
||||
from dsv4.kernels.compressor.production_compress import csa_compress_production_nvfp4
|
||||
fp4_data, sf, gsa = csa_compress_production_nvfp4(kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight, m=m)
|
||||
|
||||
# Dequant NVFP4 → BF16
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
dequant_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"])
|
||||
nvfp4_bf16 = dequant_mod.dequant_nvfp4(
|
||||
fp4_data.view(torch.uint8),
|
||||
sf.view(torch.uint8),
|
||||
gsa,
|
||||
)
|
||||
|
||||
# Compare
|
||||
ref_f = ref_bf16.float()
|
||||
nvfp4_f = nvfp4_bf16.float()
|
||||
cos = torch.nn.functional.cosine_similarity(ref_f.flatten(), nvfp4_f.flatten(), dim=0).item()
|
||||
max_err = (ref_f - nvfp4_f).abs().max().item()
|
||||
ref_max = ref_f.abs().max().item()
|
||||
|
||||
print(f"CSA compress + NVFP4 quantize:")
|
||||
print(f" ref shape: {tuple(ref_bf16.shape)}, nvfp4 shape: {tuple(nvfp4_bf16.shape)}")
|
||||
print(f" fp4 shape: {tuple(fp4_data.shape)}, sf shape: {tuple(sf.shape)}, gsa shape: {tuple(gsa.shape)}")
|
||||
print(f" |ref|_max: {ref_max:.4f}, |nvfp4|_max: {nvfp4_f.abs().max().item():.4f}")
|
||||
print(f" max_error: {max_err:.6f}")
|
||||
print(f" cosine: {cos:.6f}")
|
||||
assert cos >= 0.999, f"CSA compress+quant cos={cos:.6f} < 0.999"
|
||||
print(f" ✅ PASS (cos={cos:.6f})")
|
||||
|
||||
|
||||
def test_hca_compress_quant():
|
||||
"""KV-2: HCA compress + NVFP4 quantize vs BF16 reference."""
|
||||
torch.manual_seed(42)
|
||||
device = 'cuda'
|
||||
hd = 512
|
||||
m = 128
|
||||
T = 256 # 2 compressed blocks
|
||||
|
||||
kv_proj = torch.randn(T, hd, device=device) * 0.5
|
||||
gate_proj = torch.randn(T, hd, device=device) * 0.3
|
||||
position_bias = torch.randn(m, hd, device=device) * 0.1
|
||||
kv_norm_weight = torch.randn(hd, device=device).abs() + 0.5
|
||||
|
||||
# BF16 reference path
|
||||
from dsv4.kernels.compressor.production_compress import hca_compress_production
|
||||
ref_bf16 = hca_compress_production(kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight, m=m)
|
||||
|
||||
# NVFP4 fused path
|
||||
from dsv4.kernels.compressor.production_compress import hca_compress_production_nvfp4
|
||||
fp4_data, sf, gsa = hca_compress_production_nvfp4(kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight, m=m)
|
||||
|
||||
# Dequant NVFP4 → BF16
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
dequant_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"])
|
||||
nvfp4_bf16 = dequant_mod.dequant_nvfp4(
|
||||
fp4_data.view(torch.uint8),
|
||||
sf.view(torch.uint8),
|
||||
gsa,
|
||||
)
|
||||
|
||||
# Compare
|
||||
ref_f = ref_bf16.float()
|
||||
nvfp4_f = nvfp4_bf16.float()
|
||||
cos = torch.nn.functional.cosine_similarity(ref_f.flatten(), nvfp4_f.flatten(), dim=0).item()
|
||||
max_err = (ref_f - nvfp4_f).abs().max().item()
|
||||
ref_max = ref_f.abs().max().item()
|
||||
|
||||
print(f"HCA compress + NVFP4 quantize:")
|
||||
print(f" ref shape: {tuple(ref_bf16.shape)}, nvfp4 shape: {tuple(nvfp4_bf16.shape)}")
|
||||
print(f" |ref|_max: {ref_max:.4f}, |nvfp4|_max: {nvfp4_f.abs().max().item():.4f}")
|
||||
print(f" max_error: {max_err:.6f}")
|
||||
print(f" cosine: {cos:.6f}")
|
||||
assert cos >= 0.999, f"HCA compress+quant cos={cos:.6f} < 0.999"
|
||||
print(f" ✅ PASS (cos={cos:.6f})")
|
||||
|
||||
|
||||
def test_dequant_selective():
|
||||
"""Test selective dequant: only top-k entries from a larger FP4 buffer."""
|
||||
torch.manual_seed(42)
|
||||
device = 'cuda'
|
||||
M = 64 # total entries in cache
|
||||
N = 512 # hd
|
||||
K = 8 # top-k
|
||||
|
||||
# Create BF16 data
|
||||
bf16_data = torch.randn(M, N, device=device, dtype=torch.bfloat16) * 2.0
|
||||
|
||||
# Quantize to NVFP4
|
||||
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
|
||||
fp4, sf, gsa = quantize_nvfp4_gpu_fused(bf16_data)
|
||||
|
||||
# Select K random indices
|
||||
indices = torch.randperm(M, device=device)[:K].to(torch.int32)
|
||||
|
||||
# Selective dequant
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
dequant_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"])
|
||||
sel_bf16 = dequant_mod.dequant_nvfp4_selective(
|
||||
fp4.view(torch.uint8),
|
||||
sf.view(torch.uint8),
|
||||
gsa,
|
||||
indices,
|
||||
)
|
||||
|
||||
# Full dequant for comparison
|
||||
full_bf16 = dequant_mod.dequant_nvfp4(
|
||||
fp4.view(torch.uint8),
|
||||
sf.view(torch.uint8),
|
||||
gsa,
|
||||
)
|
||||
|
||||
# Compare selected entries
|
||||
ref = full_bf16[indices.cpu().numpy()].to(device)
|
||||
cos = torch.nn.functional.cosine_similarity(sel_bf16.float().flatten(), ref.float().flatten(), dim=0).item()
|
||||
|
||||
print(f"Selective dequant (M={M}, K={K}, N={N}):")
|
||||
print(f" sel shape: {tuple(sel_bf16.shape)}")
|
||||
print(f" cosine vs full dequant: {cos:.6f}")
|
||||
assert cos >= 0.9999, f"Selective dequant cos={cos:.6f} < 0.9999"
|
||||
print(f" ✅ PASS (cos={cos:.6f})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("KV-1/KV-2: Compress + NVFP4 Quantize Tests")
|
||||
print("Production values: hd=512, m=4 (CSA), m=128 (HCA)")
|
||||
print("=" * 60)
|
||||
|
||||
test_csa_compress_quant()
|
||||
print()
|
||||
test_hca_compress_quant()
|
||||
print()
|
||||
test_dequant_selective()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("ALL TESTS PASSED")
|
||||
print("=" * 60)
|
||||
Reference in New Issue
Block a user