diff --git a/tests/test_compile_custom_op.py b/tests/test_compile_custom_op.py index d6be877e..d4580c38 100644 --- a/tests/test_compile_custom_op.py +++ b/tests/test_compile_custom_op.py @@ -136,6 +136,12 @@ def main(): topk_ids = torch.tensor([[0, 1]] * 4, dtype=torch.int32, device=DEVICE) topk_weights = torch.tensor([[0.6, 0.4]] * 4, dtype=torch.float32, device=DEVICE) + # 0. Warmup: compute activation global scales + print("\n[0] Computing activation global scales (warmup)...") + runner.compute_activation_global_scales(hidden_states, topk_weights, topk_ids) + print(f" L1 gs: {runner._l1_activation_global_scale:.6f}") + print(f" L2 gs: {runner._l2_activation_global_scale:.6f}") + # 1. Eager mode (baseline) print("\n[1/2] Running eager mode (baseline)...") runner._ensure_stacked()