[torch.compile] Improve Cold Start for MoEs (#32805)

Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
Richard Zou
2026-01-22 10:44:40 -05:00
committed by GitHub
parent 15e302dfce
commit 654a71fc3c
3 changed files with 71 additions and 9 deletions

View File

@@ -23,7 +23,7 @@ from tests.kernels.utils import opcheck, stack_and_dev, torch_experts, torch_moe
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.fused_moe import (
fused_topk,
)
@@ -713,6 +713,10 @@ def test_mixtral_moe(
vllm_moe.experts.quant_method.process_weights_after_loading(vllm_moe.experts)
# need to override the forward context for unittests, otherwise it assumes
# we're running the model forward pass (the model specified in vllm_config)
get_forward_context().remaining_moe_layers = None
# Run forward passes for both MoE blocks
hf_states, _ = hf_moe.forward(hf_inputs)
vllm_states = vllm_moe.forward(vllm_inputs)