From 654a71fc3c6d0665b8b4805b0219e35dc485416e Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Thu, 22 Jan 2026 10:44:40 -0500 Subject: [PATCH] [torch.compile] Improve Cold Start for MoEs (#32805) Signed-off-by: Richard Zou --- tests/kernels/moe/test_moe.py | 6 ++- vllm/forward_context.py | 35 ++++++++++++++++- vllm/model_executor/layers/fused_moe/layer.py | 39 +++++++++++++++---- 3 files changed, 71 insertions(+), 9 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index e34c78074..58a4d09d9 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -23,7 +23,7 @@ from tests.kernels.utils import opcheck, stack_and_dev, torch_experts, torch_moe from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed.parallel_state import init_distributed_environment -from vllm.forward_context import set_forward_context +from vllm.forward_context import get_forward_context, set_forward_context from vllm.model_executor.layers.fused_moe import ( fused_topk, ) @@ -713,6 +713,10 @@ def test_mixtral_moe( vllm_moe.experts.quant_method.process_weights_after_loading(vllm_moe.experts) + # need to override the forward context for unittests, otherwise it assumes + # we're running the model forward pass (the model specified in vllm_config) + get_forward_context().remaining_moe_layers = None + # Run forward passes for both MoE blocks hf_states, _ = hf_moe.forward(hf_inputs) vllm_states = vllm_moe.forward(vllm_inputs) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index ed91af44a..a856f6f31 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -210,6 +210,30 @@ class ForwardContext: # If True, bypass the compiled model call, e.g. by using .forward() directly skip_compiled: bool = False + # For torch.compile cold start times, we need to avoid hard-coding + # any strings into the graph. Right now, the vllm.moe_forward + # and vllm.moe_forward_shared custom operators hard-code strings into + # the graph. + # + # The workaround is to store a list of the strings that each of those + # custom ops needs, in reverse order, in the ForwardContext. + # The ForwardContext object is alive for the duration of the forward pass. + # When the custom op needs the string, pop the string from this list. + # + # This assumes that the custom operators will always be executed in + # order and that torch.compile will not try to reorder these + # operations with respect to each other. + # + # TODO(https://github.com/vllm-project/vllm/issues/31985): + # There are longer-term solutions, like unwrapping the moe custom operator, + # that aren't ready yet. + # We could also treat the string as a "symbolic input" to the graph but + # the PyTorch-side bits for that aren't ready yet either. + # + # If this value is None (like in some tests), then we end up baking the string + # into the graph. Otherwise, the moe custom ops will pop a string from this list. + remaining_moe_layers: list[str] | None = None + additional_kwargs: dict[str, Any] = field(default_factory=dict) def __post_init__(self): @@ -245,8 +269,17 @@ def create_forward_context( additional_kwargs: dict[str, Any] | None = None, skip_compiled: bool = False, ): + no_compile_layers = vllm_config.compilation_config.static_forward_context + from vllm.model_executor.layers.fused_moe.layer import FusedMoE + + remaining_moe_layers = [ + name for name, layer in no_compile_layers.items() if isinstance(layer, FusedMoE) + ] + remaining_moe_layers.reverse() + return ForwardContext( - no_compile_layers=vllm_config.compilation_config.static_forward_context, + no_compile_layers=no_compile_layers, + remaining_moe_layers=remaining_moe_layers, virtual_engine=virtual_engine, attn_metadata=attn_metadata, dp_metadata=dp_metadata, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index dd0a837c8..3b85bc7a2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -22,7 +22,11 @@ from vllm.distributed import ( tensor_model_parallel_all_reduce, ) from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState -from vllm.forward_context import ForwardContext, get_forward_context +from vllm.forward_context import ( + ForwardContext, + get_forward_context, + is_forward_context_available, +) from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.fused_moe.config import ( @@ -1564,6 +1568,15 @@ class FusedMoE(CustomOp): states = self.maybe_all_reduce_tensor_model_parallel(states) return states + def encode_layer_name() -> str: + # Can be unavailable or None in unittests + if ( + is_forward_context_available() + and get_forward_context().remaining_moe_layers is not None + ): + return "from_forward_context" + return self.layer_name + if self.shared_experts is None: if current_platform.is_tpu() or current_platform.is_cpu(): # TODO: Once the OOM issue for the TPU backend is resolved, we @@ -1573,7 +1586,7 @@ class FusedMoE(CustomOp): assert not isinstance(fused_output, tuple) else: fused_output = torch.ops.vllm.moe_forward( - hidden_states, router_logits, self.layer_name + hidden_states, router_logits, encode_layer_name() ) return reduce_output(fused_output)[..., :og_hidden_states] else: @@ -1586,7 +1599,7 @@ class FusedMoE(CustomOp): ) else: shared_output, fused_output = torch.ops.vllm.moe_forward_shared( - hidden_states, router_logits, self.layer_name + hidden_states, router_logits, encode_layer_name() ) return ( reduce_output(shared_output)[..., :og_hidden_states], @@ -1936,13 +1949,26 @@ class FusedMoE(CustomOp): return s +def get_layer_from_name(layer_name: str) -> FusedMoE: + forward_context: ForwardContext = get_forward_context() + if layer_name == "from_forward_context": + if not forward_context.remaining_moe_layers: + raise AssertionError( + "We expected the number of MOE layers in `remaining_moe_layers` " + "to be equal to the number of " + "{vllm.moe_forward, vllm.moe_forward_shared} calls." + ) + layer_name = forward_context.remaining_moe_layers.pop() + self = cast(FusedMoE, forward_context.no_compile_layers[layer_name]) + return self + + def moe_forward( hidden_states: torch.Tensor, router_logits: torch.Tensor, layer_name: str, ) -> torch.Tensor: - forward_context: ForwardContext = get_forward_context() - self = forward_context.no_compile_layers[layer_name] + self = get_layer_from_name(layer_name) assert self.shared_experts is None return self.forward_impl(hidden_states, router_logits) @@ -1969,8 +1995,7 @@ def moe_forward_shared( router_logits: torch.Tensor, layer_name: str, ) -> tuple[torch.Tensor, torch.Tensor]: - forward_context: ForwardContext = get_forward_context() - self = forward_context.no_compile_layers[layer_name] + self = get_layer_from_name(layer_name) assert self.shared_experts is not None return self.forward_impl(hidden_states, router_logits)