Add single-layer trace (Phase 2.6) for detailed debugging
This commit is contained in:
@@ -61,6 +61,7 @@ PROMPT = _args.prompt or "The capital of France is"
|
||||
NUM_GPUS = 8
|
||||
SKIP_ROUTED_MOE = _args.skip_moe # If True, only use shared expert (debug)
|
||||
INVERSE_ROPE = not _args.no_inverse_rope # If False, skip inverse RoPE on attention output (diagnostic)
|
||||
# NOTE(review): the Phase 2.6 single-layer trace block in main() is guarded by
# `if True:` and does not read this flag — confirm whether it should be `if LAYER_TRACE:`.
LAYER_TRACE = False # If True, trace every computation step for first layer of first prefill token
|
||||
MHC_DIAG = True # If True, print per-layer mHC diagnostics (B_l row/col sums, C_l values)
|
||||
# When True: applies inverse RoPE at query position → converts absolute→relative
|
||||
# When False: leaves relative position encoding intact for output projection
|
||||
@@ -795,6 +796,123 @@ def main():
|
||||
ffn_mhc_blocks, attn_norms, ffn_norms, embed, lm_w,
|
||||
final_norm_w, tokenizer)
|
||||
|
||||
# ==== Phase 2.6: Single-layer trace ====
|
||||
if True: # Always run the trace
|
||||
print(f"\n{'='*70}\nPhase 2.6: Single-Layer Trace (layer 0, first prefill token)\n{'='*70}", flush=True)
|
||||
li = 0
|
||||
dev = f"cuda:0"
|
||||
w = layer_weights[li]
|
||||
pre = f"model.layers.{li}.self_attn"
|
||||
T_dim = 1
|
||||
positions = torch.tensor([0], dtype=torch.long, device=dev)
|
||||
rope_cos, rope_sin = rope_caches[0]
|
||||
|
||||
# Start from the embedding
|
||||
tid = torch.tensor([tokenizer.encode("The")[-1]], dtype=torch.long, device=dev)
|
||||
emb = embed(tid) # (1, H)
|
||||
X = mHCBlock.init_state(emb, 4) # (1, 4, H)
|
||||
print(f" X after init_state: |X|={X.abs().max().item():.4f} stream0_mean={X[:,0,:].float().abs().mean().item():.6f}", flush=True)
|
||||
|
||||
# mHC pre_block
|
||||
attn_mhc = attn_mhc_blocks[0]
|
||||
x_in, ctx = attn_mhc.pre_block(X)
|
||||
print(f" x_in (mHC pre_block): |x_in|={x_in.abs().max().item():.4f} mean={x_in.float().abs().mean().item():.6f}", flush=True)
|
||||
B_l = ctx.B_l
|
||||
C_l = ctx.C_l
|
||||
print(f" B_l row_sums={B_l[0].sum(dim=-1).tolist()}", flush=True)
|
||||
print(f" C_l={C_l[0].tolist()}", flush=True)
|
||||
|
||||
# RMSNorm
|
||||
a_norm = attn_norms[0]
|
||||
x_normed = a_norm.forward(x_in)
|
||||
print(f" x_normed: |x|={x_normed.abs().max().item():.4f} mean={x_normed.float().abs().mean().item():.6f}", flush=True)
|
||||
|
||||
# Q projection
|
||||
c_Q = nvfp4_linear(x_normed,
|
||||
w[f"{pre}.q_a_proj.weight"],
|
||||
w[f"{pre}.q_a_proj.weight_scale"],
|
||||
w[f"{pre}.q_a_proj.weight_scale_2"])
|
||||
print(f" c_Q (q_a_proj): |c_Q|={c_Q.abs().max().item():.4f} mean={c_Q.float().abs().mean().item():.6f}", flush=True)
|
||||
|
||||
# q_a_norm
|
||||
q_norm_w = w.get(f"{pre}.q_a_norm.weight")
|
||||
if q_norm_w is not None:
|
||||
c_Q_f = c_Q.float()
|
||||
c_Q_rms = c_Q_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
c_Q = (c_Q_f * c_Q_rms * q_norm_w.float()).bfloat16()
|
||||
print(f" c_Q after q_a_norm: |c_Q|={c_Q.abs().max().item():.4f}", flush=True)
|
||||
|
||||
q = nvfp4_linear(c_Q,
|
||||
w[f"{pre}.q_b_proj.weight"],
|
||||
w[f"{pre}.q_b_proj.weight_scale"],
|
||||
w[f"{pre}.q_b_proj.weight_scale_2"])
|
||||
q_heads = q.reshape(T_dim, n_h, hd)
|
||||
print(f" q_heads: |q|={q_heads.abs().max().item():.4f} mean={q_heads.float().abs().mean().item():.6f}", flush=True)
|
||||
|
||||
# KV projection
|
||||
kv = nvfp4_linear(x_normed,
|
||||
w[f"{pre}.kv_proj.weight"],
|
||||
w[f"{pre}.kv_proj.weight_scale"],
|
||||
w[f"{pre}.kv_proj.weight_scale_2"])
|
||||
print(f" kv (kv_proj): |kv|={kv.abs().max().item():.4f} mean={kv.float().abs().mean().item():.6f}", flush=True)
|
||||
|
||||
# kv_norm
|
||||
kv_norm_w = w.get(f"{pre}.kv_norm.weight")
|
||||
if kv_norm_w is not None:
|
||||
kv_f = kv.float()
|
||||
kv_rms = kv_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
kv = (kv_f * kv_rms * kv_norm_w.float()).bfloat16()
|
||||
print(f" kv after kv_norm: |kv|={kv.abs().max().item():.4f}", flush=True)
|
||||
|
||||
kv_new = kv.reshape(T_dim, 1, hd) # (1, 1, hd)
|
||||
print(f" kv_new shape: {kv_new.shape}", flush=True)
|
||||
|
||||
# Apply RoPE
|
||||
q_heads = apply_rope_partial(q_heads, positions, rope_cos, rope_sin, hd, rd)
|
||||
kv_new = apply_rope_partial(kv_new, positions, rope_cos, rope_sin, hd, rd)
|
||||
print(f" After RoPE: |q|={q_heads.abs().max().item():.4f} |kv|={kv_new.abs().max().item():.4f}", flush=True)
|
||||
|
||||
# Self-attention over a single token: with only one key the softmax weight is trivially 1.0, so attn_out equals V for every head
|
||||
q_input = q_heads.permute(1, 0, 2) # (n_h, 1, hd)
|
||||
k_input = kv_new.permute(1, 0, 2) # (1, 1, hd) -> expand
|
||||
k_expanded = k_input.expand(n_h, -1, -1).contiguous()
|
||||
v_expanded = k_expanded.clone() # K=V in DSV4 MQA
|
||||
attn_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_input, k_expanded, v_expanded, scale=1.0/math.sqrt(hd))
|
||||
attn_out = attn_out.permute(1, 0, 2) # (1, n_h, hd)
|
||||
print(f" attn_out: |o|={attn_out.abs().max().item():.4f} mean={attn_out.float().abs().mean().item():.6f}", flush=True)
|
||||
|
||||
# Inverse RoPE
|
||||
if INVERSE_ROPE:
|
||||
attn_out = apply_inverse_rope(attn_out, positions, rope_cos, rope_sin, hd, rd)
|
||||
print(f" After inverse RoPE: |o|={attn_out.abs().max().item():.4f}", flush=True)
|
||||
|
||||
# Output projection
|
||||
o_groups = cfg.get("num_output_groups", 16)
|
||||
o_rank = cfg.get("output_group_dim", 1024)
|
||||
heads_per_group = n_h // o_groups
|
||||
group_input_dim = heads_per_group * hd
|
||||
attn_flat = attn_out.reshape(T_dim, n_h * hd)
|
||||
attn_grouped = attn_flat.reshape(T_dim, o_groups, heads_per_group * hd)
|
||||
oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16()
|
||||
oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim)
|
||||
attn_for_bmm = attn_grouped.permute(1, 0, 2)
|
||||
grouped_out = torch.bmm(attn_for_bmm, oa_3d.transpose(1, 2))
|
||||
grouped_flat = grouped_out.permute(1, 0, 2).reshape(T_dim, o_groups * o_rank)
|
||||
print(f" grouped_out (wo_a): |o|={grouped_flat.abs().max().item():.4f} mean={grouped_flat.float().abs().mean().item():.6f}", flush=True)
|
||||
|
||||
F_attn = nvfp4_linear(grouped_flat,
|
||||
w[f"{pre}.o_b_proj.weight"],
|
||||
w[f"{pre}.o_b_proj.weight_scale"],
|
||||
w[f"{pre}.o_b_proj.weight_scale_2"])
|
||||
print(f" F_attn (wo_b): |F|={F_attn.abs().max().item():.4f} mean={F_attn.float().abs().mean().item():.6f}", flush=True)
|
||||
|
||||
# mHC post_block
|
||||
X_mid = attn_mhc.post_block(X, F_attn, ctx)
|
||||
print(f" X_mid: |X|={X_mid.abs().max().item():.4f} stream0_mean={X_mid[:,0,:].float().abs().mean().item():.6f}", flush=True)
|
||||
|
||||
print(f" Layer 0 trace complete.", flush=True)
|
||||
|
||||
# ==== Phase 3: Inference ====
|
||||
print(f"\n{'='*70}\nPhase 3: Inference\n{'='*70}")
|
||||
# DeepSeek V4 chat format: <|begin▁of▁sentence|><|User|>prompt<|Assistant|>
|
||||
|
||||
Reference in New Issue
Block a user