[Refactor] Remove unused _moe_permute function (#33108)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-01-26 16:06:45 -05:00
committed by GitHub
parent ebe0ba91db
commit 8f987883cb
2 changed files with 38 additions and 169 deletions

View File

@@ -10,8 +10,6 @@ from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_permute,
_moe_unpermute_and_reduce,
moe_permute, moe_permute,
moe_unpermute, moe_unpermute,
) )
@@ -41,7 +39,6 @@ def benchmark_permute(
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
num_iters: int = 100, num_iters: int = 100,
use_customized_permute: bool = False,
) -> float: ) -> float:
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype # init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
@@ -64,14 +61,7 @@ def benchmark_permute(
input_gating.copy_(gating_output[i]) input_gating.copy_(gating_output[i])
def run(): def run():
if use_customized_permute: moe_permute(
(
permuted_hidden_states,
a1q_scale,
first_token_off,
inv_perm_idx,
m_indices,
) = moe_permute(
qhidden_states, qhidden_states,
a1q_scale=None, a1q_scale=None,
topk_ids=topk_ids, topk_ids=topk_ids,
@@ -79,14 +69,6 @@ def benchmark_permute(
expert_map=None, expert_map=None,
align_block_size=align_block_size, align_block_size=align_block_size,
) )
else:
(
permuted_hidden_states,
a1q_scale,
sorted_token_ids,
expert_ids,
inv_perm,
) = _moe_permute(qhidden_states, None, topk_ids, num_experts, None, 16)
# JIT compilation & warmup # JIT compilation & warmup
run() run()
@@ -131,11 +113,9 @@ def benchmark_unpermute(
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
num_iters: int = 100, num_iters: int = 100,
use_customized_permute: bool = False,
) -> float: ) -> float:
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype # init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
output_hidden_states = torch.empty_like(hidden_states)
if use_fp8_w8a8: if use_fp8_w8a8:
align_block_size = 128 # deepgemm needs 128 m aligned block align_block_size = 128 # deepgemm needs 128 m aligned block
qhidden_states, scale = _fp8_quantize(hidden_states, None, None) qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
@@ -150,13 +130,12 @@ def benchmark_unpermute(
) )
def prepare(): def prepare():
if use_customized_permute:
( (
permuted_hidden_states, permuted_hidden_states,
a1q_scale, _,
first_token_off, first_token_off,
inv_perm_idx, inv_perm_idx,
m_indices, _,
) = moe_permute( ) = moe_permute(
qhidden_states, qhidden_states,
a1q_scale=None, a1q_scale=None,
@@ -170,35 +149,10 @@ def benchmark_unpermute(
permuted_hidden_states.to(dtype), permuted_hidden_states.to(dtype),
first_token_off, first_token_off,
inv_perm_idx, inv_perm_idx,
m_indices,
)
else:
(
permuted_qhidden_states,
a1q_scale,
sorted_token_ids,
expert_ids,
inv_perm,
) = _moe_permute(
qhidden_states, None, topk_ids, num_experts, None, block_m=16
)
# convert to fp16/bf16 as gemm output
return (
permuted_qhidden_states.to(dtype),
a1q_scale,
sorted_token_ids,
expert_ids,
inv_perm,
) )
def run(input: tuple): def run(input: tuple):
if use_customized_permute: (permuted_hidden_states, first_token_off, inv_perm_idx) = input
(
permuted_hidden_states,
first_token_off,
inv_perm_idx,
m_indices,
) = input
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
moe_unpermute( moe_unpermute(
output, output,
@@ -207,21 +161,6 @@ def benchmark_unpermute(
inv_perm_idx, inv_perm_idx,
first_token_off, first_token_off,
) )
else:
(
permuted_hidden_states,
a1q_scale,
sorted_token_ids,
expert_ids,
inv_perm,
) = input
_moe_unpermute_and_reduce(
output_hidden_states,
permuted_hidden_states,
inv_perm,
topk_weights,
True,
)
# JIT compilation & warmup # JIT compilation & warmup
input = prepare() input = prepare()
@@ -276,8 +215,7 @@ class BenchmarkWorker:
dtype: torch.dtype, dtype: torch.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_customized_permute: bool = False, ) -> tuple[float, float]:
) -> tuple[dict[str, int], float]:
set_random_seed(self.seed) set_random_seed(self.seed)
permute_time = benchmark_permute( permute_time = benchmark_permute(
@@ -289,7 +227,6 @@ class BenchmarkWorker:
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
num_iters=100, num_iters=100,
use_customized_permute=use_customized_permute,
) )
unpermute_time = benchmark_unpermute( unpermute_time = benchmark_unpermute(
num_tokens, num_tokens,
@@ -300,7 +237,6 @@ class BenchmarkWorker:
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
num_iters=100, num_iters=100,
use_customized_permute=use_customized_permute,
) )
return permute_time, unpermute_time return permute_time, unpermute_time
@@ -347,7 +283,6 @@ def main(args: argparse.Namespace):
dtype = torch.float16 if current_platform.is_rocm() else config.dtype dtype = torch.float16 if current_platform.is_rocm() else config.dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16" use_int8_w8a16 = args.dtype == "int8_w8a16"
use_customized_permute = args.use_customized_permute
if args.batch_size is None: if args.batch_size is None:
batch_sizes = [ batch_sizes = [
@@ -399,7 +334,6 @@ def main(args: argparse.Namespace):
dtype, dtype,
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
use_customized_permute,
) )
for batch_size in batch_sizes for batch_size in batch_sizes
], ],
@@ -419,7 +353,6 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto" "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
) )
parser.add_argument("--use-customized-permute", action="store_true")
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False) parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--trust-remote-code", action="store_true") parser.add_argument("--trust-remote-code", action="store_true")

