[BugFix] Fix DP chunking (#34379)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Bill Nell <bnell@redhat.com> Co-authored-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -421,7 +421,7 @@ class DefaultMoERunner(MoERunner):
|
||||
layer: torch.nn.Module,
|
||||
full_hidden_states: torch.Tensor,
|
||||
full_router_logits: torch.Tensor,
|
||||
shared_input: torch.Tensor | None,
|
||||
full_shared_input: torch.Tensor | None,
|
||||
has_separate_shared_experts: bool,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.batched_hidden_states is not None
|
||||
@@ -449,6 +449,11 @@ class DefaultMoERunner(MoERunner):
|
||||
chunk_size = chunk_end - chunk_start
|
||||
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
|
||||
router_logits = full_router_logits[chunk_start:chunk_end, :]
|
||||
shared_input = (
|
||||
full_shared_input[chunk_start:chunk_end, :]
|
||||
if full_shared_input is not None
|
||||
else None
|
||||
)
|
||||
|
||||
assert self.batched_hidden_states is not None
|
||||
assert self.batched_router_logits is not None
|
||||
@@ -476,8 +481,13 @@ class DefaultMoERunner(MoERunner):
|
||||
staged_hidden_states.copy_(hidden_states, non_blocking=True)
|
||||
staged_router_logits.copy_(router_logits, non_blocking=True)
|
||||
|
||||
shared_input = (
|
||||
shared_input if shared_input is not None else staged_hidden_states
|
||||
)
|
||||
|
||||
# Matrix multiply.
|
||||
if self.quant_method.is_monolithic:
|
||||
assert has_separate_shared_experts or self.shared_experts is None
|
||||
final_hidden_states = self.quant_method.apply_monolithic(
|
||||
layer=layer,
|
||||
x=staged_hidden_states,
|
||||
@@ -501,7 +511,7 @@ class DefaultMoERunner(MoERunner):
|
||||
assert not isinstance(final_hidden_states, tuple)
|
||||
assert self.shared_experts is not None
|
||||
|
||||
shared_output = self.shared_experts(staged_hidden_states)
|
||||
shared_output = self.shared_experts(shared_input)
|
||||
|
||||
final_hidden_states = (
|
||||
shared_output,
|
||||
|
||||
Reference in New Issue
Block a user