[Refactor] Remove unused _moe_permute function (#33108)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -10,8 +10,6 @@ from transformers import AutoConfig
|
|||||||
|
|
||||||
from vllm.model_executor.layers.fused_moe import fused_topk
|
from vllm.model_executor.layers.fused_moe import fused_topk
|
||||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||||
_moe_permute,
|
|
||||||
_moe_unpermute_and_reduce,
|
|
||||||
moe_permute,
|
moe_permute,
|
||||||
moe_unpermute,
|
moe_unpermute,
|
||||||
)
|
)
|
||||||
@@ -41,7 +39,6 @@ def benchmark_permute(
|
|||||||
use_fp8_w8a8: bool,
|
use_fp8_w8a8: bool,
|
||||||
use_int8_w8a16: bool,
|
use_int8_w8a16: bool,
|
||||||
num_iters: int = 100,
|
num_iters: int = 100,
|
||||||
use_customized_permute: bool = False,
|
|
||||||
) -> float:
|
) -> float:
|
||||||
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||||
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||||
@@ -64,14 +61,7 @@ def benchmark_permute(
|
|||||||
input_gating.copy_(gating_output[i])
|
input_gating.copy_(gating_output[i])
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
if use_customized_permute:
|
moe_permute(
|
||||||
(
|
|
||||||
permuted_hidden_states,
|
|
||||||
a1q_scale,
|
|
||||||
first_token_off,
|
|
||||||
inv_perm_idx,
|
|
||||||
m_indices,
|
|
||||||
) = moe_permute(
|
|
||||||
qhidden_states,
|
qhidden_states,
|
||||||
a1q_scale=None,
|
a1q_scale=None,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
@@ -79,14 +69,6 @@ def benchmark_permute(
|
|||||||
expert_map=None,
|
expert_map=None,
|
||||||
align_block_size=align_block_size,
|
align_block_size=align_block_size,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
(
|
|
||||||
permuted_hidden_states,
|
|
||||||
a1q_scale,
|
|
||||||
sorted_token_ids,
|
|
||||||
expert_ids,
|
|
||||||
inv_perm,
|
|
||||||
) = _moe_permute(qhidden_states, None, topk_ids, num_experts, None, 16)
|
|
||||||
|
|
||||||
# JIT compilation & warmup
|
# JIT compilation & warmup
|
||||||
run()
|
run()
|
||||||
@@ -131,11 +113,9 @@ def benchmark_unpermute(
|
|||||||
use_fp8_w8a8: bool,
|
use_fp8_w8a8: bool,
|
||||||
use_int8_w8a16: bool,
|
use_int8_w8a16: bool,
|
||||||
num_iters: int = 100,
|
num_iters: int = 100,
|
||||||
use_customized_permute: bool = False,
|
|
||||||
) -> float:
|
) -> float:
|
||||||
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||||
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||||
output_hidden_states = torch.empty_like(hidden_states)
|
|
||||||
if use_fp8_w8a8:
|
if use_fp8_w8a8:
|
||||||
align_block_size = 128 # deepgemm needs 128 m aligned block
|
align_block_size = 128 # deepgemm needs 128 m aligned block
|
||||||
qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
|
qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
|
||||||
@@ -150,13 +130,12 @@ def benchmark_unpermute(
|
|||||||
)
|
)
|
||||||
|
|
||||||
def prepare():
|
def prepare():
|
||||||
if use_customized_permute:
|
|
||||||
(
|
(
|
||||||
permuted_hidden_states,
|
permuted_hidden_states,
|
||||||
a1q_scale,
|
_,
|
||||||
first_token_off,
|
first_token_off,
|
||||||
inv_perm_idx,
|
inv_perm_idx,
|
||||||
m_indices,
|
_,
|
||||||
) = moe_permute(
|
) = moe_permute(
|
||||||
qhidden_states,
|
qhidden_states,
|
||||||
a1q_scale=None,
|
a1q_scale=None,
|
||||||
@@ -170,35 +149,10 @@ def benchmark_unpermute(
|
|||||||
permuted_hidden_states.to(dtype),
|
permuted_hidden_states.to(dtype),
|
||||||
first_token_off,
|
first_token_off,
|
||||||
inv_perm_idx,
|
inv_perm_idx,
|
||||||
m_indices,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
(
|
|
||||||
permuted_qhidden_states,
|
|
||||||
a1q_scale,
|
|
||||||
sorted_token_ids,
|
|
||||||
expert_ids,
|
|
||||||
inv_perm,
|
|
||||||
) = _moe_permute(
|
|
||||||
qhidden_states, None, topk_ids, num_experts, None, block_m=16
|
|
||||||
)
|
|
||||||
# convert to fp16/bf16 as gemm output
|
|
||||||
return (
|
|
||||||
permuted_qhidden_states.to(dtype),
|
|
||||||
a1q_scale,
|
|
||||||
sorted_token_ids,
|
|
||||||
expert_ids,
|
|
||||||
inv_perm,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def run(input: tuple):
|
def run(input: tuple):
|
||||||
if use_customized_permute:
|
(permuted_hidden_states, first_token_off, inv_perm_idx) = input
|
||||||
(
|
|
||||||
permuted_hidden_states,
|
|
||||||
first_token_off,
|
|
||||||
inv_perm_idx,
|
|
||||||
m_indices,
|
|
||||||
) = input
|
|
||||||
output = torch.empty_like(hidden_states)
|
output = torch.empty_like(hidden_states)
|
||||||
moe_unpermute(
|
moe_unpermute(
|
||||||
output,
|
output,
|
||||||
@@ -207,21 +161,6 @@ def benchmark_unpermute(
|
|||||||
inv_perm_idx,
|
inv_perm_idx,
|
||||||
first_token_off,
|
first_token_off,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
(
|
|
||||||
permuted_hidden_states,
|
|
||||||
a1q_scale,
|
|
||||||
sorted_token_ids,
|
|
||||||
expert_ids,
|
|
||||||
inv_perm,
|
|
||||||
) = input
|
|
||||||
_moe_unpermute_and_reduce(
|
|
||||||
output_hidden_states,
|
|
||||||
permuted_hidden_states,
|
|
||||||
inv_perm,
|
|
||||||
topk_weights,
|
|
||||||
True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# JIT compilation & warmup
|
# JIT compilation & warmup
|
||||||
input = prepare()
|
input = prepare()
|
||||||
@@ -276,8 +215,7 @@ class BenchmarkWorker:
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
use_fp8_w8a8: bool,
|
use_fp8_w8a8: bool,
|
||||||
use_int8_w8a16: bool,
|
use_int8_w8a16: bool,
|
||||||
use_customized_permute: bool = False,
|
) -> tuple[float, float]:
|
||||||
) -> tuple[dict[str, int], float]:
|
|
||||||
set_random_seed(self.seed)
|
set_random_seed(self.seed)
|
||||||
|
|
||||||
permute_time = benchmark_permute(
|
permute_time = benchmark_permute(
|
||||||
@@ -289,7 +227,6 @@ class BenchmarkWorker:
|
|||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
num_iters=100,
|
num_iters=100,
|
||||||
use_customized_permute=use_customized_permute,
|
|
||||||
)
|
)
|
||||||
unpermute_time = benchmark_unpermute(
|
unpermute_time = benchmark_unpermute(
|
||||||
num_tokens,
|
num_tokens,
|
||||||
@@ -300,7 +237,6 @@ class BenchmarkWorker:
|
|||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
num_iters=100,
|
num_iters=100,
|
||||||
use_customized_permute=use_customized_permute,
|
|
||||||
)
|
)
|
||||||
return permute_time, unpermute_time
|
return permute_time, unpermute_time
|
||||||
|
|
||||||
@@ -347,7 +283,6 @@ def main(args: argparse.Namespace):
|
|||||||
dtype = torch.float16 if current_platform.is_rocm() else config.dtype
|
dtype = torch.float16 if current_platform.is_rocm() else config.dtype
|
||||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||||
use_customized_permute = args.use_customized_permute
|
|
||||||
|
|
||||||
if args.batch_size is None:
|
if args.batch_size is None:
|
||||||
batch_sizes = [
|
batch_sizes = [
|
||||||
@@ -399,7 +334,6 @@ def main(args: argparse.Namespace):
|
|||||||
dtype,
|
dtype,
|
||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
use_customized_permute,
|
|
||||||
)
|
)
|
||||||
for batch_size in batch_sizes
|
for batch_size in batch_sizes
|
||||||
],
|
],
|
||||||
@@ -419,7 +353,6 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
|
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
|
||||||
)
|
)
|
||||||
parser.add_argument("--use-customized-permute", action="store_true")
|
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument("--batch-size", type=int, required=False)
|
parser.add_argument("--batch-size", type=int, required=False)
|
||||||
parser.add_argument("--trust-remote-code", action="store_true")
|
parser.add_argument("--trust-remote-code", action="store_true")
|
||||||
|
|||||||
@@ -3,70 +3,6 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
|
||||||
moe_align_block_size,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm
|
|
||||||
|
|
||||||
|
|
||||||
def _moe_permute(
|
|
||||||
curr_hidden_states: torch.Tensor,
|
|
||||||
a1q_scale: torch.Tensor | None,
|
|
||||||
curr_topk_ids: torch.Tensor,
|
|
||||||
global_num_experts: int,
|
|
||||||
expert_map: torch.Tensor | None,
|
|
||||||
block_m: int,
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Determine the sorted_token_ids, expert_ids for the given problem size.
|
|
||||||
Permute the hidden states and scales according to `sorted_token_ids`.
|
|
||||||
"""
|
|
||||||
top_k_num = curr_topk_ids.size(1)
|
|
||||||
|
|
||||||
tokens_in_chunk = curr_hidden_states.size(0)
|
|
||||||
|
|
||||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
|
||||||
curr_topk_ids, block_m, global_num_experts, expert_map, pad_sorted_ids=True
|
|
||||||
)
|
|
||||||
|
|
||||||
inv_perm: torch.Tensor | None = None
|
|
||||||
|
|
||||||
num_tokens = top_k_num * tokens_in_chunk
|
|
||||||
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
|
|
||||||
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
|
|
||||||
|
|
||||||
# Permute according to sorted token ids.
|
|
||||||
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
|
|
||||||
|
|
||||||
curr_hidden_states = _fp8_perm(curr_hidden_states, sorted_token_ids // top_k_num)
|
|
||||||
|
|
||||||
if a1q_scale is not None:
|
|
||||||
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
|
|
||||||
|
|
||||||
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, inv_perm)
|
|
||||||
|
|
||||||
|
|
||||||
def _moe_unpermute_and_reduce(
|
|
||||||
out: torch.Tensor,
|
|
||||||
curr_hidden: torch.Tensor,
|
|
||||||
inv_perm: torch.Tensor | None,
|
|
||||||
topk_weight: torch.Tensor,
|
|
||||||
apply_router_weight_on_input: bool,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Unpermute the final result and apply topk_weights, then perform the final
|
|
||||||
reduction on the hidden states.
|
|
||||||
"""
|
|
||||||
M, topk = topk_weight.size()
|
|
||||||
K = curr_hidden.size(-1)
|
|
||||||
if inv_perm is not None:
|
|
||||||
curr_hidden = curr_hidden[inv_perm, ...]
|
|
||||||
curr_hidden = curr_hidden.view(-1, topk, K)
|
|
||||||
if not apply_router_weight_on_input:
|
|
||||||
curr_hidden.mul_(topk_weight.view(M, -1, 1))
|
|
||||||
ops.moe_sum(curr_hidden, out)
|
|
||||||
|
|
||||||
|
|
||||||
def moe_permute(
|
def moe_permute(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user