265 lines
12 KiB
Python
265 lines
12 KiB
Python
#!/usr/bin/env python3
"""Diagnostic: track per-layer magnitudes to find residual explosion.

Runs the single token "The" through all 61 layers and prints, for each layer:
    |X_in|, |x_normed|, |F_attn|, |X_mid|, |F_ffn|, |X_next|
Each value is the maximum absolute element of the corresponding tensor
(|X_next| is printed under the column header |X_out|).

This identifies WHERE the residual stream starts exploding.

Usage (on B200):
    source /root/dsv4-nvfp4-workspace/venv/bin/activate
    cd /root/dsv4-nvfp4-workspace/kernel
    python3 tests/test_residual_diagnostic.py
"""
|
|
import os, sys, math, json
|
|
import torch
|
|
from pathlib import Path
|
|
|
|
# Put the parent of this file's directory at the front of sys.path so the
# project imports that follow (e.g. `tests.test_minimal_e2e`) can be resolved.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Directory main() reads config.json, the tokenizer and the weights from.
CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
# Layer `li` is processed on GPU `li % NUM_GPUS`.
NUM_GPUS = 8
|
|
|
|
# Import helpers from the main test
|
|
from tests.test_minimal_e2e import (
|
|
FP4_LUT, dequant_nvfp4_weight, nvfp4_linear, RMSNorm,
|
|
build_rope_cache, apply_rope_partial, apply_inverse_rope,
|
|
load_weights_to_cpu, get_layer_weights
|
|
)
|
|
from single_shot_inference import mHCBlock
|
|
|
|
|
|
def _absmax(t):
    """Return the largest absolute element of tensor *t* as a Python float."""
    return t.float().abs().max().item()


def _rms_scale(x, weight, eps=1e-6):
    """Scale *x* by its inverse RMS over the last dim and by *weight*.

    The arithmetic is done in float32 and the result is returned as bfloat16.
    *weight* is moved to the device of *x* if it is not already there.
    """
    x_f = x.float()
    inv_rms = x_f.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
    return (x_f * inv_rms * weight.to(x.device).float()).bfloat16()


def _quant_linear(x, w, prefix):
    """Call ``nvfp4_linear`` with the three tensors stored under *prefix*.

    Looks up ``{prefix}.weight``, ``{prefix}.weight_scale`` and
    ``{prefix}.weight_scale_2`` in the mapping *w*.
    """
    return nvfp4_linear(
        x,
        w[f"{prefix}.weight"],
        w[f"{prefix}.weight_scale"],
        w[f"{prefix}.weight_scale_2"],
    )


def _swiglu_expert(x, w, prefix, limit):
    """Run one gate/up/down expert whose weights live under *prefix*.

    Computes ``down(silu(gate(x)) * up(x))`` with the element-wise product
    done in float32. When *limit* is not ``None`` both ``silu(gate)`` and
    ``up`` are clamped to ``[-limit, limit]`` before being multiplied.
    """
    gate = _quant_linear(x, w, f"{prefix}.gate_proj")
    up = _quant_linear(x, w, f"{prefix}.up_proj")
    act = torch.nn.functional.silu(gate.float())
    up_f = up.float()
    if limit is not None:
        act = act.clamp(-limit, limit)
        up_f = up_f.clamp(-limit, limit)
    hidden = (act * up_f).bfloat16()
    return _quant_linear(hidden, w, f"{prefix}.down_proj")


