[BugFix] 1D query fix for MoE models (#3597)
This commit is contained in:
@@ -150,11 +150,11 @@ class DeepseekMoE(nn.Module):
|
||||
self.w2 = self.w2.view(len(w2), *w2s[0].shape)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
if self.config.n_shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
# router_logits: (batch * sequence_length, n_experts)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = fused_moe(hidden_states,
|
||||
self.w1,
|
||||
@@ -169,8 +169,7 @@ class DeepseekMoE(nn.Module):
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
return final_hidden_states.view(batch_size, sequence_length,
|
||||
hidden_dim)
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
|
||||
class DeepseekAttention(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user