Files
nvfp4-megamoe-kernel/tests/test_nvfp4_attention_b200.py

253 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Test NVFP4 attention: quantize Q and K, GEMM in NVFP4, softmax computed in FP32 and cast back to the input dtype (BF16).
Step 1: Verify NVFP4 quantize/dequant roundtrip for attention
Step 2: Q×K^T using CuTeDSL NVFP4 GEMM
Step 3: Softmax + attn×V
Step 4: Full pipeline with real weights, compare to BF16 SDPA
Usage (on B200):
    cd /root/nvfp4-megamoe-kernel
    PYTHONPATH=/root/nvfp4-megamoe-kernel tests/venv/bin/python tests/test_nvfp4_attention_b200.py
"""
import sys, os, json, torch, torch.nn.functional as F, math
from safetensors import safe_open
# Repository checkout; put first on sys.path so the in-function `cutedsl` imports resolve.
REPO = "/root/nvfp4-megamoe-kernel"
sys.path.insert(0, REPO)

# Checkpoint directory (holds model.safetensors.index.json and its shards).
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
DEV = "cuda:0"

# Attention geometry used by main().
H = 7168      # in_features of the q_a / kv projection runners
NH = 128      # number of heads (Q is viewed as (tokens, NH, HD))
HD = 512      # per-head width
NOPE = 448    # leading per-head features left unrotated by RoPE
ROPE = 64     # trailing per-head features rotated by RoPE (NOPE + ROPE == HD)
QL = 1536     # in_features of the q_b projection runner
OL = 1024     # o_b projection in_features is OG * OL
OG = 16
HPG = NH // OG  # presumably heads per output group -- not read anywhere in this file

EPS = 1e-6          # epsilon passed to rms()
WINDOW = 8192       # main() sizes the cos/sin table as WINDOW + 256 positions
SCALE = HD ** -0.5  # softmax scale applied to Q.K^T

# Lookup table indexed by a 4-bit code: entries 0-7 are the non-negative
# values, entries 8-15 their negations. The integer literal -0 is plain 0.
E2M1 = torch.tensor(
    [
        0, .5, 1., 1.5, 2., 3., 4., 6.,
        -0, -.5, -1., -1.5, -2., -3., -4., -6.,
    ],
    dtype=torch.float32,
)

# Tensors already read from disk by P(), keyed by checkpoint tensor name.
_cache = {}
def P(k, wm, md):
    """Return checkpoint tensor ``k``, reading it from disk at most once.

    Args:
        k: Tensor name; also the key into ``wm``.
        wm: Mapping from tensor name to the shard file that stores it.
        md: Directory containing the shard files.

    Returns:
        The tensor as loaded by safetensors; later calls with the same ``k``
        return the same cached object.
    """
    try:
        return _cache[k]
    except KeyError:
        pass  # first request for this tensor: fall through and load it

    shard_path = os.path.join(md, wm[k])
    with safe_open(shard_path, framework="pt") as shard:
        tensor = shard.get_tensor(k)
    _cache[k] = tensor
    return tensor
def dequant(w, sf, gs):
    """Expand packed 4-bit weights to a dense bfloat16 matrix.

    Args:
        w: 2-D integer tensor of shape (rows, cols / 2); each element packs
            two 4-bit E2M1 codes, low nibble first.
        sf: Per-block scales; each column of ``sf`` covers 16 consecutive
            output columns. Extra rows/columns beyond (rows, cols) are ignored.
        gs: Global scale multiplied into every element.

    Returns:
        A (rows, cols) bfloat16 tensor on ``w``'s device.
    """
    rows, packed_cols = w.shape
    cols = packed_cols * 2
    table = E2M1.to(w.device)

    low_vals = table[(w & 0xF).long()]
    high_vals = table[((w >> 4) & 0xF).long()]
    # Interleave so the low nibble lands on even columns, the high nibble on odd.
    values = torch.stack((low_vals, high_vals), dim=-1).reshape(rows, cols)

    block_scale = sf.float().repeat_interleave(16, dim=1)[:rows, :cols]
    return (values * block_scale * gs).to(torch.bfloat16)
def rms(x, w, eps=1e-6):
    """Apply RMS normalisation over the last dimension.

    The statistics and the scaling are computed in float32; only the final
    result is cast back to ``x``'s dtype.

    Args:
        x: Input tensor; normalised along its last dimension.
        w: Per-feature gain, broadcast against ``x``.
        eps: Added to the mean square before the reciprocal square root.

    Returns:
        Tensor with ``x``'s shape and dtype.
    """
    # Upcast once and reuse it, instead of relying on implicit
    # low-precision x float32 type promotion in the multiply.
    x32 = x.float()
    inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return (w.float() * (x32 * inv_rms)).to(x.dtype)
def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None):
    """Build an initialised CuTeDSL NVFP4 linear runner from checkpoint tensors.

    Args:
        w: Packed FP4 weight, 2-D; reinterpreted as ``float4_e2m1fn_x2`` and
            transposed before being handed to the runner.
        sf: Block scales for ``w``; converted to ``float8_e4m3fn`` if needed
            and transposed the same way as ``w``.
        gs_t: Global-scale tensor. One element normally; two elements when
            ``fused`` is set and the weight is two matrices stacked along the
            output dimension.
        inf: Runner ``in_features``.
        outf: Runner ``out_features``.
        fused: Treat ``gs_t`` as two per-half global scales.
        lw: Optional sequence whose first entry is the number of output
            features in the first fused half; defaults to ``outf // 2``.

    Returns:
        The runner, after ``finalize_weights()`` and ``_ensure_initialized()``.
    """
    from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

    fp4 = w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
    s = sf if sf.dtype == torch.float8_e4m3fn else sf.to(torch.float8_e4m3fn)
    s = s.permute(1, 0).contiguous()

    if fused and gs_t.numel() == 2:
        g1, g2 = gs_t[0].item(), gs_t[1].item()
        gs = max(g1, g2)
        if g1 != g2:
            # The runner takes a single global scale, so fold each half's
            # ratio to the larger scale into that half's block scales.
            split = lw[0] if lw else outf // 2
            s32 = s.float()
            # `split` counts output features. After the permute above the
            # output-feature axis is dim 1 (assumes sf shares w's
            # (out, in/16) row layout, as dequant's indexing implies --
            # TODO confirm). The previous code sliced dim 0, i.e. the
            # input-block axis, which rescaled the wrong entries.
            s32[:, :split] *= g1 / gs
            s32[:, split:] *= g2 / gs
            s = s32.to(torch.float8_e4m3fn)
    else:
        gs = gs_t.max().item() if gs_t.numel() > 1 else gs_t.item()

    r = CuTeDSLNvfp4Linear(
        in_features=inf, out_features=outf, max_num_tokens=8192, device=str(w.device)
    )
    r.fp4 = [fp4]
    r.sf = [s]
    r.gs = [gs]
    r.finalize_weights()
    r._ensure_initialized()
    return r
def apply_gptj_rope(x, positions, cos_sin, nope, rope):
    """Rotate the RoPE slice of ``x`` using interleaved (even/odd) pairs.

    Features ``[nope, nope + rope)`` of the last dimension are rotated pair
    by pair -- (0, 1), (2, 3), ... -- and every other feature is copied
    through unchanged.

    Args:
        x: Tensor whose last dimension holds the features. Rank 2
            ``(tokens, features)`` or rank 3 ``(tokens, heads, features)``.
        positions: Integer index tensor selecting one table row per token.
        cos_sin: Table whose columns ``[0, rope // 2)`` are cosines and
            ``[rope // 2, 2 * (rope // 2))`` are sines; further columns are
            ignored.
        nope: Number of leading features that are not rotated.
        rope: Number of features to rotate.

    Returns:
        A new tensor shaped like ``x``; ``x`` itself is returned untouched
        when ``rope`` is 0 or ``x`` is empty.
    """
    if rope == 0 or x.numel() == 0:
        return x

    half = rope // 2
    cos = cos_sin[positions, :half].to(x.dtype)
    # Bound the sine slice so a table wider than `rope` columns still works.
    sin = cos_sin[positions, half:2 * half].to(x.dtype)
    if x.dim() == 3:
        # Insert a head axis so the per-token table broadcasts over heads.
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

    # Bound the rotary slice so features after nope + rope are left alone.
    rotary = slice(nope, nope + rope)
    # Views into x are safe here: results are written to a separate clone.
    even = x[..., rotary][..., 0::2]
    odd = x[..., rotary][..., 1::2]

    out = x.clone()
    out[..., rotary][..., 0::2] = even * cos - odd * sin
    out[..., rotary][..., 1::2] = even * sin + odd * cos
    return out
def build_cos_sin(max_pos=4096, rope_dim=ROPE, base=10000.0):
    """Build the rotary cos/sin lookup table.

    Args:
        max_pos: Number of positions (rows) in the table.
        rope_dim: Rotary width; the table has ``2 * (rope_dim // 2)`` columns.
        base: Frequency base; frequency ``i`` is ``base ** -(i / half)``.
            Defaults to the previously hard-coded 10000.

    Returns:
        A float32 CPU tensor of shape ``(max_pos, 2 * (rope_dim // 2))`` with
        the cosines in the first half of the columns and the sines in the
        second half.
    """
    half = rope_dim // 2
    exponents = torch.arange(0, half, dtype=torch.float32) / half
    inv_freq = 1.0 / (base ** exponents)
    angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)
    return torch.cat([angles.cos(), angles.sin()], dim=-1)
def bf16_full_attention(q, kv, scale):
    """Reference causal self-attention with one K/V shared by all heads.

    ``kv`` serves as both keys and values. The previous implementation
    expanded it into ``T * NH`` identical (T, HD) copies for a batched
    matmul; because every copy is the same matrix, two plain 2-D matmuls
    give the same result without the O(T^2 * NH * HD) temporaries.

    Args:
        q: Queries of shape (T, NH, HD).
        kv: Shared keys/values of shape (T, HD).
        scale: Multiplier applied to the raw Q.K^T scores.

    Returns:
        Attention output of shape (T, NH, HD) in ``q``'s dtype.
    """
    T, NH, HD = q.shape
    q_2d = q.reshape(T * NH, HD)  # row t * NH + h holds head h of token t

    scores = torch.matmul(q_2d, kv.transpose(0, 1)) * scale  # (T*NH, T)

    # Row t * NH + h belongs to token t and may attend to keys 0..t.
    token_pos = torch.arange(T, device=q.device)
    query_pos = token_pos.repeat_interleave(NH)
    causal = token_pos.unsqueeze(0) <= query_pos.unsqueeze(1)
    scores = scores.masked_fill(~causal, float("-inf"))

    # Softmax in float32, then back to the working dtype for the value matmul.
    weights = F.softmax(scores.float(), dim=-1).to(q.dtype)
    out = torch.matmul(weights, kv)  # (T*NH, HD)
    return out.reshape(T, NH, HD)
def nvfp4_qk_attention(q, kv, scale):
    """Causal attention whose Q.K^T product runs through the NVFP4 GEMM runner.

    Q.K^T is treated as a linear layer: Q, flattened to (T*NH, HD), is the
    activation and K^T is the weight, giving (T*NH, T) scores. The softmax
    is computed in float32 and the weights @ V product in ``q``'s dtype with
    the unquantised ``kv``.

    Args:
        q: Queries of shape (T, NH, HD).
        kv: Shared keys/values of shape (T, HD).
        scale: Multiplier applied to the raw scores.

    Returns:
        Attention output of shape (T, NH, HD) in ``q``'s dtype.
    """
    from cutedsl.bridge import quantize_to_nvfp4
    from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

    T, NH, HD = q.shape
    device = q.device

    # Q stays unquantised here: the runner is fed the high-precision
    # activation directly, exactly as main() does for the projections. The
    # earlier explicit quantisation of Q produced values that were never read.
    q_2d = q.reshape(T * NH, HD)

    # K as the "weight": transpose to (HD, T) and quantise.
    kv_T = kv.T.contiguous()
    w_fp4, w_sf, w_gs = quantize_to_nvfp4(kv_T)

    # GEMM problem size: (M, K) x (K, N) -> (M, N).
    M = T * NH
    K = HD
    N = T
    runner = CuTeDSLNvfp4Linear(
        in_features=K, out_features=N, max_num_tokens=M, device=str(device)
    )

    # NOTE(review): the original comments assume quantize_to_nvfp4(kv_T)
    # returns (K//2, N) / (K//16, N) and that the runner wants (N, K//2) /
    # (N, K//16), hence the transposes below. make_runner in this file does
    # the opposite -- it permutes (N, K//2) weights to (K//2, N) and views
    # them as float4_e2m1fn_x2 / float8_e4m3fn before assigning them.
    # Confirm which layout and dtypes CuTeDSLNvfp4Linear expects, and along
    # which axis quantize_to_nvfp4 packs.
    runner.fp4 = [w_fp4.T.contiguous()]
    runner.sf = [w_sf.T.contiguous()]
    runner.gs = [w_gs]
    runner.finalize_weights()
    runner._ensure_initialized()

    scores = runner.run(q_2d) * scale  # (T*NH, T)

    # Causal mask: row t * NH + h belongs to token t and may see keys 0..t.
    query_pos = torch.arange(T, device=device).unsqueeze(1).repeat(1, NH).reshape(T * NH)
    kv_pos = torch.arange(T, device=device).unsqueeze(0)
    causal = kv_pos <= query_pos.unsqueeze(1)
    scores = scores.masked_fill(~causal, float('-inf'))

    # Softmax in float32 for numerical stability, then back to q's dtype.
    weights = F.softmax(scores.float(), dim=-1).to(q.dtype)  # (T*NH, T)

    # attn x V: (T*NH, T) x (T, HD) -> (T*NH, HD); V is the unquantised kv.
    out = torch.matmul(weights, kv)
    return out.reshape(T, NH, HD)
def main():
    """Compare NVFP4 Q.K^T attention with the reference on layer 0.

    Loads the layer-0 attention weights from the checkpoint at MODEL, pushes
    a fixed six-token prompt through the q_a / kv / q_b projections using the
    CuTeDSL NVFP4 runners, applies RoPE to Q, and prints the amax / NaN
    status of both attention outputs plus their cosine similarity with a
    PASS/FAIL marker (threshold 0.98).

    An exception in the NVFP4 path is printed with its traceback and does
    not abort the script, so the reference numbers are always reported.
    """
    torch.cuda.set_device(0)
    torch.manual_seed(42)
    print("=" * 70)
    print(" NVFP4 Attention Kernel Test")
    print("=" * 70)

    index_path = os.path.join(MODEL, "model.safetensors.index.json")
    with open(index_path, encoding="utf-8") as f:
        wm = json.load(f)["weight_map"]

    def G(k):
        """Load checkpoint tensor ``k`` and move it to the test device."""
        return P(k, wm, MODEL).to(DEV)

    p = "model.layers.0"
    a = f"{p}.self_attn"

    # Load only the weights the comparison reads. The o_a / o_b projection
    # weights, the sinks tensor, the dequantised BF16 copies and the o_b
    # runner were built previously but never used, so they are no longer
    # loaded or computed.
    emb = G("model.embed_tokens.weight")
    anorm = G(f"{p}.input_layernorm.weight")
    qn = G(f"{a}.q_a_norm.weight")
    kvn = G(f"{a}.kv_norm.weight")
    qa_w = G(f"{a}.q_a_proj.weight")
    qa_sf = G(f"{a}.q_a_proj.weight_scale")
    qa_gs = G(f"{a}.q_a_proj.weight_scale_2")
    qb_w = G(f"{a}.q_b_proj.weight")
    qb_sf = G(f"{a}.q_b_proj.weight_scale")
    qb_gs = G(f"{a}.q_b_proj.weight_scale_2")
    kv_w = G(f"{a}.kv_proj.weight")
    kv_sf = G(f"{a}.kv_proj.weight_scale")
    kv_gs = G(f"{a}.kv_proj.weight_scale_2")

    # CuTeDSL runners for the three projections.
    r_qa = make_runner(qa_w, qa_sf, qa_gs, H, qa_w.shape[0])
    r_qb = make_runner(qb_w, qb_sf, qb_gs, QL, qb_w.shape[0])
    r_kv = make_runner(kv_w, kv_sf, kv_gs, H, kv_w.shape[0])

    # Input
    token_ids = torch.tensor([1, 450, 8403, 315, 5413, 374], dtype=torch.long, device=DEV)
    NT = len(token_ids)
    cos_sin = build_cos_sin(max_pos=WINDOW + 256).to(DEV)
    positions = torch.arange(NT, dtype=torch.int64, device=DEV)
    print(f" Input: {NT} tokens, {NH} heads, HD={HD}")

    with torch.no_grad():
        hidden = emb[token_ids]
        normed = rms(hidden, anorm, EPS)

        # Projections
        qa_cute = r_qa.run(normed)
        kv_cute = r_kv.run(normed)
        qa_n = rms(qa_cute, qn, EPS)
        kv_n = rms(kv_cute, kvn, EPS)
        q_cute = r_qb.run(qa_n).view(NT, NH, HD)
        q_rope = apply_gptj_rope(q_cute, positions, cos_sin, NOPE, ROPE)

        # -- BF16 reference --
        print("\n--- Step 1: BF16 reference attention ---")
        o_bf16 = bf16_full_attention(q_rope, kv_n, SCALE)
        print(f" BF16 attention output: amax={o_bf16.amax():.4f} NaN={torch.isnan(o_bf16).any()}")

        # -- NVFP4 Q.K^T attention --
        print("\n--- Step 2: NVFP4 Q×K^T attention ---")
        try:
            o_nvfp4 = nvfp4_qk_attention(q_rope, kv_n, SCALE)
            print(f" NVFP4 attention output: amax={o_nvfp4.amax():.4f} NaN={torch.isnan(o_nvfp4).any()}")
            c = F.cosine_similarity(
                o_nvfp4.flatten().unsqueeze(0).float(),
                o_bf16.flatten().unsqueeze(0).float(),
            ).item()
            # Both branches of the marker were empty strings before, so the
            # threshold check never showed up in the output.
            print(f" NVFP4 vs BF16 cosine: {c:.6f} {'PASS' if c >= 0.98 else 'FAIL'}")
        except Exception as e:
            # Deliberately best-effort: report the failure and keep going.
            print(f" ERROR: {e}")
            import traceback
            traceback.print_exc()

    print("\n" + "=" * 70)
    print(" Done")
    print("=" * 70)
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()