357 lines
15 KiB
Python
357 lines
15 KiB
Python
#!/usr/bin/env python3
|
|
"""
DeepSeek-V4 KV Cache Kernel — NVFP4 Compressed Storage

Architecture:
- SWA cache: (T, HD=512) per token, stored as fp8_e4m3 (512 bytes per token)
- CSA cache (C4A): every 4th token stored, (T//4, HD) fp8 (128 bytes per token, amortized over all tokens)
- HCA cache (C128A): every 128th token stored, (T//128, HD) fp8 (4 bytes per token, amortized over all tokens)

The KV latent is (1, HD=512) — a single KV head. After kv_norm + RoPE,
it is quantized to fp8_e4m3 and stored in the paged KV cache.
Note: this script quantizes the latent right after kv_norm and applies RoPE
to the dequantized KV instead (item 3 below).

For CSA/HCA layers, the compressor further reduces the cache:
- The indexer finds the top-k positions in the compressed cache
- Attention only attends to those positions

This script tests:
1. KV quantization: BF16 → fp8_e4m3 (with per-token scale)
2. KV dequantization: fp8_e4m3 → BF16
3. RoPE on dequantized KV (applied after dequant)
4. Full attention using the cache
5. Compressed cache (CSA/HCA) storage and retrieval
6. NVFP4 quantize/dequantize roundtrip of the KV latent
7. Paged cache write/read roundtrip

Usage (on B200):
    cd /root/nvfp4-megamoe-kernel
    PYTHONPATH=/root/nvfp4-megamoe-kernel tests/venv/bin/python tests/test_kv_cache_b200.py
"""
|
|
|
|
import sys, os, json, torch, torch.nn.functional as F, math
|
|
from safetensors import safe_open
|
|
|
|
# Repository checkout. It is pushed to the front of sys.path, presumably so the
# function-local `cutedsl` imports further down resolve.
REPO = "/root/nvfp4-megamoe-kernel"
sys.path.insert(0, REPO)

# Checkpoint directory; main() reads model.safetensors.index.json from it.
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
DEV = "cuda:0"

# Model dimensions, described by how this file uses them.
H = 7168        # in_features of the q_a / kv projections
NH = 128        # number of query heads (q is viewed as (T, NH, HD))
HD = 512        # per-head width, also the width of the KV latent
NOPE = 448      # leading channels that RoPE leaves untouched
ROPE = 64       # trailing channels that RoPE rotates
QL = 1536       # in_features of the q_b projection
OL = 1024       # not referenced elsewhere in this file
OG = 16         # only used to derive HPG
HPG = NH // OG  # not referenced elsewhere in this file
EPS = 1e-6      # epsilon passed to rms()
WINDOW = 128    # not referenced elsewhere in this file
SCALE = HD ** -0.5  # softmax scale applied to attention scores

# Lookup table from a 4-bit code to its value; codes 8..15 mirror codes 0..7
# with a negative sign (code 8 is stored as +0.0, exactly as before).
E2M1 = torch.tensor(
    [
        0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
        0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
    ],
    dtype=torch.float32,
)

# Memo of tensors already read from disk by P().
_cache = {}
|
|
def P(k, wm, md):
    """Load tensor ``k`` from a sharded safetensors checkpoint, memoized.

    Args:
        k: Tensor name; also the key into ``wm``.
        wm: Mapping from tensor name to the shard file name that holds it.
        md: Directory containing the shard files.

    Returns:
        The tensor as returned by ``safe_open(...).get_tensor(k)``.

    Raises:
        KeyError: If ``k`` is not present in ``wm``.
    """
    # Key on the directory as well as the name: keying on the name alone would
    # hand back a stale tensor when the same name is requested from another
    # checkpoint directory.
    cache_key = (md, k)
    cached = _cache.get(cache_key)
    if cached is not None:
        return cached

    with safe_open(os.path.join(md, wm[k]), framework="pt") as f:
        tensor = f.get_tensor(k)
    _cache[cache_key] = tensor
    return tensor
