test: decode FMHA layer comparison — checks FMHA accuracy during decode step
This commit is contained in:
543
tests/unit/test_decode_fmha_layer.py
Normal file
543
tests/unit/test_decode_fmha_layer.py
Normal file
@@ -0,0 +1,543 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Production FMHA layer comparison test — DECODE phase.
|
||||
|
||||
The key difference from test_production_fmha_layer.py:
|
||||
- That test checks FMHA cos during PREFILL (or with random Q after prefill)
|
||||
- This test checks FMHA cos during the FIRST DECODE STEP
|
||||
|
||||
Why this matters:
|
||||
During decode, the KV cache has compressed entries (CSA/HCA) + SWA window.
|
||||
The CSA path uses indexer top-k to select which compressed entries to attend to.
|
||||
The HCA path gathers ALL compressed entries. The SWA-only path has no compression.
|
||||
If the per-layer cos is 0.999993 during prefill but drops during decode,
|
||||
the bug is in the decode-time KV gathering or compressed/SWA parity.
|
||||
|
||||
Strategy:
|
||||
1. Run full production pipeline (single_shot_inference.py forward_layer)
|
||||
for ALL prefill tokens through layers 0-4, populating KV caches.
|
||||
2. Run the FIRST decode token through forward_layer, but capture the
|
||||
production FMHA inputs (q_heads, gathered KV) at each layer.
|
||||
3. For each layer, ALSO run reference FMHA (dequantize KV to BF16, PyTorch SDPA)
|
||||
on the SAME gathered KV that the production kernel saw.
|
||||
4. Compare raw FMHA output (before inverse RoPE, before output projection).
|
||||
|
||||
Production values: HD=512, NOPE=448, ROPE=64, H=128, 61 layers, 8 GPUs.
|
||||
"""
|
||||
import os, sys, json, math, time
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
CHECKPOINT_DIR = os.environ.get(
|
||||
"CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
|
||||
NUM_GPUS = int(os.environ.get("NUM_GPUS", "8"))
|
||||
DEVICE = "cuda:0"
|
||||
# How many layers to test (first N layers)
|
||||
TEST_LAYERS = int(os.environ.get("TEST_LAYERS", "5"))
|
||||
|
||||
|
||||
def cosine(a, b):
|
||||
return F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0).item()
|
||||
|
||||
|
||||
def main():
|
||||
torch.manual_seed(42)
|
||||
print("=" * 70)
|
||||
print("DECODE FMHA LAYER COMPARISON TEST")
|
||||
print("Tests FMHA accuracy during DECODE (not prefill)")
|
||||
print("=" * 70)
|
||||
|
||||
with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
|
||||
cfg = json.load(f)
|
||||
n_layers = cfg["num_hidden_layers"]
|
||||
H = cfg["hidden_size"]
|
||||
hd = cfg["head_dim"]
|
||||
n_h = cfg["num_attention_heads"]
|
||||
rd = cfg.get("qk_rope_head_dim", 64)
|
||||
nope_dim = hd - rd
|
||||
cr = cfg.get("compress_ratios", [128] * n_layers)
|
||||
print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}, nope_dim={nope_dim}")
|
||||
print(f"Compress ratios (first {TEST_LAYERS}): {cr[:TEST_LAYERS]}")
|
||||
|
||||
# Import from single_shot_inference.py
|
||||
from single_shot_inference import (
|
||||
load_all_weights, make_nvfp4_linear, get_nvfp4_weight,
|
||||
rmsnorm, unweighted_rmsnorm, _apply_rope, build_rope_cache,
|
||||
KVCache, Compressor, Indexer, forward_layer, moe_forward,
|
||||
_load_moe_weights_stacked, _load_shared_expert_weights,
|
||||
_cache_layer_weights_no_experts,
|
||||
)
|
||||
from dsv4.layers.mhc import mHCLayer, mHCContext
|
||||
from dsv4.layers.router import Router
|
||||
from dsv4.layers.moe import Nvfp4MoE
|
||||
from dsv4.layers.shared_expert import Nvfp4SharedExpert
|
||||
from dsv4.layers.grouped_linear import Nvfp4GroupedLinear
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
from dsv4.ops.quantize import (
|
||||
rmsnorm_quantize_nvfp4, mhc_rmsnorm_quantize_nvfp4, dequantize_nvfp4,
|
||||
quantize_to_nvfp4,
|
||||
)
|
||||
|
||||
print("Loading weights...")
|
||||
all_w = load_all_weights(CHECKPOINT_DIR)
|
||||
|
||||
o_groups = cfg.get("o_groups", 16)
|
||||
o_rank = cfg.get("o_lora_rank", 1024)
|
||||
n_ih = cfg.get("index_n_heads", 64)
|
||||
ihd = cfg.get("index_head_dim", 128)
|
||||
itk = cfg.get("index_topk", 1024)
|
||||
|
||||
rope_caches = {g: build_rope_cache(65536, rd, f"cuda:{g}", 10000., "yarn", 16., 4096, 32, 1)
|
||||
for g in range(NUM_GPUS)}
|
||||
|
||||
# Build all production components
|
||||
prod_lins, attn_mhcs, ffn_mhcs = {}, {}, {}
|
||||
attn_norms, ffn_norms = {}, {}
|
||||
compressors, indexers, kv_caches = {}, {}, {}
|
||||
routers, moe_runners, se_runners = {}, {}, {}
|
||||
|
||||
for li in range(TEST_LAYERS):
|
||||
gpu = li % NUM_GPUS
|
||||
dev = f"cuda:{gpu}"
|
||||
torch.cuda.set_device(gpu)
|
||||
pfx = f"model.layers.{li}.self_attn"
|
||||
mlp_pfx = f"model.layers.{li}.mlp"
|
||||
ratio = cr[li] if li < len(cr) else 128
|
||||
|
||||
# Attention linears
|
||||
pl = {}
|
||||
pl['q_a'] = make_nvfp4_linear(H, 1536, dev, all_w, pfx, 'q_a_proj')
|
||||
pl['q_b'] = make_nvfp4_linear(1536, H * hd, dev, all_w, pfx, 'q_b_proj')
|
||||
pl['kv'] = make_nvfp4_linear(H, hd, dev, all_w, pfx, 'kv_proj')
|
||||
hpg = n_h // o_groups
|
||||
wo_a = Nvfp4GroupedLinear(n_local_groups=o_groups, heads_per_group=hpg,
|
||||
head_dim=hd, o_lora_rank=o_rank, max_num_tokens=8192, device=dev)
|
||||
oa_w, oa_ws, oa_ws2, oa_isc = get_nvfp4_weight(all_w, pfx, 'o_a_proj')
|
||||
if oa_w is not None and oa_ws is not None:
|
||||
wo_a.load_nvfp4_weight(oa_w.to(dev), oa_ws.to(dev),
|
||||
oa_ws2.to(dev) if oa_ws2 is not None else None,
|
||||
oa_isc.to(dev) if oa_isc is not None else None)
|
||||
else:
|
||||
oa_bf = all_w.get(f"{pfx}.o_a_proj.weight")
|
||||
if oa_bf is not None:
|
||||
wo_a.set_bf16_weight(oa_bf.bfloat16().to(dev))
|
||||
pl['o_a'] = wo_a; wo_a._use_runtime_gsa = True
|
||||
pl['o_b'] = make_nvfp4_linear(o_groups * o_rank, H, dev, all_w, pfx, 'o_b_proj')
|
||||
prod_lins[li] = pl
|
||||
|
||||
# mHC
|
||||
for tag, blocks, fn_s, base_s, scale_s in [
|
||||
("attn", attn_mhcs, f"model.layers.{li}.attn_hc.fn",
|
||||
f"model.layers.{li}.attn_hc.base", f"model.layers.{li}.attn_hc.scale"),
|
||||
("ffn", ffn_mhcs, f"model.layers.{li}.ffn_hc.fn",
|
||||
f"model.layers.{li}.ffn_hc.base", f"model.layers.{li}.ffn_hc.scale"),
|
||||
]:
|
||||
fn, base, scale = all_w.get(fn_s), all_w.get(base_s), all_w.get(scale_s)
|
||||
if fn is not None and base is not None and scale is not None:
|
||||
m = mHCLayer(hidden_dim=H, n_hc=4, t_max_sinkhorn=20, device=dev)
|
||||
n = 4
|
||||
m.load_weights(
|
||||
W_pre=fn[0:n].to(dev, torch.float32), W_post=fn[n:2*n].to(dev, torch.float32),
|
||||
W_comb=fn[2*n:].to(dev, torch.float32),
|
||||
S_pre=base[0:n].reshape(1, n).to(dev, torch.float32),
|
||||
S_post=base[n:2*n].reshape(n, 1).to(dev, torch.float32),
|
||||
S_comb=base[2*n:].reshape(n, n).to(dev, torch.float32),
|
||||
alpha_pre=scale[0].item(), alpha_post=scale[1].item(), alpha_comb=scale[2].item())
|
||||
blocks[li] = m
|
||||
|
||||
an_k = f"model.layers.{li}.input_layernorm.weight"
|
||||
if an_k in all_w: attn_norms[li] = all_w[an_k].to(dev, torch.float32)
|
||||
fn_k = f"model.layers.{li}.post_attention_layernorm.weight"
|
||||
if fn_k in all_w: ffn_norms[li] = all_w[fn_k].to(dev, torch.float32)
|
||||
|
||||
max_comp = (8192 + ratio - 1) // ratio if ratio > 0 else 0
|
||||
kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), max_comp=max_comp,
|
||||
device=dev, indexer_key_dim=ihd, compress_ratio=ratio, indexer_top_k=itk, rope_dim=rd)
|
||||
if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev)
|
||||
if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev)
|
||||
|
||||
# Router
|
||||
is_hash = (li < cfg.get("num_hash_layers", 3)) and (f"{mlp_pfx}.gate.tid2eid" in all_w)
|
||||
router = Router(hidden_size=H, num_experts=cfg["n_routed_experts"],
|
||||
top_k=cfg.get("num_experts_per_tok", 6),
|
||||
routed_scaling_factor=cfg.get("routed_scaling_factor", 2.5),
|
||||
mode="hash" if is_hash else "dense",
|
||||
vocab_size=cfg.get("vocab_size", 128000) if is_hash else None, device=dev)
|
||||
if is_hash:
|
||||
router.load_weights(hash_lut=all_w[f"{mlp_pfx}.gate.tid2eid"].to(dev, torch.int32))
|
||||
else:
|
||||
eb = all_w.get(f"{mlp_pfx}.gate.e_score_correction_bias")
|
||||
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, mlp_pfx, 'gate')
|
||||
E = cfg["n_routed_experts"]
|
||||
if gate_w is not None and gate_ws is not None:
|
||||
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
|
||||
gate_lin.fp4 = [gate_w.to(dev).view(torch.float4_e2m1fn_x2) if gate_w.dtype == torch.uint8 else gate_w.to(dev)]
|
||||
gate_lin.sf = [gate_ws.to(dev)]
|
||||
ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0
|
||||
isc_v = gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0)
|
||||
gate_lin.gs = [1.0]
|
||||
gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)]
|
||||
gate_lin._activation_global_scale = isc_v
|
||||
gate_lin._use_runtime_gsa = True
|
||||
gate_lin.finalize_weights()
|
||||
router.load_nvfp4_gate(gate_lin)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
else:
|
||||
gw = all_w.get(f"{mlp_pfx}.gate.weight")
|
||||
if gw is not None:
|
||||
g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous()
|
||||
g_bf16 = g_bf16.bfloat16().to(dev)
|
||||
g_fp4, g_sf, g_gs = quantize_to_nvfp4(g_bf16)
|
||||
gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
|
||||
gate_lin.fp4 = [g_fp4]
|
||||
gate_lin.sf = [g_sf]
|
||||
gate_lin.gs = [g_gs]
|
||||
gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)]
|
||||
gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
|
||||
gate_lin._use_runtime_gsa = True
|
||||
gate_lin.finalize_weights()
|
||||
router.load_nvfp4_gate(gate_lin)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
router.finalize_weights(); routers[li] = router
|
||||
|
||||
moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,
|
||||
intermediate_size=cfg.get("moe_intermediate_size", 3072),
|
||||
top_k=cfg.get("num_experts_per_tok", 6), device=dev)
|
||||
moe.set_swiglu_limit(cfg.get("swiglu_limit", 10.0)); moe.set_fused_swiglu(True)
|
||||
_load_moe_weights_stacked(all_w, li, mlp_pfx, dev, moe, cfg)
|
||||
moe._ensure_stacked(); moe._use_runtime_gsa = True; moe_runners[li] = moe
|
||||
|
||||
se = Nvfp4SharedExpert(hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072),
|
||||
device=dev, swiglu_limit=cfg.get("swiglu_limit", 10.0))
|
||||
se.set_fused_swiglu(True)
|
||||
_load_shared_expert_weights(all_w, li, mlp_pfx, dev, se, cfg)
|
||||
se._ensure_initialized(); se._use_runtime_gsa = True; se_runners[li] = se
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
for li in range(TEST_LAYERS):
|
||||
pfx = f"model.layers.{li}.self_attn.compressor"
|
||||
dev = f"cuda:{li % NUM_GPUS}"
|
||||
if li in compressors: compressors[li].load(all_w, pfx, dev=dev)
|
||||
if li in indexers: indexers[li].load(all_w, f"{pfx}.indexer", dev=dev)
|
||||
print("Components built")
|
||||
|
||||
# Embedding + tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
|
||||
bos = tokenizer.bos_token_id or 0
|
||||
USER_TOKEN, ASSISTANT_TOKEN, THINK_START = 128803, 128804, 128821
|
||||
input_ids = [bos, USER_TOKEN]
|
||||
input_ids += tokenizer.encode('\n\nThe capital of France is', add_special_tokens=False)
|
||||
input_ids.append(ASSISTANT_TOKEN)
|
||||
input_ids.append(THINK_START)
|
||||
print(f"Input: {len(input_ids)} tokens: {input_ids}")
|
||||
|
||||
torch.cuda.set_device(0)
|
||||
embed_w = all_w.get("model.embed_tokens.weight")
|
||||
embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to(DEVICE))
|
||||
devs_list = [f"cuda:{g}" for g in range(NUM_GPUS)]
|
||||
layer_w = _cache_layer_weights_no_experts(all_w, TEST_LAYERS, devs_list)
|
||||
del all_w; import gc; gc.collect()
|
||||
for g in range(NUM_GPUS): torch.cuda.set_device(g); torch.cuda.empty_cache()
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
# ================================================================
|
||||
# PHASE 1: Run full production pipeline to populate KV caches
|
||||
# ================================================================
|
||||
print(f"\n{'='*70}")
|
||||
print("PHASE 1: Populating KV caches (prefill)")
|
||||
print(f"{'='*70}")
|
||||
for pi, tid_val in enumerate(input_ids):
|
||||
t1 = time.time()
|
||||
tid = torch.tensor([tid_val], dtype=torch.long, device=DEVICE)
|
||||
pos = torch.tensor([pi], dtype=torch.long, device=DEVICE)
|
||||
tid32 = torch.tensor([tid_val], dtype=torch.int32, device=DEVICE)
|
||||
|
||||
X = mHCLayer.init_state(embed(tid))
|
||||
for li in range(TEST_LAYERS):
|
||||
gpu = li % NUM_GPUS
|
||||
if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}")
|
||||
torch.cuda.set_device(gpu)
|
||||
X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu],
|
||||
attn_mhcs.get(li), ffn_mhcs.get(li), attn_norms.get(li), ffn_norms.get(li),
|
||||
kv_caches[li], pos, tid32, compressors.get(li), indexers.get(li),
|
||||
moe_runners.get(li), se_runners.get(li), routers.get(li),
|
||||
prod_lin=prod_lins.get(li), _use_fused_rmsnorm_quantize=True)
|
||||
if pi % 5 == 0:
|
||||
print(f" Token {pi}/{len(input_ids)}: {time.time()-t1:.2f}s", flush=True)
|
||||
|
||||
# Print KV cache state after prefill
|
||||
print(f"\nKV cache state after prefill ({len(input_ids)} tokens):")
|
||||
for li in range(TEST_LAYERS):
|
||||
kc = kv_caches[li]
|
||||
ratio = cr[li] if li < len(cr) else 128
|
||||
print(f" L{li} (ratio={ratio}): n_comp={kc.n_comp} swa_len={kc.swa_len} "
|
||||
f"total_KV={kc.n_comp + kc.swa_len}")
|
||||
|
||||
# ================================================================
|
||||
# PHASE 2: Run ONE decode step, capturing FMHA inputs/outputs
|
||||
# ================================================================
|
||||
print(f"\n{'='*70}")
|
||||
print("PHASE 2: Decode FMHA comparison per layer")
|
||||
print(f"{'='*70}")
|
||||
|
||||
# Use a real next token — the model's own greedy output would require
|
||||
# a full forward pass to get logits. Instead, use a reasonable continuation
|
||||
# token. For "The capital of France is" → the space token or a letter.
|
||||
# Actually, we need to run the FULL decode forward pass (all layers) to get
|
||||
# the actual Q at each layer. So we'll intercept inside forward_attention.
|
||||
#
|
||||
# Approach: duplicate the forward_attention logic, capturing FMHA inputs
|
||||
# at each layer, then compare against reference SDPA.
|
||||
|
||||
# First, we need the hidden state X at the decode position.
|
||||
# We'll re-run the decode step manually, layer by layer, capturing
|
||||
# the production FMHA inputs and comparing against reference.
|
||||
|
||||
# Decode token: use the actual next position
|
||||
decode_pos = len(input_ids)
|
||||
# Use a common token — the " " (space) token
|
||||
decode_tid = tokenizer.encode(" the", add_special_tokens=False)
|
||||
if len(decode_tid) > 0:
|
||||
decode_tid = decode_tid[0]
|
||||
else:
|
||||
decode_tid = tokenizer.convert_tokens_to_ids(" ")
|
||||
print(f"Decode token: id={decode_tid} pos={decode_pos}")
|
||||
|
||||
# Get initial hidden state from embedding
|
||||
dec_tid = torch.tensor([decode_tid], dtype=torch.long, device=DEVICE)
|
||||
dec_tid32 = torch.tensor([decode_tid], dtype=torch.int32, device=DEVICE)
|
||||
dec_pos = torch.tensor([decode_pos], dtype=torch.long, device=DEVICE)
|
||||
|
||||
X = mHCLayer.init_state(embed(dec_tid))
|
||||
print(f"Initial X: shape={tuple(X.shape)} |X|={X.abs().max().item():.4f}")
|
||||
|
||||
results = {}
|
||||
|
||||
for li in range(TEST_LAYERS):
|
||||
gpu = li % NUM_GPUS
|
||||
dev = f"cuda:{gpu}"
|
||||
torch.cuda.set_device(gpu)
|
||||
if X.device != torch.device(f"cuda:{gpu}"): X = X.to(dev)
|
||||
|
||||
ratio = cr[li] if li < len(cr) else 128
|
||||
kc = kv_caches[li]
|
||||
pfx = f"model.layers.{li}.self_attn"
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
|
||||
# ---- mHC pre_block + rmsnorm (same as forward_layer) ----
|
||||
attn_mhc = attn_mhcs.get(li)
|
||||
ffn_mhc = ffn_mhcs.get(li)
|
||||
attn_norm_w = attn_norms.get(li)
|
||||
ffn_norm_w = ffn_norms.get(li)
|
||||
|
||||
A_l_a, B_l_a, C_l_a = attn_mhc._dynamic_params(X)
|
||||
ctx_a = mHCContext(B_l=B_l_a, C_l=C_l_a)
|
||||
|
||||
# Fused mHC + rmsnorm + NVFP4 quantize (production path)
|
||||
x_quant_attn = mhc_rmsnorm_quantize_nvfp4(
|
||||
X, A_l_a, attn_norm_w.to(dev, torch.float32))
|
||||
x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa)
|
||||
|
||||
# ---- Manually replicate forward_attention to capture FMHA inputs ----
|
||||
T = x_normed.shape[0]
|
||||
pl = prod_lins[li]
|
||||
|
||||
# 1. Q projection
|
||||
q_a = pl['q_a'].run_from_quantized(x_quant_attn)
|
||||
q_norm_w = layer_w[li].get(f"{pfx}.q_a_norm.weight")
|
||||
if q_norm_w is not None:
|
||||
q_a_quant = rmsnorm_quantize_nvfp4(q_a, q_norm_w.to(dev, torch.float32))
|
||||
q_a = dequantize_nvfp4(q_a_quant.x_fp4, q_a_quant.x_sf, q_a_quant.gsa)
|
||||
q = pl['q_b'].run_from_quantized(q_a_quant)
|
||||
else:
|
||||
q = pl['q_b'](q_a)
|
||||
q = unweighted_rmsnorm(q).bfloat16()
|
||||
q_heads = q.reshape(T, n_h, hd)
|
||||
q_heads = _apply_rope(q_heads, dec_pos, *rope_caches[gpu][:2], rd)
|
||||
|
||||
# 2. KV projection + cache
|
||||
kv = pl['kv'].run_from_quantized(x_quant_attn)
|
||||
kv_norm_w = layer_w[li].get(f"{pfx}.kv_norm.weight")
|
||||
if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32))
|
||||
kv_3d = kv.reshape(T, 1, hd)
|
||||
kv_3d = _apply_rope(kv_3d, dec_pos, *rope_caches[gpu][:2], rd)
|
||||
kv_roped = kv_3d.reshape(T, hd)
|
||||
kc.append_swa(kv_roped, dec_pos)
|
||||
|
||||
# 3. Compressor → compressed KV
|
||||
compressor = compressors.get(li)
|
||||
indexer = indexers.get(li)
|
||||
comp_pos, block_bias = None, None
|
||||
if compressor is not None and compressor.ratio > 0:
|
||||
comp_kv_fp32, comp_pos, block_bias = compressor.forward(x_normed, dec_pos)
|
||||
if comp_kv_fp32 is not None:
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
|
||||
nope_fp32 = comp_kv_fp32[:, :nope_dim].contiguous()
|
||||
rope_bf16 = comp_kv_fp32[:, nope_dim:].bfloat16().contiguous()
|
||||
rope_3d = rope_bf16.unsqueeze(1)
|
||||
rope_3d = _apply_rope(rope_3d, comp_pos, *rope_caches[gpu][:2], rd)
|
||||
rope_bf16 = rope_3d.squeeze(1)
|
||||
nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
|
||||
kc.set_compressed_mixed(nope_fp8, nope_scale, rope_bf16, comp_pos)
|
||||
if compressor.is_csa and indexer is not None and indexer.compressor is not None:
|
||||
comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, dec_pos)
|
||||
kc.set_indexer_keys_fp8(comp_idx_kv)
|
||||
|
||||
# 4. Indexer top-k (CSA layers)
|
||||
topk_idx = None
|
||||
if indexer is not None and ratio == 4:
|
||||
topk_idx = indexer.forward(q_a, x_normed, kc, dec_pos, layer_idx=li)
|
||||
if topk_idx is not None:
|
||||
print(f" L{li} CSA: indexer topk shape={tuple(topk_idx.shape)} "
|
||||
f"range=[{topk_idx.min().item()}, {topk_idx.max().item()}] "
|
||||
f"n_comp={kc.n_comp}", flush=True)
|
||||
|
||||
# 5. Gather KV — same logic as forward_attention
|
||||
swa_kv, _swa_pos = kc.get_swa()
|
||||
swa_len = swa_kv.shape[0]
|
||||
|
||||
if kc.n_comp > 0:
|
||||
if ratio == 4:
|
||||
# CSA: gather top-k compressed rows
|
||||
assert topk_idx is not None, f"CSA layer {li}: indexer returned no top-k"
|
||||
tk = topk_idx[0].clamp(0, kc.n_comp - 1).int()
|
||||
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kc.gather_mixed_selective(tk)
|
||||
gather_mode = f"CSA top-k ({tk.numel()} comp + {swa_len} SWA)"
|
||||
elif ratio > 4:
|
||||
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kc.gather_mixed_all()
|
||||
gather_mode = f"HCA all ({kc.n_comp} comp + {swa_len} SWA)"
|
||||
else:
|
||||
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kc.gather_mixed_swa_only()
|
||||
gather_mode = f"SWA-only ({swa_len} SWA)"
|
||||
else:
|
||||
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kc.gather_mixed_swa_only()
|
||||
gather_mode = f"SWA-only ({swa_len} SWA)"
|
||||
|
||||
seq_len = kv_nope_scale.shape[0]
|
||||
if seq_len == 0:
|
||||
print(f" L{li}: SKIPPED (seq_len=0)")
|
||||
continue
|
||||
|
||||
print(f" L{li}: {gather_mode} → seq_len={seq_len}", flush=True)
|
||||
|
||||
# 6. Run production mixed FP8 FMHA
|
||||
from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw
|
||||
q_4d = q_heads.permute(1, 0, 2).unsqueeze(0).contiguous() # (1, n_h, T, hd)
|
||||
sinks = layer_w[li].get(f"{pfx}.sinks")
|
||||
sink_bias = None
|
||||
if sinks is not None:
|
||||
sink_bias = sinks.to(device=dev).float().reshape(n_h)
|
||||
|
||||
try:
|
||||
o_prod_4d, lse_prod = fmha_mixed_fp8_decode_raw(
|
||||
q_4d, kv_nope_fp8, kv_nope_scale, kv_rope_bf16,
|
||||
scale, attn_sink=sink_bias, rope_dim=rd)
|
||||
except Exception as e:
|
||||
print(f" L{li}: PROD FMHA FAILED: {e}")
|
||||
results[li] = {'cos': -1.0, 'error': str(e)}
|
||||
continue
|
||||
o_prod = o_prod_4d.squeeze(0) # (n_h, T, hd)
|
||||
|
||||
# 7. Reference: dequantize mixed KV to BF16, run SDPA
|
||||
nope_dequant = kv_nope_fp8.view(torch.float8_e4m3fn).float() * kv_nope_scale.unsqueeze(-1).float()
|
||||
kv_full = torch.cat([nope_dequant.bfloat16(), kv_rope_bf16], dim=-1) # (N, hd)
|
||||
k_4d = kv_full.unsqueeze(0).unsqueeze(0).expand(1, 1, -1, -1) # (1, 1, N, hd)
|
||||
v_4d = k_4d.clone()
|
||||
o_ref_4d = F.scaled_dot_product_attention(q_4d, k_4d, v_4d, scale=scale) # (1, H, T, hd)
|
||||
o_ref = o_ref_4d.squeeze(0) # (n_h, T, hd)
|
||||
|
||||
# 8. Compare
|
||||
cos_val = cosine(o_prod, o_ref)
|
||||
mag_prod = o_prod.float().abs().max().item()
|
||||
mag_ref = o_ref.float().abs().max().item()
|
||||
|
||||
# Per-head cosine
|
||||
o_prod_h = o_prod.float().squeeze(1) # (n_h, hd)
|
||||
o_ref_h = o_ref.float().squeeze(1)
|
||||
per_head_cos = F.cosine_similarity(o_prod_h, o_ref_h, dim=-1)
|
||||
min_head = per_head_cos.min().item()
|
||||
mean_head = per_head_cos.mean().item()
|
||||
worst_heads = per_head_cos.argsort()[:5]
|
||||
|
||||
results[li] = {
|
||||
'cos': cos_val, 'mag_prod': mag_prod, 'mag_ref': mag_ref,
|
||||
'seq_len': seq_len, 'ratio': ratio, 'gather_mode': gather_mode,
|
||||
'n_comp': kc.n_comp, 'swa_len': swa_len,
|
||||
'min_head_cos': min_head, 'mean_head_cos': mean_head,
|
||||
}
|
||||
|
||||
status = "PASS" if cos_val >= 0.999 else "FAIL"
|
||||
print(f" L{li}: {status} cos={cos_val:.6f} min_head={min_head:.6f} mean_head={mean_head:.6f} "
|
||||
f"|prod|={mag_prod:.4f} |ref|={mag_ref:.4f} seq={seq_len} {gather_mode}", flush=True)
|
||||
if cos_val < 0.999:
|
||||
print(f" Worst heads: {worst_heads.tolist()} cos={[f'{c:.4f}' for c in per_head_cos[worst_heads].tolist()]}")
|
||||
|
||||
# ---- Continue through the rest of the layer (so subsequent layers get correct X) ----
|
||||
# Apply inverse RoPE to production output
|
||||
attn_out = o_prod.permute(1, 0, 2) # (T, n_h, hd)
|
||||
attn_out = _apply_rope(attn_out, dec_pos, *rope_caches[gpu][:2], rd, inverse=True)
|
||||
|
||||
# Output projection
|
||||
wo_a_lin = pl.get('o_a')
|
||||
if wo_a_lin is not None:
|
||||
g_3d = wo_a_lin.run(attn_out)
|
||||
g_flat = g_3d.reshape(T, -1)
|
||||
F_attn = pl['o_b'](g_flat)
|
||||
else:
|
||||
hpg_fb = n_h // o_groups; gid_fb = hpg_fb * hd
|
||||
oa_full = layer_w[li].get(f"{pfx}.o_a_proj.weight")
|
||||
if oa_full is not None:
|
||||
oa_bf = oa_full.bfloat16().to(dev); a_flat = attn_out.reshape(T, n_h * hd)
|
||||
a_grp = a_flat.reshape(T, o_groups, gid_fb); oa_3d = oa_bf.reshape(o_groups, o_rank, gid_fb)
|
||||
g_out = torch.bmm(a_grp.permute(1, 0, 2), oa_3d.transpose(1, 2))
|
||||
g_flat = g_out.permute(1, 0, 2).reshape(T, o_groups * o_rank)
|
||||
F_attn = pl['o_b'](g_flat)
|
||||
else:
|
||||
F_attn = torch.zeros(T, H, dtype=torch.bfloat16, device=dev)
|
||||
|
||||
# mHC post_block
|
||||
X_mid = attn_mhc.post_block(X, F_attn, ctx_a)
|
||||
|
||||
# FFN mHC + MoE
|
||||
A_l_f, B_l_f, C_l_f = ffn_mhc._dynamic_params(X_mid)
|
||||
ctx_f = mHCContext(B_l=B_l_f, C_l=C_l_f)
|
||||
x_quant_ffn = mhc_rmsnorm_quantize_nvfp4(
|
||||
X_mid, A_l_f, ffn_norm_w.to(dev, torch.float32))
|
||||
x_ffn = dequantize_nvfp4(x_quant_ffn.x_fp4, x_quant_ffn.x_sf, x_quant_ffn.gsa)
|
||||
F_ffn = moe_forward(x_ffn, li, moe_runners.get(li), se_runners.get(li),
|
||||
routers.get(li), dec_tid32.to(dev))
|
||||
X = ffn_mhc.post_block(X_mid, F_ffn, ctx_f)
|
||||
|
||||
# ================================================================
|
||||
# Summary
|
||||
# ================================================================
|
||||
print(f"\n{'='*70}")
|
||||
print("DECODE FMHA COMPARISON SUMMARY")
|
||||
print(f"{'='*70}")
|
||||
all_pass = True
|
||||
for li in sorted(results.keys()):
|
||||
r = results[li]
|
||||
c = r.get('cos', -1.0)
|
||||
status = "PASS" if c >= 0.999 else "FAIL"
|
||||
if c < 0.999: all_pass = False
|
||||
print(f" L{li}: {status} cos={c:.6f} seq={r.get('seq_len','?')} "
|
||||
f"mode={r.get('gather_mode','?')} "
|
||||
f"n_comp={r.get('n_comp','?')} swa={r.get('swa_len','?')}")
|
||||
|
||||
print()
|
||||
if all_pass:
|
||||
print("ALL DECODE LAYERS PASSED (cos >= 0.999)")
|
||||
else:
|
||||
print("SOME DECODE LAYERS FAILED — investigate KV gathering or compressed/SWA parity")
|
||||
print()
|
||||
print("If prefill cos was 0.999993 but decode cos < 0.999:")
|
||||
print(" → Bug is in decode-time KV gathering or compressed/SWA parity")
|
||||
print(" → Check: gather_mixed_selective (CSA), gather_mixed_all (HCA)")
|
||||
print(" → Check: SWA positions vs compressed positions (causality)")
|
||||
print(" → Check: indexer top-k indices validity")
|
||||
return 0 if all_pass else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user