Call _ensure_stacked() before using runner buffers

This commit is contained in:
2026-05-17 21:22:30 +00:00
parent 1acf01fc1a
commit b7acac5e4e

View File

@@ -172,6 +172,7 @@ def main():
intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=NUM_TOKENS,
top_k=TOP_K, device=DEVICE,
)
runner._ensure_stacked()
# Just use the runner's scale assembly
l1_gsa = torch.full((NUM_EXPERTS,), l1_gs, dtype=torch.float32, device=DEVICE)
l1_scale_a = runner._assemble_scales_cudagraph_safe(