|
|
|
|
def dequant(w, sf, gs):
    """Expand packed 4-bit weights to a BF16 matrix.

    Each element of ``w`` holds two 4-bit codes: the low nibble becomes the
    even output column and the high nibble the odd one. Codes are mapped
    through the ``E2M1`` table, then multiplied by the per-16-column block
    scale from ``sf`` and by the global scale ``gs``.

    Args:
        w: 2-D packed tensor of shape (rows, cols // 2).
        sf: Block scales; each column is repeated 16 times along dim 1 and
            the result is cropped to (rows, cols).
        gs: Global scale multiplied into every element.

    Returns:
        A (rows, cols) BF16 tensor.
    """
    rows, packed_cols = w.shape
    cols = packed_cols * 2
    table = E2M1.to(w.device)

    low = table[(w & 0xF).long()]
    high = table[((w >> 4) & 0xF).long()]
    # Pair (low, high) per packed element and flatten the pairs so that the
    # low nibble lands on even columns and the high nibble on odd columns.
    values = torch.stack((low, high), dim=-1).reshape(rows, cols)

    block_scale = sf.float().repeat_interleave(16, dim=1)[:rows, :cols]
    return (values * block_scale * gs).to(torch.bfloat16)
|
|
|
|
def rms(x, w, eps=1e-6):
    """Apply RMS normalization over the last dimension.

    Args:
        x: Input tensor; the statistics are computed in float32.
        w: Per-channel gain, broadcast against the last dimension of ``x``.
        eps: Added to the mean square before the reciprocal square root.

    Returns:
        ``w * x / sqrt(mean(x**2) + eps)`` cast back to ``x.dtype``.
    """
    mean_square = x.float().pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    normalized = (x * inv_rms).float()
    return (w.float() * normalized).to(x.dtype)
|
|
|
|
def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None):
    """Build an initialized ``CuTeDSLNvfp4Linear`` runner from packed weights.

    Args:
        w: Packed weight, reinterpreted as ``torch.float4_e2m1fn_x2`` and
            transposed before being handed to the runner. Callers in this
            file pass ``outf = w.shape[0]``, i.e. output features on dim 0.
        sf: Block scales laid out like ``w`` (output features on dim 0);
            converted to ``float8_e4m3fn`` if needed and transposed.
        gs_t: Global-scale tensor. One element for a plain layer; two
            elements for a fused layer whose halves have separate scales.
        inf: ``in_features`` for the runner.
        outf: ``out_features`` for the runner.
        fused: Whether ``w`` is two projections concatenated along the
            output dimension.
        lw: Optional sequence whose first entry is the number of output
            features in the first fused part; defaults to ``outf // 2``.

    Returns:
        The runner after ``finalize_weights()`` and ``_ensure_initialized()``.
    """
    from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

    fp4 = w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
    s = sf if sf.dtype == torch.float8_e4m3fn else sf.to(torch.float8_e4m3fn)
    s = s.permute(1, 0).contiguous()

    if fused and gs_t.numel() == 2:
        g1, g2 = gs_t[0].item(), gs_t[1].item()
        gs = max(g1, g2)
        if g1 != g2:
            # The runner takes a single global scale, so fold each part's
            # ratio to the shared scale into its block scales.
            split = lw[0] if lw else outf // 2
            s32 = s.float()
            # `s` was transposed above, so output features now run along
            # dim 1; `split` counts output features and must slice columns.
            # NOTE(review): the previous code sliced dim 0 (`s32[:split]`),
            # which only matches if `sf` arrives with output features on
            # dim 1 — confirm against a fused caller.
            s32[:, :split] *= g1 / gs
            s32[:, split:] *= g2 / gs
            s = s32.to(torch.float8_e4m3fn)
    else:
        gs = gs_t.max().item() if gs_t.numel() > 1 else gs_t.item()

    r = CuTeDSLNvfp4Linear(
        in_features=inf,
        out_features=outf,
        max_num_tokens=8192,
        device=str(w.device),
    )
    r.fp4 = [fp4]
    r.sf = [s]
    r.gs = [gs]
    r.finalize_weights()
    r._ensure_initialized()
    return r
