Files
nvfp4-megamoe-kernel/tests/unit/test_decode_fmha_layer.py

573 lines
28 KiB
Python

#!/usr/bin/env python3
"""Production FMHA layer comparison test — DECODE phase.
The key difference from test_production_fmha_layer.py:
- That test checks FMHA cos during PREFILL (or with random Q after prefill)
- This test checks FMHA cos during the FIRST DECODE STEP
Why this matters:
During decode, the KV cache has compressed entries (CSA/HCA) + SWA window.
The CSA path uses indexer top-k to select which compressed entries to attend to.
The HCA path gathers ALL compressed entries. The SWA-only path has no compression.
If the per-layer cosine similarity is 0.999993 during prefill but drops during
decode, the bug is likely in decode-time KV gathering or compressed/SWA parity.
Strategy:
1. Run the full production pipeline (single_shot_inference.py forward_layer)
for ALL prefill tokens through the first TEST_LAYERS layers (default 5,
i.e. layers 0-4), populating the KV caches.
2. Replay the FIRST decode token layer by layer with a hand-written copy of
forward_layer's attention path, capturing the production FMHA inputs
(q_heads, gathered KV) at each layer.
3. For each layer, also run a reference FMHA (KV dequantized to BF16, attention
computed in PyTorch) on the SAME gathered KV that the production kernel saw.
4. Compare raw FMHA output (before inverse RoPE, before output projection).
Production values: HD=512 (NOPE=448 + ROPE=64), H=128 attention heads,
61 layers, 8 GPUs.
"""
import os, sys, json, math, time
import torch
import torch.nn.functional as F
# Run configuration; every value except DEVICE can be overridden from the environment.
CHECKPOINT_DIR = os.getenv("CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
NUM_GPUS = int(os.getenv("NUM_GPUS", "8"))  # layer li is placed on GPU li % NUM_GPUS
DEVICE = "cuda:0"  # device for the embedding table and the token/position tensors
TEST_LAYERS = int(os.getenv("TEST_LAYERS", "5"))  # compare the first TEST_LAYERS decoder layers
def cosine(a, b):
    """Return the cosine similarity of *a* and *b* as a Python float.

    Both tensors are promoted to float32 and flattened first, so inputs of
    any shape or dtype (e.g. bf16 kernel outputs) are compared over all of
    their elements as single vectors.
    """
    flat_a = a.float().reshape(-1)
    flat_b = b.float().reshape(-1)
    similarity = F.cosine_similarity(flat_a, flat_b, dim=0)
    return similarity.item()
# Minimum cosine similarity (production vs. reference FMHA output) for a layer
# to count as PASS; override with the COS_THRESHOLD environment variable.
COS_THRESHOLD = float(os.environ.get("COS_THRESHOLD", "0.999"))


def _reference_fmha(q_4d, kv_nope_fp8, kv_nope_scale, kv_rope_bf16, scale, sink_bias=None):
    """Compute an FP32 reference for the mixed FP8/BF16 decode FMHA.

    The FP8 NoPE columns of the gathered KV are dequantized, rounded to BF16
    and concatenated with the BF16 RoPE columns. The same tensor is used as
    K and V, and its single KV head is broadcast to every query head. The
    attention math runs in FP32 whether or not a sink is present.

    DSV4 sinks are denominator-only:
        O = sum(P * V) / (sum(P) + exp(sb)),  P = exp(Q K^T * scale)
    which equals softmax(Q K^T * scale) @ V scaled by sigmoid(logsumexp - sb).

    Args:
        q_4d: query, (1, n_heads, T, hd).
        kv_nope_fp8: FP8 NoPE part of the gathered KV, (N, nope_dim); its
            storage is reinterpreted as float8_e4m3fn.
        kv_nope_scale: dequantization scale, broadcast via ``unsqueeze(-1)``;
            assumed to be one scale per row, (N,) — TODO confirm.
        kv_rope_bf16: BF16 RoPE part of the gathered KV, (N, rope_dim).
        scale: softmax scale applied to Q K^T.
        sink_bias: optional per-head sink logits, reshapeable to (n_heads,).

    Returns:
        BF16 tensor of shape (1, n_heads, T, hd).
    """
    n_heads = q_4d.shape[1]
    nope = kv_nope_fp8.view(torch.float8_e4m3fn).float() * kv_nope_scale.unsqueeze(-1).float()
    # Round NoPE through BF16 before concatenating with the BF16 RoPE columns.
    kv_full = torch.cat([nope.bfloat16(), kv_rope_bf16], dim=-1).float()  # (N, hd)
    scores = torch.matmul(q_4d.float(), kv_full.transpose(0, 1)) * scale  # (1, H, T, N)
    out = F.softmax(scores, dim=-1) @ kv_full  # (1, H, T, hd)
    if sink_bias is not None:
        lse = torch.logsumexp(scores, dim=-1, keepdim=True)  # (1, H, T, 1)
        sb_4d = sink_bias.float().reshape(1, n_heads, 1, 1)
        # Z / (Z + exp(sb)) == sigmoid(log Z - sb); sigmoid cannot overflow like exp.
        out = out * torch.sigmoid(lse - sb_4d)
    return out.bfloat16()


def _layer_fmha_stats(li, o_prod, o_ref, threshold, seq_len, gather_mode):
    """Compare production and reference FMHA outputs for one layer.

    Prints a PASS/FAIL line. On failure it also prints the five worst heads by
    cosine and by max-abs magnitude ratio. Per-head statistics squeeze dim 1,
    i.e. they assume T == 1 (a single decode token).

    Args:
        li: layer index (for printing).
        o_prod, o_ref: (n_h, T, hd) production / reference outputs.
        threshold: minimum whole-tensor cosine for PASS.
        seq_len, gather_mode: KV description echoed in the report line.

    Returns:
        dict with 'cos', 'mag_prod', 'mag_ref', 'min_head_cos', 'mean_head_cos'.
    """
    cos_val = cosine(o_prod, o_ref)
    mag_prod = o_prod.float().abs().max().item()
    mag_ref = o_ref.float().abs().max().item()
    o_prod_h = o_prod.float().squeeze(1)  # (n_h, hd)
    o_ref_h = o_ref.float().squeeze(1)
    per_head_cos = F.cosine_similarity(o_prod_h, o_ref_h, dim=-1)
    per_head_mag_ratio = (o_prod_h.abs().max(dim=-1).values
                          / (o_ref_h.abs().max(dim=-1).values + 1e-8))  # (n_h,)
    min_head = per_head_cos.min().item()
    mean_head = per_head_cos.mean().item()
    status = "PASS" if cos_val >= threshold else "FAIL"
    print(f" L{li}: {status} cos={cos_val:.6f} min_head={min_head:.6f} mean_head={mean_head:.6f} "
          f"|prod|={mag_prod:.4f} |ref|={mag_ref:.4f} seq={seq_len} {gather_mode}", flush=True)
    if cos_val < threshold:
        worst_heads = per_head_cos.argsort()[:5]
        # Heads whose max-abs ratio deviates most from 1.0
        worst_mag = per_head_mag_ratio.sub(1.0).abs().argsort(descending=True)[:5]
        cos_list = [f'{c:.4f}' for c in per_head_cos[worst_heads].tolist()]
        mag_list = [f'{r:.4f}' for r in per_head_mag_ratio[worst_mag].tolist()]
        print(f" Worst heads (cos): {worst_heads.tolist()} cos={cos_list}")
        print(f" Worst heads (mag): {worst_mag.tolist()} ratio={mag_list}")
        print(f" Mag ratio range: [{per_head_mag_ratio.min().item():.4f}, {per_head_mag_ratio.max().item():.4f}]")
    return {'cos': cos_val, 'mag_prod': mag_prod, 'mag_ref': mag_ref,
            'min_head_cos': min_head, 'mean_head_cos': mean_head}


def _print_summary(results, threshold):
    """Print the per-layer summary and return True iff every tested layer passed.

    Skipped layers and layers whose production kernel raised count as
    failures. An empty ``results`` dict also counts as a failure, so the test
    cannot pass vacuously.
    """
    print(f"\n{'='*70}")
    print("DECODE FMHA COMPARISON SUMMARY")
    print(f"{'='*70}")
    all_pass = bool(results)
    for li in sorted(results):
        r = results[li]
        c = r.get('cos', -1.0)
        if r.get('skipped'):
            status = "SKIPPED"
        elif c >= threshold:
            status = "PASS"
        else:
            status = "FAIL"
        if status != "PASS":
            all_pass = False
        print(f" L{li}: {status} cos={c:.6f} seq={r.get('seq_len','?')} "
              f"mode={r.get('gather_mode','?')} "
              f"n_comp={r.get('n_comp','?')} swa={r.get('swa_len','?')}")
    print()
    if not results:
        print("NO DECODE LAYERS WERE COMPARED — treating as failure")
    elif all_pass:
        print(f"ALL DECODE LAYERS PASSED (cos >= {threshold})")
    else:
        print("SOME DECODE LAYERS FAILED — investigate KV gathering or compressed/SWA parity")
        print()
        print("If prefill cos was 0.999993 but decode cos < 0.999:")
        print(" → Bug is in decode-time KV gathering or compressed/SWA parity")
        print(" → Check: gather_mixed_selective (CSA), gather_mixed_all (HCA)")
        print(" → Check: SWA positions vs compressed positions (causality)")
        print(" → Check: indexer top-k indices validity")
    return all_pass


def main():
    """Run the decode-phase FMHA comparison and return a process exit code.

    Phase 1 pushes every prompt token through ``forward_layer`` for the first
    ``TEST_LAYERS`` layers to populate the KV caches. Phase 2 replays one
    decode token layer by layer. It replicates the attention path by hand so
    the inputs of the production mixed-FP8 FMHA can be captured and compared
    with an FP32 PyTorch reference on the same gathered KV. The rest of each
    layer (output projection, mHC post blocks, MoE) is then run, so later
    layers receive the propagated hidden state.

    Returns:
        0 if every tested layer reached ``COS_THRESHOLD``, otherwise 1.
    """
    torch.manual_seed(42)
    print("=" * 70)
    print("DECODE FMHA LAYER COMPARISON TEST")
    print("Tests FMHA accuracy during DECODE (not prefill)")
    print("=" * 70)
    with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
        cfg = json.load(f)
    n_layers = cfg["num_hidden_layers"]
    H = cfg["hidden_size"]
    hd = cfg["head_dim"]
    n_h = cfg["num_attention_heads"]
    rd = cfg.get("qk_rope_head_dim", 64)
    nope_dim = hd - rd
    cr = cfg.get("compress_ratios", [128] * n_layers)
    print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}, nope_dim={nope_dim}")
    print(f"Compress ratios (first {TEST_LAYERS}): {cr[:TEST_LAYERS]}")

    def layer_ratio(li):
        # Layers beyond the configured list fall back to ratio 128.
        return cr[li] if li < len(cr) else 128

    # Import from single_shot_inference.py and the dsv4 package
    from single_shot_inference import (
        load_all_weights, make_nvfp4_linear, get_nvfp4_weight,
        rmsnorm, unweighted_rmsnorm, _apply_rope, build_rope_cache,
        KVCache, Compressor, Indexer, forward_layer, moe_forward,
        _load_moe_weights_stacked, _load_shared_expert_weights,
        _cache_layer_weights_no_experts,
    )
    from dsv4.layers.mhc import mHCLayer, mHCContext
    from dsv4.layers.router import Router
    from dsv4.layers.moe import Nvfp4MoE
    from dsv4.layers.shared_expert import Nvfp4SharedExpert
    from dsv4.layers.grouped_linear import Nvfp4GroupedLinear
    from dsv4.layers.linear import Nvfp4Linear
    from dsv4.ops.quantize import (
        rmsnorm_quantize_nvfp4, mhc_rmsnorm_quantize_nvfp4, dequantize_nvfp4,
        quantize_to_nvfp4,
    )
    from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw

    print("Loading weights...")
    all_w = load_all_weights(CHECKPOINT_DIR)
    o_groups = cfg.get("o_groups", 16)
    o_rank = cfg.get("o_lora_rank", 1024)
    n_ih = cfg.get("index_n_heads", 64)
    ihd = cfg.get("index_head_dim", 128)
    itk = cfg.get("index_topk", 1024)
    rope_caches = {g: build_rope_cache(65536, rd, f"cuda:{g}", 10000., "yarn", 16., 4096, 32, 1)
                   for g in range(NUM_GPUS)}

    # Build all production components; layer li lives on GPU li % NUM_GPUS
    prod_lins, attn_mhcs, ffn_mhcs = {}, {}, {}
    attn_norms, ffn_norms = {}, {}
    compressors, indexers, kv_caches = {}, {}, {}
    routers, moe_runners, se_runners = {}, {}, {}
    for li in range(TEST_LAYERS):
        gpu = li % NUM_GPUS
        dev = f"cuda:{gpu}"
        torch.cuda.set_device(gpu)
        pfx = f"model.layers.{li}.self_attn"
        mlp_pfx = f"model.layers.{li}.mlp"
        ratio = layer_ratio(li)
        # Attention linears
        pl = {}
        pl['q_a'] = make_nvfp4_linear(H, 1536, dev, all_w, pfx, 'q_a_proj')
        # q_b emits one hd-wide vector per attention head (reshaped to (T, n_h, hd) in phase 2).
        pl['q_b'] = make_nvfp4_linear(1536, n_h * hd, dev, all_w, pfx, 'q_b_proj')
        pl['kv'] = make_nvfp4_linear(H, hd, dev, all_w, pfx, 'kv_proj')
        hpg = n_h // o_groups
        wo_a = Nvfp4GroupedLinear(n_local_groups=o_groups, heads_per_group=hpg,
                                  head_dim=hd, o_lora_rank=o_rank, max_num_tokens=8192, device=dev)
        oa_w, oa_ws, oa_ws2, oa_isc = get_nvfp4_weight(all_w, pfx, 'o_a_proj')
        if oa_w is not None and oa_ws is not None:
            wo_a.load_nvfp4_weight(oa_w.to(dev), oa_ws.to(dev),
                                   oa_ws2.to(dev) if oa_ws2 is not None else None,
                                   oa_isc.to(dev) if oa_isc is not None else None)
        else:
            oa_bf = all_w.get(f"{pfx}.o_a_proj.weight")
            if oa_bf is not None:
                wo_a.set_bf16_weight(oa_bf.bfloat16().to(dev))
        pl['o_a'] = wo_a
        wo_a._use_runtime_gsa = True
        pl['o_b'] = make_nvfp4_linear(o_groups * o_rank, H, dev, all_w, pfx, 'o_b_proj')
        prod_lins[li] = pl
        # mHC: fn/base rows are split into pre (n), post (n) and comb (n*n) parts
        for blocks, hc_name in ((attn_mhcs, "attn_hc"), (ffn_mhcs, "ffn_hc")):
            hc_pfx = f"model.layers.{li}.{hc_name}"
            fn = all_w.get(f"{hc_pfx}.fn")
            base = all_w.get(f"{hc_pfx}.base")
            hc_scale = all_w.get(f"{hc_pfx}.scale")
            if fn is not None and base is not None and hc_scale is not None:
                m = mHCLayer(hidden_dim=H, n_hc=4, t_max_sinkhorn=20, device=dev)
                n = 4
                m.load_weights(
                    W_pre=fn[0:n].to(dev, torch.float32), W_post=fn[n:2*n].to(dev, torch.float32),
                    W_comb=fn[2*n:].to(dev, torch.float32),
                    S_pre=base[0:n].reshape(1, n).to(dev, torch.float32),
                    S_post=base[n:2*n].reshape(n, 1).to(dev, torch.float32),
                    S_comb=base[2*n:].reshape(n, n).to(dev, torch.float32),
                    alpha_pre=hc_scale[0].item(), alpha_post=hc_scale[1].item(),
                    alpha_comb=hc_scale[2].item())
                blocks[li] = m
        an_k = f"model.layers.{li}.input_layernorm.weight"
        if an_k in all_w:
            attn_norms[li] = all_w[an_k].to(dev, torch.float32)
        fn_k = f"model.layers.{li}.post_attention_layernorm.weight"
        if fn_k in all_w:
            ffn_norms[li] = all_w[fn_k].to(dev, torch.float32)
        max_comp = (8192 + ratio - 1) // ratio if ratio > 0 else 0
        kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), max_comp=max_comp,
                                device=dev, indexer_key_dim=ihd, compress_ratio=ratio,
                                indexer_top_k=itk, rope_dim=rd)
        if ratio > 0:
            compressors[li] = Compressor(ratio, hd, H, dev)
        if ratio == 4:
            indexers[li] = Indexer(n_ih, ihd, itk, dev)
        # Router: hash-table routing for the first num_hash_layers layers, otherwise a gate linear
        is_hash = (li < cfg.get("num_hash_layers", 3)) and (f"{mlp_pfx}.gate.tid2eid" in all_w)
        router = Router(hidden_size=H, num_experts=cfg["n_routed_experts"],
                        top_k=cfg.get("num_experts_per_tok", 6),
                        routed_scaling_factor=cfg.get("routed_scaling_factor", 2.5),
                        mode="hash" if is_hash else "dense",
                        vocab_size=cfg.get("vocab_size", 128000) if is_hash else None, device=dev)
        if is_hash:
            router.load_weights(hash_lut=all_w[f"{mlp_pfx}.gate.tid2eid"].to(dev, torch.int32))
        else:
            eb = all_w.get(f"{mlp_pfx}.gate.e_score_correction_bias")
            gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, mlp_pfx, 'gate')
            E = cfg["n_routed_experts"]
            if gate_w is not None and gate_ws is not None:
                # Pre-quantized NVFP4 gate from the checkpoint
                gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
                gate_lin.fp4 = [gate_w.to(dev).view(torch.float4_e2m1fn_x2)
                                if gate_w.dtype == torch.uint8 else gate_w.to(dev)]
                gate_lin.sf = [gate_ws.to(dev)]
                ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0
                isc_v = gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0)
                gate_lin.gs = [1.0]
                gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)]
                gate_lin._activation_global_scale = isc_v
                gate_lin._use_runtime_gsa = True
                gate_lin.finalize_weights()
                router.load_nvfp4_gate(gate_lin)
                router.load_weights(e_bias=eb.to(dev, torch.float32))
            else:
                # BF16 gate: quantize to NVFP4 here
                gw = all_w.get(f"{mlp_pfx}.gate.weight")
                if gw is not None:
                    g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous()
                    g_bf16 = g_bf16.bfloat16().to(dev)
                    g_fp4, g_sf, g_gs = quantize_to_nvfp4(g_bf16)
                    gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
                    gate_lin.fp4 = [g_fp4]
                    gate_lin.sf = [g_sf]
                    gate_lin.gs = [g_gs]
                    gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)]
                    gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
                    gate_lin._use_runtime_gsa = True
                    gate_lin.finalize_weights()
                    router.load_nvfp4_gate(gate_lin)
                    router.load_weights(e_bias=eb.to(dev, torch.float32))
        router.finalize_weights()
        routers[li] = router
        moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,
                       intermediate_size=cfg.get("moe_intermediate_size", 3072),
                       top_k=cfg.get("num_experts_per_tok", 6), device=dev)
        moe.set_swiglu_limit(cfg.get("swiglu_limit", 10.0))
        moe.set_fused_swiglu(True)
        _load_moe_weights_stacked(all_w, li, mlp_pfx, dev, moe, cfg)
        moe._ensure_stacked()
        moe._use_runtime_gsa = True
        moe_runners[li] = moe
        se = Nvfp4SharedExpert(hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072),
                               device=dev, swiglu_limit=cfg.get("swiglu_limit", 10.0))
        se.set_fused_swiglu(True)
        _load_shared_expert_weights(all_w, li, mlp_pfx, dev, se, cfg)
        se._ensure_initialized()
        se._use_runtime_gsa = True
        se_runners[li] = se
        torch.cuda.empty_cache()
    for li in range(TEST_LAYERS):
        pfx = f"model.layers.{li}.self_attn.compressor"
        dev = f"cuda:{li % NUM_GPUS}"
        if li in compressors:
            compressors[li].load(all_w, pfx, dev=dev)
        if li in indexers:
            indexers[li].load(all_w, f"{pfx}.indexer", dev=dev)
    print("Components built")

    # Embedding + tokenizer
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
    bos = tokenizer.bos_token_id or 0
    USER_TOKEN, ASSISTANT_TOKEN, THINK_START = 128803, 128804, 128821
    input_ids = [bos, USER_TOKEN]
    input_ids += tokenizer.encode('\n\nThe capital of France is', add_special_tokens=False)
    input_ids.append(ASSISTANT_TOKEN)
    input_ids.append(THINK_START)
    print(f"Input: {len(input_ids)} tokens: {input_ids}")
    torch.cuda.set_device(0)
    embed_w = all_w.get("model.embed_tokens.weight")
    embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to(DEVICE))
    devs_list = [f"cuda:{g}" for g in range(NUM_GPUS)]
    layer_w = _cache_layer_weights_no_experts(all_w, TEST_LAYERS, devs_list)
    # Drop the host-side checkpoint dict before running
    del all_w
    import gc
    gc.collect()
    for g in range(NUM_GPUS):
        torch.cuda.set_device(g)
        torch.cuda.empty_cache()
    torch.cuda.set_device(0)

    # ================================================================
    # PHASE 1: Run full production pipeline to populate KV caches
    # ================================================================
    print(f"\n{'='*70}")
    print("PHASE 1: Populating KV caches (prefill)")
    print(f"{'='*70}")
    for pi, tid_val in enumerate(input_ids):
        t1 = time.time()
        tid = torch.tensor([tid_val], dtype=torch.long, device=DEVICE)
        pos = torch.tensor([pi], dtype=torch.long, device=DEVICE)
        tid32 = torch.tensor([tid_val], dtype=torch.int32, device=DEVICE)
        X = mHCLayer.init_state(embed(tid))
        for li in range(TEST_LAYERS):
            gpu = li % NUM_GPUS
            if X.device != torch.device(f"cuda:{gpu}"):
                X = X.to(f"cuda:{gpu}")
            torch.cuda.set_device(gpu)
            X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu],
                              attn_mhcs.get(li), ffn_mhcs.get(li), attn_norms.get(li), ffn_norms.get(li),
                              kv_caches[li], pos, tid32, compressors.get(li), indexers.get(li),
                              moe_runners.get(li), se_runners.get(li), routers.get(li),
                              prod_lin=prod_lins.get(li), _use_fused_rmsnorm_quantize=True)
        if pi % 5 == 0:
            print(f" Token {pi}/{len(input_ids)}: {time.time()-t1:.2f}s", flush=True)
    # Print KV cache state after prefill
    print(f"\nKV cache state after prefill ({len(input_ids)} tokens):")
    for li in range(TEST_LAYERS):
        kc = kv_caches[li]
        print(f" L{li} (ratio={layer_ratio(li)}): n_comp={kc.n_comp} swa_len={kc.swa_len} "
              f"total_KV={kc.n_comp + kc.swa_len}")

    # ================================================================
    # PHASE 2: Run ONE decode step, capturing FMHA inputs/outputs
    # ================================================================
    print(f"\n{'='*70}")
    print("PHASE 2: Decode FMHA comparison per layer")
    print(f"{'='*70}")
    # The decode step is replayed by hand, layer by layer, duplicating the
    # forward_attention logic so the production FMHA inputs can be captured
    # and compared against the reference. The token is a fixed plausible
    # continuation rather than the model's greedy prediction (which would
    # need logits from a full forward pass).
    decode_pos = len(input_ids)
    # First sub-token of " the"; fall back to the " " token.
    decode_tid = tokenizer.encode(" the", add_special_tokens=False)
    if len(decode_tid) > 0:
        decode_tid = decode_tid[0]
    else:
        decode_tid = tokenizer.convert_tokens_to_ids(" ")
    print(f"Decode token: id={decode_tid} pos={decode_pos}")
    # Initial hidden state from the embedding
    dec_tid = torch.tensor([decode_tid], dtype=torch.long, device=DEVICE)
    dec_tid32 = torch.tensor([decode_tid], dtype=torch.int32, device=DEVICE)
    dec_pos = torch.tensor([decode_pos], dtype=torch.long, device=DEVICE)
    X = mHCLayer.init_state(embed(dec_tid))
    print(f"Initial X: shape={tuple(X.shape)} |X|={X.abs().max().item():.4f}")
    softmax_scale = 1.0 / math.sqrt(hd)
    results = {}
    for li in range(TEST_LAYERS):
        gpu = li % NUM_GPUS
        dev = f"cuda:{gpu}"
        torch.cuda.set_device(gpu)
        if X.device != torch.device(dev):
            X = X.to(dev)
        ratio = layer_ratio(li)
        kc = kv_caches[li]
        pfx = f"model.layers.{li}.self_attn"
        pos_dev = dec_pos.to(dev)
        rope_cs = rope_caches[gpu][:2]

        # ---- mHC pre_block + rmsnorm (same as forward_layer) ----
        attn_mhc = attn_mhcs.get(li)
        ffn_mhc = ffn_mhcs.get(li)
        attn_norm_w = attn_norms.get(li)
        ffn_norm_w = ffn_norms.get(li)
        # The manual replay below dereferences all four; fail with a clear message instead of AttributeError.
        missing = [name for name, obj in (("attn_hc", attn_mhc), ("ffn_hc", ffn_mhc),
                                          ("input_layernorm", attn_norm_w),
                                          ("post_attention_layernorm", ffn_norm_w))
                   if obj is None]
        if missing:
            raise RuntimeError(f"L{li}: missing weights {missing}; cannot replay the decode step")
        A_l_a, B_l_a, C_l_a = attn_mhc._dynamic_params(X)
        ctx_a = mHCContext(B_l=B_l_a, C_l=C_l_a)
        # Fused mHC + rmsnorm + NVFP4 quantize (production path)
        x_quant_attn = mhc_rmsnorm_quantize_nvfp4(X, A_l_a, attn_norm_w.to(dev, torch.float32))
        x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa)

        # ---- Manually replicate forward_attention to capture FMHA inputs ----
        T = x_normed.shape[0]
        pl = prod_lins[li]
        # 1. Q projection
        q_a = pl['q_a'].run_from_quantized(x_quant_attn)
        q_norm_w = layer_w[li].get(f"{pfx}.q_a_norm.weight")
        if q_norm_w is not None:
            q_a_quant = rmsnorm_quantize_nvfp4(q_a, q_norm_w.to(dev, torch.float32))
            q_a = dequantize_nvfp4(q_a_quant.x_fp4, q_a_quant.x_sf, q_a_quant.gsa)
            q = pl['q_b'].run_from_quantized(q_a_quant)
        else:
            q = pl['q_b'](q_a)
        q = unweighted_rmsnorm(q).bfloat16()
        q_heads = _apply_rope(q.reshape(T, n_h, hd), pos_dev, *rope_cs, rd)
        # 2. KV projection, RoPE, append to the SWA window
        kv = pl['kv'].run_from_quantized(x_quant_attn)
        kv_norm_w = layer_w[li].get(f"{pfx}.kv_norm.weight")
        if kv_norm_w is not None:
            kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32))
        kv_roped = _apply_rope(kv.reshape(T, 1, hd), pos_dev, *rope_cs, rd).reshape(T, hd)
        kc.append_swa(kv_roped, pos_dev)
        # 3. Compressor → compressed KV (FP8 NoPE + BF16 RoPE)
        compressor = compressors.get(li)
        indexer = indexers.get(li)
        if compressor is not None and compressor.ratio > 0:
            comp_kv_fp32, comp_pos, _block_bias = compressor.forward(x_normed, pos_dev)
            if comp_kv_fp32 is not None:
                from dsv4.kernels.cuda.loader import get_cuda_module
                kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
                nope_fp32 = comp_kv_fp32[:, :nope_dim].contiguous()
                rope_bf16 = comp_kv_fp32[:, nope_dim:].bfloat16().contiguous()
                rope_bf16 = _apply_rope(rope_bf16.unsqueeze(1), comp_pos, *rope_cs, rd).squeeze(1)
                nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
                kc.set_compressed_mixed(nope_fp8, nope_scale, rope_bf16, comp_pos)
            # NOTE(review): called whether or not the main compressor emitted a block,
            # so the indexer's compressor sees every token — confirm this matches forward_attention.
            if compressor.is_csa and indexer is not None and indexer.compressor is not None:
                comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, pos_dev)
                kc.set_indexer_keys_fp8(comp_idx_kv)
        # 4. Indexer top-k (CSA layers)
        topk_idx = None
        if indexer is not None and ratio == 4:
            topk_idx = indexer.forward(q_a, x_normed, kc, pos_dev, layer_idx=li)
        if topk_idx is not None:
            print(f" L{li} CSA: indexer topk shape={tuple(topk_idx.shape)} "
                  f"range=[{topk_idx.min().item()}, {topk_idx.max().item()}] "
                  f"n_comp={kc.n_comp}", flush=True)
        # 5. Gather KV — same logic as forward_attention
        swa_kv, _swa_pos = kc.get_swa()
        swa_len = swa_kv.shape[0]
        if kc.n_comp > 0 and ratio == 4:
            # CSA: gather top-k compressed rows
            assert topk_idx is not None, f"CSA layer {li}: indexer returned no top-k"
            raw_tk = topk_idx[0]
            # Clamping keeps the gather in bounds but hides bad indices; report them.
            n_invalid = int(((raw_tk < 0) | (raw_tk >= kc.n_comp)).sum().item())
            if n_invalid:
                print(f" L{li} WARNING: {n_invalid}/{raw_tk.numel()} top-k indices outside "
                      f"[0, {kc.n_comp}) were clamped into range", flush=True)
            tk = raw_tk.clamp(0, kc.n_comp - 1).int()
            kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kc.gather_mixed_selective(tk)
            gather_mode = f"CSA top-k ({tk.numel()} comp + {swa_len} SWA)"
        elif kc.n_comp > 0 and ratio > 4:
            # HCA: attend to every compressed row
            kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kc.gather_mixed_all()
            gather_mode = f"HCA all ({kc.n_comp} comp + {swa_len} SWA)"
        else:
            kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kc.gather_mixed_swa_only()
            gather_mode = f"SWA-only ({swa_len} SWA)"
        seq_len = kv_nope_scale.shape[0]
        layer_info = {'seq_len': seq_len, 'ratio': ratio, 'gather_mode': gather_mode,
                      'n_comp': kc.n_comp, 'swa_len': swa_len}

        if seq_len == 0:
            print(f" L{li}: SKIPPED (seq_len=0)")
            results[li] = {'cos': -1.0, 'skipped': True, **layer_info}
            # No keys to attend to: propagate a zero attention output so the
            # remaining layers still run on a propagated hidden state.
            o_attn = torch.zeros(n_h, T, hd, dtype=torch.bfloat16, device=dev)
        else:
            print(f" L{li}: {gather_mode} → seq_len={seq_len}", flush=True)
            q_4d = q_heads.permute(1, 0, 2).unsqueeze(0).contiguous()  # (1, n_h, T, hd)
            sinks = layer_w[li].get(f"{pfx}.sinks")
            sink_bias = sinks.to(device=dev).float().reshape(n_h) if sinks is not None else None
            # 6. Reference on the exact KV the production kernel will see
            o_ref = _reference_fmha(q_4d, kv_nope_fp8, kv_nope_scale, kv_rope_bf16,
                                    softmax_scale, sink_bias).squeeze(0)  # (n_h, T, hd)
            # 7. Production mixed FP8 FMHA
            try:
                o_prod_4d, _lse_prod = fmha_mixed_fp8_decode_raw(
                    q_4d, kv_nope_fp8, kv_nope_scale, kv_rope_bf16,
                    softmax_scale, attn_sink=sink_bias, rope_dim=rd)
            except Exception as e:
                print(f" L{li}: PROD FMHA FAILED: {e}")
                results[li] = {'cos': -1.0, 'error': str(e), **layer_info}
                # Propagate the reference output so later layers are still tested.
                o_attn = o_ref
            else:
                o_attn = o_prod_4d.squeeze(0)  # (n_h, T, hd)
                # 8. Compare
                stats = _layer_fmha_stats(li, o_attn, o_ref, COS_THRESHOLD, seq_len, gather_mode)
                results[li] = {**stats, **layer_info}

        # ---- Continue through the rest of the layer (so subsequent layers get correct X) ----
        # Inverse RoPE on the attention output, then grouped o_a + o_b projection
        attn_out = _apply_rope(o_attn.permute(1, 0, 2), pos_dev, *rope_cs, rd, inverse=True)  # (T, n_h, hd)
        g_flat = pl['o_a'].run(attn_out).reshape(T, -1)
        F_attn = pl['o_b'](g_flat)
        # mHC post_block
        X_mid = attn_mhc.post_block(X, F_attn, ctx_a)
        # FFN mHC + MoE
        A_l_f, B_l_f, C_l_f = ffn_mhc._dynamic_params(X_mid)
        ctx_f = mHCContext(B_l=B_l_f, C_l=C_l_f)
        x_quant_ffn = mhc_rmsnorm_quantize_nvfp4(X_mid, A_l_f, ffn_norm_w.to(dev, torch.float32))
        x_ffn = dequantize_nvfp4(x_quant_ffn.x_fp4, x_quant_ffn.x_sf, x_quant_ffn.gsa)
        F_ffn = moe_forward(x_ffn, li, moe_runners.get(li), se_runners.get(li),
                            routers.get(li), dec_tid32.to(dev))
        X = ffn_mhc.post_block(X_mid, F_ffn, ctx_f)

    # ================================================================
    # Summary
    # ================================================================
    all_pass = _print_summary(results, COS_THRESHOLD)
    return 0 if all_pass else 1
if __name__ == "__main__":
    # Propagate main()'s 0/1 result as the process exit status.
    raise SystemExit(main())