Files
nvfp4-megamoe-kernel/tests/unit/test_production_fmha_layer.py

349 lines
16 KiB
Python

#!/usr/bin/env python3
"""Production FMHA layer comparison test — real model weights, real pipeline.
Strategy:
1. Run the full production pipeline (single_shot_inference.py forward_layer)
for all prefill tokens through layers 0-4, populating each layer's KV cache.
2. After prefill, for each layer, gather the cached KV in the production mixed
format (FP8 NoPE + scales, BF16 RoPE) and run BOTH the production mixed-FP8
FMHA decode kernel and a reference (dequantize KV to BF16, run PyTorch SDPA)
on that SAME gathered KV, feeding the same random query to both.
3. Compare raw FMHA output (before inverse RoPE, before output projection).
This isolates the FMHA kernel's accuracy from the rest of the pipeline.
Production values: HD=512, NOPE=448, ROPE=64, 128 heads, 61 layers, 8 GPUs.
"""
import os, sys, json, math, time
import torch
import torch.nn.functional as F
# Checkpoint root (config.json, tokenizer files, weights); overridable via env.
CHECKPOINT_DIR = os.environ.get(
    "CHECKPOINT_DIR",
    "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4",
)
# GPU count used for layer placement: layer li is placed on cuda:(li % NUM_GPUS).
NUM_GPUS = int(os.environ.get("NUM_GPUS", "8"))
# Device that holds the token embedding.
DEVICE = "cuda:0"
def cosine(a, b):
    """Return the cosine similarity of two tensors compared as flat FP32 vectors.

    Both tensors are flattened before comparison, so outputs that differ only
    in size-1 dims (e.g. (1, H, 1, D) vs (1, H, D)) compare element-for-element.
    Both are upcast to float32 before the reduction.

    Raises:
        ValueError: if ``a`` and ``b`` hold a different number of elements.
            ``F.cosine_similarity`` broadcasts its inputs. Without this check,
            a single-element tensor would be silently broadcast against the
            other, yielding a meaningless similarity instead of an error.
    """
    if a.numel() != b.numel():
        raise ValueError(
            f"cosine(): element count mismatch ({a.numel()} vs {b.numel()}); "
            f"shapes {tuple(a.shape)} vs {tuple(b.shape)}")
    return F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0).item()
