Add unit test for Mixtral MoE layer (#2677)

Philipp Moritz
2024-01-31 14:34:17 -08:00
committed by GitHub
parent 89efcf1ce5
commit d0d93b92b1
5 changed files with 119 additions and 55 deletions


@@ -235,7 +235,9 @@ def fused_moe(hidden_states: torch.Tensor,
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [torch.float16, torch.bfloat16]
+    assert hidden_states.dtype in [
+        torch.float32, torch.float16, torch.bfloat16
+    ]
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
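
The widened assertion means fused_moe now also accepts float32 activations, so a unit test can run the kernel at full precision and compare it against an unfused PyTorch reference. Below is a minimal sketch of such a reference, assuming the Mixtral-style SwiGLU expert layout implied by the shapes above (w1 holds the merged gate/up projections per expert, w2 the down projection); the function name and the top-k weighting details are illustrative, not the actual test code:

import torch


def torch_moe_ref(hidden_states: torch.Tensor,  # [M, hidden]
                  w1: torch.Tensor,              # [E, 2 * intermediate, hidden]
                  w2: torch.Tensor,              # [E, hidden, intermediate]
                  gating_output: torch.Tensor,   # [M, E]
                  topk: int) -> torch.Tensor:
    """Unfused per-expert loop mirroring what the fused kernel computes."""
    M, _ = hidden_states.shape
    routing = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(routing, topk, dim=-1)  # both [M, topk]
    out = torch.zeros(M, topk, w2.shape[1],
                      dtype=hidden_states.dtype,
                      device=hidden_states.device)
    for e in range(w1.shape[0]):
        token_idx, slot_idx = torch.where(topk_ids == e)
        if token_idx.numel() == 0:
            continue
        x = hidden_states[token_idx]
        gate, up = (x @ w1[e].t()).chunk(2, dim=-1)
        out[token_idx, slot_idx] = (torch.nn.functional.silu(gate) * up) @ w2[e].t()
    # Weight each selected expert's output by its routing weight and sum over
    # the top-k slots (any renormalization of the top-k weights must match
    # what the fused kernel is configured to do).
    return (out * topk_weights.to(out.dtype).unsqueeze(-1)).sum(dim=1)

Running both paths in float32 first makes it easier to pick a tight tolerance; the same comparison can then be repeated in float16/bfloat16 with a looser one.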


@@ -70,13 +70,14 @@ class MixtralMoE(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         params_dtype: Optional[torch.dtype] = None,
+        tp_size: Optional[int] = None,
     ):
         super().__init__()
-        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
         self.num_total_experts = num_experts
         self.top_k = top_k
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // tp_size
+        self.intermediate_size = intermediate_size // self.tp_size
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -141,8 +142,9 @@
                                         selected_experts,
                                         inplace=True)
-        final_hidden_states = tensor_model_parallel_all_reduce(
-            final_hidden_states)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
         return final_hidden_states.view(batch_size, sequence_length,
                                         hidden_size)
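
Together, the explicit tp_size argument and the guard around the all-reduce let a test construct MixtralMoE on a single GPU without initializing a distributed process group: with tp_size=1 the world-size lookup is never hit and no collective is issued. A hypothetical sketch of such a test; the import path and the size values are assumptions, while the constructor arguments follow the signature shown in the diff:

import torch

from vllm.model_executor.models.mixtral import MixtralMoE  # assumed path


@torch.inference_mode()
def test_mixtral_moe_single_gpu():
    layer = MixtralMoE(
        num_experts=8,                  # top_k/num_experts mirror Mixtral's routing
        top_k=2,
        hidden_size=1024,               # small illustrative sizes
        intermediate_size=3584,
        params_dtype=torch.float16,
        tp_size=1,                      # new argument: no TP group required
    ).cuda()
    x = torch.randn(2, 7, 1024, device="cuda", dtype=torch.float16)
    out = layer(x)
    assert out.shape == x.shape         # forward returns [batch, seq, hidden]

Skipping the all-reduce when tp_size == 1 also removes a pointless collective from single-GPU serving, not just from tests.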