diff --git a/tests/unit/test_part_a_decode_diagnostics.py b/tests/unit/test_part_a_decode_diagnostics.py index fff125f7..b7c8ed3f 100644 --- a/tests/unit/test_part_a_decode_diagnostics.py +++ b/tests/unit/test_part_a_decode_diagnostics.py @@ -269,6 +269,14 @@ def main(): X_diag, A_l_a, attn_norms.get(li).to(dev, torch.float32)) x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa) print(f" |x_normed|={x_normed.abs().max().item():.2f} gsa={x_quant_attn.gsa}", flush=True) + # Run compressor and print raw output + comp_diag = compressors.get(li) + if comp_diag is not None: + comp_kv_d, comp_pos_d, _ = comp_diag.forward(x_normed, pos) + if comp_kv_d is not None: + print(f" Compressor output: |comp_kv|={comp_kv_d.abs().max().item():.2f} shape={tuple(comp_kv_d.shape)}", flush=True) + else: + print(f" Compressor output: None (n_complete=0)", flush=True) # Print KV cache state BEFORE calling forward_attention kc_diag = kv_caches[li] swa_kv_d, swa_pos_d = kc_diag.get_swa()