[BugFix] Fix DP chunking (#34379)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -421,7 +421,7 @@ class DefaultMoERunner(MoERunner):
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
full_hidden_states: torch.Tensor,
|
full_hidden_states: torch.Tensor,
|
||||||
full_router_logits: torch.Tensor,
|
full_router_logits: torch.Tensor,
|
||||||
shared_input: torch.Tensor | None,
|
full_shared_input: torch.Tensor | None,
|
||||||
has_separate_shared_experts: bool,
|
has_separate_shared_experts: bool,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||||
assert self.batched_hidden_states is not None
|
assert self.batched_hidden_states is not None
|
||||||
@@ -449,6 +449,11 @@ class DefaultMoERunner(MoERunner):
|
|||||||
chunk_size = chunk_end - chunk_start
|
chunk_size = chunk_end - chunk_start
|
||||||
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
|
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
|
||||||
router_logits = full_router_logits[chunk_start:chunk_end, :]
|
router_logits = full_router_logits[chunk_start:chunk_end, :]
|
||||||
|
shared_input = (
|
||||||
|
full_shared_input[chunk_start:chunk_end, :]
|
||||||
|
if full_shared_input is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
assert self.batched_hidden_states is not None
|
assert self.batched_hidden_states is not None
|
||||||
assert self.batched_router_logits is not None
|
assert self.batched_router_logits is not None
|
||||||
@@ -476,8 +481,13 @@ class DefaultMoERunner(MoERunner):
|
|||||||
staged_hidden_states.copy_(hidden_states, non_blocking=True)
|
staged_hidden_states.copy_(hidden_states, non_blocking=True)
|
||||||
staged_router_logits.copy_(router_logits, non_blocking=True)
|
staged_router_logits.copy_(router_logits, non_blocking=True)
|
||||||
|
|
||||||
|
shared_input = (
|
||||||
|
shared_input if shared_input is not None else staged_hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
if self.quant_method.is_monolithic:
|
if self.quant_method.is_monolithic:
|
||||||
|
assert has_separate_shared_experts or self.shared_experts is None
|
||||||
final_hidden_states = self.quant_method.apply_monolithic(
|
final_hidden_states = self.quant_method.apply_monolithic(
|
||||||
layer=layer,
|
layer=layer,
|
||||||
x=staged_hidden_states,
|
x=staged_hidden_states,
|
||||||
@@ -501,7 +511,7 @@ class DefaultMoERunner(MoERunner):
|
|||||||
assert not isinstance(final_hidden_states, tuple)
|
assert not isinstance(final_hidden_states, tuple)
|
||||||
assert self.shared_experts is not None
|
assert self.shared_experts is not None
|
||||||
|
|
||||||
shared_output = self.shared_experts(staged_hidden_states)
|
shared_output = self.shared_experts(shared_input)
|
||||||
|
|
||||||
final_hidden_states = (
|
final_hidden_states = (
|
||||||
shared_output,
|
shared_output,
|
||||||
|
|||||||
Reference in New Issue
Block a user