From 4df841fe7538cb8de281b9d78e37ba51ac35b5da Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Sun, 8 Feb 2026 13:42:56 -0500
Subject: [PATCH] [torch.compile] Add an option to force-enable the MOE cold
 start optimization (#33735)

Signed-off-by: Richard Zou
---
 vllm/config/compilation.py | 12 +++++++++---
 vllm/config/vllm.py        |  8 ++++++++
 vllm/forward_context.py    | 10 +---------
 3 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 2b4ce27a3..fb7a1466b 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -593,7 +593,7 @@ class CompilationConfig:
     local_cache_dir: str = field(default=None, init=False)  # type: ignore
     """local cache dir for each rank"""
 
-    fast_moe_cold_start = True
+    fast_moe_cold_start: bool | None = None
     """Optimization for fast MOE cold start.
 
     This is a bit of a hack that assumes that:
@@ -604,8 +604,14 @@ class CompilationConfig:
     When the above two conditions hold, this option greatly decreases cold
     start time for MOE models.
 
-    If the above two conditions don't hold, then this option will lead to silent
-    incorrectness. The only condition in which this doesn't hold is speculative
+    The options are:
+    - True: the optimization is always on
+    - False: the optimization is always off
+    - None: the optimization is on, except when speculative decoding is used
+
+    If conditions 1 and 2 don't hold, then this option will lead to silent
+    incorrectness.
+    The only case in which they don't hold is speculative
     decoding, where there is a draft model that may have MOEs in them.
 
     NB: We're working on a longer-term solution that doesn't need these
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index f0e304d9d..c1ef8e6aa 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -806,6 +806,14 @@ class VllmConfig:
         else:
             self.compilation_config.custom_ops.append("+rms_norm")
 
+        if self.compilation_config.fast_moe_cold_start is None:
+            # Resolve the default behavior as safely as possible: this config
+            # is unsafe if any speculative-decoding draft model has a MOE, so
+            # conservatively turn it off whenever speculative decoding is set.
+            self.compilation_config.fast_moe_cold_start = (
+                self.speculative_config is None
+            )
+
         if current_platform.support_static_graph_mode():
             # if cudagraph_mode has full cudagraphs, we need to check support
             if model_config := self.model_config:
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index e308c05bc..d357c8929 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -287,15 +287,7 @@ def create_forward_context(
     skip_compiled: bool = False,
 ):
     if vllm_config.compilation_config.fast_moe_cold_start:
-        if vllm_config.speculative_config is None:
-            all_moe_layers = vllm_config.compilation_config.static_all_moe_layers
-        else:
-            logger.warning_once(
-                "vllm_config.compilation_config.fast_moe_cold_start is not "
-                "compatible with speculative decoding so we are ignoring "
-                "fast_moe_cold_start."
-            )
-            all_moe_layers = None
+        all_moe_layers = vllm_config.compilation_config.static_all_moe_layers
     else:
         all_moe_layers = None
 
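
Illustrative usage, a minimal sketch rather than part of the patch: with this
change applied, fast_moe_cold_start becomes a tri-state knob that a user can
pin instead of relying on the speculative-decoding default resolved in
VllmConfig. The LLM(..., compilation_config=...) call below uses vLLM's
public LLM API; the model name is a placeholder.

    from vllm import LLM
    from vllm.config import CompilationConfig

    # Force the MOE cold-start optimization on, even when speculative decoding
    # is configured (only safe if the draft model has no MOE layers).
    llm = LLM(
        model="org/some-moe-model",  # hypothetical placeholder model
        compilation_config=CompilationConfig(fast_moe_cold_start=True),
    )

    # Leaving fast_moe_cold_start=None keeps the default behavior: the
    # optimization stays on unless speculative decoding is configured.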