Add unit test for Mixtral MoE layer (#2677)
@@ -235,7 +235,9 @@ def fused_moe(hidden_states: torch.Tensor,
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [torch.float16, torch.bfloat16]
+    assert hidden_states.dtype in [
+        torch.float32, torch.float16, torch.bfloat16
+    ]
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
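The relaxed assertion above lets fused_moe accept float32 activations in addition to half precision, which is what a correctness test needs when it checks the fused kernel against a plain PyTorch reference run in full precision. A minimal sketch of such a reference follows; the helper name torch_moe, its argument order, and the top-k renormalization step are assumptions for illustration, not part of this diff. The weight shapes match the unpacking shown above (w1 is (E, 2 * intermediate, hidden); w2 is assumed to be (E, hidden, intermediate)).

import torch
import torch.nn.functional as F


def torch_moe(hidden_states, w1, w2, gating_output, topk):
    # Hypothetical naive reference: softmax-route each token to its top-k
    # experts and run a SwiGLU-style expert MLP in plain PyTorch.
    weights = F.softmax(gating_output, dim=-1)
    topk_weights, topk_ids = torch.topk(weights, topk, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(hidden_states)
    for expert_id in range(w1.shape[0]):
        mask = topk_ids == expert_id                  # (num_tokens, topk)
        if not mask.any():
            continue
        token_ids, slot = mask.nonzero(as_tuple=True)
        x = hidden_states[token_ids]
        gate, up = (x @ w1[expert_id].t()).chunk(2, dim=-1)
        expert_out = (F.silu(gate) * up) @ w2[expert_id].t()
        out.index_add_(0, token_ids,
                       expert_out * topk_weights[token_ids, slot].unsqueeze(-1))
    return out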
@@ -70,13 +70,14 @@ class MixtralMoE(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         params_dtype: Optional[torch.dtype] = None,
+        tp_size: Optional[int] = None,
     ):
         super().__init__()
-        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
         self.num_total_experts = num_experts
         self.top_k = top_k
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // tp_size
+        self.intermediate_size = intermediate_size // self.tp_size

         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
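Threading an optional tp_size through the constructor lets a unit test build the layer on a single device without first initializing the tensor-parallel group; when the argument is omitted, the world size is still queried as before and the per-rank intermediate_size split is unchanged. A hedged usage sketch (the num_experts and top_k parameter names are inferred from the assignments above; the Mixtral-8x7B sizes are example values only):

import torch
from vllm.model_executor.models.mixtral import MixtralMoE  # import path may differ by version

# Passing tp_size=1 keeps get_tensor_model_parallel_world_size() out of the
# constructor, so no distributed process group is required for the test.
moe = MixtralMoE(
    num_experts=8,
    top_k=2,
    hidden_size=4096,
    intermediate_size=14336,
    params_dtype=torch.float32,
    tp_size=1,
)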
@@ -141,8 +142,9 @@ class MixtralMoE(nn.Module):
                                         selected_experts,
                                         inplace=True)

-        final_hidden_states = tensor_model_parallel_all_reduce(
-            final_hidden_states)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)

         return final_hidden_states.view(batch_size, sequence_length,
                                         hidden_size)
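Guarding the all-reduce behind self.tp_size > 1 means the single-GPU path never issues a collective call, so the forward pass can run inside an ordinary unit test. A rough sketch of how such a test might compare the layer against the float32 torch_moe reference above; the attribute names moe.gate, moe.ws and moe.w2s are assumptions about the module internals and may not match the actual code.

import torch

# Illustrative only: run the layer and the naive reference on the same random
# activations and compare within a loose tolerance.
hidden_states = torch.randn(2, 16, 4096, dtype=torch.float32)
fused_out = moe(hidden_states)

gating = hidden_states.view(-1, 4096) @ moe.gate.weight.t()   # assumed gate attribute
ref_out = torch_moe(hidden_states.view(-1, 4096), moe.ws, moe.w2s, gating, topk=2)
torch.testing.assert_close(fused_out.view(-1, 4096), ref_out, rtol=1e-2, atol=1e-2)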