@torch.no_grad()
def main():
    """Trace one token through every layer and print per-layer magnitudes.

    Reads ``config.json``, the tokenizer and the weights from
    ``CHECKPOINT_DIR``, places layer ``li`` on GPU ``li % NUM_GPUS`` and, for
    each layer, prints the maximum absolute element of the layer input, the
    normalised attention input, the attention output, the state after the
    attention residual update, the FFN output and the layer output, plus
    NaN/INF flags for the layer output. Finally prints the top-5 next-token
    candidates.

    Runs under ``torch.no_grad()``: this is inference only, and without it any
    parameter that requires grad would keep every layer's activations alive.

    Raises:
        KeyError: if a required config entry or checkpoint tensor is missing.
        ValueError: if the number of routed-expert weights does not match the
            number of selected experts.
    """
    # Local imports as in the original script, but done first so a missing
    # package fails before the checkpoint is loaded.
    from transformers import AutoTokenizer
    from dsv4.layers.mhc import mHCLayer

    device = 'cuda:0'
    with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
        cfg = json.load(f)

    n_layers = cfg["num_hidden_layers"]
    H = cfg["hidden_size"]
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64))
    n_hc = 4
    o_rank = cfg.get("output_group_dim", 1024)
    o_groups = cfg.get("num_output_groups", 16)
    heads_per_group = n_h // o_groups
    group_input_dim = heads_per_group * hd

    # Layer-invariant settings, read once instead of once per layer.
    norm_eps = cfg.get('rms_norm_eps', 1e-6)
    top_k = cfg.get("num_experts_per_tok", 6)
    routed_scaling = cfg.get("routed_scaling_factor", 2.5)
    swiglu_limit = cfg.get("swiglu_limit", 10.0)
    attn_scale = 1.0 / math.sqrt(hd)
    # The first layers try a token-id -> expert-id table before falling back
    # to the learned gate.
    n_hash_layers = 3
    # NOTE(review): the shared expert is clamped to a hard-coded 10 rather
    # than `swiglu_limit` -- kept as in the original; confirm this is intended.
    shared_expert_limit = 10

    print("Loading weights to CPU...")
    all_weights = load_weights_to_cpu(CHECKPOINT_DIR)

    embed_w = all_weights.get("model.embed_tokens.weight")
    if embed_w is None:
        raise KeyError("model.embed_tokens.weight not found in checkpoint")
    embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to(device))

    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
    input_ids = tokenizer.encode("The")
    # NOTE(review): only the first id is traced. If this tokenizer prepends a
    # BOS token, that id is BOS rather than "The" -- confirm.
    token_id = input_ids[0]
    tid = torch.tensor([token_id], dtype=torch.long, device=device)
    positions = torch.tensor([0], dtype=torch.long, device=device)

    X = mHCLayer.init_state(embed(tid), n_hc)

    print(f"\n{'L':>3} {'|X_in|':>10} {'|x_norm|':>10} {'|F_attn|':>10} {'|X_mid|':>10} {'|F_ffn|':>10} {'|X_out|':>10} nan? inf?")

    # build_rope_cache is called with identical arguments for every layer on
    # a given device, so build it once per device.
    rope_by_device = {}

    for li in range(n_layers):
        gpu = li % NUM_GPUS
        dev = f"cuda:{gpu}"
        X = X.to(dev)
        torch.cuda.set_device(gpu)

        layer_pre = f"model.layers.{li}"
        w = get_layer_weights(all_weights, li, dev)
        if dev not in rope_by_device:
            rope_by_device[dev] = build_rope_cache(8192, rd, dev)
        rope_cos, rope_sin = rope_by_device[dev]
        positions_dev = positions.to(dev)

        # Build per-layer components
        attn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
        attn_mhc.load_from_checkpoint(
            all_weights[f"{layer_pre}.attn_hc.fn"],
            all_weights[f"{layer_pre}.attn_hc.base"],
            all_weights[f"{layer_pre}.attn_hc.scale"])
        ffn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
        ffn_mhc.load_from_checkpoint(
            all_weights[f"{layer_pre}.ffn_hc.fn"],
            all_weights[f"{layer_pre}.ffn_hc.base"],
            all_weights[f"{layer_pre}.ffn_hc.scale"])
        attn_norm = RMSNorm(H, eps=norm_eps, device=dev)
        attn_norm.weight = all_weights[f"{layer_pre}.input_layernorm.weight"].to(device=dev, dtype=torch.float32)
        ffn_norm = RMSNorm(H, eps=norm_eps, device=dev)
        ffn_norm.weight = all_weights[f"{layer_pre}.post_attention_layernorm.weight"].to(device=dev, dtype=torch.float32)

        # Track magnitudes
        x_in_mag = _absmax(X)

        # ---------------- ATTENTION ----------------
        x_in_attn, attn_ctx = attn_mhc.pre_block(X)
        x_normed = attn_norm.forward(x_in_attn)
        x_normed_mag = _absmax(x_normed)

        pre = f"{layer_pre}.self_attn"

        # Q projection: q_a -> optional norm -> q_b
        c_Q = _quant_linear(x_normed, w, f"{pre}.q_a_proj")
        q_norm_w = w.get(f"{pre}.q_a_norm.weight")
        if q_norm_w is not None:
            c_Q = _rms_scale(c_Q, q_norm_w)
        q = _quant_linear(c_Q, w, f"{pre}.q_b_proj")

        # Single KV projection with optional norm
        kv = _quant_linear(x_normed, w, f"{pre}.kv_proj")
        kv_norm_w = w.get(f"{pre}.kv_norm.weight")
        if kv_norm_w is not None:
            kv = _rms_scale(kv, kv_norm_w)

        q_heads = q.reshape(1, n_h, hd)
        kv_new = kv.reshape(1, 1, hd)
        q_roped = apply_rope_partial(q_heads, positions_dev, rope_cos, rope_sin, hd, rd)
        kv_roped = apply_rope_partial(kv_new, positions_dev, rope_cos, rope_sin, hd, rd)

        # The same roped KV vector is used as both key and value for every
        # head; the inverse rope applied to the attention output follows.
        k_exp = kv_roped.expand(n_h, -1, -1).contiguous()
        v_exp = k_exp
        q_input = q_roped.permute(1, 0, 2)

        sink_key = f"{pre}.sinks"
        if sink_key in w:
            sinks = w[sink_key].to(device=dev)
            # Append an all-zero key/value slot per head; its attention logit
            # is the per-head sink value supplied through the additive mask.
            sink_k = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=dev)
            sink_v = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=dev)
            k_ws = torch.cat([k_exp, sink_k], dim=1)
            v_ws = torch.cat([v_exp, sink_v], dim=1)
            sink_bias = torch.zeros(n_h, 1, 2, dtype=torch.bfloat16, device=dev)
            # One vectorised write instead of a Python loop over heads.
            sink_bias[:, 0, -1] = sinks[:n_h].reshape(n_h)
            attn_out = torch.nn.functional.scaled_dot_product_attention(
                q_input, k_ws, v_ws, attn_mask=sink_bias, scale=attn_scale)
        else:
            attn_out = torch.nn.functional.scaled_dot_product_attention(
                q_input, k_exp, v_exp, scale=attn_scale, is_causal=False)
        attn_out = attn_out.permute(1, 0, 2)

        attn_out = apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd)

        # Output projection: grouped o_a (plain bf16 bmm) followed by o_b
        attn_flat = attn_out.reshape(1, n_h * hd)
        attn_grouped = attn_flat.reshape(1, o_groups, group_input_dim)
        oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16()
        oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim)
        grouped_out = torch.bmm(attn_grouped.permute(1, 0, 2), oa_3d.transpose(1, 2))
        grouped_flat = grouped_out.permute(1, 0, 2).reshape(1, o_groups * o_rank)
        F_attn = _quant_linear(grouped_flat, w, f"{pre}.o_b_proj")
        f_attn_mag = _absmax(F_attn)

        X_mid = attn_mhc.post_block(X, F_attn, attn_ctx)
        x_mid_mag = _absmax(X_mid)

        # ---------------- FFN (routed experts + shared expert) ----------------
        x_ffn, ffn_ctx = ffn_mhc.pre_block(X_mid)
        x_ffn_normed = ffn_norm.forward(x_ffn)

        # Expert selection
        expert_ids = None
        expert_weights = None
        tid2eid_key = f"{layer_pre}.mlp.gate.tid2eid"
        if li < n_hash_layers and tid2eid_key in w:
            # Table lookup by token id, uniform weights.
            expert_ids = w[tid2eid_key][token_id]
            expert_weights = torch.ones(top_k, dtype=torch.float32, device=dev) / top_k
        else:
            gate_w = w[f"{layer_pre}.mlp.gate.weight"]
            router_logits = torch.nn.functional.linear(x_ffn_normed, gate_w.bfloat16())
            unbiased = torch.sqrt(torch.nn.functional.softplus(router_logits.float()) + 1e-6)
            # The bias only influences WHICH experts are picked; the mixing
            # weights come from the unbiased scores.
            selection_scores = unbiased
            e_bias_key = f"{layer_pre}.mlp.gate.e_bias"
            if e_bias_key in w:
                selection_scores = unbiased + w[e_bias_key].float().unsqueeze(0)
            _, indices = selection_scores.topk(top_k, dim=-1)
            unbiased_scores = torch.gather(unbiased, -1, indices)
            expert_weights = unbiased_scores / unbiased_scores.sum(dim=-1, keepdim=True)
            expert_ids = indices[0]

        expert_outputs = [
            _swiglu_expert(x_ffn_normed, w, f"{layer_pre}.mlp.experts.{eid}", swiglu_limit)
            for eid in expert_ids.reshape(-1).tolist()
        ]

        # BUG FIX: on the learned-gate path `expert_weights` has a leading
        # batch dim, so zipping it directly against `expert_outputs` iterated
        # over a single row and only the first expert was accumulated.
        # Flatten to one scalar per expert and insist the counts agree.
        mix_weights = expert_weights.reshape(-1).tolist()
        if len(mix_weights) != len(expert_outputs):
            raise ValueError(
                f"layer {li}: {len(expert_outputs)} experts selected but "
                f"{len(mix_weights)} routing weights produced")

        routed_out = torch.zeros_like(x_ffn_normed)
        for expert_out, w_val in zip(expert_outputs, mix_weights):
            routed_out = routed_out + (expert_out.float() * w_val).bfloat16()
        routed_out = (routed_out.float() * routed_scaling).bfloat16()

        # Shared expert
        shared_out = _swiglu_expert(
            x_ffn_normed, w, f"{layer_pre}.mlp.shared_experts", shared_expert_limit)

        F_ffn = routed_out + shared_out
        f_ffn_mag = _absmax(F_ffn)

        X_next = ffn_mhc.post_block(X_mid, F_ffn, ffn_ctx)
        x_out_mag = _absmax(X_next)

        has_nan = torch.isnan(X_next).any().item()
        has_inf = torch.isinf(X_next).any().item()

        print(f"{li:3d} {x_in_mag:10.3f} {x_normed_mag:10.3f} {f_attn_mag:10.3f} {x_mid_mag:10.3f} {f_ffn_mag:10.3f} {x_out_mag:10.3f} {'NaN' if has_nan else ''} {'INF' if has_inf else ''}")

        X = X_next
        # Drop this layer's device weights before the next layer is loaded.
        del w
        torch.cuda.empty_cache()

    # ---------------- Final logits ----------------
    X = X.to('cuda:0')
    lm_w = all_weights.get("lm_head.weight", embed_w).bfloat16().to('cuda:0')
    final_norm_w = all_weights.get("model.norm.weight")
    x_out = X[:, 0, :]
    if final_norm_w is not None:
        x_out = _rms_scale(x_out, final_norm_w)
    logits = torch.nn.functional.linear(x_out, lm_w)
    top5_vals, top5_ids = torch.topk(logits[0], 5)
    top5_str = ' '.join(
        f'{tokenizer.decode([tok])}({val:.1f})'
        for tok, val in zip(top5_ids.tolist(), top5_vals.tolist()))
    print(f"\nTop-5: {top5_str}")
|
|
|
|
# Run the diagnostic only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
|