PART A: add raw compressor output diagnostic
This commit is contained in:
@@ -269,6 +269,14 @@ def main():
|
||||
X_diag, A_l_a, attn_norms.get(li).to(dev, torch.float32))
|
||||
x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa)
|
||||
print(f" |x_normed|={x_normed.abs().max().item():.2f} gsa={x_quant_attn.gsa}", flush=True)
|
||||
# Run compressor and print raw output
|
||||
comp_diag = compressors.get(li)
|
||||
if comp_diag is not None:
|
||||
comp_kv_d, comp_pos_d, _ = comp_diag.forward(x_normed, pos)
|
||||
if comp_kv_d is not None:
|
||||
print(f" Compressor output: |comp_kv|={comp_kv_d.abs().max().item():.2f} shape={tuple(comp_kv_d.shape)}", flush=True)
|
||||
else:
|
||||
print(f" Compressor output: None (n_complete=0)", flush=True)
|
||||
# Print KV cache state BEFORE calling forward_attention
|
||||
kc_diag = kv_caches[li]
|
||||
swa_kv_d, swa_pos_d = kc_diag.get_swa()
|
||||
|
||||
Reference in New Issue
Block a user