diff --git a/tests/unit/test_part_a_decode_diagnostics.py b/tests/unit/test_part_a_decode_diagnostics.py index 99e2d6e3..fff125f7 100644 --- a/tests/unit/test_part_a_decode_diagnostics.py +++ b/tests/unit/test_part_a_decode_diagnostics.py @@ -269,20 +269,30 @@ def main(): X_diag, A_l_a, attn_norms.get(li).to(dev, torch.float32)) x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa) print(f" |x_normed|={x_normed.abs().max().item():.2f} gsa={x_quant_attn.gsa}", flush=True) - # Print KV cache state + # Print KV cache state BEFORE calling forward_attention kc_diag = kv_caches[li] swa_kv_d, swa_pos_d = kc_diag.get_swa() print(f" KV: n_comp={kc_diag.n_comp} swa_len={swa_kv_d.shape[0]}", flush=True) # Gather KV and print ratio_diag = cr[li] if li < len(cr) else 128 + seq_len_d = 0 if kc_diag.n_comp > 0: if ratio_diag == 4: + # Need to compute indexer top-k first + # Run Q projection to get q_a + pl_diag = prod_lins.get(li) + q_a_d = pl_diag['q_a'].run_from_quantized(x_quant_attn) + q_norm_w_d = layer_w[li].get(f"model.layers.{li}.self_attn.q_a_norm.weight") + if q_norm_w_d is not None: + q_a_quant_d = rmsnorm_quantize_nvfp4(q_a_d, q_norm_w_d.to(dev, torch.float32)) + q_a_d = dequantize_nvfp4(q_a_quant_d.x_fp4, q_a_quant_d.x_sf, q_a_quant_d.gsa) topk_idx_d = None if indexers.get(li) is not None: topk_idx_d = indexers[li].forward(q_a_d, x_normed, kc_diag, pos, layer_idx=li) if topk_idx_d is not None: tk_d = topk_idx_d[0].clamp(0, kc_diag.n_comp - 1).int() kv_nope_fp8_d, kv_nope_scale_d, kv_rope_bf16_d = kc_diag.gather_mixed_selective(tk_d) + print(f" CSA topk: {tk_d.tolist()[:10]}", flush=True) else: kv_nope_fp8_d, kv_nope_scale_d, kv_rope_bf16_d = kc_diag.gather_mixed_swa_only() elif ratio_diag > 4: @@ -298,11 +308,16 @@ def main(): print(f" Gathered KV: seq_len={seq_len_d} |nope_fp8|={nope_max:.2f} |nope_scale|={scale_max:.6f} |rope_bf16|={rope_max:.2f}", flush=True) nope_dequant_max = (kv_nope_fp8_d.view(torch.float8_e4m3fn).float() * kv_nope_scale_d.unsqueeze(-1).float()).abs().max().item() print(f" |nope_dequant_max|={nope_dequant_max:.4f}", flush=True) + # Now run FMHA F_attn_d, q_a_d = forward_attention( x_normed, layer_w[li], li, cfg, *rope_caches[gpu], kv_caches[li], pos, compressors.get(li), indexers.get(li), prod_lins.get(li), x_quant=x_quant_attn) print(f" |F_attn|={F_attn_d.abs().max().item():.2f}", flush=True) + # Check if Q heads are reasonable + q_heads_diag = pl_diag['q_b'].run_from_quantized(rmsnorm_quantize_nvfp4(q_a_d, layer_w[li].get(f"model.layers.{li}.self_attn.q_a_norm.weight").to(dev, torch.float32))) + q_heads_diag = unweighted_rmsnorm(q_heads_diag).bfloat16() + print(f" |Q_heads|={q_heads_diag.abs().max().item():.4f}", flush=True) X_mid_d = attn_mhc_d.post_block(X_diag, F_attn_d, ctx_a_d) print(f" |X_mid|={X_mid_d.abs().max().item():.2f} B_l_row=[{B_l_a.sum(-1).min().item():.4f},{B_l_a.sum(-1).max().item():.4f}] C_l=[{C_l_a.min().item():.4f},{C_l_a.max().item():.4f}]", flush=True) A_l_f, B_l_f, C_l_f = ffn_mhc_d._dynamic_params(X_mid_d)