[Refactor] Remove align block size logic in moe_permute (#33449)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@@ -11,8 +11,6 @@ def moe_permute(
     n_expert: int,
     n_local_expert: int = -1,
     expert_map: torch.Tensor | None = None,
-    align_block_size: int | None = None,
-    fill_invalid_expert: int = -1,
     permuted_hidden_states: torch.Tensor | None = None,
+) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
@@ -27,9 +25,6 @@ def moe_permute(
     - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
       from the global expert space to the local expert space of the expert
       parallel shard.
-    - align_block_size (Optional[int]): align group gemm block size for deepgemm
-    - fill_invalid_expert(int): fill expert id in m_indices for invalid expert
-      to workaround DeepGemm unsupported -1 in m_indices
     - permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor.
       If None, the output tensor will be created in this function.
     Returns:
@@ -37,12 +32,9 @@ def moe_permute(
     - a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states
       if original scale not per-tensor scaling
     - expert_first_token_offset (torch.Tensor): offset of the first token
-      of each expert for standard grouped gemm. if enable 'align_block_size'
-      expert_first_token_offset will align up to 'align_block_size'.
+      of each expert for standard grouped gemm.
     - inv_permuted_idx (torch.Tensor): idx map for moe_unpermute.
     - permuted_idx (torch.Tensor): idx map from hidden to permuted_hidden.
-    - m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records
-      the group which the j-th row of the LHS belong to.`
     """
     n_token, n_hidden = hidden_states.size()
     topk = topk_ids.size(1)
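For context on the Returns section above: expert_first_token_offset is the exclusive prefix sum of per-expert token counts, so expert i owns permuted rows [offset[i], offset[i+1]). A minimal illustration of that relationship (toy values, not the kernel's actual computation):

import torch

# Toy routing: 4 local experts with these per-expert token counts after permutation.
tokens_per_expert = torch.tensor([3, 0, 2, 5], dtype=torch.int64)

# Exclusive prefix sum of length n_local_expert + 1;
# expert i owns permuted rows [offset[i], offset[i+1]).
offset = torch.zeros(tokens_per_expert.numel() + 1, dtype=torch.int64)
offset[1:] = torch.cumsum(tokens_per_expert, dim=0)

print(offset.tolist())  # [0, 3, 3, 5, 10] -> expert 2 owns rows 3 and 4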
@@ -50,17 +42,6 @@ def moe_permute(
         "permue kernel need hidden dim align to 16B"
     )
     permuted_row_size = n_token * topk
-    if align_block_size is not None:
-        permuted_row_size = (
-            (
-                permuted_row_size
-                + n_expert * (align_block_size - 1)
-                + align_block_size
-                - 1
-            )
-            // align_block_size
-            * align_block_size
-        )
     if n_local_expert == -1:
         n_local_expert = n_expert
     if permuted_hidden_states is None:
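The deleted branch above computed a worst-case upper bound: each of the n_expert groups can gain at most align_block_size - 1 rows of padding when rounded up to a block multiple, and the grand total is then itself aligned up. A hedged restatement of that arithmetic as a standalone helper (hypothetical name, mirroring the removed expression):

def worst_case_aligned_rows(
    n_token: int, topk: int, n_expert: int, align_block_size: int
) -> int:
    # Mirror of the removed padding math: upper-bound the permuted row
    # count when every expert group is padded to a block multiple.
    rows = n_token * topk
    padded = rows + n_expert * (align_block_size - 1) + align_block_size - 1
    return padded // align_block_size * align_block_size

# e.g. 7 tokens, topk=2, 4 experts, blocks of 8: 14 raw rows -> 48 padded rows
assert worst_case_aligned_rows(7, 2, 4, 8) == 48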
@@ -78,12 +59,6 @@ def moe_permute(
         0, n_token * topk, dtype=torch.int32, device=hidden_states.device
     ).reshape((n_token, topk))
 
-    m_indices = torch.full(
-        (permuted_row_size,),
-        fill_invalid_expert,
-        dtype=torch.int32,
-        device=hidden_states.device,
-    )
     expert_first_token_offset = torch.empty(
         n_local_expert + 1, dtype=torch.int64, device=hidden_states.device
     )
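The torch.full deleted above pre-filled m_indices with fill_invalid_expert because DeepGemm's grouped gemm rejects -1 as a group id for padding rows; real rows were then overwritten with their owning expert. A rough sketch of the resulting layout (toy values; the real map was built by the permute kernel on the GPU):

import torch

fill_invalid_expert = 0                    # stand-in id DeepGemm accepts, unlike -1
offset = torch.tensor([0, 3, 3, 5, 10])    # toy expert_first_token_offset
permuted_row_size = 16                     # padded beyond the 10 real rows

m_indices = torch.full((permuted_row_size,), fill_invalid_expert, dtype=torch.int32)
for expert in range(offset.numel() - 1):
    # rows [offset[e], offset[e+1]) belong to expert e
    m_indices[offset[expert]:offset[expert + 1]] = expert

print(m_indices.tolist())
# [0, 0, 0, 2, 2, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 0]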
@@ -105,12 +80,10 @@ def moe_permute(
         n_expert,
         n_local_expert,
         topk,
-        align_block_size,
         permuted_hidden_states,
         expert_first_token_offset,
         inv_permuted_idx,
         permuted_idx,
-        m_indices,
     )
 
     if a1q_scale is not None and a1q_scale.dim() > 1:
@@ -120,7 +93,7 @@ def moe_permute(
         a1q_scale,
         expert_first_token_offset,
         inv_permuted_idx.flatten(),
-        m_indices,
         permuted_idx,
     )
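The inv_permuted_idx / permuted_idx pair returned above is a forward/inverse pair of row maps; moe_unpermute uses the inverse map to restore the original token order. A generic inverse-permutation round-trip (illustrative only; the kernel builds these maps on the GPU, and the exact direction convention is as documented in the docstring above):

import torch

perm = torch.tensor([2, 0, 3, 1])          # toy row permutation
inv = torch.empty_like(perm)
inv[perm] = torch.arange(perm.numel())     # invert it

x = torch.arange(4.0)
assert torch.equal(x[perm][inv], x)        # permute then unpermute restores order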