Add routed_scaling_factor to MoE grouped topk (#23123)

Signed-off-by: Xin Yang <xyangx@amazon.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Xin Yang
2025-08-29 21:36:48 -07:00
committed by GitHub
parent 5b31cb1781
commit 8fb85b7bb6
19 changed files with 77 additions and 4 deletions

View File

@@ -1011,7 +1011,8 @@ def grouped_topk(
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights * routed_scaling_factor
if routed_scaling_factor != 1.0:
topk_weights = topk_weights * routed_scaling_factor
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
@@ -1790,8 +1791,8 @@ def fused_moe(
Defaults to False.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.