[torch.compile] Speed up MOE handling in forward_context (#33184)
Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
@@ -715,7 +715,7 @@ def test_mixtral_moe(
|
||||
|
||||
# need to override the forward context for unittests, otherwise it assumes
|
||||
# we're running the model forward pass (the model specified in vllm_config)
|
||||
get_forward_context().remaining_moe_layers = None
|
||||
get_forward_context().all_moe_layers = None
|
||||
|
||||
# Run forward passes for both MoE blocks
|
||||
hf_states, _ = hf_moe.forward(hf_inputs)
|
||||
|
||||
@@ -597,6 +597,10 @@ class CompilationConfig:
|
||||
Map from layer name to layer objects that need to be accessed outside
|
||||
model code, e.g., Attention, FusedMOE when dp_size>1."""
|
||||
|
||||
static_all_moe_layers: list[str] = field(default_factory=list, init=False)
|
||||
"""The names of all the MOE layers in the model
|
||||
"""
|
||||
|
||||
# Attention ops; used for piecewise cudagraphs
|
||||
# Use PyTorch operator format: "namespace::name"
|
||||
_attention_ops: ClassVar[list[str]] = [
|
||||
|
||||
@@ -217,9 +217,11 @@ class ForwardContext:
|
||||
# the graph.
|
||||
#
|
||||
# The workaround is to store a list of the strings that each of those
|
||||
# custom ops needs, in reverse order, in the ForwardContext.
|
||||
# custom ops needs in the ForwardContext (all_moe_layers)
|
||||
# as well as a counter (moe_layer_index).
|
||||
# The ForwardContext object is alive for the duration of the forward pass.
|
||||
# When the custom op needs the string, pop the string from this list.
|
||||
# When the custom op needs a layer string, get the next string
|
||||
# from all_moe_layers and increment the counter.
|
||||
#
|
||||
# This assumes that the custom operators will always be executed in
|
||||
# order and that torch.compile will not try to reorder these
|
||||
@@ -233,7 +235,8 @@ class ForwardContext:
|
||||
#
|
||||
# If this value is None (like in some tests), then we end up baking the string
|
||||
# into the graph. Otherwise, the moe custom ops will pop a string from this list.
|
||||
remaining_moe_layers: list[str] | None = None
|
||||
all_moe_layers: list[str] | None = None
|
||||
moe_layer_index: int = 0
|
||||
|
||||
additional_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@@ -271,17 +274,9 @@ def create_forward_context(
|
||||
additional_kwargs: dict[str, Any] | None = None,
|
||||
skip_compiled: bool = False,
|
||||
):
|
||||
no_compile_layers = vllm_config.compilation_config.static_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
|
||||
remaining_moe_layers = [
|
||||
name for name, layer in no_compile_layers.items() if isinstance(layer, FusedMoE)
|
||||
]
|
||||
remaining_moe_layers.reverse()
|
||||
|
||||
return ForwardContext(
|
||||
no_compile_layers=no_compile_layers,
|
||||
remaining_moe_layers=remaining_moe_layers,
|
||||
no_compile_layers=vllm_config.compilation_config.static_forward_context,
|
||||
all_moe_layers=vllm_config.compilation_config.static_all_moe_layers,
|
||||
virtual_engine=virtual_engine,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mapping=slot_mapping or {},
|
||||
|
||||
@@ -407,6 +407,7 @@ class FusedMoE(CustomOp):
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError("Duplicate layer name: {}".format(prefix))
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
compilation_config.static_all_moe_layers.append(prefix)
|
||||
self.layer_name = prefix
|
||||
|
||||
self.enable_eplb = enable_eplb
|
||||
@@ -1566,7 +1567,7 @@ class FusedMoE(CustomOp):
|
||||
# Can be unavailable or None in unittests
|
||||
if (
|
||||
is_forward_context_available()
|
||||
and get_forward_context().remaining_moe_layers is not None
|
||||
and get_forward_context().all_moe_layers is not None
|
||||
):
|
||||
return "from_forward_context"
|
||||
return self.layer_name
|
||||
@@ -1987,13 +1988,17 @@ class FusedMoE(CustomOp):
|
||||
def get_layer_from_name(layer_name: str) -> FusedMoE:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
if layer_name == "from_forward_context":
|
||||
if not forward_context.remaining_moe_layers:
|
||||
all_moe_layers = forward_context.all_moe_layers
|
||||
assert all_moe_layers is not None
|
||||
moe_layer_index = forward_context.moe_layer_index
|
||||
if moe_layer_index >= len(all_moe_layers):
|
||||
raise AssertionError(
|
||||
"We expected the number of MOE layers in `remaining_moe_layers` "
|
||||
"We expected the number of MOE layers in `all_moe_layers` "
|
||||
"to be equal to the number of "
|
||||
"{vllm.moe_forward, vllm.moe_forward_shared} calls."
|
||||
)
|
||||
layer_name = forward_context.remaining_moe_layers.pop()
|
||||
layer_name = all_moe_layers[moe_layer_index]
|
||||
forward_context.moe_layer_index += 1
|
||||
self = cast(FusedMoE, forward_context.no_compile_layers[layer_name])
|
||||
return self
|
||||
|
||||
|
||||
Reference in New Issue
Block a user