def main():
    """Compare the production mixed-FP8 FMHA decode kernel against an SDPA reference.

    Phase 1 runs the real production pipeline (``forward_layer``) over every
    prompt token for the first ``TEST_LAYERS`` layers. This populates each
    layer's ``KVCache``.

    Phase 2 then does the following for each layer:
      * gathers the cached KV in the mixed format (FP8 NoPE bytes, their scales,
        BF16 RoPE part);
      * runs ``fmha_mixed_fp8_decode_raw`` with a random query;
      * compares the result against PyTorch SDPA over the same KV dequantized
        to BF16.

    Returns:
        0 if at least one layer was compared and every compared layer reached
        cosine similarity >= 0.999; otherwise 1. A NaN similarity counts as a
        failure.
    """
    import gc

    # Cosine-similarity pass threshold. Always test via `passed()`: NaN compares
    # False against both `>=` and `<`, so a `c < threshold` failure check would
    # let a NaN kernel output count as a pass.
    COS_THRESHOLD = 0.999

    def passed(c):
        return c >= COS_THRESHOLD

    torch.manual_seed(42)
    print("=" * 70)
    print("PRODUCTION FMHA LAYER COMPARISON TEST")
    print("=" * 70)
    with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
        cfg = json.load(f)
    n_layers = cfg["num_hidden_layers"]
    H = cfg["hidden_size"]            # model hidden size (NOT the head count; that is n_h)
    hd = cfg["head_dim"]
    n_h = cfg["num_attention_heads"]
    rd = cfg.get("qk_rope_head_dim", 64)
    cr = cfg.get("compress_ratios", [128] * n_layers)
    print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}")

    from single_shot_inference import (
        load_all_weights, make_nvfp4_linear, get_nvfp4_weight,
        rmsnorm, unweighted_rmsnorm, _apply_rope, build_rope_cache,
        KVCache, Compressor, Indexer, forward_layer, moe_forward,
        _load_moe_weights_stacked, _load_shared_expert_weights,
        _cache_layer_weights_no_experts,
    )
    from dsv4.layers.mhc import mHCLayer, mHCContext
    from dsv4.layers.router import Router
    from dsv4.layers.moe import Nvfp4MoE
    from dsv4.layers.shared_expert import Nvfp4SharedExpert
    from dsv4.layers.grouped_linear import Nvfp4GroupedLinear
    from dsv4.layers.linear import Nvfp4Linear
    from dsv4.ops.quantize import (
        rmsnorm_quantize_nvfp4, mhc_rmsnorm_quantize_nvfp4, dequantize_nvfp4,
        quantize_to_nvfp4,
    )

    print("Loading weights...")
    all_w = load_all_weights(CHECKPOINT_DIR)
    TEST_LAYERS = 5
    q_rank = cfg.get("q_lora_rank", 1536)
    o_groups = cfg.get("o_groups", 16)
    o_rank = cfg.get("o_lora_rank", 1024)
    n_ih = cfg.get("index_n_heads", 64)
    ihd = cfg.get("index_head_dim", 128)
    itk = cfg.get("index_topk", 1024)
    # mHC width: used both for mHCLayer(n_hc=...) and for slicing the fn/base tensors.
    n_hc = 4
    # Hardcoded RoPE/YaRN parameters. NOTE(review): confirm they match config.json.
    rope_caches = {g: build_rope_cache(65536, rd, f"cuda:{g}", 10000., "yarn", 16., 4096, 32, 1)
                   for g in range(NUM_GPUS)}

    # Build all production components (same as single_shot main())
    prod_lins, attn_mhcs, ffn_mhcs = {}, {}, {}
    attn_norms, ffn_norms = {}, {}
    compressors, indexers, kv_caches = {}, {}, {}
    routers, moe_runners, se_runners = {}, {}, {}
    for li in range(TEST_LAYERS):
        gpu = li % NUM_GPUS
        dev = f"cuda:{gpu}"
        torch.cuda.set_device(gpu)
        pfx = f"model.layers.{li}.self_attn"
        mlp_pfx = f"model.layers.{li}.mlp"
        ratio = cr[li] if li < len(cr) else 128

        # --- Attention linears ---
        pl = {}
        pl['q_a'] = make_nvfp4_linear(H, q_rank, dev, all_w, pfx, 'q_a_proj')
        # q_b produces the query for every head, so its output width is n_h * hd.
        # H is the hidden size here, so H * hd would be the wrong width.
        pl['q_b'] = make_nvfp4_linear(q_rank, n_h * hd, dev, all_w, pfx, 'q_b_proj')
        pl['kv'] = make_nvfp4_linear(H, hd, dev, all_w, pfx, 'kv_proj')
        hpg = n_h // o_groups
        wo_a = Nvfp4GroupedLinear(n_local_groups=o_groups, heads_per_group=hpg,
                                  head_dim=hd, o_lora_rank=o_rank, max_num_tokens=8192, device=dev)
        oa_w, oa_ws, oa_ws2, oa_isc = get_nvfp4_weight(all_w, pfx, 'o_a_proj')
        if oa_w is not None and oa_ws is not None:
            wo_a.load_nvfp4_weight(oa_w.to(dev), oa_ws.to(dev),
                                   oa_ws2.to(dev) if oa_ws2 is not None else None,
                                   oa_isc.to(dev) if oa_isc is not None else None)
        else:
            # Fall back to a BF16 o_a weight if no NVFP4 one is present.
            oa_bf = all_w.get(f"{pfx}.o_a_proj.weight")
            if oa_bf is not None:
                wo_a.set_bf16_weight(oa_bf.bfloat16().to(dev))
        pl['o_a'] = wo_a
        wo_a._use_runtime_gsa = True
        pl['o_b'] = make_nvfp4_linear(o_groups * o_rank, H, dev, all_w, pfx, 'o_b_proj')
        prod_lins[li] = pl

        # --- mHC (attention and FFN sub-blocks) ---
        for blocks, hc_pfx in ((attn_mhcs, f"model.layers.{li}.attn_hc"),
                               (ffn_mhcs, f"model.layers.{li}.ffn_hc")):
            fn = all_w.get(f"{hc_pfx}.fn")
            base = all_w.get(f"{hc_pfx}.base")
            scale = all_w.get(f"{hc_pfx}.scale")
            if fn is None or base is None or scale is None:
                continue  # no mHC weights for this sub-block
            m = mHCLayer(hidden_dim=H, n_hc=n_hc, t_max_sinkhorn=20, device=dev)
            # fn/base are split row-wise as [pre (n_hc) | post (n_hc) | comb (rest)].
            m.load_weights(
                W_pre=fn[0:n_hc].to(dev, torch.float32),
                W_post=fn[n_hc:2 * n_hc].to(dev, torch.float32),
                W_comb=fn[2 * n_hc:].to(dev, torch.float32),
                S_pre=base[0:n_hc].reshape(1, n_hc).to(dev, torch.float32),
                S_post=base[n_hc:2 * n_hc].reshape(n_hc, 1).to(dev, torch.float32),
                S_comb=base[2 * n_hc:].reshape(n_hc, n_hc).to(dev, torch.float32),
                alpha_pre=scale[0].item(), alpha_post=scale[1].item(),
                alpha_comb=scale[2].item())
            blocks[li] = m

        an_k = f"model.layers.{li}.input_layernorm.weight"
        if an_k in all_w:
            attn_norms[li] = all_w[an_k].to(dev, torch.float32)
        fn_k = f"model.layers.{li}.post_attention_layernorm.weight"
        if fn_k in all_w:
            ffn_norms[li] = all_w[fn_k].to(dev, torch.float32)

        # --- KV cache / compressor / indexer ---
        # Compressed-slot capacity for up to 8192 positions (ceil(8192 / ratio)).
        max_comp = (8192 + ratio - 1) // ratio if ratio > 0 else 0
        kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), max_comp=max_comp,
                                device=dev, indexer_key_dim=ihd, compress_ratio=ratio,
                                indexer_top_k=itk, rope_dim=rd)
        if ratio > 0:
            compressors[li] = Compressor(ratio, hd, H, dev)
        if ratio == 4:
            indexers[li] = Indexer(n_ih, ihd, itk, dev)

        # --- Router ---
        is_hash = (li < cfg.get("num_hash_layers", 3)) and (f"{mlp_pfx}.gate.tid2eid" in all_w)
        router = Router(hidden_size=H, num_experts=cfg["n_routed_experts"],
                        top_k=cfg.get("num_experts_per_tok", 6),
                        routed_scaling_factor=cfg.get("routed_scaling_factor", 2.5),
                        mode="hash" if is_hash else "dense",
                        vocab_size=cfg.get("vocab_size", 128000) if is_hash else None, device=dev)
        if is_hash:
            # Hash routing uses the checkpoint's tid2eid table (presumably token id -> expert ids).
            router.load_weights(hash_lut=all_w[f"{mlp_pfx}.gate.tid2eid"].to(dev, torch.int32))
        else:
            eb = all_w.get(f"{mlp_pfx}.gate.e_score_correction_bias")
            gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, mlp_pfx, 'gate')
            E = cfg["n_routed_experts"]
            if gate_w is not None and gate_ws is not None:
                # Pre-quantized NVFP4 gate from the checkpoint.
                gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
                gate_lin.fp4 = [gate_w.to(dev).view(torch.float4_e2m1fn_x2)
                                if gate_w.dtype == torch.uint8 else gate_w.to(dev)]
                gate_lin.sf = [gate_ws.to(dev)]
                ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0
                # Default activation scale: presumably 1 / (FP4 max 6 * FP8-E4M3 max 448).
                isc_v = gate_isc.float().item() if gate_isc is not None else 1.0 / (6.0 * 448.0)
                gate_lin.gs = [1.0]
                gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)]
                gate_lin._activation_global_scale = isc_v
                gate_lin._use_runtime_gsa = True
                gate_lin.finalize_weights()
                router.load_nvfp4_gate(gate_lin)
                router.load_weights(e_bias=eb.to(dev, torch.float32))
            else:
                # BF16 gate weight — quantize to NVFP4
                gw = all_w.get(f"{mlp_pfx}.gate.weight")
                if gw is not None:
                    g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous()
                    g_bf16 = g_bf16.bfloat16().to(dev)
                    g_fp4, g_sf, g_gs = quantize_to_nvfp4(g_bf16)
                    gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
                    gate_lin.fp4 = [g_fp4]
                    gate_lin.sf = [g_sf]
                    gate_lin.gs = [g_gs]
                    gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)]
                    gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
                    gate_lin._use_runtime_gsa = True
                    gate_lin.finalize_weights()
                    router.load_nvfp4_gate(gate_lin)
                    router.load_weights(e_bias=eb.to(dev, torch.float32))
        router.finalize_weights()
        routers[li] = router

        # --- Routed experts (MoE) and shared expert ---
        moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,
                       intermediate_size=cfg.get("moe_intermediate_size", 3072),
                       top_k=cfg.get("num_experts_per_tok", 6), device=dev)
        moe.set_swiglu_limit(cfg.get("swiglu_limit", 10.0))
        moe.set_fused_swiglu(True)
        _load_moe_weights_stacked(all_w, li, mlp_pfx, dev, moe, cfg)
        moe._ensure_stacked()
        moe._use_runtime_gsa = True
        moe_runners[li] = moe
        se = Nvfp4SharedExpert(hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072),
                               device=dev, swiglu_limit=cfg.get("swiglu_limit", 10.0))
        se.set_fused_swiglu(True)
        _load_shared_expert_weights(all_w, li, mlp_pfx, dev, se, cfg)
        se._ensure_initialized()
        se._use_runtime_gsa = True
        se_runners[li] = se
        torch.cuda.empty_cache()

    # Compressor / indexer weights live under self_attn.compressor[.indexer].
    for li in range(TEST_LAYERS):
        pfx = f"model.layers.{li}.self_attn.compressor"
        dev = f"cuda:{li % NUM_GPUS}"
        if li in compressors:
            compressors[li].load(all_w, pfx, dev=dev)
        if li in indexers:
            indexers[li].load(all_w, f"{pfx}.indexer", dev=dev)
    print("Components built")

    # Embedding + tokenizer
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
    bos = tokenizer.bos_token_id or 0
    # Hardcoded chat-template token ids. NOTE(review): confirm against the tokenizer config.
    USER_TOKEN, ASSISTANT_TOKEN, THINK_START = 128803, 128804, 128821
    input_ids = [bos, USER_TOKEN]
    input_ids += tokenizer.encode('\n\nThe capital of France is', add_special_tokens=False)
    input_ids += [ASSISTANT_TOKEN, THINK_START]
    print(f"Input: {len(input_ids)} tokens")

    torch.cuda.set_device(0)
    embed_w = all_w.get("model.embed_tokens.weight")
    embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to(DEVICE))
    devs_list = [f"cuda:{g}" for g in range(NUM_GPUS)]
    layer_w = _cache_layer_weights_no_experts(all_w, TEST_LAYERS, devs_list)
    # The raw checkpoint dict is not used past this point; release it.
    del all_w
    gc.collect()
    for g in range(NUM_GPUS):
        torch.cuda.set_device(g)
        torch.cuda.empty_cache()
    torch.cuda.set_device(0)

    # ================================================================
    # PHASE 1: Run full production pipeline to populate KV caches
    # ================================================================
    print(f"\nPhase 1: Populating KV caches...")
    for pi, tid_val in enumerate(input_ids):
        t1 = time.time()
        tid = torch.tensor([tid_val], dtype=torch.long, device=DEVICE)
        pos = torch.tensor([pi], dtype=torch.long, device=DEVICE)
        tid32 = torch.tensor([tid_val], dtype=torch.int32, device=DEVICE)
        X = mHCLayer.init_state(embed(tid))
        for li in range(TEST_LAYERS):
            gpu = li % NUM_GPUS
            # Hidden state follows the layer onto its GPU.
            if X.device != torch.device(f"cuda:{gpu}"):
                X = X.to(f"cuda:{gpu}")
            torch.cuda.set_device(gpu)
            X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu],
                              attn_mhcs.get(li), ffn_mhcs.get(li), attn_norms.get(li), ffn_norms.get(li),
                              kv_caches[li], pos, tid32, compressors.get(li), indexers.get(li),
                              moe_runners.get(li), se_runners.get(li), routers.get(li),
                              prod_lin=prod_lins.get(li), _use_fused_rmsnorm_quantize=True)
        if pi % 5 == 0:
            print(f"  Token {pi}/{len(input_ids)}: {time.time()-t1:.2f}s", flush=True)

    # ================================================================
    # PHASE 2: For each layer, gather KV, run production FMHA, compare vs SDPA
    # ================================================================
    print(f"\nPhase 2: FMHA comparison per layer...")
    from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw
    scale_val = 1.0 / math.sqrt(hd)
    results = {}
    for li in range(TEST_LAYERS):
        gpu = li % NUM_GPUS
        dev = f"cuda:{gpu}"
        torch.cuda.set_device(gpu)
        ratio = cr[li] if li < len(cr) else 128
        k_cache = kv_caches[li]
        # Gather KV in mixed format (same selection as the production path):
        # gather_mixed_all() only for ratio > 4 layers that hold compressed
        # entries, otherwise the sliding-window-only gather.
        if k_cache.n_comp > 0 and ratio > 4:
            kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = k_cache.gather_mixed_all()
        else:
            kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = k_cache.gather_mixed_swa_only()
        seq_len = kv_nope_scale.shape[0]
        if seq_len == 0:
            print(f"  L{li}: SKIPPED (seq_len=0)")
            continue

        # Generate a test Q (random, on this GPU), shared by kernel and reference.
        q_bf16 = torch.randn(1, n_h, 1, hd, dtype=torch.bfloat16, device=dev) * 0.5

        # 1. Run production mixed FP8 FMHA; a kernel failure is recorded as a FAIL.
        try:
            o_prod, lse_prod = fmha_mixed_fp8_decode_raw(
                q_bf16, kv_nope_fp8, kv_nope_scale, kv_rope_bf16, scale_val, rope_dim=rd)
        except Exception as e:
            print(f"  L{li}: PROD FMHA FAILED: {e}")
            results[li] = {'cos': -1.0, 'error': str(e)}
            continue

        # 2. Reference: dequantize KV, run SDPA. K and V are the same latent
        #    rows, and the single KV head is broadcast across all query heads.
        nope_dequant = kv_nope_fp8.view(torch.float8_e4m3fn).float() * kv_nope_scale.unsqueeze(-1).float()
        kv_full = torch.cat([nope_dequant.bfloat16(), kv_rope_bf16], dim=-1)  # (N, hd)
        k_4d = kv_full.unsqueeze(0).unsqueeze(0).expand(1, 1, -1, -1)  # (1, 1, N, hd)
        v_4d = k_4d.clone()
        o_ref = F.scaled_dot_product_attention(q_bf16, k_4d, v_4d, scale=scale_val)  # (1, n_h, 1, hd)

        # 3. Compare
        cos_val = cosine(o_prod, o_ref)
        mag_prod = o_prod.float().abs().max().item()
        mag_ref = o_ref.float().abs().max().item()
        # Per-head cosine. reshape() accepts any output layout that is
        # heads-then-hd with size-1 batch/seq dims, e.g. (1, n_h, 1, hd),
        # (1, n_h, hd), (1, 1, n_h, hd) or (n_h, hd).
        o_prod_h = o_prod.float().reshape(n_h, hd)
        o_ref_h = o_ref.float().reshape(n_h, hd)
        per_head_cos = F.cosine_similarity(o_prod_h, o_ref_h, dim=-1)
        min_head = per_head_cos.min().item()
        mean_head = per_head_cos.mean().item()
        results[li] = {
            'cos': cos_val, 'mag_prod': mag_prod, 'mag_ref': mag_ref,
            'seq_len': seq_len, 'ratio': ratio,
            'min_head_cos': min_head, 'mean_head_cos': mean_head,
        }
        ok = passed(cos_val)
        status = "PASS" if ok else "FAIL"
        print(f"  L{li}: {status} cos={cos_val:.6f} min_head={min_head:.6f} mean_head={mean_head:.6f} "
              f"|prod|={mag_prod:.4f} |ref|={mag_ref:.4f} seq_len={seq_len} ratio={ratio}")
        if not ok:
            worst = per_head_cos.argsort()[:5]
            print(f"    Worst heads: {worst.tolist()} cos={[f'{c:.4f}' for c in per_head_cos[worst].tolist()]}")

    # Summary
    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    # With no compared layers there is nothing to vouch for, so start from False.
    all_pass = bool(results)
    for li in sorted(results):
        r = results[li]
        c = r.get('cos', -1.0)
        ok = passed(c)
        all_pass = all_pass and ok
        status = "PASS" if ok else "FAIL"
        print(f"  L{li}: {status} cos={c:.6f} seq={r.get('seq_len','?')} ratio={r.get('ratio','?')}")
    print()
    if not results:
        print("NO LAYERS COMPARED — every layer was skipped (seq_len=0)")
    elif all_pass:
        print("ALL PASSED (cos >= 0.999)")
    else:
        print("SOME FAILED — see per-layer output above")
    return 0 if all_pass else 1
if __name__ == "__main__":
    # main() returns 0 on pass and 1 on failure; surface it as the process exit code.
    raise SystemExit(main())