[torch.compile] Don't do the fast moe cold start optimization if there is speculative decoding (#33624)

Signed-off-by: Richard Zou <zou3519@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
(cherry picked from commit 5eac9a1b34)
This commit is contained in:
Richard Zou
2026-02-02 19:38:49 -08:00
committed by khluu
parent 611b18757e
commit e4bf6ed90d
2 changed files with 32 additions and 1 deletions

View File

@@ -581,6 +581,24 @@ class CompilationConfig:
local_cache_dir: str = field(default=None, init=False) # type: ignore
"""local cache dir for each rank"""
fast_moe_cold_start = True
"""Optimization for fast MOE cold start.
This is a bit of a hack that assumes that:
1. the only decoder forward pass being run is that of the current model
2. the decoder forward pass runs all of the MOEs in the order in which they
are initialized
When the above two conditions hold, this option greatly decreases cold start
time for MOE models.
If either of the above conditions doesn't hold, then this option will lead to
silent incorrectness. The only case in which they don't hold is speculative
decoding, where there is a draft model that may itself contain MOEs.
NB: We're working on a longer-term solution that doesn't need these assumptions.
"""
# keep track of enabled and disabled custom ops
enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
"""custom ops that are enabled"""

View File

@@ -274,9 +274,22 @@ def create_forward_context(
additional_kwargs: dict[str, Any] | None = None,
skip_compiled: bool = False,
):
if vllm_config.compilation_config.fast_moe_cold_start:
if vllm_config.speculative_config is None:
all_moe_layers = vllm_config.compilation_config.static_all_moe_layers
else:
logger.warning_once(
"vllm_config.compilation_config.fast_moe_cold_start is not "
"compatible with speculative decoding so we are ignoring "
"fast_moe_cold_start."
)
all_moe_layers = None
else:
all_moe_layers = None
return ForwardContext(
no_compile_layers=vllm_config.compilation_config.static_forward_context,
all_moe_layers=vllm_config.compilation_config.static_all_moe_layers,
all_moe_layers=all_moe_layers,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
slot_mapping=slot_mapping or {},