[Refactor] Remove align block size logic in moe_permute (#33449)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Author: Wentao Ye
Date: 2026-02-06 13:57:06 -05:00
Committed by: GitHub
Parent: 16786da735
Commit: 77c09e1130
8 changed files with 38 additions and 297 deletions


@@ -11,8 +11,6 @@ def moe_permute(
 n_expert: int,
 n_local_expert: int = -1,
 expert_map: torch.Tensor | None = None,
-align_block_size: int | None = None,
-fill_invalid_expert: int = -1,
 permuted_hidden_states: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
 """
@@ -27,9 +25,6 @@ def moe_permute(
 - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
 from the global expert space to the local expert space of the expert
 parallel shard.
-- align_block_size (Optional[int]): align group gemm block size for deepgemm
-- fill_invalid_expert(int): fill expert id in m_indices for invalid expert
-to workaround DeepGemm unsupported -1 in m_indices
 - permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor.
 If None, the output tensor will be created in this function.
 Returns:
@@ -37,12 +32,9 @@ def moe_permute(
 - a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states
 if original scale not per-tensor scaling
 - expert_first_token_offset (torch.Tensor): offset of the first token
-of each expert for standard grouped gemm. if enable 'align_block_size'
-expert_first_token_offset will align up to 'align_block_size'.
+of each expert for standard grouped gemm.
 - inv_permuted_idx (torch.Tensor): idx map for moe_unpermute.
 - permuted_idx (torch.Tensor): idx map from hidden to permuted_hidden.
-- m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records
-the group which the j-th row of the LHS belong to.`
 """
 n_token, n_hidden = hidden_states.size()
 topk = topk_ids.size(1)
@@ -50,17 +42,6 @@ def moe_permute(
"permue kernel need hidden dim align to 16B"
)
permuted_row_size = n_token * topk
if align_block_size is not None:
permuted_row_size = (
(
permuted_row_size
+ n_expert * (align_block_size - 1)
+ align_block_size
- 1
)
// align_block_size
* align_block_size
)
if n_local_expert == -1:
n_local_expert = n_expert
if permuted_hidden_states is None:
@@ -78,12 +59,6 @@ def moe_permute(
 0, n_token * topk, dtype=torch.int32, device=hidden_states.device
 ).reshape((n_token, topk))
-m_indices = torch.full(
-(permuted_row_size,),
-fill_invalid_expert,
-dtype=torch.int32,
-device=hidden_states.device,
-)
 expert_first_token_offset = torch.empty(
 n_local_expert + 1, dtype=torch.int64, device=hidden_states.device
 )
@@ -105,12 +80,10 @@ def moe_permute(
 n_expert,
 n_local_expert,
 topk,
-align_block_size,
 permuted_hidden_states,
 expert_first_token_offset,
 inv_permuted_idx,
 permuted_idx,
-m_indices,
 )
 if a1q_scale is not None and a1q_scale.dim() > 1:
@@ -120,7 +93,7 @@ def moe_permute(
 a1q_scale,
 expert_first_token_offset,
 inv_permuted_idx.flatten(),
-m_indices,
+permuted_idx,
 )
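
For context, a minimal self-contained sketch (not part of the commit) of the buffer-size arithmetic being removed: the old path padded permuted_row_size so each expert's rows could be rounded up to align_block_size for DeepGEMM's grouped GEMM, while the new path allocates exactly n_token * topk permuted rows. The sizes below are made up for illustration.

# Illustration only, not code from this commit: old vs. new permuted buffer sizing.
n_token, topk, n_expert = 1024, 2, 64   # made-up problem sizes
align_block_size = 128                  # made-up DeepGEMM block size

# New behaviour after this commit: one row per (token, expert-choice) pair.
new_size = n_token * topk  # 2048

# Old behaviour: reserve worst-case per-expert padding, then round the total
# up to a multiple of align_block_size.
old_size = (
    (n_token * topk + n_expert * (align_block_size - 1) + align_block_size - 1)
    // align_block_size
    * align_block_size
)  # 10240 with these sizes

print(new_size, old_size)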