diff --git a/tests/debug_wo_a3.py b/tests/debug_wo_a3.py
index b276179c..8245f422 100644
--- a/tests/debug_wo_a3.py
+++ b/tests/debug_wo_a3.py
@@ -45,7 +45,7 @@ runner._ensure_initialized()
 
 # Compute activation gs
 with torch.no_grad():
-    _, _, gs = quantize_to_nvfp4(o_g.reshape(T, GI)[:1])
+    _, _, gs = quantize_to_nvfp4(o_g[:, 0, :])  # use first group's activation
 print(f"\nActivation gs from sample: {gs:.6f}")
 print(f"Runner gs: {runner._activation_global_scale:.6f}")
 