[torch.compile] Improve Cold Start for MoEs (#32805)
Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
@@ -23,7 +23,7 @@ from tests.kernels.utils import opcheck, stack_and_dev, torch_experts, torch_moe
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.parallel_state import init_distributed_environment
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
fused_topk,
|
||||
)
|
||||
@@ -713,6 +713,10 @@ def test_mixtral_moe(
|
||||
|
||||
vllm_moe.experts.quant_method.process_weights_after_loading(vllm_moe.experts)
|
||||
|
||||
# need to override the forward context for unittests, otherwise it assumes
|
||||
# we're running the model forward pass (the model specified in vllm_config)
|
||||
get_forward_context().remaining_moe_layers = None
|
||||
|
||||
# Run forward passes for both MoE blocks
|
||||
hf_states, _ = hf_moe.forward(hf_inputs)
|
||||
vllm_states = vllm_moe.forward(vllm_inputs)
|
||||
|
||||
@@ -210,6 +210,30 @@ class ForwardContext:
|
||||
# If True, bypass the compiled model call, e.g. by using .forward() directly
|
||||
skip_compiled: bool = False
|
||||
|
||||
# For torch.compile cold start times, we need to avoid hard-coding
|
||||
# any strings into the graph. Right now, the vllm.moe_forward
|
||||
# and vllm.moe_forward_shared custom operators hard-code strings into
|
||||
# the graph.
|
||||
#
|
||||
# The workaround is to store a list of the strings that each of those
|
||||
# custom ops needs, in reverse order, in the ForwardContext.
|
||||
# The ForwardContext object is alive for the duration of the forward pass.
|
||||
# When the custom op needs the string, pop the string from this list.
|
||||
#
|
||||
# This assumes that the custom operators will always be executed in
|
||||
# order and that torch.compile will not try to reorder these
|
||||
# operations with respect to each other.
|
||||
#
|
||||
# TODO(https://github.com/vllm-project/vllm/issues/31985):
|
||||
# There are longer-term solutions, like unwrapping the moe custom operator,
|
||||
# that aren't ready yet.
|
||||
# We could also treat the string as a "symbolic input" to the graph but
|
||||
# the PyTorch-side bits for that aren't ready yet either.
|
||||
#
|
||||
# If this value is None (like in some tests), then we end up baking the string
|
||||
# into the graph. Otherwise, the moe custom ops will pop a string from this list.
|
||||
remaining_moe_layers: list[str] | None = None
|
||||
|
||||
additional_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -245,8 +269,17 @@ def create_forward_context(
|
||||
additional_kwargs: dict[str, Any] | None = None,
|
||||
skip_compiled: bool = False,
|
||||
):
|
||||
no_compile_layers = vllm_config.compilation_config.static_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
|
||||
remaining_moe_layers = [
|
||||
name for name, layer in no_compile_layers.items() if isinstance(layer, FusedMoE)
|
||||
]
|
||||
remaining_moe_layers.reverse()
|
||||
|
||||
return ForwardContext(
|
||||
no_compile_layers=vllm_config.compilation_config.static_forward_context,
|
||||
no_compile_layers=no_compile_layers,
|
||||
remaining_moe_layers=remaining_moe_layers,
|
||||
virtual_engine=virtual_engine,
|
||||
attn_metadata=attn_metadata,
|
||||
dp_metadata=dp_metadata,
|
||||
|
||||
@@ -22,7 +22,11 @@ from vllm.distributed import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.forward_context import (
|
||||
ForwardContext,
|
||||
get_forward_context,
|
||||
is_forward_context_available,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
@@ -1564,6 +1568,15 @@ class FusedMoE(CustomOp):
|
||||
states = self.maybe_all_reduce_tensor_model_parallel(states)
|
||||
return states
|
||||
|
||||
def encode_layer_name() -> str:
|
||||
# Can be unavailable or None in unittests
|
||||
if (
|
||||
is_forward_context_available()
|
||||
and get_forward_context().remaining_moe_layers is not None
|
||||
):
|
||||
return "from_forward_context"
|
||||
return self.layer_name
|
||||
|
||||
if self.shared_experts is None:
|
||||
if current_platform.is_tpu() or current_platform.is_cpu():
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we
|
||||
@@ -1573,7 +1586,7 @@ class FusedMoE(CustomOp):
|
||||
assert not isinstance(fused_output, tuple)
|
||||
else:
|
||||
fused_output = torch.ops.vllm.moe_forward(
|
||||
hidden_states, router_logits, self.layer_name
|
||||
hidden_states, router_logits, encode_layer_name()
|
||||
)
|
||||
return reduce_output(fused_output)[..., :og_hidden_states]
|
||||
else:
|
||||
@@ -1586,7 +1599,7 @@ class FusedMoE(CustomOp):
|
||||
)
|
||||
else:
|
||||
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
|
||||
hidden_states, router_logits, self.layer_name
|
||||
hidden_states, router_logits, encode_layer_name()
|
||||
)
|
||||
return (
|
||||
reduce_output(shared_output)[..., :og_hidden_states],
|
||||
@@ -1936,13 +1949,26 @@ class FusedMoE(CustomOp):
|
||||
return s
|
||||
|
||||
|
||||
def get_layer_from_name(layer_name: str) -> FusedMoE:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
if layer_name == "from_forward_context":
|
||||
if not forward_context.remaining_moe_layers:
|
||||
raise AssertionError(
|
||||
"We expected the number of MOE layers in `remaining_moe_layers` "
|
||||
"to be equal to the number of "
|
||||
"{vllm.moe_forward, vllm.moe_forward_shared} calls."
|
||||
)
|
||||
layer_name = forward_context.remaining_moe_layers.pop()
|
||||
self = cast(FusedMoE, forward_context.no_compile_layers[layer_name])
|
||||
return self
|
||||
|
||||
|
||||
def moe_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self = get_layer_from_name(layer_name)
|
||||
assert self.shared_experts is None
|
||||
return self.forward_impl(hidden_states, router_logits)
|
||||
|
||||
@@ -1969,8 +1995,7 @@ def moe_forward_shared(
|
||||
router_logits: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self = get_layer_from_name(layer_name)
|
||||
assert self.shared_experts is not None
|
||||
return self.forward_impl(hidden_states, router_logits)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user