|
|
|
|
def build_cos_sin(max_pos=4096, rope_dim=ROPE):
    """Precompute the rotary cos/sin table.

    Args:
        max_pos: Number of positions (rows) in the table.
        rope_dim: Number of rotated channels; ``rope_dim // 2`` frequencies
            are generated with base 10000.

    Returns:
        A float32 tensor of shape (max_pos, 2 * (rope_dim // 2)): the first
        half of each row holds cosines, the second half sines.
    """
    n_freq = rope_dim // 2
    exponents = torch.arange(0, n_freq, dtype=torch.float32) / n_freq
    inv_freq = 1.0 / (10000.0 ** exponents)
    pos = torch.arange(max_pos, dtype=torch.float32)
    angles = torch.outer(pos, inv_freq)
    return torch.cat([angles.cos(), angles.sin()], dim=-1)
|
|
|
|
def apply_gptj_rope(x, positions, cos_sin, nope, rope):
    """Rotate interleaved (even, odd) channel pairs of ``x`` by position.

    Channels ``[nope, nope + rope)`` of the last dimension are treated as
    adjacent pairs ``(x[2i], x[2i+1])`` and rotated by the angle stored in
    ``cos_sin`` for the token's position. All other channels are copied
    through unchanged, including any channels after the rotary section
    (previously such inputs failed with a shape mismatch).

    Args:
        x: (T, D) or (T, heads, D) tensor.
        positions: (T,) integer row indices into ``cos_sin``.
        cos_sin: Table whose rows are ``[cos(0..half), sin(0..half)]`` with
            ``half = rope // 2``.
        nope: Number of leading channels that are not rotated.
        rope: Number of rotated channels.

    Returns:
        A new tensor shaped like ``x``. When ``rope == 0`` or ``x`` is empty,
        ``x`` itself is returned without copying.
    """
    if rope == 0 or x.numel() == 0:
        return x

    half = rope // 2
    # cos/sin are cast to x.dtype so the rotation runs in the input precision.
    cos = cos_sin[positions, :half].to(x.dtype)
    sin = cos_sin[positions, half:2 * half].to(x.dtype)
    if x.dim() == 3:
        # Broadcast the per-token angles across the head dimension.
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

    rot_end = nope + rope
    even = x[..., nope:rot_end:2]
    odd = x[..., nope + 1:rot_end:2]

    # `out` is a separate copy, so the views of `x` above stay valid while
    # it is being written.
    out = x.clone()
    out[..., nope:rot_end:2] = even * cos - odd * sin
    out[..., nope + 1:rot_end:2] = even * sin + odd * cos
    return out
|
|
|
|
|
|
# ── KV Cache Kernels ────────────────────────────────────────────────
|
|
|
|
def kv_quantize_fp8(kv_bf16):
|
|
"""Quantize KV latent to fp8_e4m3 with per-token scale.
|
|
|
|
kv_bf16: (T, HD) BF16
|
|
Returns: (T, HD) fp8, (T, 1) per-token scale (BF16)
|
|
"""
|
|
# Per-token absmax
|
|
amax = kv_bf16.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
|
|
fp8_max = torch.tensor(448.0, dtype=torch.float32, device=kv_bf16.device) # e4m3 max
|
|
scale = fp8_max / amax # (T, 1)
|
|
kv_scaled = kv_bf16.float() * scale
|
|
kv_fp8 = kv_scaled.to(torch.float8_e4m3fn)
|
|
# Store inverse scale for dequant
|
|
inv_scale = amax / fp8_max # (T, 1) — multiply by this to recover
|
|
return kv_fp8, inv_scale.to(torch.bfloat16)
|
|
|
|
|
|
def kv_dequantize_fp8(kv_fp8, inv_scale):
|
|
"""Dequantize fp8 KV back to BF16.
|
|
|
|
kv_fp8: (T, HD) fp8_e4m3
|
|
inv_scale: (T, 1) per-token scale
|
|
Returns: (T, HD) BF16
|
|
"""
|
|
return (kv_fp8.to(torch.bfloat16) * inv_scale).to(torch.bfloat16)
|
|
|
|
|
|
def kv_quantize_nvfp4(kv_bf16):
    """Quantize the KV latent to NVFP4 via ``cutedsl.bridge.quantize_to_nvfp4``.

    NVFP4 stores 4 bits per element, half the size of fp8.

    Args:
        kv_bf16: (T, HD) BF16 tensor.

    Returns:
        Whatever ``quantize_to_nvfp4`` returns, unchanged. The caller in this
        file unpacks it as ``(fp4, sf, gs)``; the shapes are presumably
        (T, HD//2), (T, HD//16) and a scalar — confirm against the bridge.
    """
    # Imported lazily so the module can be loaded without cutedsl installed.
    from cutedsl.bridge import quantize_to_nvfp4

    quantized = quantize_to_nvfp4(kv_bf16)
    return quantized
