diff --git a/single_shot_inference.py b/single_shot_inference.py index 1ed2d3b7..79f4cb8d 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -228,13 +228,16 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin): # ---- Reshape for attention ---- q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2) # (n_h, T, hd) - kv_dim = kv.shape[-1] # Should be hd=512 for all layer types - - # kv_proj outputs (hd,) = 1 KV head for MQA - # The Z (compression weights) come from compressor.gate_proj separately + # For decode, KV is just the current token's projection k = kv.reshape(T, 1, hd).permute(1, 0, 2) # (1, T, hd) — MQA v = k.clone() + # Debug + has_nan_q = torch.isnan(q_heads.float()).any().item() + has_nan_kv = torch.isnan(k.float()).any().item() + if li == 0: + print(f" L{li}: q nan={has_nan_q}, kv nan={has_nan_kv}, q range=[{q_heads.float().min().item():.4f}, {q_heads.float().max().item():.4f}]") + # ---- Apply RoPE ---- pos = torch.tensor([0], dtype=torch.long, device=x.device) # decode step position q_heads = apply_rope(q_heads, pos, rope_cos, rope_sin, rd) @@ -245,6 +248,11 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin): attn_out = dsv4_attention(q_heads, k, v) # (n_h, T, hd) attn_out = attn_out.permute(1, 0, 2).reshape(T, n_h * hd) # (T, n_h*hd) + # Debug + has_nan_attn = torch.isnan(attn_out.float()).any().item() + if li == 0: + print(f" L{li}: attn_out nan={has_nan_attn}, range=[{attn_out.float().min().item():.4f}, {attn_out.float().max().item():.4f}]") + # ---- Output projection: wo_a (BF16 batched matmul) → wo_b (NVFP4) ---- # wo_a: grouped linear — input per group: (heads_per_group * hd) → o_lora_rank # Implemented as batched matmul: (n_groups, heads_per_group*hd) × (n_groups, heads_per_group*hd, o_rank)