diff --git a/single_shot_inference.py b/single_shot_inference.py
index b0cf2cc8..74fa9d91 100644
--- a/single_shot_inference.py
+++ b/single_shot_inference.py
@@ -61,6 +61,7 @@ PROMPT = _args.prompt or "The capital of France is"
 NUM_GPUS = 8
 SKIP_ROUTED_MOE = _args.skip_moe # If True, only use shared expert (debug)
 INVERSE_ROPE = not _args.no_inverse_rope # If False, skip inverse RoPE on attention output (diagnostic)
+LAYER_TRACE = False  # If True, run the Phase 2.6 single-layer trace (layer 0 attention path, one probe token) before inference
 MHC_DIAG = True # If True, print per-layer mHC diagnostics (B_l row/col sums, C_l values)
 # When True: applies inverse RoPE at query position → converts absolute→relative
 # When False: leaves relative position encoding intact for output projection
@@ -795,6 +796,123 @@ def main():
         ffn_mhc_blocks, attn_norms, ffn_norms,
         embed, lm_w, final_norm_w, tokenizer)

+    # ==== Phase 2.6: Single-layer trace (opt-in diagnostic, gated by LAYER_TRACE) ====
+    if LAYER_TRACE:  # Off by default: every print below forces a .item() device sync
+        print(f"\n{'='*70}\nPhase 2.6: Single-Layer Trace (layer 0, first prefill token)\n{'='*70}", flush=True)
+        li = 0
+        dev = "cuda:0"
+        w = layer_weights[li]
+        pre = f"model.layers.{li}.self_attn"
+        T_dim = 1
+        positions = torch.tensor([0], dtype=torch.long, device=dev)
+        rope_cos, rope_sin = rope_caches[0]
+
+        # Probe token: last id of tokenizer.encode("The"), embedded and expanded via init_state
+        tid = torch.tensor([tokenizer.encode("The")[-1]], dtype=torch.long, device=dev)
+        emb = embed(tid)  # (1, H)
+        X = mHCBlock.init_state(emb, 4)  # (1, 4, H)
+        print(f" X after init_state: |X|={X.abs().max().item():.4f} stream0_mean={X[:,0,:].float().abs().mean().item():.6f}", flush=True)
+
+        # mHC pre_block: produces the attention input plus the ctx reused by post_block
+        attn_mhc = attn_mhc_blocks[0]
+        x_in, ctx = attn_mhc.pre_block(X)
+        print(f" x_in (mHC pre_block): |x_in|={x_in.abs().max().item():.4f} mean={x_in.float().abs().mean().item():.6f}", flush=True)
+        B_l = ctx.B_l
+        C_l = ctx.C_l
+        print(f" B_l row_sums={B_l[0].sum(dim=-1).tolist()}", flush=True)
+        print(f" C_l={C_l[0].tolist()}", flush=True)
+
+        # RMSNorm (attention input norm)
+        a_norm = attn_norms[0]
+        x_normed = a_norm.forward(x_in)
+        print(f" x_normed: |x|={x_normed.abs().max().item():.4f} mean={x_normed.float().abs().mean().item():.6f}", flush=True)
+
+        # Q projection, stage A (q_a_proj)
+        c_Q = nvfp4_linear(x_normed,
+                           w[f"{pre}.q_a_proj.weight"],
+                           w[f"{pre}.q_a_proj.weight_scale"],
+                           w[f"{pre}.q_a_proj.weight_scale_2"])
+        print(f" c_Q (q_a_proj): |c_Q|={c_Q.abs().max().item():.4f} mean={c_Q.float().abs().mean().item():.6f}", flush=True)
+
+        # q_a_norm: inline RMS norm in float32 (eps=1e-6), applied only if the weight exists
+        q_norm_w = w.get(f"{pre}.q_a_norm.weight")
+        if q_norm_w is not None:
+            c_Q_f = c_Q.float()
+            c_Q_rms = c_Q_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
+            c_Q = (c_Q_f * c_Q_rms * q_norm_w.float()).bfloat16()
+            print(f" c_Q after q_a_norm: |c_Q|={c_Q.abs().max().item():.4f}", flush=True)
+
+        q = nvfp4_linear(c_Q,
+                         w[f"{pre}.q_b_proj.weight"],
+                         w[f"{pre}.q_b_proj.weight_scale"],
+                         w[f"{pre}.q_b_proj.weight_scale_2"])
+        q_heads = q.reshape(T_dim, n_h, hd)
+        print(f" q_heads: |q|={q_heads.abs().max().item():.4f} mean={q_heads.float().abs().mean().item():.6f}", flush=True)
+
+        # KV projection
+        kv = nvfp4_linear(x_normed,
+                          w[f"{pre}.kv_proj.weight"],
+                          w[f"{pre}.kv_proj.weight_scale"],
+                          w[f"{pre}.kv_proj.weight_scale_2"])
+        print(f" kv (kv_proj): |kv|={kv.abs().max().item():.4f} mean={kv.float().abs().mean().item():.6f}", flush=True)
+
+        # kv_norm: same inline float32 RMS norm as q_a_norm, applied only if the weight exists
+        kv_norm_w = w.get(f"{pre}.kv_norm.weight")
+        if kv_norm_w is not None:
+            kv_f = kv.float()
+            kv_rms = kv_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
+            kv = (kv_f * kv_rms * kv_norm_w.float()).bfloat16()
+            print(f" kv after kv_norm: |kv|={kv.abs().max().item():.4f}", flush=True)
+
+        kv_new = kv.reshape(T_dim, 1, hd)  # (1, 1, hd)
+        print(f" kv_new shape: {kv_new.shape}", flush=True)
+
+        # Apply RoPE to both the query heads and the single shared KV head
+        q_heads = apply_rope_partial(q_heads, positions, rope_cos, rope_sin, hd, rd)
+        kv_new = apply_rope_partial(kv_new, positions, rope_cos, rope_sin, hd, rd)
+        print(f" After RoPE: |q|={q_heads.abs().max().item():.4f} |kv|={kv_new.abs().max().item():.4f}", flush=True)
+
+        # Self-attention (single token, trivially weight=1.0)
+        q_input = q_heads.permute(1, 0, 2)  # (n_h, 1, hd)
+        k_input = kv_new.permute(1, 0, 2)  # (1, 1, hd) -> expand
+        k_expanded = k_input.expand(n_h, -1, -1).contiguous()
+        v_expanded = k_expanded.clone()  # K=V in DSV4 MQA
+        attn_out = torch.nn.functional.scaled_dot_product_attention(
+            q_input, k_expanded, v_expanded, scale=1.0/math.sqrt(hd))
+        attn_out = attn_out.permute(1, 0, 2)  # (1, n_h, hd)
+        print(f" attn_out: |o|={attn_out.abs().max().item():.4f} mean={attn_out.float().abs().mean().item():.6f}", flush=True)
+
+        # Inverse RoPE (skipped when INVERSE_ROPE is False)
+        if INVERSE_ROPE:
+            attn_out = apply_inverse_rope(attn_out, positions, rope_cos, rope_sin, hd, rd)
+            print(f" After inverse RoPE: |o|={attn_out.abs().max().item():.4f}", flush=True)
+
+        # Output projection: per-group bmm with o_a_proj (cast to bf16), then o_b_proj via nvfp4_linear
+        o_groups = cfg.get("num_output_groups", 16)
+        o_rank = cfg.get("output_group_dim", 1024)
+        heads_per_group = n_h // o_groups
+        group_input_dim = heads_per_group * hd
+        attn_flat = attn_out.reshape(T_dim, n_h * hd)
+        attn_grouped = attn_flat.reshape(T_dim, o_groups, group_input_dim)
+        oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16()
+        oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim)
+        attn_for_bmm = attn_grouped.permute(1, 0, 2)
+        grouped_out = torch.bmm(attn_for_bmm, oa_3d.transpose(1, 2))
+        grouped_flat = grouped_out.permute(1, 0, 2).reshape(T_dim, o_groups * o_rank)
+        print(f" grouped_out (wo_a): |o|={grouped_flat.abs().max().item():.4f} mean={grouped_flat.float().abs().mean().item():.6f}", flush=True)
+
+        F_attn = nvfp4_linear(grouped_flat,
+                              w[f"{pre}.o_b_proj.weight"],
+                              w[f"{pre}.o_b_proj.weight_scale"],
+                              w[f"{pre}.o_b_proj.weight_scale_2"])
+        print(f" F_attn (wo_b): |F|={F_attn.abs().max().item():.4f} mean={F_attn.float().abs().mean().item():.6f}", flush=True)
+
+        # mHC post_block: combine the attention output with the pre-block state via ctx
+        X_mid = attn_mhc.post_block(X, F_attn, ctx)
+        print(f" X_mid: |X|={X_mid.abs().max().item():.4f} stream0_mean={X_mid[:,0,:].float().abs().mean().item():.6f}", flush=True)
+
+        print(f" Layer 0 trace complete.", flush=True)
+
     # ==== Phase 3: Inference ====
     print(f"\n{'='*70}\nPhase 3: Inference\n{'='*70}")
     # DeepSeek V4 chat format: <|begin▁of▁sentence|><|User|>prompt<|Assistant|>