[Bugfix] Fix GLM-4 MoE router logits dtype for data parallel chunking (#31055)
Signed-off-by: ReinforcedKnowledge <reinforced.knowledge@gmail.com>
This commit is contained in:
@@ -1006,6 +1006,9 @@ class FusedMoEConfig:
     # The activation type.
     in_dtype: torch.dtype
 
+    # Defaults to in_dtype if not specified.
+    router_logits_dtype: torch.dtype | None = None
+
     max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
 
     has_bias: bool = False
@@ -1022,6 +1025,9 @@ class FusedMoEConfig:
 
         assert self.max_num_tokens > 0
 
+        if self.router_logits_dtype is None:
+            self.router_logits_dtype = self.in_dtype
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
@@ -314,6 +314,7 @@ class FusedMoE(CustomOp):
         renormalize: Whether to renormalize the logits in the fused_moe kernel
         quant_config: Quantization configure.
         enable_eplb: Whether to enable expert parallelism load balancer.
+        router_logits_dtype: Data type for router logits buffers.
     """
 
     def __init__(
@@ -348,6 +349,7 @@ class FusedMoE(CustomOp):
         expert_mapping: list[tuple[str, str, int, str]] | None = None,
         n_shared_experts: int | None = None,
         routing_method_type: int | None = None,
+        router_logits_dtype: torch.dtype | None = None,
     ):
         super().__init__()
 
@@ -559,6 +561,7 @@ class FusedMoE(CustomOp):
             num_local_experts=self.local_num_experts,
             moe_parallel_config=self.moe_parallel_config,
             in_dtype=moe_in_dtype,
+            router_logits_dtype=router_logits_dtype,
             max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
             has_bias=has_bias,
             is_act_and_mul=is_act_and_mul,
@@ -1509,7 +1512,9 @@ class FusedMoE(CustomOp):
         )
 
         self.batched_router_logits = torch.zeros(
-            logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
+            logits_shape,
+            dtype=moe.router_logits_dtype,
+            device=torch.cuda.current_device(),
         )
 
     def select_experts(
@@ -197,6 +197,7 @@ class Glm4MoE(nn.Module):
             e_score_correction_bias=self.gate.e_score_correction_bias,
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts,
+            router_logits_dtype=torch.float32,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
Reference in New Issue
Block a user