From 136b0bfa59377ed2bbd3b3716036a96267cfe80b Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 11 Feb 2026 23:44:03 -0700 Subject: [PATCH] [BugFix] Fix DP chunking (#34379) Signed-off-by: Lucas Wilkinson Signed-off-by: Bill Nell Co-authored-by: Bill Nell --- .../layers/fused_moe/runner/default_moe_runner.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py index b265cbb41..e68d35b31 100644 --- a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py +++ b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py @@ -421,7 +421,7 @@ class DefaultMoERunner(MoERunner): layer: torch.nn.Module, full_hidden_states: torch.Tensor, full_router_logits: torch.Tensor, - shared_input: torch.Tensor | None, + full_shared_input: torch.Tensor | None, has_separate_shared_experts: bool, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.batched_hidden_states is not None @@ -449,6 +449,11 @@ class DefaultMoERunner(MoERunner): chunk_size = chunk_end - chunk_start hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] + shared_input = ( + full_shared_input[chunk_start:chunk_end, :] + if full_shared_input is not None + else None + ) assert self.batched_hidden_states is not None assert self.batched_router_logits is not None @@ -476,8 +481,13 @@ class DefaultMoERunner(MoERunner): staged_hidden_states.copy_(hidden_states, non_blocking=True) staged_router_logits.copy_(router_logits, non_blocking=True) + shared_input = ( + shared_input if shared_input is not None else staged_hidden_states + ) + # Matrix multiply. if self.quant_method.is_monolithic: + assert has_separate_shared_experts or self.shared_experts is None final_hidden_states = self.quant_method.apply_monolithic( layer=layer, x=staged_hidden_states, @@ -501,7 +511,7 @@ class DefaultMoERunner(MoERunner): assert not isinstance(final_hidden_states, tuple) assert self.shared_experts is not None - shared_output = self.shared_experts(staged_hidden_states) + shared_output = self.shared_experts(shared_input) final_hidden_states = ( shared_output,