[Refactor] Remove align block size logic in moe_permute (#33449)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-02-06 13:57:06 -05:00
committed by GitHub
parent 16786da735
commit 77c09e1130
8 changed files with 38 additions and 297 deletions

View File

@@ -40,10 +40,8 @@ def torch_permute(
n_local_expert: int,
start_expert: int,
expert_map: torch.Tensor | None = None,
align_block_size: int | None = None,
fill_invalid_expert: int = -1,
) -> list[torch.Tensor]:
n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1]
n_token = hidden_states.shape[0]
if expert_map is not None:
is_local_expert = expert_map[topk_ids] != -1
not_local_expert = expert_map[topk_ids] == -1
@@ -70,107 +68,19 @@ def torch_permute(
_, src2dst_idx = torch.sort(dst_row_id2src_row_id_map)
valid_row_idx = []
if align_block_size is None:
permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // topk, ...]
permuted_row_size = permuted_hidden_states.shape[0]
m_indices = torch.empty(
permuted_row_size, device="cuda", dtype=torch.int32
).fill_(fill_invalid_expert)
for i in range(1, n_local_expert + 1):
first_token_offset = expert_first_token_offset[i - 1]
last_token_offset = expert_first_token_offset[i]
m_indices[first_token_offset:last_token_offset] = i - 1
src_row_id2dst_row_id_map = torch.arange(
0, n_token * topk, device="cuda", dtype=torch.int32
)[src2dst_idx].reshape((n_token, topk))
valid_row_idx += [i for i in range(expert_first_token_offset[-1])]
dst_row_id2src_row_id_map[expert_first_token_offset[-1] :] = n_token * topk
return [
permuted_hidden_states,
expert_first_token_offset,
src_row_id2dst_row_id_map,
dst_row_id2src_row_id_map,
m_indices,
valid_row_idx,
]
else:
permuted_row_size = (
(topk * n_token + n_expert * (align_block_size - 1) + align_block_size - 1)
// align_block_size
* align_block_size
)
permuted_idx = torch.full(
(permuted_row_size,),
n_token * topk,
dtype=torch.int32,
device=hidden_states.device,
)
permuted_hidden_states = torch.empty(
(permuted_row_size, n_hidden), device="cuda", dtype=hidden_states.dtype
)
align_src_row_id2dst_row_id = torch.empty(
n_token * topk, device="cuda", dtype=torch.int32
)
align_expert_first_token_offset = torch.zeros_like(expert_first_token_offset)
m_indices = torch.empty(
permuted_row_size, device="cuda", dtype=torch.int32
).fill_(fill_invalid_expert)
# get align_permuted_hidden_states,
# valid row_idx and align_expert_first_token_offset
for i in range(1, n_local_expert + 1):
first_token_offset = expert_first_token_offset[i - 1]
last_token_offset = expert_first_token_offset[i]
n_token_in_expert = last_token_offset - first_token_offset
align_expert_first_token_offset[i] = (
align_expert_first_token_offset[i - 1]
+ (n_token_in_expert + align_block_size - 1)
// align_block_size
* align_block_size
)
align_first_token_offset = align_expert_first_token_offset[i - 1]
align_last_token_offset = align_expert_first_token_offset[i]
dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[
first_token_offset : first_token_offset + n_token_in_expert
]
# store token in current expert with align_first_token_offset
permuted_hidden_states[
align_first_token_offset : align_first_token_offset + n_token_in_expert,
...,
] = hidden_states[dst_row_id2src_row_id_in_expert // topk, ...]
permuted_idx[
align_first_token_offset : align_first_token_offset + n_token_in_expert
] = dst_row_id2src_row_id_in_expert
# set current expert m_indices
m_indices[align_first_token_offset:align_last_token_offset] = i - 1
valid_row_idx += [
i
for i in range(
align_first_token_offset,
align_first_token_offset + n_token_in_expert,
)
]
# get align_src_row_id2dst_row_id
for i in range(n_token * topk):
eid = sorted_topk_ids[i]
if eid >= n_local_expert:
# check token not in local expert
align_src_row_id2dst_row_id[i] = align_expert_first_token_offset[-1]
continue
first_token_offset = expert_first_token_offset[eid]
align_first_token_offset = align_expert_first_token_offset[eid]
token_offset = i - first_token_offset
align_src_row_id2dst_row_id[i] = align_first_token_offset + token_offset
align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[src2dst_idx].reshape(
(n_token, topk)
)
return [
permuted_hidden_states,
align_expert_first_token_offset,
align_src_row_id2dst_row_id,
permuted_idx,
m_indices,
valid_row_idx,
]
permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // topk, ...]
src_row_id2dst_row_id_map = torch.arange(
0, n_token * topk, device="cuda", dtype=torch.int32
)[src2dst_idx].reshape((n_token, topk))
valid_row_idx += [i for i in range(expert_first_token_offset[-1])]
dst_row_id2src_row_id_map[expert_first_token_offset[-1] :] = n_token * topk
return [
permuted_hidden_states,
expert_first_token_offset,
src_row_id2dst_row_id_map,
dst_row_id2src_row_id_map,
valid_row_idx,
]
def torch_unpermute(
@@ -207,7 +117,6 @@ def torch_unpermute(
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("align_block_size", [None, 128])
def test_moe_permute_unpermute(
n_token: int,
n_hidden: int,
@@ -215,11 +124,9 @@ def test_moe_permute_unpermute(
n_expert: int,
ep_size: int,
dtype: torch.dtype,
align_block_size: int | None,
):
if not moe_permute_unpermute_supported():
pytest.skip("moe_permute_unpermute is not supported on this platform.")
fill_invalid_expert = 0
ep_rank = np.random.randint(0, ep_size)
expert_map = None
n_local_expert = n_expert
@@ -238,7 +145,6 @@ def test_moe_permute_unpermute(
gold_expert_first_token_offset,
gold_inv_permuted_idx,
gold_permuted_idx,
gold_m_indices,
valid_row_idx,
) = torch_permute(
hidden_states,
@@ -249,8 +155,6 @@ def test_moe_permute_unpermute(
n_local_expert,
start_expert,
expert_map=expert_map,
align_block_size=align_block_size,
fill_invalid_expert=fill_invalid_expert,
)
(
@@ -258,7 +162,7 @@ def test_moe_permute_unpermute(
_,
expert_first_token_offset,
inv_permuted_idx,
m_indices,
_,
) = moe_permute(
hidden_states=hidden_states,
a1q_scale=None,
@@ -266,8 +170,6 @@ def test_moe_permute_unpermute(
n_expert=n_expert,
n_local_expert=n_local_expert,
expert_map=expert_map,
align_block_size=align_block_size,
fill_invalid_expert=fill_invalid_expert,
)
# check expert_first_token_offset
@@ -278,11 +180,6 @@ def test_moe_permute_unpermute(
torch.testing.assert_close(
gold_inv_permuted_idx.flatten(), inv_permuted_idx, atol=0, rtol=0
)
# check mindice
# current kernel usage assumes deepgemm requires align_block_size
# when it's not provided then we don't compute m_indices (for cutlass)
if align_block_size is not None:
torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
# check permuted_hidden_states, only valid token
torch.testing.assert_close(