469 lines
25 KiB
Python
469 lines
25 KiB
Python
#!/usr/bin/env python3
|
|
"""PART A — Decode Diagnostics: Production pipeline per-layer diagnostics.
|
|
|
|
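This test runs the FULL production pipeline (``forward_layer`` from
single_shot_inference.py) for every prefill token, then replays the first
decode step layer by layer through the same production components so that
intermediates can be captured. It prints per-layer diagnostics:

- |X| per layer (mHC residual growth)
- |F_attn| and |F_ffn| magnitudes
- Compressed/SWA visible-range diagnostics (causality, overlap)
- KV cache state (n_comp, swa_len)

If |X| blows up (exceeds 1e6) during prefill, the offending layer is re-run
with a detailed breakdown and the script exits with status 1.

Production values: HD=512 (NOPE=448 + ROPE=64), H=128 attention heads,
61 layers, 8 GPUs, 384 routed experts.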
|
|
"""
|
|
import os, sys, json, math, time
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
# Runtime knobs; each one can be overridden through an environment variable.
CHECKPOINT_DIR = os.getenv("CHECKPOINT_DIR",
                           "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
NUM_GPUS = int(os.getenv("NUM_GPUS", "8"))        # layer li runs on cuda:(li % NUM_GPUS)
DEVICE = "cuda:0"                                  # device for the embedding lookup / initial state
TEST_LAYERS = int(os.getenv("TEST_LAYERS", "5"))  # number of leading decoder layers to build and run
|
|
|
|
|
|
def _compress_ratio(cr, li, default=128):
    """Return the compress ratio configured for layer ``li``.

    ``cr`` is the config's ``compress_ratios`` list; ``default`` is used when
    the list is shorter than the layer index.
    """
    return cr[li] if li < len(cr) else default


def _max_abs(t):
    """Return ``max(|t|)`` as a Python float.

    ``.item()`` works for a tensor on any CUDA device, so no copy to
    ``DEVICE`` is needed just to read the scalar.
    """
    return t.abs().max().item()


def _is_blowup(mag, limit=1e6):
    """Return True when ``mag`` is NaN/Inf or exceeds ``limit``.

    NaN must be tested explicitly: ``float('nan') > limit`` is False, so a
    plain threshold comparison would let NaN activations pass unnoticed.
    """
    return not math.isfinite(mag) or mag > limit


def main():
    """Run prefill plus one decode step and print per-layer diagnostics.

    Phase 1 pushes every prompt token through the production ``forward_layer``
    for the first ``TEST_LAYERS`` layers, tracking |X| in/out. If |X| exceeds
    1e6 or becomes NaN/Inf, the offending layer is re-run step by step with
    detailed prints and the function returns early.

    Phase 2 replays the first decode step layer by layer (attention, mHC,
    MoE) so |F_attn|, |F_ffn|, mHC parameters and the compressed-KV
    causality check can be printed.

    Returns:
        Process exit status: 0 on completion, 1 if prefill aborted on a blowup.
    """
    import gc

    torch.manual_seed(42)
    print("=" * 70)
    print("PART A — DECODE DIAGNOSTICS (Production Pipeline)")
    print("=" * 70)

    with open(os.path.join(CHECKPOINT_DIR, "config.json"), encoding="utf-8") as f:
        cfg = json.load(f)
    n_layers = cfg["num_hidden_layers"]
    H = cfg["hidden_size"]
    hd = cfg["head_dim"]
    n_h = cfg["num_attention_heads"]
    rd = cfg.get("qk_rope_head_dim", 64)
    nope_dim = hd - rd
    cr = cfg.get("compress_ratios", [128] * n_layers)
    print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}, nope_dim={nope_dim}")
    print(f"Compress ratios (first {TEST_LAYERS}): {cr[:TEST_LAYERS]}")

    # Project-local imports (deferred until after the config summary).
    from single_shot_inference import (
        load_all_weights, make_nvfp4_linear, get_nvfp4_weight,
        unweighted_rmsnorm, build_rope_cache,
        KVCache, Compressor, Indexer, forward_layer, forward_attention, moe_forward,
        _load_moe_weights_stacked, _load_shared_expert_weights,
        _cache_layer_weights_no_experts,
    )
    from dsv4.layers.mhc import mHCLayer, mHCContext
    from dsv4.layers.router import Router
    from dsv4.layers.moe import Nvfp4MoE
    from dsv4.layers.shared_expert import Nvfp4SharedExpert
    from dsv4.layers.grouped_linear import Nvfp4GroupedLinear
    from dsv4.layers.linear import Nvfp4Linear
    from dsv4.ops.quantize import (
        rmsnorm_quantize_nvfp4, mhc_rmsnorm_quantize_nvfp4, dequantize_nvfp4,
        quantize_to_nvfp4,
    )

    print("Loading weights...")
    all_w = load_all_weights(CHECKPOINT_DIR)

    o_groups = cfg.get("o_groups", 16)
    o_rank = cfg.get("o_lora_rank", 1024)
    q_lora = cfg.get("q_lora_rank", 1536)
    n_ih = cfg.get("index_n_heads", 64)
    ihd = cfg.get("index_head_dim", 128)
    itk = cfg.get("index_topk", 1024)
    n_experts = cfg["n_routed_experts"]
    top_k = cfg.get("num_experts_per_tok", 6)
    moe_inter = cfg.get("moe_intermediate_size", 3072)
    swiglu_limit = cfg.get("swiglu_limit", 10.0)
    # Token capacity shared by the grouped o_a projection and the compressed-KV cache.
    max_tokens = 8192
    # Default NVFP4 activation global scale. 6.0 and 448.0 are the FP4-E2M1
    # and FP8-E4M3 maxima (presumably the intended derivation).
    default_act_gs = 1.0 / (6.0 * 448.0)

    # One RoPE cache per GPU; the YaRN parameters are hard-coded here.
    rope_caches = {
        g: build_rope_cache(65536, rd, f"cuda:{g}", 10000., "yarn", 16., 4096, 32, 1)
        for g in range(NUM_GPUS)
    }

    # Per-layer production components, placed round-robin: layer li -> cuda:(li % NUM_GPUS).
    prod_lins, attn_mhcs, ffn_mhcs = {}, {}, {}
    attn_norms, ffn_norms = {}, {}
    compressors, indexers, kv_caches = {}, {}, {}
    routers, moe_runners, se_runners = {}, {}, {}

    for li in range(TEST_LAYERS):
        gpu = li % NUM_GPUS
        dev = f"cuda:{gpu}"
        torch.cuda.set_device(gpu)
        pfx = f"model.layers.{li}.self_attn"
        mlp_pfx = f"model.layers.{li}.mlp"
        ratio = _compress_ratio(cr, li)

        # --- attention projections ---
        pl = {
            'q_a': make_nvfp4_linear(H, q_lora, dev, all_w, pfx, 'q_a_proj'),
            # q_b produces one head_dim vector per attention head.
            'q_b': make_nvfp4_linear(q_lora, n_h * hd, dev, all_w, pfx, 'q_b_proj'),
            'kv': make_nvfp4_linear(H, hd, dev, all_w, pfx, 'kv_proj'),
        }
        hpg = n_h // o_groups
        wo_a = Nvfp4GroupedLinear(n_local_groups=o_groups, heads_per_group=hpg,
                                  head_dim=hd, o_lora_rank=o_rank,
                                  max_num_tokens=max_tokens, device=dev)
        oa_w, oa_ws, oa_ws2, oa_isc = get_nvfp4_weight(all_w, pfx, 'o_a_proj')
        if oa_w is not None and oa_ws is not None:
            wo_a.load_nvfp4_weight(
                oa_w.to(dev), oa_ws.to(dev),
                oa_ws2.to(dev) if oa_ws2 is not None else None,
                oa_isc.to(dev) if oa_isc is not None else None)
        else:
            # Fall back to a BF16 checkpoint weight when no NVFP4 tensors exist.
            oa_bf = all_w.get(f"{pfx}.o_a_proj.weight")
            if oa_bf is not None:
                wo_a.set_bf16_weight(oa_bf.bfloat16().to(dev))
        wo_a._use_runtime_gsa = True
        pl['o_a'] = wo_a
        pl['o_b'] = make_nvfp4_linear(o_groups * o_rank, H, dev, all_w, pfx, 'o_b_proj')
        prod_lins[li] = pl

        # --- mHC blocks: fn/base/scale are split into pre/post/comb parts of size n_hc ---
        n_hc = 4
        for blocks, hc_pfx in ((attn_mhcs, f"model.layers.{li}.attn_hc"),
                               (ffn_mhcs, f"model.layers.{li}.ffn_hc")):
            fn = all_w.get(f"{hc_pfx}.fn")
            base = all_w.get(f"{hc_pfx}.base")
            scale = all_w.get(f"{hc_pfx}.scale")
            if fn is None or base is None or scale is None:
                print(f"  WARNING: L{li} missing {hc_pfx} weights; mHC block skipped", flush=True)
                continue
            m = mHCLayer(hidden_dim=H, n_hc=n_hc, t_max_sinkhorn=20, device=dev)
            m.load_weights(
                W_pre=fn[0:n_hc].to(dev, torch.float32),
                W_post=fn[n_hc:2 * n_hc].to(dev, torch.float32),
                W_comb=fn[2 * n_hc:].to(dev, torch.float32),
                S_pre=base[0:n_hc].reshape(1, n_hc).to(dev, torch.float32),
                S_post=base[n_hc:2 * n_hc].reshape(n_hc, 1).to(dev, torch.float32),
                S_comb=base[2 * n_hc:].reshape(n_hc, n_hc).to(dev, torch.float32),
                alpha_pre=scale[0].item(), alpha_post=scale[1].item(),
                alpha_comb=scale[2].item())
            blocks[li] = m

        # --- pre-block RMSNorm weights ---
        an_k = f"model.layers.{li}.input_layernorm.weight"
        if an_k in all_w:
            attn_norms[li] = all_w[an_k].to(dev, torch.float32)
        fn_k = f"model.layers.{li}.post_attention_layernorm.weight"
        if fn_k in all_w:
            ffn_norms[li] = all_w[fn_k].to(dev, torch.float32)

        # --- KV cache; compressor when ratio > 0; indexer only for ratio == 4 (CSA) ---
        max_comp = (max_tokens + ratio - 1) // ratio if ratio > 0 else 0
        kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), max_comp=max_comp,
                                device=dev, indexer_key_dim=ihd, compress_ratio=ratio,
                                indexer_top_k=itk, rope_dim=rd)
        if ratio > 0:
            compressors[li] = Compressor(ratio, hd, H, dev)
        if ratio == 4:
            indexers[li] = Indexer(n_ih, ihd, itk, dev)

        # --- router: hash LUT for early layers that ship tid2eid, otherwise an NVFP4 gate ---
        is_hash = (li < cfg.get("num_hash_layers", 3)) and (f"{mlp_pfx}.gate.tid2eid" in all_w)
        router = Router(hidden_size=H, num_experts=n_experts, top_k=top_k,
                        routed_scaling_factor=cfg.get("routed_scaling_factor", 2.5),
                        mode="hash" if is_hash else "dense",
                        vocab_size=cfg.get("vocab_size", 128000) if is_hash else None,
                        device=dev)
        if is_hash:
            router.load_weights(hash_lut=all_w[f"{mlp_pfx}.gate.tid2eid"].to(dev, torch.int32))
        else:
            eb = all_w.get(f"{mlp_pfx}.gate.e_score_correction_bias")
            gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, mlp_pfx, 'gate')
            gate_lin = None
            if gate_w is not None and gate_ws is not None:
                # Pre-quantized NVFP4 gate from the checkpoint.
                gate_lin = Nvfp4Linear(in_features=H, out_features=n_experts, device=dev)
                gate_lin.fp4 = [gate_w.to(dev).view(torch.float4_e2m1fn_x2)
                                if gate_w.dtype == torch.uint8 else gate_w.to(dev)]
                gate_lin.sf = [gate_ws.to(dev)]
                ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0
                gate_lin.gs = [1.0]
                gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)]
                gate_lin._activation_global_scale = (
                    gate_isc.float().item() if gate_isc is not None else default_act_gs)
            else:
                gw = all_w.get(f"{mlp_pfx}.gate.weight")
                if gw is not None:
                    # BF16 gate: quantize to NVFP4 here; accepts (E, H) or (H, E) layout.
                    g_bf16 = gw if gw.shape == (n_experts, H) else gw.T.contiguous()
                    g_fp4, g_sf, g_gs = quantize_to_nvfp4(g_bf16.bfloat16().to(dev))
                    gate_lin = Nvfp4Linear(in_features=H, out_features=n_experts, device=dev)
                    gate_lin.fp4 = [g_fp4]
                    gate_lin.sf = [g_sf]
                    gate_lin.gs = [g_gs]
                    gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)]
                    gate_lin._activation_global_scale = default_act_gs
            if gate_lin is None:
                print(f"  WARNING: L{li} has no router gate weight", flush=True)
            else:
                gate_lin._use_runtime_gsa = True
                gate_lin.finalize_weights()
                router.load_nvfp4_gate(gate_lin)
                if eb is not None:
                    router.load_weights(e_bias=eb.to(dev, torch.float32))
        router.finalize_weights()
        routers[li] = router

        # --- routed experts + shared expert ---
        moe = Nvfp4MoE(num_experts=n_experts, hidden_size=H, intermediate_size=moe_inter,
                       top_k=top_k, device=dev)
        moe.set_swiglu_limit(swiglu_limit)
        moe.set_fused_swiglu(True)
        _load_moe_weights_stacked(all_w, li, mlp_pfx, dev, moe, cfg)
        moe._ensure_stacked()
        moe._use_runtime_gsa = True
        moe_runners[li] = moe

        se = Nvfp4SharedExpert(hidden_size=H, intermediate_size=moe_inter,
                               device=dev, swiglu_limit=swiglu_limit)
        se.set_fused_swiglu(True)
        _load_shared_expert_weights(all_w, li, mlp_pfx, dev, se, cfg)
        se._ensure_initialized()
        se._use_runtime_gsa = True
        se_runners[li] = se
        torch.cuda.empty_cache()

    for li in range(TEST_LAYERS):
        comp_pfx = f"model.layers.{li}.self_attn.compressor"
        dev = f"cuda:{li % NUM_GPUS}"
        if li in compressors:
            compressors[li].load(all_w, comp_pfx, dev=dev)
        if li in indexers:
            indexers[li].load(all_w, f"{comp_pfx}.indexer", dev=dev)

    # Verify compressor kv_norm_w loaded correctly
    for li in range(TEST_LAYERS):
        comp = compressors.get(li)
        if comp is None:
            continue
        w = comp.kv_norm_w
        if w is not None:
            print(f" L{li} compressor kv_norm_w: shape={tuple(w.shape)} |w|={_max_abs(w):.4f}", flush=True)
        else:
            print(f" L{li} compressor kv_norm_w: MISSING!", flush=True)
    print("Production components built")

    # Embedding + tokenizer
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
    bos = tokenizer.bos_token_id or 0
    USER_TOKEN, ASSISTANT_TOKEN, THINK_START = 128803, 128804, 128821
    input_ids = [
        bos, USER_TOKEN,
        *tokenizer.encode('\n\nThe capital of France is', add_special_tokens=False),
        ASSISTANT_TOKEN, THINK_START,
    ]
    print(f"Input: {len(input_ids)} tokens")

    torch.cuda.set_device(0)
    embed_w = all_w.get("model.embed_tokens.weight")
    if embed_w is None:
        raise KeyError("model.embed_tokens.weight not found in checkpoint")
    prod_embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to(DEVICE))
    devs_list = [f"cuda:{g}" for g in range(NUM_GPUS)]
    layer_w = _cache_layer_weights_no_experts(all_w, TEST_LAYERS, devs_list)
    # Drop the full weight dict now that per-layer weights are cached.
    del all_w, embed_w
    gc.collect()
    for g in range(NUM_GPUS):
        torch.cuda.set_device(g)
        torch.cuda.empty_cache()
    torch.cuda.set_device(0)

    # ================================================================
    # PHASE 1: Prefill — production, with per-layer |X| tracking
    # ================================================================
    print(f"\n{'='*70}")
    print("PHASE 1: Prefill — PRODUCTION (per-layer |X| tracking)")
    print(f"{'='*70}")

    print(f"\n {'tok':>3} {'L':>3} {'|X_in|':>12} {'|X_out|':>12} {'ratio':>5} {'n_comp':>6} {'swa':>4}")
    print(f" {'---':>3} {'---':>3} {'---':>12} {'---':>12} {'---':>5} {'---':>6} {'---':>4}")

    last_tok = len(input_ids) - 1
    last_layer = TEST_LAYERS - 1
    for pi, tid_val in enumerate(input_ids):
        t1 = time.time()
        tid = torch.tensor([tid_val], dtype=torch.long, device=DEVICE)
        pos = torch.tensor([pi], dtype=torch.long, device=DEVICE)
        tid32 = torch.tensor([tid_val], dtype=torch.int32, device=DEVICE)

        X = mHCLayer.init_state(prod_embed(tid))
        for li in range(TEST_LAYERS):
            gpu = li % NUM_GPUS
            dev = f"cuda:{gpu}"
            if X.device != torch.device(dev):
                X = X.to(dev)
            torch.cuda.set_device(gpu)

            X_prev = X.clone()  # layer input, kept for blowup diagnostics
            X_in_mag = _max_abs(X)
            X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu],
                              attn_mhcs.get(li), ffn_mhcs.get(li), attn_norms.get(li), ffn_norms.get(li),
                              kv_caches[li], pos, tid32, compressors.get(li), indexers.get(li),
                              moe_runners.get(li), se_runners.get(li), routers.get(li),
                              prod_lin=prod_lins.get(li), _use_fused_rmsnorm_quantize=True)
            X_out_mag = _max_abs(X)

            kc = kv_caches[li]
            ratio = _compress_ratio(cr, li)
            # Every layer for the first 3 tokens and the last token; for all
            # other tokens only the first and last layer.
            if pi < 3 or pi == last_tok or li in (0, last_layer):
                print(f" {pi:>3} {li:>3} {X_in_mag:>12.2f} {X_out_mag:>12.2f} {ratio:>5} {kc.n_comp:>6} {kc.swa_len:>4}", flush=True)

            # Early abort if |X| blows up (> 1e6 or NaN/Inf) — detailed diagnostics on THIS layer
            if _is_blowup(X_out_mag):
                print(f" *** BLOWUP at token {pi} layer {li}: |X|={X_out_mag:.2e} ***", flush=True)
                print(f" Re-running layer {li} with detailed diagnostics...", flush=True)
                # NOTE(review): forward_layer above has already processed this
                # token, and the re-run below calls the compressor and
                # forward_attention again. If those append to kv_caches[li] /
                # compressor state (likely — confirm), the KV figures printed
                # here include that extra step.
                X_diag = X_prev.clone()  # X before this layer
                attn_mhc_d = attn_mhcs.get(li)
                ffn_mhc_d = ffn_mhcs.get(li)
                # Hoisted: both are needed by the CSA top-k path AND the |Q_heads| check.
                pl_diag = prod_lins.get(li)
                q_norm_w_d = layer_w[li].get(f"model.layers.{li}.self_attn.q_a_norm.weight")
                if q_norm_w_d is not None:
                    q_norm_w_d = q_norm_w_d.to(dev, torch.float32)

                A_l_a, B_l_a, C_l_a = attn_mhc_d._dynamic_params(X_diag)
                ctx_a_d = mHCContext(B_l=B_l_a, C_l=C_l_a)
                x_quant_attn = mhc_rmsnorm_quantize_nvfp4(
                    X_diag, A_l_a, attn_norms.get(li).to(dev, torch.float32))
                x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa)
                print(f" |x_normed|={_max_abs(x_normed):.2f} gsa={x_quant_attn.gsa}", flush=True)

                # Run compressor and print raw output
                comp_diag = compressors.get(li)
                if comp_diag is not None:
                    comp_kv_d, _, _ = comp_diag.forward(x_normed, pos)
                    if comp_kv_d is not None:
                        print(f" Compressor output: |comp_kv|={_max_abs(comp_kv_d):.2f} shape={tuple(comp_kv_d.shape)}", flush=True)
                    else:
                        print(" Compressor output: None (n_complete=0)", flush=True)

                # Print KV cache state BEFORE calling forward_attention
                kc_diag = kv_caches[li]
                swa_kv_d, _ = kc_diag.get_swa()
                print(f" KV: n_comp={kc_diag.n_comp} swa_len={swa_kv_d.shape[0]}", flush=True)

                # Gather KV: CSA -> indexer top-k, HCA -> all compressed, otherwise SWA only
                ratio_diag = _compress_ratio(cr, li)
                gathered = None
                if kc_diag.n_comp > 0 and ratio_diag == 4:
                    # CSA needs the indexer top-k, computed from the normed q_a.
                    q_a_d = pl_diag['q_a'].run_from_quantized(x_quant_attn)
                    if q_norm_w_d is not None:
                        q_a_quant_d = rmsnorm_quantize_nvfp4(q_a_d, q_norm_w_d)
                        q_a_d = dequantize_nvfp4(q_a_quant_d.x_fp4, q_a_quant_d.x_sf, q_a_quant_d.gsa)
                    topk_idx_d = None
                    if indexers.get(li) is not None:
                        topk_idx_d = indexers[li].forward(q_a_d, x_normed, kc_diag, pos, layer_idx=li)
                    if topk_idx_d is not None:
                        tk_d = topk_idx_d[0].clamp(0, kc_diag.n_comp - 1).int()
                        gathered = kc_diag.gather_mixed_selective(tk_d)
                        print(f" CSA topk: {tk_d[:10].tolist()}", flush=True)
                elif kc_diag.n_comp > 0 and ratio_diag > 4:
                    gathered = kc_diag.gather_mixed_all()
                if gathered is None:
                    gathered = kc_diag.gather_mixed_swa_only()
                kv_nope_fp8_d, kv_nope_scale_d, kv_rope_bf16_d = gathered

                seq_len_d = kv_nope_scale_d.shape[0]
                nope_fp8_f = kv_nope_fp8_d.view(torch.float8_e4m3fn).float()
                nope_max = _max_abs(nope_fp8_f)
                scale_max = _max_abs(kv_nope_scale_d)
                rope_max = _max_abs(kv_rope_bf16_d.float())
                print(f" Gathered KV: seq_len={seq_len_d} |nope_fp8|={nope_max:.2f} |nope_scale|={scale_max:.6f} |rope_bf16|={rope_max:.2f}", flush=True)
                nope_dequant_max = _max_abs(nope_fp8_f * kv_nope_scale_d.unsqueeze(-1).float())
                print(f" |nope_dequant_max|={nope_dequant_max:.4f}", flush=True)

                # Now run FMHA
                F_attn_d, q_a_d = forward_attention(
                    x_normed, layer_w[li], li, cfg, *rope_caches[gpu],
                    kv_caches[li], pos, compressors.get(li), indexers.get(li), pl_diag,
                    x_quant=x_quant_attn)
                print(f" |F_attn|={_max_abs(F_attn_d):.2f}", flush=True)
                # Check if Q heads are reasonable (needs the q_a_norm weight)
                if q_norm_w_d is not None:
                    q_heads_diag = pl_diag['q_b'].run_from_quantized(
                        rmsnorm_quantize_nvfp4(q_a_d, q_norm_w_d))
                    q_heads_diag = unweighted_rmsnorm(q_heads_diag).bfloat16()
                    print(f" |Q_heads|={_max_abs(q_heads_diag):.4f}", flush=True)

                X_mid_d = attn_mhc_d.post_block(X_diag, F_attn_d, ctx_a_d)
                print(f" |X_mid|={_max_abs(X_mid_d):.2f} B_l_row=[{B_l_a.sum(-1).min().item():.4f},{B_l_a.sum(-1).max().item():.4f}] C_l=[{C_l_a.min().item():.4f},{C_l_a.max().item():.4f}]", flush=True)
                A_l_f, B_l_f, C_l_f = ffn_mhc_d._dynamic_params(X_mid_d)
                ctx_f_d = mHCContext(B_l=B_l_f, C_l=C_l_f)
                x_quant_ffn = mhc_rmsnorm_quantize_nvfp4(
                    X_mid_d, A_l_f, ffn_norms.get(li).to(dev, torch.float32))
                x_ffn = dequantize_nvfp4(x_quant_ffn.x_fp4, x_quant_ffn.x_sf, x_quant_ffn.gsa)
                F_ffn_d = moe_forward(x_ffn, li, moe_runners.get(li), se_runners.get(li),
                                      routers.get(li), tid32.to(dev))
                print(f" |F_ffn|={_max_abs(F_ffn_d):.2f}", flush=True)
                X_next_d = ffn_mhc_d.post_block(X_mid_d, F_ffn_d, ctx_f_d)
                print(f" |X_next|={_max_abs(X_next_d):.2e}", flush=True)

                # Per-component magnitudes of the mHC update: B^T @ X and C * F
                BX = torch.bmm(ctx_a_d.B_l.transpose(-1, -2), X_diag.float())
                CF = ctx_a_d.C_l.unsqueeze(-1) * F_attn_d.unsqueeze(1)
                print(f" |B@X|={_max_abs(BX):.2f} |C*F|={_max_abs(CF):.2f}", flush=True)
                BX_f = torch.bmm(ctx_f_d.B_l.transpose(-1, -2), X_mid_d.float())
                CF_f = ctx_f_d.C_l.unsqueeze(-1) * F_ffn_d.unsqueeze(1)
                print(f" FFN: |B@X|={_max_abs(BX_f):.2f} |C*F|={_max_abs(CF_f):.2f}", flush=True)
                return 1

        if pi % 5 == 0:
            print(f" Token {pi}/{len(input_ids)} done: {time.time()-t1:.2f}s |X|={_max_abs(X):.2f}", flush=True)

    # KV cache state
    print(f"\nProduction KV cache state after prefill ({len(input_ids)} tokens):")
    for li in range(TEST_LAYERS):
        kc = kv_caches[li]
        ratio = _compress_ratio(cr, li)
        print(f" L{li} (ratio={ratio}): n_comp={kc.n_comp} swa_len={kc.swa_len} total_KV={kc.n_comp + kc.swa_len}")

    # ================================================================
    # PHASE 2: Decode step — per-layer diagnostics
    # ================================================================
    print(f"\n{'='*70}")
    print("PHASE 2: Decode step — per-layer diagnostics")
    print(f"{'='*70}")

    decode_pos = len(input_ids)
    decode_ids = tokenizer.encode(" the", add_special_tokens=False)
    decode_tid = decode_ids[0] if decode_ids else 2

    dec_tid = torch.tensor([decode_tid], dtype=torch.long, device=DEVICE)
    dec_tid32 = torch.tensor([decode_tid], dtype=torch.int32, device=DEVICE)
    dec_pos = torch.tensor([decode_pos], dtype=torch.long, device=DEVICE)

    X = mHCLayer.init_state(prod_embed(dec_tid))
    print(f"\nInitial X: shape={tuple(X.shape)} |X|={_max_abs(X):.6f}")

    print(f"\n {'L':>3} {'ratio':>5} {'|X_in|':>12} {'|X_out|':>12} {'|F_attn|':>10} {'|F_ffn|':>10} {'n_comp':>6} {'swa':>4} {'mode':>8} {'leak':>5}")
    print(f" {'-'*3} {'-'*5} {'-'*12} {'-'*12} {'-'*10} {'-'*10} {'-'*6} {'-'*4} {'-'*8} {'-'*5}")

    # The decode layer is replayed step by step (intended to mirror
    # forward_layer — TODO confirm it stays in sync) to capture intermediates.
    for li in range(TEST_LAYERS):
        gpu = li % NUM_GPUS
        dev = f"cuda:{gpu}"
        torch.cuda.set_device(gpu)
        if X.device != torch.device(dev):
            X = X.to(dev)

        ratio = _compress_ratio(cr, li)
        kc = kv_caches[li]
        X_in_mag = _max_abs(X)

        # Attention half: mHC pre-params -> fused RMSNorm+NVFP4 -> attention -> mHC post
        attn_mhc = attn_mhcs.get(li)
        ffn_mhc = ffn_mhcs.get(li)
        A_l_a, B_l_a, C_l_a = attn_mhc._dynamic_params(X)
        ctx_a = mHCContext(B_l=B_l_a, C_l=C_l_a)
        x_quant_attn = mhc_rmsnorm_quantize_nvfp4(
            X, A_l_a, attn_norms.get(li).to(dev, torch.float32))
        x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa)

        F_attn, _ = forward_attention(
            x_normed, layer_w[li], li, cfg, *rope_caches[gpu],
            kc, dec_pos, compressors.get(li), indexers.get(li), prod_lins.get(li),
            x_quant=x_quant_attn)
        X_mid = attn_mhc.post_block(X, F_attn, ctx_a)

        # FFN half: same pattern with the MoE + shared expert
        A_l_f, B_l_f, C_l_f = ffn_mhc._dynamic_params(X_mid)
        ctx_f = mHCContext(B_l=B_l_f, C_l=C_l_f)
        x_quant_ffn = mhc_rmsnorm_quantize_nvfp4(
            X_mid, A_l_f, ffn_norms.get(li).to(dev, torch.float32))
        x_ffn = dequantize_nvfp4(x_quant_ffn.x_fp4, x_quant_ffn.x_sf, x_quant_ffn.gsa)
        F_ffn = moe_forward(x_ffn, li, moe_runners.get(li), se_runners.get(li),
                            routers.get(li), dec_tid32.to(dev))
        X_next = ffn_mhc.post_block(X_mid, F_ffn, ctx_f)

        X_out_mag = _max_abs(X_next)
        f_attn_mag = _max_abs(F_attn)
        f_ffn_mag = _max_abs(F_ffn)

        swa_kv, _ = kc.get_swa()
        swa_len = swa_kv.shape[0]
        n_comp = kc.n_comp
        mode = "CSA" if ratio == 4 else ("HCA" if ratio > 4 else "SWA")

        # Causality check: flag any visible compressed entry whose recorded
        # position is >= the decode position. NOTE(review): whether == is a
        # real violation depends on what comp_pos records — confirm.
        future_leak = False
        if n_comp > 0 and kc.comp_pos is not None and kc.comp_pos.numel() > 0:
            future_leak = (kc.comp_pos[:n_comp] >= decode_pos).any().item()

        print(f" {li:>3} {ratio:>5} {X_in_mag:>12.2f} {X_out_mag:>12.2f} "
              f"{f_attn_mag:>10.2f} {f_ffn_mag:>10.2f} {n_comp:>6} {swa_len:>4} {mode:>8} "
              f"{'YES!' if future_leak else 'no':>5}")

        # mHC diagnostics (attention block)
        print(f" mHC: B_l row_sum=[{B_l_a.sum(-1).min().item():.4f},{B_l_a.sum(-1).max().item():.4f}] "
              f"col_sum=[{B_l_a.sum(-2).min().item():.4f},{B_l_a.sum(-2).max().item():.4f}] "
              f"A=[{A_l_a.min().item():.4f},{A_l_a.max().item():.4f}] "
              f"C=[{C_l_a.min().item():.4f},{C_l_a.max().item():.4f}]")

        # CSA specifics
        if ratio == 4 and n_comp > 0:
            print(f" CSA: n_comp={n_comp} swa_len={swa_len} total_attend={n_comp + swa_len}")

        X = X_next

    # Summary
    print(f"\n{'='*70}")
    print("PART A SUMMARY")
    print(f"{'='*70}")
    print("Production pipeline diagnostics complete.")
    print("Check the |X| values above for:")
    print(" 1. Exponential growth (mHC residual blowup)")
    print(" 2. Sudden jumps (NVFP4 quantization error)")
    print(" 3. NaN/Inf (numerical instability)")
    print(" 4. future_leak=YES (causality violation in compressed KV)")
    return 0
|
|
|
|
|
|
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|