PART A: fix KV diagnostics — compute q_a before indexer, add Q_heads magnitude check

This commit is contained in:
2026-06-03 06:33:51 +00:00
parent 86e59c16c5
commit f2c1b3afd5

View File

@@ -269,20 +269,30 @@ def main():
X_diag, A_l_a, attn_norms.get(li).to(dev, torch.float32))
x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa)
print(f" |x_normed|={x_normed.abs().max().item():.2f} gsa={x_quant_attn.gsa}", flush=True)
# Print KV cache state
# Print KV cache state BEFORE calling forward_attention
kc_diag = kv_caches[li]
swa_kv_d, swa_pos_d = kc_diag.get_swa()
print(f" KV: n_comp={kc_diag.n_comp} swa_len={swa_kv_d.shape[0]}", flush=True)
# Gather KV and print
ratio_diag = cr[li] if li < len(cr) else 128
seq_len_d = 0
if kc_diag.n_comp > 0:
if ratio_diag == 4:
# The indexer's top-k selection must be computed first, and it takes q_a as input,
# so run the q_a projection on the quantized attention input here.
pl_diag = prod_lins.get(li)
q_a_d = pl_diag['q_a'].run_from_quantized(x_quant_attn)
q_norm_w_d = layer_w[li].get(f"model.layers.{li}.self_attn.q_a_norm.weight")
if q_norm_w_d is not None:
q_a_quant_d = rmsnorm_quantize_nvfp4(q_a_d, q_norm_w_d.to(dev, torch.float32))
q_a_d = dequantize_nvfp4(q_a_quant_d.x_fp4, q_a_quant_d.x_sf, q_a_quant_d.gsa)
topk_idx_d = None
if indexers.get(li) is not None:
topk_idx_d = indexers[li].forward(q_a_d, x_normed, kc_diag, pos, layer_idx=li)
if topk_idx_d is not None:
tk_d = topk_idx_d[0].clamp(0, kc_diag.n_comp - 1).int()
kv_nope_fp8_d, kv_nope_scale_d, kv_rope_bf16_d = kc_diag.gather_mixed_selective(tk_d)
print(f" CSA topk: {tk_d.tolist()[:10]}", flush=True)
else:
kv_nope_fp8_d, kv_nope_scale_d, kv_rope_bf16_d = kc_diag.gather_mixed_swa_only()
elif ratio_diag > 4:
@@ -298,11 +308,16 @@ def main():
print(f" Gathered KV: seq_len={seq_len_d} |nope_fp8|={nope_max:.2f} |nope_scale|={scale_max:.6f} |rope_bf16|={rope_max:.2f}", flush=True)
nope_dequant_max = (kv_nope_fp8_d.view(torch.float8_e4m3fn).float() * kv_nope_scale_d.unsqueeze(-1).float()).abs().max().item()
print(f" |nope_dequant_max|={nope_dequant_max:.4f}", flush=True)
# Now run FMHA
F_attn_d, q_a_d = forward_attention(
x_normed, layer_w[li], li, cfg, *rope_caches[gpu],
kv_caches[li], pos, compressors.get(li), indexers.get(li), prod_lins.get(li),
x_quant=x_quant_attn)
print(f" |F_attn|={F_attn_d.abs().max().item():.2f}", flush=True)
# Sanity check: print the max magnitude of the Q heads (q_b projection, then unweighted RMSNorm)
q_heads_diag = pl_diag['q_b'].run_from_quantized(rmsnorm_quantize_nvfp4(q_a_d, layer_w[li].get(f"model.layers.{li}.self_attn.q_a_norm.weight").to(dev, torch.float32)))
q_heads_diag = unweighted_rmsnorm(q_heads_diag).bfloat16()
print(f" |Q_heads|={q_heads_diag.abs().max().item():.4f}", flush=True)
X_mid_d = attn_mhc_d.post_block(X_diag, F_attn_d, ctx_a_d)
print(f" |X_mid|={X_mid_d.abs().max().item():.2f} B_l_row=[{B_l_a.sum(-1).min().item():.4f},{B_l_a.sum(-1).max().item():.4f}] C_l=[{C_l_a.min().item():.4f},{C_l_a.max().item():.4f}]", flush=True)
A_l_f, B_l_f, C_l_f = ffn_mhc_d._dynamic_params(X_mid_d)