[BugFix] 1D query fix for MoE models (#3597)

Author: Nick Hill
Date: 2024-03-24 16:00:16 -07:00
Committed by: GitHub
Parent: af9e53496f
Commit: 41deac4a3d

4 changed files with 15 additions and 15 deletions


@@ -150,11 +150,11 @@ class DeepseekMoE(nn.Module):
         self.w2 = self.w2.view(len(w2), *w2s[0].shape)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         if self.config.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
-        # router_logits: (batch * sequence_length, n_experts)
+        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = fused_moe(hidden_states,
                                         self.w1,
@@ -169,8 +169,7 @@ class DeepseekMoE(nn.Module):
         final_hidden_states = tensor_model_parallel_all_reduce(
             final_hidden_states)
-        return final_hidden_states.view(batch_size, sequence_length,
-                                        hidden_dim)
+        return final_hidden_states.view(num_tokens, hidden_dim)


 class DeepseekAttention(nn.Module):
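
Context for the fix, as read from the diff (not part of the commit message): with vLLM's 1D query layout, the model runner hands DeepseekMoE.forward a flattened (num_tokens, hidden_dim) tensor in which the tokens of every sequence in the batch are concatenated, so the removed three-way (batch, seq, hidden) unpack would raise at runtime. A minimal sketch of the new shape contract, using toy names rather than vLLM's actual modules:

    import torch

    def toy_moe_forward(hidden_states: torch.Tensor) -> torch.Tensor:
        # New contract: one flattened tensor of shape (num_tokens, hidden_dim),
        # where num_tokens is the total token count across every sequence in
        # the batch. There is no (batch, seq_len) structure left to unpack.
        num_tokens, hidden_dim = hidden_states.shape
        out = hidden_states  # expert routing/mixing elided in this sketch
        return out.view(num_tokens, hidden_dim)

    # Two sequences of different lengths (3 and 5 tokens), hidden size 16,
    # concatenated into a single token stream as the 1D layout requires.
    flat = torch.randn(3 + 5, 16)
    assert toy_moe_forward(flat).shape == (8, 16)

    # The removed three-way unpack would fail on this 2D input:
    #   batch_size, sequence_length, hidden_dim = flat.shape
    #   ValueError: not enough values to unpack (expected 3, got 2)

Note that the patched code keeps hidden_states.view(-1, hidden_dim), which is a no-op on an already-flat input, so the method tolerates either layout.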