[torch.compile] Don't do the fast moe cold start optimization if there is speculative decoding (#33624)
Signed-off-by: Richard Zou <zou3519@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
(cherry picked from commit 5eac9a1b34)
This commit is contained in:
@@ -581,6 +581,24 @@ class CompilationConfig:
|
||||
local_cache_dir: str = field(default=None, init=False) # type: ignore
|
||||
"""local cache dir for each rank"""
|
||||
|
||||
fast_moe_cold_start = True
|
||||
"""Optimization for fast MOE cold start.
|
||||
|
||||
This is a bit of a hack that assumes that:
|
||||
1. the only decoder forward pass being run is that of the current model
|
||||
2. the decoder forward pass runs all of the MOEs in the order in which they
|
||||
are initialized
|
||||
|
||||
When the above two conditions hold, this option greatly decreases cold start
|
||||
time for MOE models.
|
||||
|
||||
If the above two conditions don't hold, then this option will lead to silent
|
||||
incorrectness. The only condition in which this doesn't hold is speculative
|
||||
decoding, where there is a draft model that may have MOEs in it.
|
||||
|
||||
NB: We're working on a longer-term solution that doesn't need these assumptions.
|
||||
"""
|
||||
|
||||
# keep track of enabled and disabled custom ops
|
||||
enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
|
||||
"""custom ops that are enabled"""
|
||||
|
||||
@@ -274,9 +274,22 @@ def create_forward_context(
|
||||
additional_kwargs: dict[str, Any] | None = None,
|
||||
skip_compiled: bool = False,
|
||||
):
|
||||
if vllm_config.compilation_config.fast_moe_cold_start:
|
||||
if vllm_config.speculative_config is None:
|
||||
all_moe_layers = vllm_config.compilation_config.static_all_moe_layers
|
||||
else:
|
||||
logger.warning_once(
|
||||
"vllm_config.compilation_config.fast_moe_cold_start is not "
|
||||
"compatible with speculative decoding so we are ignoring "
|
||||
"fast_moe_cold_start."
|
||||
)
|
||||
all_moe_layers = None
|
||||
else:
|
||||
all_moe_layers = None
|
||||
|
||||
return ForwardContext(
|
||||
no_compile_layers=vllm_config.compilation_config.static_forward_context,
|
||||
all_moe_layers=vllm_config.compilation_config.static_all_moe_layers,
|
||||
all_moe_layers=all_moe_layers,
|
||||
virtual_engine=virtual_engine,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mapping=slot_mapping or {},
|
||||
|
||||
Reference in New Issue
Block a user