|
|
|
|
|
|
def kv_dequantize_nvfp4(kv_fp4, kv_sf, kv_gs, head_dim=HD):
    """Expand NVFP4-packed KV rows to BF16.

    Args:
        kv_fp4: (T, head_dim // 2) packed tensor; reinterpreted as uint8 so
            each byte yields two 4-bit codes.
        kv_sf: (T, head_dim // 16) block scales, one per 16 output channels.
        kv_gs: Global scale multiplied into every element.
        head_dim: Number of output channels per row.

    Returns:
        A (T, head_dim) BF16 tensor.
    """
    dev = kv_fp4.device
    table = E2M1.to(dev)
    codes = kv_fp4.view(torch.uint8)

    values = torch.empty(kv_fp4.shape[0], head_dim, dtype=torch.float32, device=dev)
    # Low nibble feeds the even channel, high nibble the odd channel.
    values[:, 0::2] = table[(codes & 0xF).long()]
    values[:, 1::2] = table[((codes >> 4) & 0xF).long()]

    block_scale = kv_sf.float().repeat_interleave(16, dim=1)[:, :head_dim]
    return (values * block_scale * kv_gs).to(torch.bfloat16)
|
|
|
|
|
|
def paged_kv_write(kv_fp8, slot_mapping, cache, block_size):
    """Scatter token rows into a paged cache, in place.

    Slot ``s`` addresses row ``cache[s // block_size, s % block_size]``.
    Tokens whose slot is negative or lies past the last block are skipped.
    Negative slots used to wrap around through negative indexing and
    overwrite rows of the last block.

    All tokens are written with one indexed assignment instead of a
    per-token loop. If the same slot appears more than once in a single
    call, which row ends up stored is unspecified.

    Args:
        kv_fp8: (T, HD) rows to write; cast to ``cache.dtype`` if needed.
        slot_mapping: (T,) integer tensor of flat slot indices.
        cache: (num_blocks, block_size, HD) cache tensor, modified in place.
        block_size: Number of slots per block.

    Raises:
        IndexError: If ``slot_mapping`` has fewer entries than ``kv_fp8``
            has rows.
    """
    num_tokens = kv_fp8.shape[0]
    if slot_mapping.shape[0] < num_tokens:
        raise IndexError(
            f"slot_mapping has {slot_mapping.shape[0]} entries for {num_tokens} tokens"
        )
    if num_tokens == 0:
        return

    slots = slot_mapping[:num_tokens].to(device=cache.device, dtype=torch.long)
    block_idx = torch.div(slots, block_size, rounding_mode="floor")
    offset = slots - block_idx * block_size
    valid = (slots >= 0) & (block_idx < cache.shape[0])

    src = kv_fp8.to(device=cache.device, dtype=cache.dtype)
    dst = cache
    if cache.element_size() == 1:
        # Move 1-byte dtypes (fp8) as raw bytes: this is bit-exact and does
        # not depend on indexing kernels being available for float8 dtypes.
        src = src.view(torch.uint8)
        dst = cache.view(torch.uint8)
    dst[block_idx[valid], offset[valid]] = src[valid]
