[Refactor] Remove align block size logic in moe_permute (#33449)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Author: Wentao Ye
Date: 2026-02-06 13:57:06 -05:00
Committed by: GitHub
Parent commit: 16786da735
Commit: 77c09e1130
8 changed files with 38 additions and 297 deletions
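
For context on the removed parameter: the only purpose documented in the diff below is the "deepgemm needs 128 m aligned block" comment, i.e. each expert's chunk of tokens had to be padded so its M dimension was a multiple of 128. A minimal sketch of that alignment arithmetic follows; the helper name is made up for illustration and is not part of vLLM.

# Hypothetical helper, illustration only: the M-alignment that the removed
# align_block_size=128 requested is a round-up to the next multiple of 128.
def round_up_to_block(num_tokens: int, block: int = 128) -> int:
    return (num_tokens + block - 1) // block * block

assert round_up_to_block(1) == 128    # even a single token occupies one full block
assert round_up_to_block(300) == 384  # 300 rows padded up to 3 * 128
assert round_up_to_block(256) == 256  # already aligned, no padding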

@@ -44,10 +44,8 @@ def benchmark_permute(
     hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
     # output_hidden_states = torch.empty_like(hidden_states)
     if use_fp8_w8a8:
-        align_block_size = 128  # deepgemm needs 128 m aligned block
         qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
     else:
-        align_block_size = None
         qhidden_states = hidden_states
 
     gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
@@ -67,7 +65,6 @@ def benchmark_permute(
         topk_ids=topk_ids,
         n_expert=num_experts,
         expert_map=None,
-        align_block_size=align_block_size,
     )
 
     # JIT compilation & warmup
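
The quantized path above keeps calling _fp8_quantize; only the alignment bookkeeping goes away. As orientation, here is a rough, self-contained sketch of per-tensor FP8 quantization. It is an assumption for illustration, not vLLM's actual _fp8_quantize, which may use per-token or per-group scales.

import torch

# Rough per-tensor FP8 quantization sketch (illustration only; not vLLM's
# _fp8_quantize). Scales the input into the float8_e4m3fn range and returns
# the quantized tensor together with the scale needed to dequantize.
def fp8_quantize_sketch(x: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = x.abs().max().clamp(min=1e-12) / finfo.max
    q = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q, scale

x = torch.randn(16, 32)
q, s = fp8_quantize_sketch(x)
print(q.dtype, s.item())  # torch.float8_e4m3fn and a positive scalar scale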
@@ -117,10 +114,8 @@ def benchmark_unpermute(
     # init_dtype = torch.float16 if use_fp8_w8a8 else dtype
     hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
     if use_fp8_w8a8:
-        align_block_size = 128  # deepgemm needs 128 m aligned block
         qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
     else:
-        align_block_size = None
         qhidden_states = hidden_states
 
     input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
@@ -142,7 +137,6 @@ def benchmark_unpermute(
         topk_ids=topk_ids,
         n_expert=num_experts,
         expert_map=None,
-        align_block_size=align_block_size,
     )
     # convert to fp16/bf16 as gemm output
     return (
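
Across both benchmarks the call shape stays the same minus the align_block_size keyword. To make the operation concrete, here is a self-contained pure-PyTorch reference of what an MoE permute does (group token rows into expert-major order). This illustrates the concept only, not the vLLM moe_permute kernel; without an alignment block size, no padding rows are inserted between experts' chunks.

import torch

# Conceptual reference (not the vLLM kernel): gather token rows into
# expert-major order. Each expert's chunk is exactly as long as its token
# count; no 128-row padding is applied.
def reference_moe_permute(hidden_states: torch.Tensor, topk_ids: torch.Tensor):
    num_tokens, topk = topk_ids.shape
    flat_expert_ids = topk_ids.reshape(-1)                    # [num_tokens * topk]
    sort_order = torch.argsort(flat_expert_ids, stable=True)  # group slots by expert
    src_rows = sort_order // topk                             # source token per slot
    permuted = hidden_states[src_rows]                        # expert-major copy
    tokens_per_expert = torch.bincount(flat_expert_ids)
    return permuted, sort_order, tokens_per_expert

hidden = torch.randn(4, 8)
topk_ids = torch.tensor([[0, 2], [1, 2], [0, 1], [2, 0]])
out, order, counts = reference_moe_permute(hidden, topk_ids)
print(out.shape, counts)  # torch.Size([8, 8]) tensor([3, 2, 3])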