diff --git a/tests/test_pipeline_real_weights.py b/tests/test_pipeline_real_weights.py index 3ee66e87..71c2f76f 100644 --- a/tests/test_pipeline_real_weights.py +++ b/tests/test_pipeline_real_weights.py @@ -172,6 +172,7 @@ def main(): intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=NUM_TOKENS, top_k=TOP_K, device=DEVICE, ) + runner._ensure_stacked() # Just use the runner's scale assembly l1_gsa = torch.full((NUM_EXPERTS,), l1_gs, dtype=torch.float32, device=DEVICE) l1_scale_a = runner._assemble_scales_cudagraph_safe(