|
|
|
|
|
|
def paged_kv_read(slot_mapping, cache, block_size, num_tokens):
    """Gather token rows from a paged cache.

    Slot ``s`` addresses row ``cache[s // block_size, s % block_size]``.
    Tokens whose slot is negative or lies past the last block come back as
    zeros. Negative slots used to wrap around through negative indexing and
    return rows of the last block.

    All rows are gathered with one indexed read instead of a per-token loop.

    Args:
        slot_mapping: Integer tensor of flat slot indices; the first
            ``num_tokens`` entries are used.
        cache: (num_blocks, block_size, HD) cache tensor.
        block_size: Number of slots per block.
        num_tokens: Number of rows to read.

    Returns:
        A (num_tokens, HD) tensor with the dtype and device of ``cache``.

    Raises:
        IndexError: If ``slot_mapping`` has fewer than ``num_tokens`` entries.
    """
    head_dim = cache.shape[-1]
    out = torch.zeros(num_tokens, head_dim, dtype=cache.dtype, device=cache.device)
    if slot_mapping.shape[0] < num_tokens:
        raise IndexError(
            f"slot_mapping has {slot_mapping.shape[0]} entries for {num_tokens} tokens"
        )
    if num_tokens == 0:
        return out

    slots = slot_mapping[:num_tokens].to(device=cache.device, dtype=torch.long)
    block_idx = torch.div(slots, block_size, rounding_mode="floor")
    offset = slots - block_idx * block_size
    valid = (slots >= 0) & (block_idx < cache.shape[0])

    src = cache
    dst = out
    if cache.element_size() == 1:
        # Move 1-byte dtypes (fp8) as raw bytes: this is bit-exact and does
        # not depend on indexing kernels being available for float8 dtypes.
        src = cache.view(torch.uint8)
        dst = out.view(torch.uint8)  # shares storage with `out`
    dst[valid] = src[block_idx[valid], offset[valid]]
    return out
|
|
|
|
|
|
def _cosine(a, b):
    """Return the cosine similarity of two tensors, flattened, in float32."""
    return F.cosine_similarity(
        a.flatten().unsqueeze(0).float(), b.flatten().unsqueeze(0).float()
    ).item()


def _latent_causal_attention(q, kv):
    """Causal attention where one shared latent row per token is both key and value.

    Args:
        q: (T, NH, HD) queries.
        kv: (T, HD) latent rows shared by all heads; same dtype as ``q``.

    Returns:
        A (T, NH, HD) tensor in ``q.dtype``.
    """
    num_tokens = q.shape[0]
    scores = torch.matmul(q, kv.transpose(0, 1)) * SCALE  # (T, NH, T)
    idx = torch.arange(num_tokens, device=q.device)
    # visible[t, s] is True when key s is at or before query t.
    visible = idx.unsqueeze(0) <= idx.unsqueeze(1)
    scores = scores.masked_fill(~visible.unsqueeze(1), float("-inf"))
    weights = F.softmax(scores.float(), dim=-1).to(q.dtype)
    return torch.matmul(weights, kv)


