diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py
index 8920e90fc..8c90dd725 100644
--- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py
+++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py
@@ -10,8 +10,6 @@
 from transformers import AutoConfig
 from vllm.model_executor.layers.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
-    _moe_permute,
-    _moe_unpermute_and_reduce,
     moe_permute,
     moe_unpermute,
 )
@@ -41,7 +39,6 @@ def benchmark_permute(
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
     num_iters: int = 100,
-    use_customized_permute: bool = False,
 ) -> float:
     # init_dtype = torch.float16 if use_fp8_w8a8 else dtype
     hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
@@ -64,29 +61,14 @@ def benchmark_permute(
         input_gating.copy_(gating_output[i])
 
     def run():
-        if use_customized_permute:
-            (
-                permuted_hidden_states,
-                a1q_scale,
-                first_token_off,
-                inv_perm_idx,
-                m_indices,
-            ) = moe_permute(
-                qhidden_states,
-                a1q_scale=None,
-                topk_ids=topk_ids,
-                n_expert=num_experts,
-                expert_map=None,
-                align_block_size=align_block_size,
-            )
-        else:
-            (
-                permuted_hidden_states,
-                a1q_scale,
-                sorted_token_ids,
-                expert_ids,
-                inv_perm,
-            ) = _moe_permute(qhidden_states, None, topk_ids, num_experts, None, 16)
+        moe_permute(
+            qhidden_states,
+            a1q_scale=None,
+            topk_ids=topk_ids,
+            n_expert=num_experts,
+            expert_map=None,
+            align_block_size=align_block_size,
+        )
 
     # JIT compilation & warmup
     run()
@@ -131,11 +113,9 @@ def benchmark_unpermute(
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
     num_iters: int = 100,
-    use_customized_permute: bool = False,
 ) -> float:
     # init_dtype = torch.float16 if use_fp8_w8a8 else dtype
     hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
-    output_hidden_states = torch.empty_like(hidden_states)
     if use_fp8_w8a8:
         align_block_size = 128  # deepgemm needs 128 m aligned block
         qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
@@ -150,78 +130,37 @@ def benchmark_unpermute(
     )
 
     def prepare():
-        if use_customized_permute:
-            (
-                permuted_hidden_states,
-                a1q_scale,
-                first_token_off,
-                inv_perm_idx,
-                m_indices,
-            ) = moe_permute(
-                qhidden_states,
-                a1q_scale=None,
-                topk_ids=topk_ids,
-                n_expert=num_experts,
-                expert_map=None,
-                align_block_size=align_block_size,
-            )
-            # convert to fp16/bf16 as gemm output
-            return (
-                permuted_hidden_states.to(dtype),
-                first_token_off,
-                inv_perm_idx,
-                m_indices,
-            )
-        else:
-            (
-                permuted_qhidden_states,
-                a1q_scale,
-                sorted_token_ids,
-                expert_ids,
-                inv_perm,
-            ) = _moe_permute(
-                qhidden_states, None, topk_ids, num_experts, None, block_m=16
-            )
-            # convert to fp16/bf16 as gemm output
-            return (
-                permuted_qhidden_states.to(dtype),
-                a1q_scale,
-                sorted_token_ids,
-                expert_ids,
-                inv_perm,
-            )
+        (
+            permuted_hidden_states,
+            _,
+            first_token_off,
+            inv_perm_idx,
+            _,
+        ) = moe_permute(
+            qhidden_states,
+            a1q_scale=None,
+            topk_ids=topk_ids,
+            n_expert=num_experts,
+            expert_map=None,
+            align_block_size=align_block_size,
+        )
+        # convert to fp16/bf16 as gemm output
+        return (
+            permuted_hidden_states.to(dtype),
+            first_token_off,
+            inv_perm_idx,
+        )
 
     def run(input: tuple):
-        if use_customized_permute:
-            (
-                permuted_hidden_states,
-                first_token_off,
-                inv_perm_idx,
-                m_indices,
-            ) = input
-            output = torch.empty_like(hidden_states)
-            moe_unpermute(
-                output,
-                permuted_hidden_states,
-                topk_weights,
-                inv_perm_idx,
-                first_token_off,
-            )
-        else:
-            (
-                permuted_hidden_states,
-                a1q_scale,
-                sorted_token_ids,
-                expert_ids,
-                inv_perm,
-            ) = input
-            _moe_unpermute_and_reduce(
-                output_hidden_states,
-                permuted_hidden_states,
-                inv_perm,
-                topk_weights,
-                True,
-            )
+        (permuted_hidden_states, first_token_off, inv_perm_idx) = input
+        output = torch.empty_like(hidden_states)
+        moe_unpermute(
+            output,
+            permuted_hidden_states,
+            topk_weights,
+            inv_perm_idx,
+            first_token_off,
+        )
 
     # JIT compilation & warmup
     input = prepare()
@@ -276,8 +215,7 @@ class BenchmarkWorker:
         dtype: torch.dtype,
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
-        use_customized_permute: bool = False,
-    ) -> tuple[dict[str, int], float]:
+    ) -> tuple[float, float]:
         set_random_seed(self.seed)
 
         permute_time = benchmark_permute(
@@ -289,7 +227,6 @@ class BenchmarkWorker:
             use_fp8_w8a8,
             use_int8_w8a16,
             num_iters=100,
-            use_customized_permute=use_customized_permute,
         )
         unpermute_time = benchmark_unpermute(
             num_tokens,
@@ -300,7 +237,6 @@ class BenchmarkWorker:
             use_fp8_w8a8,
             use_int8_w8a16,
             num_iters=100,
-            use_customized_permute=use_customized_permute,
         )
         return permute_time, unpermute_time
 
@@ -347,7 +283,6 @@ def main(args: argparse.Namespace):
     dtype = torch.float16 if current_platform.is_rocm() else config.dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
     use_int8_w8a16 = args.dtype == "int8_w8a16"
-    use_customized_permute = args.use_customized_permute
 
     if args.batch_size is None:
         batch_sizes = [
@@ -399,7 +334,6 @@ def main(args: argparse.Namespace):
                 dtype,
                 use_fp8_w8a8,
                 use_int8_w8a16,
-                use_customized_permute,
             )
             for batch_size in batch_sizes
         ],
@@ -419,7 +353,6 @@ if __name__ == "__main__":
     parser.add_argument(
         "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
    )
-    parser.add_argument("--use-customized-permute", action="store_true")
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--batch-size", type=int, required=False)
     parser.add_argument("--trust-remote-code", action="store_true")
diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
index 9dcdcc380..0c8cbd04b 100644
--- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
+++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
@@ -3,70 +3,6 @@
 import torch
 
-from vllm import _custom_ops as ops
-from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
-    moe_align_block_size,
-)
-from vllm.model_executor.layers.fused_moe.utils import _fp8_perm
-
-
-def _moe_permute(
-    curr_hidden_states: torch.Tensor,
-    a1q_scale: torch.Tensor | None,
-    curr_topk_ids: torch.Tensor,
-    global_num_experts: int,
-    expert_map: torch.Tensor | None,
-    block_m: int,
-) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
-    """
-    Determine the sorted_token_ids, expert_ids for the given problem size.
-    Permute the hidden states and scales according to `sorted_token_ids`.
-    """
-    top_k_num = curr_topk_ids.size(1)
-
-    tokens_in_chunk = curr_hidden_states.size(0)
-
-    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        curr_topk_ids, block_m, global_num_experts, expert_map, pad_sorted_ids=True
-    )
-
-    inv_perm: torch.Tensor | None = None
-
-    num_tokens = top_k_num * tokens_in_chunk
-    expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
-    inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
-
-    # Permute according to sorted token ids.
-    sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
-
-    curr_hidden_states = _fp8_perm(curr_hidden_states, sorted_token_ids // top_k_num)
-
-    if a1q_scale is not None:
-        a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
-
-    return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, inv_perm)
-
-
-def _moe_unpermute_and_reduce(
-    out: torch.Tensor,
-    curr_hidden: torch.Tensor,
-    inv_perm: torch.Tensor | None,
-    topk_weight: torch.Tensor,
-    apply_router_weight_on_input: bool,
-) -> None:
-    """
-    Unpermute the final result and apply topk_weights, then perform the final
-    reduction on the hidden states.
-    """
-    M, topk = topk_weight.size()
-    K = curr_hidden.size(-1)
-    if inv_perm is not None:
-        curr_hidden = curr_hidden[inv_perm, ...]
-    curr_hidden = curr_hidden.view(-1, topk, K)
-    if not apply_router_weight_on_input:
-        curr_hidden.mul_(topk_weight.view(M, -1, 1))
-    ops.moe_sum(curr_hidden, out)
-
-
 def moe_permute(
     hidden_states: torch.Tensor,
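
Usage note: after this patch the benchmark exercises only the consolidated
moe_permute / moe_unpermute pair. The round-trip sketch below mirrors those
call sites; it requires a CUDA build of vLLM with the permute/unpermute custom
ops compiled in, and the problem sizes and synthetic top-k routing are
illustrative assumptions, not part of the patch.

import torch

from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
    moe_permute,
    moe_unpermute,
)

# Illustrative problem size (assumed, not taken from the patch).
num_tokens, hidden_size, num_experts, topk = 16, 128, 8, 2

hidden_states = torch.randn(
    num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda"
)

# Synthetic routing standing in for fused_topk; fused_topk emits int32 ids,
# so the ids are cast to match what the kernels consume in the benchmark.
gating = torch.randn(num_tokens, num_experts, device="cuda").softmax(dim=-1)
topk_weights, topk_ids = torch.topk(gating, topk, dim=-1)
topk_ids = topk_ids.to(torch.int32)

# Scatter tokens into expert-contiguous (optionally block-aligned) order.
(permuted_hidden_states, _, first_token_off, inv_perm_idx, _) = moe_permute(
    hidden_states,
    a1q_scale=None,
    topk_ids=topk_ids,
    n_expert=num_experts,
    expert_map=None,
    align_block_size=None,  # the benchmark passes 128 on the fp8/deepgemm path
)

# ... grouped per-expert GEMMs would consume permuted_hidden_states here ...

# Gather back to token order and reduce over each token's top-k experts.
output = torch.empty_like(hidden_states)
moe_unpermute(
    output, permuted_hidden_states, topk_weights, inv_perm_idx, first_token_off
)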