[torch.compile] Improve Cold Start for MoEs (#32805)
Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
@@ -23,7 +23,7 @@ from tests.kernels.utils import opcheck, stack_and_dev, torch_experts, torch_moe
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.parallel_state import init_distributed_environment
|
||||
-from vllm.forward_context import set_forward_context
|
||||
+from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
fused_topk,
|
||||
)
|
||||
@@ -713,6 +713,10 @@ def test_mixtral_moe(
|
||||
|
||||
vllm_moe.experts.quant_method.process_weights_after_loading(vllm_moe.experts)
|
||||
|
||||
+# need to override the forward context for unittests, otherwise it assumes
|
||||
+# we're running the model forward pass (the model specified in vllm_config)
|
||||
+get_forward_context().remaining_moe_layers = None
|
||||
|
||||
# Run forward passes for both MoE blocks
|
||||
hf_states, _ = hf_moe.forward(hf_inputs)
|
||||
vllm_states = vllm_moe.forward(vllm_inputs)
|
||||
|
||||
Reference in New Issue
Block a user