View File

@@ -3,70 +3,6 @@
import torch import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size,
)
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm
def _moe_permute(
    curr_hidden_states: torch.Tensor,
    a1q_scale: torch.Tensor | None,
    curr_topk_ids: torch.Tensor,
    global_num_experts: int,
    expert_map: torch.Tensor | None,
    block_m: int,
) -> tuple[
    torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor
]:
    """
    Determine the sorted_token_ids, expert_ids for the given problem size.
    Permute the hidden states and scales according to `sorted_token_ids`.

    Args:
        curr_hidden_states: (num_tokens, hidden_size) activations to permute.
        a1q_scale: optional per-token quantization scale; permuted alongside
            the hidden states when provided.
        curr_topk_ids: (num_tokens, top_k) expert ids selected per token.
        global_num_experts: total number of experts.
        expert_map: optional expert mapping forwarded to moe_align_block_size.
        block_m: block size the sorted token ids are padded/aligned to.

    Returns:
        (permuted_hidden_states, permuted_a1q_scale, sorted_token_ids,
        expert_ids, inv_perm) where inv_perm restores the original ordering.
    """
    top_k_num = curr_topk_ids.size(1)
    tokens_in_chunk = curr_hidden_states.size(0)

    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        curr_topk_ids, block_m, global_num_experts, expert_map, pad_sorted_ids=True
    )

    num_tokens = top_k_num * tokens_in_chunk
    # Expand expert ids from block granularity to one id per sorted row.
    expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
    # Inverse permutation restricted to the real (non-padding) entries.
    # NOTE: the original pre-initialized inv_perm to None and immediately
    # overwrote it — the dead assignment (and its misleading `| None`
    # annotation) has been removed; inv_perm is always a Tensor here.
    inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]

    # Permute according to sorted token ids. Clamping maps padding slots to a
    # valid index so the gather below stays in bounds.
    sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
    curr_hidden_states = _fp8_perm(curr_hidden_states, sorted_token_ids // top_k_num)
    if a1q_scale is not None:
        a1q_scale = a1q_scale[sorted_token_ids // top_k_num]

    return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, inv_perm)
def _moe_unpermute_and_reduce(
    out: torch.Tensor,
    curr_hidden: torch.Tensor,
    inv_perm: torch.Tensor | None,
    topk_weight: torch.Tensor,
    apply_router_weight_on_input: bool,
) -> None:
    """
    Unpermute the final result and apply topk_weights, then perform the final
    reduction on the hidden states, writing the result into `out`.
    """
    num_rows, top_k = topk_weight.size()
    hidden_dim = curr_hidden.size(-1)

    # Restore the original token order when a permutation was applied.
    if inv_perm is not None:
        curr_hidden = curr_hidden[inv_perm, ...]
    curr_hidden = curr_hidden.view(-1, top_k, hidden_dim)

    # Scale per-expert outputs by router weights unless they were already
    # folded into the inputs upstream.
    if not apply_router_weight_on_input:
        curr_hidden.mul_(topk_weight.view(num_rows, -1, 1))

    ops.moe_sum(curr_hidden, out)
def moe_permute( def moe_permute(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,