Call _ensure_stacked() before using runner buffers
This commit is contained in:
@@ -172,6 +172,7 @@ def main():
|
||||
intermediate_size=INTERMEDIATE_SIZE, max_num_tokens=NUM_TOKENS,
|
||||
top_k=TOP_K, device=DEVICE,
|
||||
)
|
||||
runner._ensure_stacked()
|
||||
# Just use the runner's scale assembly
|
||||
l1_gsa = torch.full((NUM_EXPERTS,), l1_gs, dtype=torch.float32, device=DEVICE)
|
||||
l1_scale_a = runner._assemble_scales_cudagraph_safe(
|
||||
|
||||
Reference in New Issue
Block a user