[torch.compile] Speed up MOE handling in forward_context (#33184)

Signed-off-by: Richard Zou <zou3519@gmail.com>
Author: Richard Zou
Date: 2026-01-27 18:17:54 -05:00
Committed by: GitHub
Parent: 3a6d5cbefd
Commit: d9aa39a3bb
4 changed files with 22 additions and 18 deletions

@@ -407,6 +407,7 @@ class FusedMoE(CustomOp):
         if prefix in compilation_config.static_forward_context:
             raise ValueError("Duplicate layer name: {}".format(prefix))
         compilation_config.static_forward_context[prefix] = self
+        compilation_config.static_all_moe_layers.append(prefix)
         self.layer_name = prefix
         self.enable_eplb = enable_eplb
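
The first hunk is the registration side: each `FusedMoE` layer already registers itself in `static_forward_context` at construction, and the commit additionally appends the layer name to `static_all_moe_layers`, giving a stable, ordered list of every MoE layer. A minimal sketch of that pattern (the classes here are simplified stand-ins, not vLLM's real `CompilationConfig`/`FusedMoE`):

```python
from dataclasses import dataclass, field


@dataclass
class CompilationConfig:
    # Stand-ins for vLLM's static (per-model, not per-step) registries.
    static_forward_context: dict[str, object] = field(default_factory=dict)
    static_all_moe_layers: list[str] = field(default_factory=list)


class FusedMoE:
    """Heavily simplified: only the registration logic from the diff."""

    def __init__(self, prefix: str, compilation_config: CompilationConfig):
        if prefix in compilation_config.static_forward_context:
            raise ValueError("Duplicate layer name: {}".format(prefix))
        compilation_config.static_forward_context[prefix] = self
        # New in this commit: record every MoE layer name in construction order.
        compilation_config.static_all_moe_layers.append(prefix)
        self.layer_name = prefix
```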
@@ -1566,7 +1567,7 @@ class FusedMoE(CustomOp):
         # Can be unavailable or None in unittests
         if (
             is_forward_context_available()
-            and get_forward_context().remaining_moe_layers is not None
+            and get_forward_context().all_moe_layers is not None
         ):
             return "from_forward_context"
         return self.layer_name
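
The second hunk is the dispatch side: whenever a forward context that carries `all_moe_layers` is active, every layer hands the same sentinel string, `"from_forward_context"`, to the MoE custom op instead of its own unique name, so the op sees identical arguments for every layer. A runnable sketch of that check, with the context plumbing replaced by a simple module-level stand-in (not vLLM's actual API):

```python
from typing import Optional


class ForwardContext:  # stand-in, not vLLM's ForwardContext
    def __init__(self, all_moe_layers: Optional[list[str]] = None):
        self.all_moe_layers = all_moe_layers


_CURRENT_CTX: Optional[ForwardContext] = None  # set once per forward pass


class FusedMoE:
    def __init__(self, layer_name: str):
        self.layer_name = layer_name

    def maybe_layer_name(self) -> str:
        # Context can be unavailable or None in unit tests; fall back
        # to the layer's real name in that case.
        if _CURRENT_CTX is not None and _CURRENT_CTX.all_moe_layers is not None:
            return "from_forward_context"  # one constant for every layer
        return self.layer_name
```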
@@ -1987,13 +1988,17 @@ class FusedMoE(CustomOp):
     def get_layer_from_name(layer_name: str) -> FusedMoE:
         forward_context: ForwardContext = get_forward_context()
         if layer_name == "from_forward_context":
-            if not forward_context.remaining_moe_layers:
+            all_moe_layers = forward_context.all_moe_layers
+            assert all_moe_layers is not None
+            moe_layer_index = forward_context.moe_layer_index
+            if moe_layer_index >= len(all_moe_layers):
                 raise AssertionError(
-                    "We expected the number of MOE layers in `remaining_moe_layers` "
+                    "We expected the number of MOE layers in `all_moe_layers` "
                     "to be equal to the number of "
                     "{vllm.moe_forward, vllm.moe_forward_shared} calls."
                 )
-            layer_name = forward_context.remaining_moe_layers.pop()
+            layer_name = all_moe_layers[moe_layer_index]
+            forward_context.moe_layer_index += 1
         self = cast(FusedMoE, forward_context.no_compile_layers[layer_name])
         return self
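
The third hunk replaces the old mutating `remaining_moe_layers.pop()` with a read-only list plus an integer cursor, `moe_layer_index`. A self-contained sketch of the new lookup; `ForwardContext` is again a simplified stand-in, and it assumes `moe_layer_index` starts at 0 for each new forward context:

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ForwardContext:  # simplified stand-in for vLLM's ForwardContext
    all_moe_layers: Optional[list[str]] = None
    moe_layer_index: int = 0  # cursor; assumed to be 0 per forward pass
    no_compile_layers: dict = field(default_factory=dict)


def get_layer_from_name(ctx: ForwardContext, layer_name: str):
    if layer_name == "from_forward_context":
        all_moe_layers = ctx.all_moe_layers
        assert all_moe_layers is not None
        if ctx.moe_layer_index >= len(all_moe_layers):
            raise AssertionError("more MoE op calls than registered MoE layers")
        # O(1) read plus an integer bump instead of list.pop(): the layer
        # list is never mutated, so it can live in the static config.
        layer_name = all_moe_layers[ctx.moe_layer_index]
        ctx.moe_layer_index += 1
    return ctx.no_compile_layers[layer_name]


# Usage: two custom-op calls resolve to layers in registration order.
ctx = ForwardContext(
    all_moe_layers=["model.layers.0.mlp", "model.layers.1.mlp"],
    no_compile_layers={"model.layers.0.mlp": "moe0", "model.layers.1.mlp": "moe1"},
)
assert get_layer_from_name(ctx, "from_forward_context") == "moe0"
assert get_layer_from_name(ctx, "from_forward_context") == "moe1"
```

Design note: the old approach consumed `remaining_moe_layers` as it went, so a fresh list had to be built for every forward pass; with the cursor, resetting per-pass state is a single integer assignment, which is presumably where the speedup in the commit title comes from.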