diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py index e70f5c0a1..8920e90fc 100644 --- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py +++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py @@ -8,7 +8,7 @@ import ray import torch from transformers import AutoConfig -from vllm.model_executor.layers.fused_moe.fused_moe import * +from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( _moe_permute, _moe_unpermute_and_reduce, @@ -86,9 +86,7 @@ def benchmark_permute( sorted_token_ids, expert_ids, inv_perm, - ) = _moe_permute( - qhidden_states, None, topk_ids, num_experts, None, align_block_size - ) + ) = _moe_permute(qhidden_states, None, topk_ids, num_experts, None, 16) # JIT compilation & warmup run() @@ -182,7 +180,7 @@ def benchmark_unpermute( expert_ids, inv_perm, ) = _moe_permute( - qhidden_states, None, topk_ids, num_experts, None, align_block_size + qhidden_states, None, topk_ids, num_experts, None, block_m=16 ) # convert to fp16/bf16 as gemm output return (