PART A: add raw compressor output diagnostic

This commit is contained in:
2026-06-03 06:56:56 +00:00
parent f2c1b3afd5
commit a682c6adf4

View File

@@ -269,6 +269,14 @@ def main():
X_diag, A_l_a, attn_norms.get(li).to(dev, torch.float32))
x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa)
print(f" |x_normed|={x_normed.abs().max().item():.2f} gsa={x_quant_attn.gsa}", flush=True)
# Run compressor and print raw output
comp_diag = compressors.get(li)
if comp_diag is not None:
comp_kv_d, comp_pos_d, _ = comp_diag.forward(x_normed, pos)
if comp_kv_d is not None:
print(f" Compressor output: |comp_kv|={comp_kv_d.abs().max().item():.2f} shape={tuple(comp_kv_d.shape)}", flush=True)
else:
print(f" Compressor output: None (n_complete=0)", flush=True)
# Print KV cache state BEFORE calling forward_attention
kc_diag = kv_caches[li]
swa_kv_d, swa_pos_d = kc_diag.get_swa()