Add NaN tracking to single_shot_inference
This commit is contained in:
@@ -228,13 +228,16 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin):
|
||||
# ---- Reshape for attention ----
|
||||
q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2) # (n_h, T, hd)
|
||||
|
||||
kv_dim = kv.shape[-1] # Should be hd=512 for all layer types
|
||||
|
||||
# kv_proj outputs (hd,) = 1 KV head for MQA
|
||||
# The Z (compression weights) come from compressor.gate_proj separately
|
||||
# For decode, KV is just the current token's projection
|
||||
k = kv.reshape(T, 1, hd).permute(1, 0, 2) # (1, T, hd) — MQA
|
||||
v = k.clone()
|
||||
|
||||
# Debug
|
||||
has_nan_q = torch.isnan(q_heads.float()).any().item()
|
||||
has_nan_kv = torch.isnan(k.float()).any().item()
|
||||
if li == 0:
|
||||
print(f" L{li}: q nan={has_nan_q}, kv nan={has_nan_kv}, q range=[{q_heads.float().min().item():.4f}, {q_heads.float().max().item():.4f}]")
|
||||
|
||||
# ---- Apply RoPE ----
|
||||
pos = torch.tensor([0], dtype=torch.long, device=x.device) # decode step position
|
||||
q_heads = apply_rope(q_heads, pos, rope_cos, rope_sin, rd)
|
||||
@@ -245,6 +248,11 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin):
|
||||
attn_out = dsv4_attention(q_heads, k, v) # (n_h, T, hd)
|
||||
attn_out = attn_out.permute(1, 0, 2).reshape(T, n_h * hd) # (T, n_h*hd)
|
||||
|
||||
# Debug
|
||||
has_nan_attn = torch.isnan(attn_out.float()).any().item()
|
||||
if li == 0:
|
||||
print(f" L{li}: attn_out nan={has_nan_attn}, range=[{attn_out.float().min().item():.4f}, {attn_out.float().max().item():.4f}]")
|
||||
|
||||
# ---- Output projection: wo_a (BF16 batched matmul) → wo_b (NVFP4) ----
|
||||
# wo_a: grouped linear — input per group: (heads_per_group * hd) → o_lora_rank
|
||||
# Implemented as batched matmul: (n_groups, heads_per_group*hd) × (n_groups, heads_per_group*hd, o_rank)
|
||||
|
||||
Reference in New Issue
Block a user