[Refactor] Remove unused _moe_permute function (#33108)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -10,8 +10,6 @@ from transformers import AutoConfig
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
_moe_permute,
|
||||
_moe_unpermute_and_reduce,
|
||||
moe_permute,
|
||||
moe_unpermute,
|
||||
)
|
||||
@@ -41,7 +39,6 @@ def benchmark_permute(
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
num_iters: int = 100,
|
||||
use_customized_permute: bool = False,
|
||||
) -> float:
|
||||
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
@@ -64,29 +61,14 @@ def benchmark_permute(
|
||||
input_gating.copy_(gating_output[i])
|
||||
|
||||
def run():
|
||||
if use_customized_permute:
|
||||
(
|
||||
permuted_hidden_states,
|
||||
a1q_scale,
|
||||
first_token_off,
|
||||
inv_perm_idx,
|
||||
m_indices,
|
||||
) = moe_permute(
|
||||
qhidden_states,
|
||||
a1q_scale=None,
|
||||
topk_ids=topk_ids,
|
||||
n_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
else:
|
||||
(
|
||||
permuted_hidden_states,
|
||||
a1q_scale,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
inv_perm,
|
||||
) = _moe_permute(qhidden_states, None, topk_ids, num_experts, None, 16)
|
||||
moe_permute(
|
||||
qhidden_states,
|
||||
a1q_scale=None,
|
||||
topk_ids=topk_ids,
|
||||
n_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
|
||||
# JIT compilation & warmup
|
||||
run()
|
||||
@@ -131,11 +113,9 @@ def benchmark_unpermute(
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
num_iters: int = 100,
|
||||
use_customized_permute: bool = False,
|
||||
) -> float:
|
||||
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
output_hidden_states = torch.empty_like(hidden_states)
|
||||
if use_fp8_w8a8:
|
||||
align_block_size = 128 # deepgemm needs 128 m aligned block
|
||||
qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
|
||||
@@ -150,78 +130,37 @@ def benchmark_unpermute(
|
||||
)
|
||||
|
||||
def prepare():
|
||||
if use_customized_permute:
|
||||
(
|
||||
permuted_hidden_states,
|
||||
a1q_scale,
|
||||
first_token_off,
|
||||
inv_perm_idx,
|
||||
m_indices,
|
||||
) = moe_permute(
|
||||
qhidden_states,
|
||||
a1q_scale=None,
|
||||
topk_ids=topk_ids,
|
||||
n_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
# convert to fp16/bf16 as gemm output
|
||||
return (
|
||||
permuted_hidden_states.to(dtype),
|
||||
first_token_off,
|
||||
inv_perm_idx,
|
||||
m_indices,
|
||||
)
|
||||
else:
|
||||
(
|
||||
permuted_qhidden_states,
|
||||
a1q_scale,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
inv_perm,
|
||||
) = _moe_permute(
|
||||
qhidden_states, None, topk_ids, num_experts, None, block_m=16
|
||||
)
|
||||
# convert to fp16/bf16 as gemm output
|
||||
return (
|
||||
permuted_qhidden_states.to(dtype),
|
||||
a1q_scale,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
inv_perm,
|
||||
)
|
||||
(
|
||||
permuted_hidden_states,
|
||||
_,
|
||||
first_token_off,
|
||||
inv_perm_idx,
|
||||
_,
|
||||
) = moe_permute(
|
||||
qhidden_states,
|
||||
a1q_scale=None,
|
||||
topk_ids=topk_ids,
|
||||
n_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
# convert to fp16/bf16 as gemm output
|
||||
return (
|
||||
permuted_hidden_states.to(dtype),
|
||||
first_token_off,
|
||||
inv_perm_idx,
|
||||
)
|
||||
|
||||
def run(input: tuple):
|
||||
if use_customized_permute:
|
||||
(
|
||||
permuted_hidden_states,
|
||||
first_token_off,
|
||||
inv_perm_idx,
|
||||
m_indices,
|
||||
) = input
|
||||
output = torch.empty_like(hidden_states)
|
||||
moe_unpermute(
|
||||
output,
|
||||
permuted_hidden_states,
|
||||
topk_weights,
|
||||
inv_perm_idx,
|
||||
first_token_off,
|
||||
)
|
||||
else:
|
||||
(
|
||||
permuted_hidden_states,
|
||||
a1q_scale,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
inv_perm,
|
||||
) = input
|
||||
_moe_unpermute_and_reduce(
|
||||
output_hidden_states,
|
||||
permuted_hidden_states,
|
||||
inv_perm,
|
||||
topk_weights,
|
||||
True,
|
||||
)
|
||||
(permuted_hidden_states, first_token_off, inv_perm_idx) = input
|
||||
output = torch.empty_like(hidden_states)
|
||||
moe_unpermute(
|
||||
output,
|
||||
permuted_hidden_states,
|
||||
topk_weights,
|
||||
inv_perm_idx,
|
||||
first_token_off,
|
||||
)
|
||||
|
||||
# JIT compilation & warmup
|
||||
input = prepare()
|
||||
@@ -276,8 +215,7 @@ class BenchmarkWorker:
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_customized_permute: bool = False,
|
||||
) -> tuple[dict[str, int], float]:
|
||||
) -> tuple[float, float]:
|
||||
set_random_seed(self.seed)
|
||||
|
||||
permute_time = benchmark_permute(
|
||||
@@ -289,7 +227,6 @@ class BenchmarkWorker:
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=100,
|
||||
use_customized_permute=use_customized_permute,
|
||||
)
|
||||
unpermute_time = benchmark_unpermute(
|
||||
num_tokens,
|
||||
@@ -300,7 +237,6 @@ class BenchmarkWorker:
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=100,
|
||||
use_customized_permute=use_customized_permute,
|
||||
)
|
||||
return permute_time, unpermute_time
|
||||
|
||||
@@ -347,7 +283,6 @@ def main(args: argparse.Namespace):
|
||||
dtype = torch.float16 if current_platform.is_rocm() else config.dtype
|
||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||
use_customized_permute = args.use_customized_permute
|
||||
|
||||
if args.batch_size is None:
|
||||
batch_sizes = [
|
||||
@@ -399,7 +334,6 @@ def main(args: argparse.Namespace):
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_customized_permute,
|
||||
)
|
||||
for batch_size in batch_sizes
|
||||
],
|
||||
@@ -419,7 +353,6 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
|
||||
)
|
||||
parser.add_argument("--use-customized-permute", action="store_true")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
parser.add_argument("--trust-remote-code", action="store_true")
|
||||
|
||||
@@ -3,70 +3,6 @@
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm
|
||||
|
||||
|
||||
def _moe_permute(
|
||||
curr_hidden_states: torch.Tensor,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
curr_topk_ids: torch.Tensor,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
block_m: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Determine the sorted_token_ids, expert_ids for the given problem size.
|
||||
Permute the hidden states and scales according to `sorted_token_ids`.
|
||||
"""
|
||||
top_k_num = curr_topk_ids.size(1)
|
||||
|
||||
tokens_in_chunk = curr_hidden_states.size(0)
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
curr_topk_ids, block_m, global_num_experts, expert_map, pad_sorted_ids=True
|
||||
)
|
||||
|
||||
inv_perm: torch.Tensor | None = None
|
||||
|
||||
num_tokens = top_k_num * tokens_in_chunk
|
||||
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
|
||||
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
|
||||
|
||||
# Permute according to sorted token ids.
|
||||
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
|
||||
|
||||
curr_hidden_states = _fp8_perm(curr_hidden_states, sorted_token_ids // top_k_num)
|
||||
|
||||
if a1q_scale is not None:
|
||||
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
|
||||
|
||||
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, inv_perm)
|
||||
|
||||
|
||||
def _moe_unpermute_and_reduce(
|
||||
out: torch.Tensor,
|
||||
curr_hidden: torch.Tensor,
|
||||
inv_perm: torch.Tensor | None,
|
||||
topk_weight: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Unpermute the final result and apply topk_weights, then perform the final
|
||||
reduction on the hidden states.
|
||||
"""
|
||||
M, topk = topk_weight.size()
|
||||
K = curr_hidden.size(-1)
|
||||
if inv_perm is not None:
|
||||
curr_hidden = curr_hidden[inv_perm, ...]
|
||||
curr_hidden = curr_hidden.view(-1, topk, K)
|
||||
if not apply_router_weight_on_input:
|
||||
curr_hidden.mul_(topk_weight.view(M, -1, 1))
|
||||
ops.moe_sum(curr_hidden, out)
|
||||
|
||||
|
||||
def moe_permute(
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user