diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py
index ad92d7b29..e3864a71d 100644
--- a/tests/kernels/moe/test_moe.py
+++ b/tests/kernels/moe/test_moe.py
@@ -715,7 +715,7 @@ def test_mixtral_moe(
 
     # need to override the forward context for unittests, otherwise it assumes
     # we're running the model forward pass (the model specified in vllm_config)
-    get_forward_context().remaining_moe_layers = None
+    get_forward_context().all_moe_layers = None
 
     # Run forward passes for both MoE blocks
     hf_states, _ = hf_moe.forward(hf_inputs)
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 3f94fa8b3..f3354d18c 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -597,6 +597,10 @@ class CompilationConfig:
     Map from layer name to layer objects that need to be accessed outside
     model code, e.g., Attention, FusedMOE when dp_size>1."""
 
+    static_all_moe_layers: list[str] = field(default_factory=list, init=False)
+    """The names of all the MOE layers in the model.
+    """
+
     # Attention ops; used for piecewise cudagraphs
     # Use PyTorch operator format: "namespace::name"
     _attention_ops: ClassVar[list[str]] = [
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 301834d19..731c45fbb 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -217,9 +217,11 @@ class ForwardContext:
     # the graph.
    #
     # The workaround is to store a list of the strings that each of those
-    # custom ops needs, in reverse order, in the ForwardContext.
+    # custom ops needs in the ForwardContext (all_moe_layers),
+    # as well as a counter (moe_layer_index).
     # The ForwardContext object is alive for the duration of the forward pass.
-    # When the custom op needs the string, pop the string from this list.
+    # When a custom op needs a layer string, it reads the next string
+    # from all_moe_layers and increments the counter.
     #
     # This assumes that the custom operators will always be executed in
     # order and that torch.compile will not try to reorder these
@@ -233,7 +235,8 @@
     #
     # If this value is None (like in some tests), then we end up baking the string
-    # into the graph. Otherwise, the moe custom ops will pop a string from this list.
+    # into the graph. Otherwise, the moe custom ops will read the next string from this list.
-    remaining_moe_layers: list[str] | None = None
+    all_moe_layers: list[str] | None = None
+    moe_layer_index: int = 0
 
     additional_kwargs: dict[str, Any] = field(default_factory=dict)
 
@@ -271,17 +274,9 @@ def create_forward_context(
     additional_kwargs: dict[str, Any] | None = None,
     skip_compiled: bool = False,
 ):
-    no_compile_layers = vllm_config.compilation_config.static_forward_context
-    from vllm.model_executor.layers.fused_moe.layer import FusedMoE
-
-    remaining_moe_layers = [
-        name for name, layer in no_compile_layers.items() if isinstance(layer, FusedMoE)
-    ]
-    remaining_moe_layers.reverse()
-
     return ForwardContext(
-        no_compile_layers=no_compile_layers,
-        remaining_moe_layers=remaining_moe_layers,
+        no_compile_layers=vllm_config.compilation_config.static_forward_context,
+        all_moe_layers=vllm_config.compilation_config.static_all_moe_layers,
         virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
         slot_mapping=slot_mapping or {},
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index e38275004..538089882 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -407,6 +407,7 @@ class FusedMoE(CustomOp):
         if prefix in compilation_config.static_forward_context:
             raise ValueError("Duplicate layer name: {}".format(prefix))
         compilation_config.static_forward_context[prefix] = self
+        compilation_config.static_all_moe_layers.append(prefix)
         self.layer_name = prefix
 
         self.enable_eplb = enable_eplb
@@ -1566,7 +1567,7 @@
         # Can be unavailable or None in unittests
         if (
             is_forward_context_available()
-            and get_forward_context().remaining_moe_layers is not None
+            and get_forward_context().all_moe_layers is not None
         ):
             return "from_forward_context"
         return self.layer_name
@@ -1987,13 +1988,17 @@
 def get_layer_from_name(layer_name: str) -> FusedMoE:
     forward_context: ForwardContext = get_forward_context()
     if layer_name == "from_forward_context":
-        if not forward_context.remaining_moe_layers:
+        all_moe_layers = forward_context.all_moe_layers
+        assert all_moe_layers is not None
+        moe_layer_index = forward_context.moe_layer_index
+        if moe_layer_index >= len(all_moe_layers):
             raise AssertionError(
-                "We expected the number of MOE layers in `remaining_moe_layers` "
+                "We expected the number of MOE layers in `all_moe_layers` "
                 "to be equal to the number of "
                 "{vllm.moe_forward, vllm.moe_forward_shared} calls."
             )
-        layer_name = forward_context.remaining_moe_layers.pop()
+        layer_name = all_moe_layers[moe_layer_index]
+        forward_context.moe_layer_index += 1
     self = cast(FusedMoE, forward_context.no_compile_layers[layer_name])
     return self
 
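For reviewers, here is a minimal runnable sketch of the read-then-increment scheme this diff introduces. `ForwardContextSketch` and `next_moe_layer_name` are hypothetical stand-ins, not vLLM APIs; only the `all_moe_layers`/`moe_layer_index` fields, the bounds check, and the cursor advance are taken from the diff above.

```python
# Hedged sketch of the mechanism only. The real ForwardContext and
# get_layer_from_name in vLLM carry much more state and logic.
from dataclasses import dataclass


@dataclass
class ForwardContextSketch:
    # Static per-model list of MoE layer names (populated once at layer
    # construction time in the diff, via static_all_moe_layers). It is
    # shared across forward passes and never mutated.
    all_moe_layers: list[str] | None = None
    # Per-forward-pass cursor into all_moe_layers; a fresh context (and
    # hence a cursor reset to 0) is created for each forward pass.
    moe_layer_index: int = 0


def next_moe_layer_name(ctx: ForwardContextSketch) -> str:
    """Read the layer name at the cursor, then advance the cursor.

    Mirrors the "from_forward_context" branch of get_layer_from_name:
    correctness relies on the MoE custom ops running in model order,
    exactly once per layer per forward pass.
    """
    assert ctx.all_moe_layers is not None
    if ctx.moe_layer_index >= len(ctx.all_moe_layers):
        raise AssertionError(
            "More MoE custom-op calls than registered MoE layers."
        )
    name = ctx.all_moe_layers[ctx.moe_layer_index]
    ctx.moe_layer_index += 1
    return name


if __name__ == "__main__":
    # One forward pass over a hypothetical model with two MoE layers.
    ctx = ForwardContextSketch(
        all_moe_layers=["model.layers.0.mlp", "model.layers.1.mlp"]
    )
    assert next_moe_layer_name(ctx) == "model.layers.0.mlp"
    assert next_moe_layer_name(ctx) == "model.layers.1.mlp"
```

Compared with the old scheme, the list itself is never popped, so it can live in CompilationConfig and be reused across forward passes instead of being rebuilt and reversed on every create_forward_context call.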