[Bugfix] Fix GLM-4 MoE router logits dtype for data parallel chunking (#31055)
Signed-off-by: ReinforcedKnowledge <reinforced.knowledge@gmail.com>
This commit is contained in:
@@ -1006,6 +1006,9 @@ class FusedMoEConfig:
|
|||||||
# The activation type.
|
# The activation type.
|
||||||
in_dtype: torch.dtype
|
in_dtype: torch.dtype
|
||||||
|
|
||||||
|
# Defaults to in_dtype if not specified.
|
||||||
|
router_logits_dtype: torch.dtype | None = None
|
||||||
|
|
||||||
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
|
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
|
||||||
|
|
||||||
has_bias: bool = False
|
has_bias: bool = False
|
||||||
@@ -1022,6 +1025,9 @@ class FusedMoEConfig:
|
|||||||
|
|
||||||
assert self.max_num_tokens > 0
|
assert self.max_num_tokens > 0
|
||||||
|
|
||||||
|
if self.router_logits_dtype is None:
|
||||||
|
self.router_logits_dtype = self.in_dtype
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tp_size(self):
|
def tp_size(self):
|
||||||
return self.moe_parallel_config.tp_size
|
return self.moe_parallel_config.tp_size
|
||||||
|
|||||||
@@ -314,6 +314,7 @@ class FusedMoE(CustomOp):
|
|||||||
renormalize: Whether to renormalize the logits in the fused_moe kernel
|
renormalize: Whether to renormalize the logits in the fused_moe kernel
|
||||||
quant_config: Quantization configure.
|
quant_config: Quantization configure.
|
||||||
enable_eplb: Whether to enable expert parallelism load balancer.
|
enable_eplb: Whether to enable expert parallelism load balancer.
|
||||||
|
router_logits_dtype: Data type for router logits buffers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -348,6 +349,7 @@ class FusedMoE(CustomOp):
|
|||||||
expert_mapping: list[tuple[str, str, int, str]] | None = None,
|
expert_mapping: list[tuple[str, str, int, str]] | None = None,
|
||||||
n_shared_experts: int | None = None,
|
n_shared_experts: int | None = None,
|
||||||
routing_method_type: int | None = None,
|
routing_method_type: int | None = None,
|
||||||
|
router_logits_dtype: torch.dtype | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -559,6 +561,7 @@ class FusedMoE(CustomOp):
|
|||||||
num_local_experts=self.local_num_experts,
|
num_local_experts=self.local_num_experts,
|
||||||
moe_parallel_config=self.moe_parallel_config,
|
moe_parallel_config=self.moe_parallel_config,
|
||||||
in_dtype=moe_in_dtype,
|
in_dtype=moe_in_dtype,
|
||||||
|
router_logits_dtype=router_logits_dtype,
|
||||||
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
|
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||||
has_bias=has_bias,
|
has_bias=has_bias,
|
||||||
is_act_and_mul=is_act_and_mul,
|
is_act_and_mul=is_act_and_mul,
|
||||||
@@ -1509,7 +1512,9 @@ class FusedMoE(CustomOp):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.batched_router_logits = torch.zeros(
|
self.batched_router_logits = torch.zeros(
|
||||||
logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
|
logits_shape,
|
||||||
|
dtype=moe.router_logits_dtype,
|
||||||
|
device=torch.cuda.current_device(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def select_experts(
|
def select_experts(
|
||||||
|
|||||||
@@ -197,6 +197,7 @@ class Glm4MoE(nn.Module):
|
|||||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||||
enable_eplb=self.enable_eplb,
|
enable_eplb=self.enable_eplb,
|
||||||
num_redundant_experts=self.n_redundant_experts,
|
num_redundant_experts=self.n_redundant_experts,
|
||||||
|
router_logits_dtype=torch.float32,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
|||||||
Reference in New Issue
Block a user