diff --git a/tests/test_warmup_gs.py b/tests/test_warmup_gs.py index 6c8b2c6b..b46a7a6b 100644 --- a/tests/test_warmup_gs.py +++ b/tests/test_warmup_gs.py @@ -154,12 +154,9 @@ def main(): [w.clone() for w in l2_fp4], [w.clone() for w in l2_sf], list(l2_gs), ) - l1_gs_val, l2_gs_val = warmup_compute_gs(runner, hidden_states, topk_weights, topk_ids) - print(f" Warmup L1 gs: {l1_gs_val:.10f}") - print(f" Warmup L2 gs: {l2_gs_val:.10f}") + # Use the runner's built-in warmup method + runner.compute_activation_global_scales(hidden_states.clone(), topk_weights, topk_ids) - runner._l1_activation_global_scale = l1_gs_val - runner._l2_activation_global_scale = l2_gs_val result = runner.run(hidden_states.clone(), topk_weights, topk_ids) cos = torch.nn.functional.cosine_similarity(