Files
nvfp4-megamoe-kernel/tests/archive/test_nvfp4_attention_b200.py
biondizzle 9cbdc92744 Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00

256 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Test NVFP4 attention: quantize Q and K, run Q×K^T as an NVFP4 GEMM, then softmax (computed in FP32, cast back to BF16) and attn×V in BF16.
Step 1: Verify NVFP4 quantize/dequant roundtrip for attention
Step 2: Q×K^T using CuTeDSL NVFP4 GEMM
Step 3: Softmax + attn×V
Step 4: Full pipeline with real weights, compare to BF16 SDPA
Usage (on B200):
cd /root/nvfp4-megamoe-kernel
PYTHONPATH=/root/nvfp4-megamoe-kernel tests/venv/bin/python tests/archive/test_nvfp4_attention_b200.py
"""
import sys, os, json, torch, torch.nn.functional as F, math
from safetensors import safe_open
# Repository root (prepended to sys.path so the `dsv4` package imports
# resolve), model checkpoint directory, and target device. Each can be
# overridden through the environment; defaults are the original B200 paths.
REPO = os.environ.get("NVFP4_REPO", "/root/nvfp4-megamoe-kernel")
if REPO not in sys.path:
    sys.path.insert(0, REPO)
MODEL = os.environ.get("NVFP4_MODEL", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
DEV = os.environ.get("NVFP4_DEVICE", "cuda:0")

# Layer geometry: hidden size H, NH query heads of HD dims each; every head is
# split into NOPE non-rotary dims followed by ROPE rotary dims (448 + 64 = 512).
H = 7168
NH = 128
HD = 512
NOPE = 448
ROPE = 64
# QL is the q_b_proj input width; OG * OL is the o_b_proj input width.
# HPG = NH // OG.
QL = 1536
OL = 1024
OG = 16
HPG = NH // OG
# RMSNorm epsilon; WINDOW sizes the RoPE table built in main()
# (WINDOW + 256 positions); SCALE = 1/sqrt(HD) multiplies Q·K^T scores.
EPS = 1e-6
WINDOW = 8192
SCALE = HD ** -0.5
# FP4 E2M1 decode table indexed by a 4-bit code: codes 8-15 mirror codes 0-7
# with the sign bit set, so code 8 is negative zero (written -0.0 so the
# sign is actually kept; the integer literal -0 would be +0.0).
E2M1 = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=torch.float32,
)
# Memo of tensors already read from the checkpoint by P().
_cache = {}
def P(k, wm, md):
    """Load tensor ``k`` from a sharded safetensors checkpoint, memoized.

    Args:
        k: tensor name.
        wm: the index's ``weight_map`` mapping tensor name -> shard filename.
        md: checkpoint directory containing the shard files.

    Returns:
        The tensor as read by ``safe_open(..., framework="pt")``.

    Raises:
        KeyError: if ``k`` is not in ``wm``.

    The memo is keyed on ``(md, k)``, not the name alone, so two checkpoint
    directories with identically named tensors never alias each other.
    """
    key = (md, k)
    cached = _cache.get(key)
    if cached is not None:
        return cached
    with safe_open(os.path.join(md, wm[k]), framework="pt") as f:
        t = f.get_tensor(k)
    _cache[key] = t
    return t
def dequant(w, sf, gs):
    """Expand a packed NVFP4 weight to a dense BF16 matrix.

    ``w`` holds two 4-bit E2M1 codes per element (low nibble first, then high
    nibble), ``sf`` holds one block scale per 16 decoded columns, and ``gs``
    is a scalar global multiplier. Returns a (rows, 2 * packed_cols) BF16
    tensor equal to ``E2M1[code] * block_scale * gs``.
    """
    table = E2M1.to(w.device)
    rows, packed_cols = w.shape
    full_cols = packed_cols * 2
    low_vals = table[(w & 0xF).long()]
    high_vals = table[((w >> 4) & 0xF).long()]
    # Pair each byte's (low, high) values on a trailing axis, then flatten so
    # they land in adjacent columns: low -> even column, high -> odd column.
    values = torch.stack((low_vals, high_vals), dim=-1).reshape(rows, full_cols)
    # One scale per 16 columns, broadcast out and trimmed to the weight's extent.
    block_scale = sf.float().repeat_interleave(16, dim=1)[:rows, :full_cols]
    return (values * block_scale * gs).to(torch.bfloat16)
def rms(x, w, eps=1e-6):
    """RMS-normalize ``x`` over its last dim and scale by ``w``.

    The statistics and the product are computed in float32; the result is
    cast back to ``x.dtype``.
    """
    mean_square = x.float().square().mean(dim=-1, keepdim=True)
    inv_rms = (mean_square + eps).rsqrt()
    normalized = (x * inv_rms).float()
    return (normalized * w.float()).to(x.dtype)
def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None, max_num_tokens=8192):
    """Build and initialize an ``Nvfp4Linear`` from checkpoint NVFP4 tensors.

    Args:
        w: packed FP4 weight (two E2M1 codes per byte), presumably (outf, inf // 2)
            as stored in the checkpoint; reinterpreted as float4_e2m1fn_x2 and
            transposed to (inf // 2, outf) before loading.
        sf: per-16-element block scales, presumably (outf, inf // 16); cast to
            float8_e4m3fn if needed and transposed to (inf // 16, outf).
        gs_t: global-scale tensor (``weight_scale_2``).
        inf, outf: in/out features of the layer.
        fused: when True and ``gs_t`` has two entries, the output features are
            two concatenated halves with their own global scales. The ratio
            g_i / max(g1, g2) is folded into each half's block scales so a
            single global scale (the max) can be used.
        lw: optional widths of the fused halves; ``lw[0]`` is the first half's
            width (defaults to ``outf // 2``).
        max_num_tokens: token capacity passed to ``Nvfp4Linear`` (default 8192).

    Returns:
        The runner after ``finalize_weights()`` and ``_ensure_initialized()``.
    """
    from dsv4.layers.linear import Nvfp4Linear

    fp4 = w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
    s = sf if sf.dtype == torch.float8_e4m3fn else sf.to(torch.float8_e4m3fn)
    # After this permute, output features run along dim 1: (K_sf, N).
    s = s.permute(1, 0).contiguous()
    if fused and gs_t.numel() == 2:
        g1, g2 = gs_t[0].item(), gs_t[1].item()
        gs = max(g1, g2)
        if g1 != g2:
            sp = lw[0] if lw else outf // 2
            s32 = s.float()
            # Split on the OUTPUT axis (columns after the permute), not dim 0,
            # which is the reduction (K) axis.
            s32[:, :sp] *= g1 / gs
            s32[:, sp:] *= g2 / gs
            s = s32.to(torch.float8_e4m3fn)
    else:
        gs = gs_t.max().item() if gs_t.numel() > 1 else gs_t.item()
    r = Nvfp4Linear(
        in_features=inf,
        out_features=outf,
        max_num_tokens=max_num_tokens,
        device=str(w.device),
    )
    r.fp4 = [fp4]
    r.sf = [s]
    r.gs = [gs]
    r.finalize_weights()
    r._ensure_initialized()
    return r
def apply_gptj_rope(x, positions, cos_sin, nope, rope):
    """Apply GPT-J style (interleaved-pair) rotary embedding to the tail of ``x``.

    The last dim of ``x`` is treated as ``nope`` pass-through dims followed by
    ``rope`` rotary dims. Adjacent pairs (even, odd) of the rotary part are
    rotated by the angle for each token's position. ``cos_sin`` rows hold cos
    in the first ``rope // 2`` columns and sin after that. For 3-D ``x`` the
    angles are broadcast over dim 1. Returns ``x`` itself when ``rope == 0``
    or ``x`` is empty; otherwise a new tensor.
    """
    if rope == 0 or x.numel() == 0:
        return x
    half = rope // 2
    table = cos_sin[positions].to(x.dtype)
    cos, sin = table[..., :half], table[..., half:]
    if x.dim() == 3:
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
    tail = x[..., nope:]
    even, odd = tail[..., 0::2], tail[..., 1::2]
    rot_even = even * cos - odd * sin
    rot_odd = even * sin + odd * cos
    # Re-interleave: stacking on a new trailing axis then flattening puts
    # rot_even in even slots and rot_odd in odd slots.
    rotated = torch.stack((rot_even, rot_odd), dim=-1).flatten(-2)
    return torch.cat((x[..., :nope], rotated), dim=-1)
def build_cos_sin(max_pos=4096, rope_dim=ROPE, base=10000.0):
    """Precompute a RoPE cos/sin lookup table.

    Args:
        max_pos: number of positions (rows) to build.
        rope_dim: rotary width; the table has ``rope_dim // 2`` frequencies.
        base: RoPE frequency base (theta). The default 10000.0 reproduces the
            previously hard-coded value.

    Returns:
        float32 tensor of shape (max_pos, 2 * (rope_dim // 2)). The first half
        of each row is cos(pos * inv_freq) and the second half is sin, the
        ``[:half]`` / ``[half:]`` split that apply_gptj_rope reads.
    """
    half = rope_dim // 2
    inv_freq = 1.0 / (base ** (torch.arange(0, half, dtype=torch.float32) / half))
    positions = torch.arange(max_pos, dtype=torch.float32)
    freqs = torch.outer(positions, inv_freq)
    return torch.cat([freqs.cos(), freqs.sin()], dim=-1)
def bf16_full_attention(q, kv, scale):
    """BF16 reference: causal self-attention with one KV tensor shared by all heads.

    Args:
        q: (T, NH, HD) queries.
        kv: (T, HD) tensor used as both keys and values for every head.
        scale: multiplier applied to the raw Q·K^T scores.

    Returns:
        (T, NH, HD) attention output in ``q.dtype``. Softmax runs in float32.
    """
    T, NH, HD = q.shape
    # Token-major flattening: row t * NH + h is head h of token t.
    q_2d = q.reshape(T * NH, HD)
    # All heads read the same kv, so a single (T*NH, HD) x (HD, T) GEMM yields
    # every score. There is no need to materialize a per-row copy of K/V,
    # which would cost O(NH * T^2 * HD) memory.
    scores = torch.matmul(q_2d, kv.transpose(0, 1)) * scale
    query_pos = torch.arange(T, device=q.device).repeat_interleave(NH)
    kv_pos = torch.arange(T, device=q.device)
    causal = kv_pos.unsqueeze(0) <= query_pos.unsqueeze(1)
    scores = scores.masked_fill(~causal, float('-inf'))
    weights = F.softmax(scores.float(), dim=-1).to(q.dtype)
    out = torch.matmul(weights, kv)
    return out.reshape(T, NH, HD)
def nvfp4_qk_attention(q, kv, scale):
    """NVFP4 attention: Q×K^T on the CuTeDSL NVFP4 GEMM, then softmax + attn×V.

    Q×K^T is (T*NH, HD) × (HD, T) = (T*NH, T), a plain GEMM with M = T*NH,
    K = HD and N = T. Q is handed to ``Nvfp4Linear.run`` as the BF16
    "activation". K (= kv) is quantized once here and loaded as the NVFP4
    "weight". Softmax is computed in float32, and attn×V runs in ``q.dtype``
    against the unquantized kv.

    Args:
        q: (T, NH, HD) queries.
        kv: (T, HD) tensor shared by all heads as both K and V.
        scale: multiplier applied to the Q·K^T scores.

    Returns:
        (T, NH, HD) attention output in ``q.dtype``.
    """
    from dsv4.ops.quantize import quantize_to_nvfp4
    from dsv4.layers.linear import Nvfp4Linear

    T, NH, HD = q.shape
    device = q.device
    M, K, N = T * NH, HD, T
    q_2d = q.reshape(M, K)

    # Quantize K in its natural (N, K) = (T, HD) layout so the 16-element
    # NVFP4 blocks run along HD, the GEMM reduction dim. This assumes
    # quantize_to_nvfp4 packs/blocks the last dim of a 2-D input (TODO
    # confirm), giving fp4 (N, K//2) and scales (N, K//16), the checkpoint
    # layout that make_runner consumes. Quantizing kv.T instead would block
    # along T.
    w_fp4, w_sf, w_gs = quantize_to_nvfp4(kv)

    # Load the weight exactly as make_runner does: reinterpret as
    # float4_e2m1fn_x2, make scales float8_e4m3fn, transpose both to K-major.
    fp4 = w_fp4.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()  # (K//2, N)
    sf = w_sf if w_sf.dtype == torch.float8_e4m3fn else w_sf.to(torch.float8_e4m3fn)
    sf = sf.permute(1, 0).contiguous()  # (K//16, N)
    # NOTE(review): assumes quantize_to_nvfp4's global scale uses the same
    # (dequant-multiplier) convention as the checkpoint weight_scale_2 values
    # make_runner passes. Confirm against dsv4.ops.quantize.
    gs = w_gs.item() if torch.is_tensor(w_gs) else float(w_gs)

    runner = Nvfp4Linear(
        in_features=K, out_features=N, max_num_tokens=M, device=str(device)
    )
    runner.fp4 = [fp4]
    runner.sf = [sf]
    runner.gs = [gs]
    runner.finalize_weights()
    runner._ensure_initialized()

    scores = runner.run(q_2d) * scale  # (M, N) = (T*NH, T)

    # Causal mask: row t * NH + h belongs to token t and may see kv positions <= t.
    query_pos = torch.arange(T, device=device).repeat_interleave(NH)
    kv_pos = torch.arange(T, device=device)
    causal = kv_pos.unsqueeze(0) <= query_pos.unsqueeze(1)
    scores = scores.masked_fill(~causal, float('-inf'))

    # Softmax in float32 for stability, cast back to the activation dtype.
    weights = F.softmax(scores.float(), dim=-1).to(q.dtype)  # (T*NH, T)
    # attn×V against the shared, unquantized kv: (T*NH, T) × (T, HD).
    out = torch.matmul(weights, kv)
    return out.reshape(T, NH, HD)
def main():
    """Compare NVFP4 Q×K^T attention against a BF16 reference on layer 0.

    Embeds a fixed 6-token prompt, runs the q/kv projections through NVFP4
    runners built by make_runner, RMS-normalizes them, and applies RoPE to the
    queries. It then computes attention with bf16_full_attention (reference)
    and nvfp4_qk_attention, and prints their cosine similarity.

    Returns:
        0 if the NVFP4 path ran and reached cosine >= 0.98 against the
        reference, otherwise 1. Used as the process exit status.
    """
    import traceback

    torch.cuda.set_device(torch.device(DEV))
    torch.manual_seed(42)
    print("=" * 70)
    print(" NVFP4 Attention Kernel Test")
    print("=" * 70)
    with open(os.path.join(MODEL, "model.safetensors.index.json")) as f:
        wm = json.load(f)["weight_map"]

    def G(k):
        # Read one checkpoint tensor (memoized by P) and move it to DEV.
        return P(k, wm, MODEL).to(DEV)

    def load_nvfp4(prefix):
        # Packed FP4 weight, block scales and global scale of one projection.
        return (
            G(f"{prefix}.weight"),
            G(f"{prefix}.weight_scale"),
            G(f"{prefix}.weight_scale_2"),
        )

    p = "model.layers.0"
    a = f"{p}.self_attn"
    # Load only the weights this test uses.
    emb = G("model.embed_tokens.weight")
    anorm = G(f"{p}.input_layernorm.weight")
    qn = G(f"{a}.q_a_norm.weight")
    kvn = G(f"{a}.kv_norm.weight")
    qa_w, qa_sf, qa_gs = load_nvfp4(f"{a}.q_a_proj")
    qb_w, qb_sf, qb_gs = load_nvfp4(f"{a}.q_b_proj")
    kv_w, kv_sf, kv_gs = load_nvfp4(f"{a}.kv_proj")

    # CuTeDSL runners
    r_qa = make_runner(qa_w, qa_sf, qa_gs, H, qa_w.shape[0])
    r_qb = make_runner(qb_w, qb_sf, qb_gs, QL, qb_w.shape[0])
    r_kv = make_runner(kv_w, kv_sf, kv_gs, H, kv_w.shape[0])

    # Input
    token_ids = torch.tensor([1, 450, 8403, 315, 5413, 374], dtype=torch.long, device=DEV)
    NT = len(token_ids)
    cos_sin = build_cos_sin(max_pos=WINDOW + 256).to(DEV)
    positions = torch.arange(NT, dtype=torch.int64, device=DEV)
    print(f" Input: {NT} tokens, {NH} heads, HD={HD}")

    status = 1
    with torch.no_grad():
        hidden = emb[token_ids]
        normed = rms(hidden, anorm, EPS)
        # Projections
        qa_n = rms(r_qa.run(normed), qn, EPS)
        kv_n = rms(r_kv.run(normed), kvn, EPS)
        q_cute = r_qb.run(qa_n).view(NT, NH, HD)
        # NOTE(review): RoPE is applied to the queries only; kv_n enters both
        # attention paths un-rotated. The comparison is still like-for-like,
        # but confirm whether K's rope dims should also be rotated.
        q_rope = apply_gptj_rope(q_cute, positions, cos_sin, NOPE, ROPE)

        # ── BF16 reference ────────────────────────────────────────────
        print("\n--- Step 1: BF16 reference attention ---")
        o_bf16 = bf16_full_attention(q_rope, kv_n, SCALE)
        print(f" BF16 attention output: amax={o_bf16.amax():.4f} NaN={torch.isnan(o_bf16).any()}")

        # ── NVFP4 Q×K^T attention ────────────────────────────────────
        print("\n--- Step 2: NVFP4 Q×K^T attention ---")
        try:
            o_nvfp4 = nvfp4_qk_attention(q_rope, kv_n, SCALE)
            print(f" NVFP4 attention output: amax={o_nvfp4.amax():.4f} NaN={torch.isnan(o_nvfp4).any()}")
            c = F.cosine_similarity(
                o_nvfp4.flatten().unsqueeze(0).float(),
                o_bf16.flatten().unsqueeze(0).float(),
            ).item()
            # NaN compares False, so a NaN cosine counts as a failure.
            passed = c >= 0.98
            print(f" NVFP4 vs BF16 cosine: {c:.6f} {'PASS' if passed else 'FAIL'}")
            status = 0 if passed else 1
        except Exception as e:
            # Top-level boundary of the experiment: report, keep going to the
            # summary, and signal failure through the exit status.
            print(f" ERROR: {e}")
            traceback.print_exc()
    print("\n" + "=" * 70)
    print(" Done")
    print("=" * 70)
    return status


if __name__ == "__main__":
    sys.exit(main())