[BugFix] Fix DP chunking (#34379)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
Lucas Wilkinson
2026-02-11 23:44:03 -07:00
committed by GitHub
parent b96f7314b4
commit 136b0bfa59

View File

@@ -421,7 +421,7 @@ class DefaultMoERunner(MoERunner):
layer: torch.nn.Module,
full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor,
shared_input: torch.Tensor | None,
full_shared_input: torch.Tensor | None,
has_separate_shared_experts: bool,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.batched_hidden_states is not None
@@ -449,6 +449,11 @@ class DefaultMoERunner(MoERunner):
chunk_size = chunk_end - chunk_start
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
router_logits = full_router_logits[chunk_start:chunk_end, :]
shared_input = (
full_shared_input[chunk_start:chunk_end, :]
if full_shared_input is not None
else None
)
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
@@ -476,8 +481,13 @@ class DefaultMoERunner(MoERunner):
staged_hidden_states.copy_(hidden_states, non_blocking=True)
staged_router_logits.copy_(router_logits, non_blocking=True)
shared_input = (
shared_input if shared_input is not None else staged_hidden_states
)
# Matrix multiply.
if self.quant_method.is_monolithic:
assert has_separate_shared_experts or self.shared_experts is None
final_hidden_states = self.quant_method.apply_monolithic(
layer=layer,
x=staged_hidden_states,
@@ -501,7 +511,7 @@ class DefaultMoERunner(MoERunner):
assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None
shared_output = self.shared_experts(staged_hidden_states)
shared_output = self.shared_experts(shared_input)
final_hidden_states = (
shared_output,