[Bugfix] Support 2D input shape in MoE layer (#6287)

This commit is contained in:
Woosuk Kwon
2024-07-10 06:03:16 -07:00
committed by GitHub
parent 8a924d2248
commit e72ae80b06
2 changed files with 7 additions and 4 deletions

View File

@@ -126,7 +126,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
+        hidden_dim = hidden_states.shape[-1]
         hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
if self.shared_expert is not None:
@@ -145,7 +147,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return final_hidden_states.view(orig_shape)
class Qwen2MoeAttention(nn.Module):