def main():
    """Run the KV-cache checks against layer-0 weights of the checkpoint.

    Loads the layer-0 attention weights from ``MODEL``, produces a normalized
    KV latent for six fixed token ids with the NVFP4 runner, and then prints
    the result of seven checks: FP8 roundtrip, NVFP4 roundtrip, paged cache
    write/read, RoPE on dequantized KV, attention with the FP8 cache versus a
    BF16 reference, and the strided CSA / HCA compressed caches. Results are
    printed only; nothing is asserted or returned.
    """
    torch.cuda.set_device(0)
    torch.manual_seed(42)

    print("=" * 70)
    print(" DeepSeek-V4 KV Cache Kernel Test")
    print(" fp8 and NVFP4 quantization for paged KV cache")
    print("=" * 70)

    # Load real weights
    with open(os.path.join(MODEL, "model.safetensors.index.json")) as f:
        wm = json.load(f)["weight_map"]

    def load(key):
        """Fetch a checkpoint tensor and move it to the test device."""
        return P(key, wm, MODEL).to(DEV)

    p = "model.layers.0"
    a = f"{p}.self_attn"
    emb = load("model.embed_tokens.weight")
    anorm = load(f"{p}.input_layernorm.weight")
    qn = load(f"{a}.q_a_norm.weight")
    kvn = load(f"{a}.kv_norm.weight")
    qa_w = load(f"{a}.q_a_proj.weight")
    qa_sf = load(f"{a}.q_a_proj.weight_scale")
    qa_gs = load(f"{a}.q_a_proj.weight_scale_2")
    kv_w = load(f"{a}.kv_proj.weight")
    kv_sf = load(f"{a}.kv_proj.weight_scale")
    kv_gs = load(f"{a}.kv_proj.weight_scale_2")

    # NOTE(review): r_qa (and r_qb below) are built but never run; they are
    # kept so runner construction for those projections is still exercised.
    r_qa = make_runner(qa_w, qa_sf, qa_gs, H, qa_w.shape[0])
    r_kv = make_runner(kv_w, kv_sf, kv_gs, H, kv_w.shape[0])

    cos_sin = build_cos_sin(max_pos=4096).to(DEV)

    token_ids = torch.tensor([1, 450, 8403, 315, 5413, 374], dtype=torch.long, device=DEV)
    NT = len(token_ids)
    positions = torch.arange(NT, dtype=torch.int64, device=DEV)

    # Nothing below needs autograd, so the whole run stays under no_grad.
    with torch.no_grad():
        hidden = emb[token_ids]
        normed = rms(hidden, anorm, EPS)
        kv_bf16 = r_kv.run(normed)
        kv_bf16 = rms(kv_bf16, kvn, EPS)

        # ── Test 1: FP8 KV quantize/dequant roundtrip ────────────────
        print("\n--- Test 1: FP8 KV quantize/dequant ---")
        kv_fp8, inv_scale = kv_quantize_fp8(kv_bf16)
        kv_recovered = kv_dequantize_fp8(kv_fp8, inv_scale)
        c = _cosine(kv_bf16, kv_recovered)
        print(f" FP8 roundtrip cosine: {c:.6f} {'✅' if c>=0.99 else '❌'}")
        print(f" FP8 cache size: {kv_fp8.numel()} bytes (vs {kv_bf16.numel()*2} BF16)")

        # ── Test 2: NVFP4 KV quantize/dequant roundtrip ──────────────
        print("\n--- Test 2: NVFP4 KV quantize/dequant ---")
        # Best effort: a failure in the NVFP4 path is reported and the
        # remaining tests still run.
        try:
            kv_nfp4, kv_nsf, kv_ngs = kv_quantize_nvfp4(kv_bf16)
            kv_n_recovered = kv_dequantize_nvfp4(kv_nfp4, kv_nsf, kv_ngs)
            c = _cosine(kv_bf16, kv_n_recovered)
            print(f" NVFP4 roundtrip cosine: {c:.6f} {'✅' if c>=0.98 else '❌'}")
            print(f" NVFP4 cache size: {kv_nfp4.view(torch.uint8).numel()} bytes (vs {kv_bf16.numel()*2} BF16, {kv_fp8.numel()} FP8)")
        except Exception as e:
            print(f" NVFP4 quantize failed: {e}")

        # ── Test 3: Paged KV cache write/read with FP8 ───────────────
        print("\n--- Test 3: Paged KV cache (FP8) ---")
        block_size = 256
        num_blocks = 64
        cache = torch.zeros(num_blocks, block_size, HD, dtype=torch.float8_e4m3fn, device=DEV)
        # Slot mapping: position → flat slot (simplified: slot = position)
        slot_mapping = positions  # (NT,)

        # Write KV into cache, then read it back
        paged_kv_write(kv_fp8, slot_mapping, cache, block_size)
        kv_read = paged_kv_read(slot_mapping, cache, block_size, NT)
        c = _cosine(kv_fp8, kv_read)
        print(f" Paged read back cosine: {c:.6f} {'✅' if c>=0.999 else '❌'}")

        # ── Test 4: Apply RoPE after dequant ─────────────────────────
        print("\n--- Test 4: RoPE on dequantized KV ---")
        # KV needs RoPE applied at the positions it was stored at
        kv_with_rope = apply_gptj_rope(kv_recovered.unsqueeze(1), positions, cos_sin, NOPE, ROPE).squeeze(1)
        print(f" KV+RoPE: amax={kv_with_rope.amax():.4f} NaN={torch.isnan(kv_with_rope).any()}")

        # ── Test 5: Full attention with FP8 KV cache ─────────────────
        print("\n--- Test 5: Full attention pipeline with FP8 KV cache ---")
        # Load q_b once; each load() call copies the tensor to the GPU again.
        qb_w = load(f"{a}.q_b_proj.weight")
        qb_sf = load(f"{a}.q_b_proj.weight_scale")
        qb_gs = load(f"{a}.q_b_proj.weight_scale_2")

        qa_bf16_ref = dequant(qa_w, qa_sf, qa_gs.item())
        qb_bf16_ref = dequant(qb_w, qb_sf, qb_gs.item())
        kv_bf16_ref = dequant(kv_w, kv_sf, kv_gs.item())

        r_qb = make_runner(qb_w, qb_sf, qb_gs, QL, qb_w.shape[0])

        # Full BF16 reference
        qa_ref = normed @ qa_bf16_ref.T
        kv_ref = normed @ kv_bf16_ref.T
        qa_n_ref = rms(qa_ref, qn, EPS)
        kv_n_ref = rms(kv_ref, kvn, EPS)
        q_ref = (qa_n_ref @ qb_bf16_ref.T).view(NT, NH, HD)
        q_rope_ref = apply_gptj_rope(q_ref, positions, cos_sin, NOPE, ROPE)
        # The reference latent gets the same RoPE as the cached one; without
        # it the two attention runs would compare rotated against unrotated
        # keys.
        kv_rope_ref = apply_gptj_rope(kv_n_ref.unsqueeze(1), positions, cos_sin, NOPE, ROPE).squeeze(1)

        # BF16 causal attention using dequantized FP8 KV cache
        kv_from_cache = kv_dequantize_fp8(kv_read, inv_scale)
        kv_from_cache_rope = apply_gptj_rope(kv_from_cache.unsqueeze(1), positions, cos_sin, NOPE, ROPE).squeeze(1)
        out = _latent_causal_attention(q_rope_ref, kv_from_cache_rope)

        # BF16 attention with original (no cache) KV. Keys AND values come
        # from the BF16 reference so the comparison also covers the values.
        out2 = _latent_causal_attention(q_rope_ref, kv_rope_ref)

        c = _cosine(out, out2)
        print(f" FP8 cached KV vs BF16 KV attention cosine: {c:.6f} {'✅' if c>=0.98 else '❌'}")

        # ── Test 6: CSA compressed cache (cr=4) ──────────────────────
        print("\n--- Test 6: CSA compressed cache (cr=4) ---")
        cr = 4
        # Store every 4th token in the compressed cache
        compressed_kv = kv_fp8[::cr]
        compressed_inv_scale = inv_scale[::cr]
        print(f" Compressed KV shape: {compressed_kv.shape} (from {kv_fp8.shape})")
        print(f" Compression ratio: {kv_fp8.shape[0] / compressed_kv.shape[0]:.0f}x")

        # Dequant compressed KV
        compressed_kv_bf16 = kv_dequantize_fp8(compressed_kv, compressed_inv_scale)
        c = _cosine(kv_bf16[::cr], compressed_kv_bf16)
        print(f" Compressed KV dequant cosine: {c:.6f} {'✅' if c>=0.99 else '❌'}")

        # ── Test 7: HCA compressed cache (cr=128) ────────────────────
        print("\n--- Test 7: HCA compressed cache (cr=128) ---")
        cr = 128
        # A strided slice always keeps row 0, so a prompt shorter than `cr`
        # still yields exactly one cached row.
        compressed_kv_128 = kv_fp8[::cr]
        print(f" HCA compressed KV shape: {compressed_kv_128.shape}")
        print(f" Tokens in HCA cache: {compressed_kv_128.shape[0]} (from {NT})")

    print(f"\n{'='*70}")
    print(f" DONE — KV cache kernels tested")
    print(f"{'='*70}")


if __name__ == "__main__":
    main()
|