Files
nvfp4-megamoe-kernel/tests/test_attention_path_b200.py

268 lines
12 KiB
Python

#!/usr/bin/env python3
"""
Pinpoint which vLLM attention component fails on B200.
Runs each step of the attention forward pass individually:
1. fused_wqa_wkv (CuTeDSL) ✅ already verified
2. q_norm + kv_norm (RMS) — trivial
3. wq_b (CuTeDSL) ✅ already verified
4. RoPE (pure PyTorch reference)
5. FlashMLA attention — THE SUSPECT
6. wo_a BMM (BF16)
7. wo_b (CuTeDSL) ✅ already verified
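Because FlashMLA cannot run standalone, step 5 substitutes a stand-in attention
output (a naive softmax attention over the real q and kv latent — not the true
MLA computation), and steps 6+7 then run on it to verify the post-attention path.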
Usage (on B200):
source /root/nvfp4-megamoe-kernel/tests/.venv/bin/activate
python3 tests/test_attention_path_b200.py
"""
import sys, os, json, torch, torch.nn.functional as F
from safetensors import safe_open
# Repository root; prepended to sys.path so the local ``cutedsl`` package
# (imported inside make_runner) can be found.
REPO = "/root/nvfp4-megamoe-kernel"
sys.path.insert(0, REPO)

# NVFP4 checkpoint directory (expected to contain model.safetensors.index.json).
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
DEV = "cuda:0"

# Model dimensions used by this test.
H = 7168         # hidden size (input width of q_a_proj / kv_proj)
NH = 128         # number of attention heads
HD = 512         # per-head dim
NOPE = 448       # non-rotary channels of each head
ROPE = 64        # rotary channels of each head (the last ROPE of HD)
QL = 1536        # q_a latent width (input width of q_b_proj)
OL = 1024        # per-group output width of o_a_proj
OG = 16          # number of o_a_proj output groups
HPG = NH // OG   # heads per output group
EPS = 1e-6       # RMSNorm epsilon

# E2M1 (FP4) lookup table: index = 4-bit code, value = decoded float.
# Code 8 is written as the int literal -0, which becomes +0.0 in the tensor.
E2M1 = torch.tensor(
    [0, .5, 1., 1.5, 2., 3., 4., 6.,
     -0, -.5, -1., -1.5, -2., -3., -4., -6.],
    dtype=torch.float32,
)

# Memo for P(): tensor name -> tensor loaded from the checkpoint.
_cache = {}
def P(k, wm, md):
    """Return checkpoint tensor *k*, loading it from its shard on first use.

    Args:
        k: tensor name.
        wm: mapping tensor name -> shard filename (the index's ``weight_map``).
        md: directory containing the shard files.

    The memo (``_cache``) is keyed only by *k*; *md* is not part of the key,
    so the same name from two different model directories would collide.
    """
    if k not in _cache:
        shard_path = os.path.join(md, wm[k])
        with safe_open(shard_path, framework="pt") as shard:
            _cache[k] = shard.get_tensor(k)
    return _cache[k]
def dequant(w, sf, gs):
    """Dequantize a packed NVFP4 weight into a BF16 reference matrix.

    Each byte of *w* holds two 4-bit E2M1 codes: the low nibble is decoded
    into the even output column, the high nibble into the odd one. Each run
    of 16 output columns is multiplied by one block scale from *sf*, and the
    whole result by the scalar global scale *gs*.

    Args:
        w: packed codes of shape (O, I // 2) (integer dtype, presumably
            uint8 — TODO confirm).
        sf: block scales indexed (row, column // 16); any rows/columns
            beyond (O, I // 16) are sliced away after expansion.
        gs: global scale (callers pass a Python scalar via ``.item()``).

    Returns:
        Tensor of shape (O, I) in bfloat16 on ``w.device``.
    """
    dev = w.device
    table = E2M1.to(dev)
    low_codes = (w & 0xF).long()
    high_codes = ((w >> 4) & 0xF).long()
    decoded_lo = table[low_codes]
    decoded_hi = table[high_codes]
    rows, packed_cols = w.shape
    cols = packed_cols * 2
    full = torch.empty(rows, cols, dtype=torch.float32, device=dev)
    full[:, 0::2] = decoded_lo
    full[:, 1::2] = decoded_hi
    # Broadcast each block scale across its 16 columns, trimming any padding.
    block_scale = sf.float().repeat_interleave(16, dim=1)[:rows, :cols]
    return (full * block_scale * gs).to(torch.bfloat16)
def rms(x, w, eps=1e-6):
    """RMS-normalize *x* over its last dimension and scale by weight *w*.

    The mean of squares is computed in float32, the scaled product is formed
    in float32, and the result is cast back to ``x.dtype``.
    """
    mean_sq = x.float().pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + eps)
    normalized = (x * inv_rms).float()
    return (w.float() * normalized).to(x.dtype)
def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None):
    """Build and initialize a CuTeDSL NVFP4 linear runner for one weight.

    Args:
        w: packed FP4 weight of shape (outf, inf // 2) bytes (the layout
            ``dequant`` consumes); reinterpreted as ``torch.float4_e2m1fn_x2``.
        sf: block scales of shape (outf, inf // 16); cast to float8_e4m3fn if
            not already.
        gs_t: global-scale tensor. With ``fused=True`` and exactly two entries,
            it holds one global scale per sub-projection of a fused weight
            (presumably concatenated along the output dimension — verify).
        inf: input feature count passed to the runner.
        outf: output feature count passed to the runner.
        fused: treat *w* as two output-concatenated sub-projections (see *gs_t*).
        lw: optional sequence whose first element is the output width of the
            first sub-projection; defaults to ``outf // 2``.

    Returns:
        A ``CuTeDSLNvfp4Linear`` with weights finalized and initialized.
    """
    from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

    # Weights and scales are handed to the runner transposed:
    # (outf, inf // 2) -> (inf // 2, outf) and (outf, inf // 16) -> (inf // 16, outf).
    fp4 = w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
    s = sf if sf.dtype == torch.float8_e4m3fn else sf.to(torch.float8_e4m3fn)
    s = s.permute(1, 0).contiguous()

    if fused and gs_t.numel() == 2:
        g1, g2 = gs_t[0].item(), gs_t[1].item()
        gs = max(g1, g2)
        if g1 != g2:
            # Fold each sub-projection's global scale into its block scales so
            # a single global scale (the max) can be used. After the transpose
            # above, OUTPUT features live on dim 1, so the split point (an
            # output width) must slice dim 1, not dim 0. Shrinking FP8 scales
            # by the ratio can cost precision if they approach the e4m3
            # subnormal range.
            sp = lw[0] if lw else outf // 2
            s32 = s.float()
            s32[:, :sp] *= g1 / gs
            s32[:, sp:] *= g2 / gs
            s = s32.to(torch.float8_e4m3fn)
    else:
        gs = gs_t.max().item() if gs_t.numel() > 1 else gs_t.item()

    r = CuTeDSLNvfp4Linear(
        in_features=inf, out_features=outf, max_num_tokens=8192, device=str(w.device)
    )
    # One entry per weight, assigned as single-element lists.
    r.fp4 = [fp4]
    r.sf = [s]
    r.gs = [gs]
    r.finalize_weights()
    r._ensure_initialized()
    return r
def build_cos_sin(max_pos=4096, rope_dim=ROPE, base=10000.0):
    """Precompute a rotary-embedding table of shape (max_pos, 2 * (rope_dim // 2)).

    Row ``p`` is ``[cos(p * f_0), ..., cos(p * f_{h-1}), sin(p * f_0), ...,
    sin(p * f_{h-1})]`` with ``h = rope_dim // 2`` and ``f_i = base ** (-i / h)``.
    Callers split each row at ``h`` to get the cos and sin halves.

    Args:
        max_pos: number of positions (rows) to precompute.
        rope_dim: rotary width.
        base: rotary frequency base (theta). Defaults to 10000.0, the value
            previously hard-coded; no frequency scaling (e.g. YaRN) is applied.

    Returns:
        float32 tensor on the CPU.
    """
    half = rope_dim // 2
    exponents = torch.arange(0, half, dtype=torch.float32) / half
    inv_freq = 1.0 / (base ** exponents)
    freqs = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)
    return torch.cat([freqs.cos(), freqs.sin()], dim=-1)
def _cosine(a, b):
    """Cosine similarity of *a* and *b* viewed as single flat float32 vectors."""
    return F.cosine_similarity(
        a.flatten().unsqueeze(0).float(), b.flatten().unsqueeze(0).float()
    ).item()


def _rotate_interleaved(x, cos, sin, inverse=False):
    """Apply a GPT-J style (interleaved-pair) rotation to the last dim of *x*.

    Channel pairs (2i, 2i+1) are rotated by the angle whose cosine/sine are
    given; *cos* and *sin* must broadcast against ``x[..., 0::2]``. With
    ``inverse=True`` the angle is negated, undoing a forward rotation.
    """
    if inverse:
        sin = -sin
    even, odd = x[..., 0::2], x[..., 1::2]
    rot_even = even * cos - odd * sin
    rot_odd = even * sin + odd * cos
    return torch.stack([rot_even, rot_odd], dim=-1).flatten(-2)


@torch.no_grad()
def main():
    """Run layer 0's attention forward pass stage by stage and print diagnostics.

    Loads layer-0 weights from MODEL and checks each CuTeDSL projection
    against its BF16 dequantized reference (cosine similarity). FlashMLA is
    replaced by a naive causal attention proxy, and the post-attention path
    and LM head are run on its output. Finally the function reports whether
    ``flash_mla`` is importable and the GPU compute capability. It only
    prints; nothing is returned.
    """
    torch.cuda.set_device(0)
    torch.manual_seed(42)
    print("=" * 70)
    print(" Attention Path Test: Pinpoint FlashMLA Failure")
    print("=" * 70)

    index_path = os.path.join(MODEL, "model.safetensors.index.json")
    with open(index_path, encoding="utf-8") as f:
        wm = json.load(f)["weight_map"]

    p = "model.layers.0"
    a = f"{p}.self_attn"

    def G(k):
        # Load checkpoint tensor *k* (memoized by P) onto the test device.
        return P(k, wm, MODEL).to(DEV)

    def load_nvfp4(proj):
        # (packed weight, block scales, global scale) of one NVFP4 projection.
        base = f"{a}.{proj}"
        return G(f"{base}.weight"), G(f"{base}.weight_scale"), G(f"{base}.weight_scale_2")

    emb = G("model.embed_tokens.weight")
    anorm = G(f"{p}.input_layernorm.weight")
    qn = G(f"{a}.q_a_norm.weight")
    kvn = G(f"{a}.kv_norm.weight")
    woa = G(f"{a}.o_a_proj.weight")  # (16384, 4096) BF16; viewed as (OG, OL, HPG * HD) below

    # NVFP4 projections.
    qa_w, qa_sf, qa_gs = load_nvfp4("q_a_proj")
    qb_w, qb_sf, qb_gs = load_nvfp4("q_b_proj")
    kv_w, kv_sf, kv_gs = load_nvfp4("kv_proj")
    wob_w, wob_sf, wob_gs = load_nvfp4("o_b_proj")

    # BF16 references.
    qa_bf16 = dequant(qa_w, qa_sf, qa_gs.item())
    qb_bf16 = dequant(qb_w, qb_sf, qb_gs.item())
    kv_bf16 = dequant(kv_w, kv_sf, kv_gs.item())
    wob_bf16 = dequant(wob_w, wob_sf, wob_gs.item())

    # CuTeDSL runners.
    r_qa = make_runner(qa_w, qa_sf, qa_gs, H, qa_w.shape[0])
    r_qb = make_runner(qb_w, qb_sf, qb_gs, QL, qb_w.shape[0])
    r_kv = make_runner(kv_w, kv_sf, kv_gs, H, kv_w.shape[0])
    r_wob = make_runner(wob_w, wob_sf, wob_gs, OG * OL, wob_w.shape[0])

    # Input.
    token_ids = torch.tensor([1, 450, 8403, 315, 5413, 374], dtype=torch.long, device=DEV)
    NT = len(token_ids)
    hidden = emb[token_ids]
    normed = rms(hidden, anorm, EPS)
    print(f" Input: {NT} tokens, amax={normed.amax():.4f}")

    # ── Step 1: fused_wqa_wkv ─────────────────────────────────────────
    print("\n--- Step 1: fused_wqa_wkv (q_a + kv) ---")
    qa_cute = r_qa.run(normed)
    kv_cute = r_kv.run(normed)
    qa_ref = normed @ qa_bf16.T
    kv_ref = normed @ kv_bf16.T
    print(f" q_a CuTeDSL vs BF16: cosine={_cosine(qa_cute, qa_ref):.6f}")
    print(f" kv CuTeDSL vs BF16: cosine={_cosine(kv_cute, kv_ref):.6f}")

    # ── Step 2: q_norm + kv_norm ──────────────────────────────────────
    print("\n--- Step 2: RMS norm (q_a_norm, kv_norm) ---")
    qa_normed = rms(qa_cute, qn, EPS)
    kv_normed = rms(kv_cute, kvn, EPS)
    print(f" q_a normed: amax={qa_normed.amax():.4f} NaN={torch.isnan(qa_normed).any().item()}")
    print(f" kv normed: amax={kv_normed.amax():.4f} NaN={torch.isnan(kv_normed).any().item()}")

    # ── Step 3: wq_b ──────────────────────────────────────────────────
    print("\n--- Step 3: wq_b (q_a → full q) ---")
    q_cute = r_qb.run(qa_normed)
    q_ref = qa_normed @ qb_bf16.T
    print(f" q_b CuTeDSL vs BF16: cosine={_cosine(q_cute, q_ref):.6f}")
    print(f" q shape: {q_cute.shape} → ({NT}, {NH}, {HD})")
    q_3d = q_cute.view(NT, NH, HD)
    print(f" q_3d amax: {q_3d.amax():.4f}")

    # ── Step 4: RoPE (reference, GPT-J style) ─────────────────────────
    print("\n--- Step 4: RoPE (GPT-J style reference) ---")
    cos_sin = build_cos_sin().to(DEV)
    positions = torch.arange(NT, dtype=torch.int64, device=DEV)
    half_rot = ROPE // 2
    cos_t = cos_sin[positions, :half_rot].to(q_3d.dtype)  # (NT, ROPE // 2)
    sin_t = cos_sin[positions, half_rot:].to(q_3d.dtype)
    cos_h, sin_h = cos_t.unsqueeze(1), sin_t.unsqueeze(1)  # broadcast over heads
    # Only the last ROPE channels of each head are rotated.
    q_with_rope = torch.cat(
        [q_3d[..., :NOPE], _rotate_interleaved(q_3d[..., NOPE:], cos_h, sin_h)], dim=-1
    )
    print(f" q with RoPE: amax={q_with_rope.amax():.4f}")

    # ── Step 5: Attention (FlashMLA replaced by a naive proxy) ────────
    print("\n--- Step 5: Attention output ---")
    print(" ⚠️ FlashMLA cannot run standalone — using reference implementation")
    print(" ⚠️ Running naive scaled dot-product attention in BF16")
    # Proxy, not real MLA: the normed kv latent serves as a single K = V shared
    # by every head. Its rope channels get the same rotation as q, which is
    # what makes the inverse RoPE in step 6 meaningful.
    print(f" kv latent: shape={kv_normed.shape} amax={kv_normed.amax():.4f}")
    if kv_normed.shape[-1] != HD:
        raise RuntimeError(
            f"kv latent width {kv_normed.shape[-1]} != head dim {HD}; "
            "the naive attention proxy needs them equal"
        )
    kv_rot = torch.cat(
        [
            kv_normed[:, :NOPE],
            _rotate_interleaved(
                kv_normed[:, NOPE:], cos_t.to(kv_normed.dtype), sin_t.to(kv_normed.dtype)
            ),
        ],
        dim=-1,
    )  # (NT, HD)
    # Attend over TOKENS, per head: (NH, NT, HD) @ (HD, NT) -> (NH, NT, NT).
    q_heads = q_with_rope.transpose(0, 1)
    scores = torch.matmul(q_heads, kv_rot.transpose(0, 1)).float() * (HD ** -0.5)
    causal = torch.ones(NT, NT, dtype=torch.bool, device=DEV).tril()
    scores = scores.masked_fill(~causal, float("-inf"))
    probs = F.softmax(scores, dim=-1).to(kv_rot.dtype)
    o_ref = torch.matmul(probs, kv_rot).transpose(0, 1).contiguous()  # (NT, NH, HD)
    print(f" Naive attention output: amax={o_ref.amax():.4f} NaN={torch.isnan(o_ref).any().item()}")

    # ── Step 6: wo_a (inverse RoPE + BMM) ─────────────────────────────
    print("\n--- Step 6: wo_a (inverse RoPE + BMM) ---")
    o_inv = torch.cat(
        [o_ref[..., :NOPE], _rotate_interleaved(o_ref[..., NOPE:], cos_h, sin_h, inverse=True)],
        dim=-1,
    )
    # Group consecutive heads: (NT, NH, HD) -> (OG, NT, HPG*HD), then
    # (OG, NT, HPG*HD) @ (OG, HPG*HD, OL) -> (OG, NT, OL) -> (NT, OG*OL).
    o_grouped = o_inv.view(NT, OG, HPG * HD).permute(1, 0, 2)
    woa_3d = woa.view(OG, OL, HPG * HD)
    z = torch.bmm(o_grouped, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(NT, OG * OL)
    print(f" wo_a z: amax={z.amax():.4f} NaN={torch.isnan(z).any().item()}")

    # ── Step 7: wo_b (CuTeDSL) ────────────────────────────────────────
    print("\n--- Step 7: wo_b (CuTeDSL vs BF16) ---")
    wob_cute = r_wob.run(z)
    wob_ref = z @ wob_bf16.T
    c = _cosine(wob_cute, wob_ref)
    print(f" wo_b CuTeDSL vs BF16: cosine={c:.6f} {'✅' if c >= 0.98 else '❌'}")

    # ── Final: attention output through LM head ───────────────────────
    print("\n--- Final: attn output → residual → norm → LM head ---")
    fnorm_w = G("model.norm.weight")
    lm_head = G("lm_head.weight")
    x = hidden + wob_cute
    x_normed = rms(x, fnorm_w, EPS)
    logits = x_normed @ lm_head.T
    print(f" logits: amax={logits.amax():.4f} NaN={torch.isnan(logits).any().item()}")
    top5 = torch.topk(logits[-1], 5)
    print(f" top5 IDs: {top5.indices.tolist()}")
    log_std = logits[-1].float().std().item()
    print(f" logit std: {log_std:.4f} {'✅' if 0.5 < log_std < 50 else '❌'}")

    # ── KEY DIAGNOSTIC: What does vLLM's FlashMLA actually do? ────────
    print("\n" + "=" * 70)
    print(" KEY DIAGNOSTIC: Check FlashMLA availability on B200")
    print("=" * 70)
    try:
        import flash_mla
    except ImportError:
        print(" flash_mla NOT available in this venv (expected — it's in the container)")
    else:
        print(f" flash_mla imported: version={getattr(flash_mla, '__version__', 'unknown')}")
        # get_device_capability returns a plain (major, minor) tuple, not a
        # namedtuple, so unpack it rather than using attribute access.
        cap = torch.cuda.get_device_capability()
        print(f" GPU capability: {cap}")
        major, minor = cap
        if major >= 10:
            print(f" ⚠️ SM{major}{minor} (Blackwell) — FlashMLA may not support this!")

    # Summary of the vLLM attention path this script cannot exercise directly.
    print("\n vLLM attention path on B200:")
    print(" 1. torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert")
    print(" → C++ CUDA kernel for RoPE + KV cache insert")
    print(" 2. self.mla_attn(q, kv, positions, output=out)")
    print(" → FlashMLA sparse attention")
    print(" Both are compiled CUDA kernels that may NOT work on SM100.")
    print(" If either returns garbage, the model outputs EOS immediately.")


if __name__ == "__main__":
    main()