Modularize fused experts and integrate PPLX kernels (#15956)

bnellnm
2025-05-14 16:11:54 -04:00
committed by GitHub
parent 418d2f8bfb
commit f9c069c85e
42 changed files with 3830 additions and 660 deletions


@@ -8,16 +8,17 @@ from typing import Any, Callable, Optional
import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, moe_kernel_quantize_input)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
@@ -484,6 +485,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
if use_fp8_w8a8 or use_int8_w8a8:
assert B_scale is not None
assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
== B_scale.shape[-2])
assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
== B_scale.shape[-1])
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
M = A.shape[0]
num_tokens = M * top_k
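The new assertions tie the trailing dimensions of B_scale to the ceiling-division of the weight dimensions by the quantization block shape. A standalone sketch of the expected shapes (sizes here are illustrative, not from this diff):

import math

# For block-quantized weights B of shape (E, N, K) with tiles of
# (block_n, block_k), the scale tensor must hold one entry per tile.
E, N, K = 8, 2048, 1024          # illustrative sizes
block_n, block_k = 128, 128
scale_shape = (E, math.ceil(N / block_n), math.ceil(K / block_k))
# Mirrors: triton.cdiv(B.shape[-2], block_shape[0]) == B_scale.shape[-2]
#          triton.cdiv(B.shape[-1], block_shape[1]) == B_scale.shape[-1]
print(scale_shape)               # (8, 16, 8)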
@@ -855,6 +870,7 @@ def fused_topk(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
indices_type: Optional[torch.dtype] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
@@ -865,10 +881,11 @@ def fused_topk(
topk,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
topk_ids = torch.empty(
M,
topk,
dtype=torch.int32 if indices_type is None else indices_type,
device=hidden_states.device)
token_expert_indices = torch.empty(M,
topk,
dtype=torch.int32,
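fused_topk gains an indices_type parameter so callers can choose the dtype of topk_ids instead of the int32 default; the PPLX all-to-all path this PR integrates passes unsigned 32-bit ids here. A hedged call sketch (sizes illustrative; torch.uint32 requires a recent PyTorch):

import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

hidden_states = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
gating_output = torch.randn(16, 8, device="cuda", dtype=torch.float32)
# Request uint32 expert ids for an all-to-all dispatch kernel.
topk_weights, topk_ids, token_expert_indices = fused_topk(
    hidden_states, gating_output, topk=2, renormalize=True,
    indices_type=torch.uint32)
assert topk_ids.dtype == torch.uint32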
@@ -962,6 +979,20 @@ def get_config_dtype_str(
return None
# TODO (bnell): use scalar_type instead of bools?
def get_config_qtype(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
) -> Optional[torch.dtype]:
if use_fp8_w8a8:
return torch.float8_e4m3fn
elif use_int8_w8a8:
return torch.int8
return None
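get_config_qtype collapses the four mutually exclusive quantization flags into the dtype activations are quantized to: fp8 w8a8 maps to float8_e4m3fn, int8 w8a8 to int8, and the weight-only schemes (w8a16/w4a16) to None since their activations stay unquantized. For example:

import torch

assert get_config_qtype(use_fp8_w8a8=True, use_int8_w8a8=False,
                        use_int8_w8a16=False,
                        use_int4_w4a16=False) == torch.float8_e4m3fn
assert get_config_qtype(use_fp8_w8a8=False, use_int8_w8a8=True,
                        use_int8_w8a16=False,
                        use_int4_w4a16=False) == torch.int8
# Weight-only schemes quantize no activations:
assert get_config_qtype(use_fp8_w8a8=False, use_int8_w8a8=False,
                        use_int8_w8a16=True, use_int4_w4a16=False) is None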
def inplace_fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@@ -1128,7 +1159,10 @@ def fused_experts(hidden_states: torch.Tensor,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False) -> torch.Tensor:
if (allow_deep_gemm and use_fp8_w8a8
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
N = w1.shape[1]
if (allow_deep_gemm and use_fp8_w8a8 and N > 512
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
assert apply_router_weight_on_input is False
return deep_gemm_moe_fp8(
@@ -1145,6 +1179,7 @@ def fused_experts(hidden_states: torch.Tensor,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
return dispatch_fused_experts_func(inplace)(
@@ -1171,87 +1206,37 @@ def fused_experts(hidden_states: torch.Tensor,
block_shape=block_shape)
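The DeepGEMM branch in fused_experts above now additionally requires the intermediate dimension N = w1.shape[1] to exceed 512; smaller problems fall back to the Triton path until better permute/unpermute ops land. A hypothetical predicate summarizing the dispatch (the helper name is illustrative, not part of this diff):

def _should_use_deep_gemm(allow_deep_gemm: bool, use_fp8_w8a8: bool,
                          hidden_states, w1, w2, expert_map) -> bool:
    # Same condition as the dispatch above: fp8 weights, a large enough
    # intermediate size, and shapes DeepGEMM can actually handle.
    N = w1.shape[1]
    return (allow_deep_gemm and use_fp8_w8a8 and N > 512
            and _valid_deep_gemm(hidden_states, w1, w2, expert_map))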
def moe_kernel_prepare_input(
A: torch.Tensor,
B: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
def fused_experts_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if use_fp8_w8a8:
assert B_scale is not None
if block_shape is None:
# If weights are per-channel (per_channel_quant=True), then
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8 quantization, dynamic or static
A, A_scale = ops.scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_channel_quant)
else:
# activation block-wise fp8 quantization
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a8:
assert B_scale is not None
if block_shape is None:
# activation channel-wise int8 quantization
assert (per_channel_quant
), "int8 quantization only supports block or channel-wise"
A, A_scale = per_token_quant_int8(A)
else:
# activation block-wise int8 quantization
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_int8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None
assert block_shape is None or block_shape[0] == 0
else:
assert A_scale is None
assert B_scale is None
return A, A_scale
def fused_experts_impl(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_zp: Optional[torch.Tensor] = None,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None):
) -> torch.Tensor:
# Check constraints.
if use_int4_w4a16:
assert hidden_states.shape[1] // 2 == w1.shape[
2], "Hidden size mismatch"
else:
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert hidden_states.shape[1] == w1.shape[2], (
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
@@ -1261,7 +1246,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
torch.float32, torch.float16, torch.bfloat16
]
num_tokens, _ = hidden_states.shape
num_tokens = hidden_states.shape[0]
E, N, _ = w1.shape
K = w2.shape[1]
if global_num_experts == -1:
@@ -1276,6 +1261,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int4_w4a16=use_int4_w4a16,
dtype=hidden_states.dtype)
qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16)
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
@@ -1338,15 +1328,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input(
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
A=curr_hidden_states,
B=w1,
A_scale=a1_scale,
B_scale=w1_scale,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
qtype=qtype,
per_channel_quant=per_channel_quant,
block_shape=block_shape)
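Unlike the removed moe_kernel_prepare_input, the shared utils helper no longer needs the weight tensor or the four boolean flags; the qtype computed once above carries the same information. A hedged standalone sketch of the call (keyword names taken from this diff's call sites):

import torch
from vllm.model_executor.layers.fused_moe.utils import (
    moe_kernel_quantize_input)

x = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16)
# Block-quantize activations to fp8 in 128-wide groups along K;
# with qtype=None the helper should pass the activation through unchanged.
xq, x_scale = moe_kernel_quantize_input(
    A=x, A_scale=None, qtype=torch.float8_e4m3fn,
    per_channel_quant=False, block_shape=[128, 128])
assert xq.dtype == torch.float8_e4m3fn
assert x_scale.shape[-1] == 256 // 128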
@@ -1357,7 +1342,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
invoke_fused_moe_kernel(qcurr_hidden_states,
w1,
intermediate_cache1,
qa1_scale,
a1q_scale,
w1_scale,
w1_zp,
curr_topk_weights,
@@ -1384,22 +1369,17 @@ def fused_experts_impl(hidden_states: torch.Tensor,
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
qintermediate_cache2, qa2_scale = moe_kernel_prepare_input(
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
A=intermediate_cache2,
B=w2,
A_scale=a2_scale,
B_scale=w2_scale,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
qtype=qtype,
per_channel_quant=per_channel_quant,
block_shape=block_shape)
invoke_fused_moe_kernel(qintermediate_cache2,
w2,
intermediate_cache3,
qa2_scale,
a2q_scale,
w2_scale,
w2_zp,
curr_topk_weights,
@@ -1534,3 +1514,209 @@ def fused_moe(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[list[int]] = None,
block_m: Optional[int] = None,
):
super().__init__()
self.use_fp8_w8a8 = use_fp8_w8a8
self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a8 = use_int8_w8a8
self.use_int8_w8a16 = use_int8_w8a16
self.block_shape = block_shape
self.block_m = block_m
self.qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16)
self.per_channel_quant = per_channel_quant
def workspace_shapes(
self,
a: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
factor = num_experts if a.dim() == 3 else 1
workspace1 = M * topk * max(N * 2, K) * factor
workspace2 = M * topk * N * factor
return (workspace1, workspace2, a.dtype)
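workspace_shapes returns sizes in elements of a.dtype, generous enough for the intermediate caches that apply() later carves out of the two buffers. A quick worked example (sizes illustrative):

# 2D activations, so factor == 1; M tokens, topk experts per token,
# N is w1's output dimension, K the hidden size.
M, topk, N, K = 64, 2, 1024, 2048
factor = 1
workspace1 = M * topk * max(N * 2, K) * factor   # 64 * 2 * 2048 = 262144
workspace2 = M * topk * N * factor               # 64 * 2 * 1024 = 131072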
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
# Check constraints.
if self.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), (
"Hidden size mismatch")
else:
assert hidden_states.size(-1) == w1.size(2), \
(f"Hidden size mismatch {hidden_states.size(-1)} "
f"!= {w1.size(2)}")
assert hidden_states.is_contiguous(
), "Hidden_states must be contiguous"
assert hidden_states.dim() == 2
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
]
E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
hidden_states, w1, w2, topk_ids)
if global_num_experts == -1:
global_num_experts = E
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
dtype=hidden_states.dtype)
config = try_get_optimal_moe_config(
w1.shape,
w2.shape,
top_k_num,
config_dtype,
num_tokens,
block_shape=self.block_shape,
)
if hidden_states.dtype == torch.bfloat16:
compute_type = tl.bfloat16
elif hidden_states.dtype == torch.float16:
compute_type = tl.float16
elif hidden_states.dtype == torch.float32:
compute_type = tl.float32
elif hidden_states.dtype == torch.float8_e4m3fn:
compute_type = tl.bfloat16
else:
raise ValueError(
f"Unsupported compute_type: {hidden_states.dtype}")
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1 = _resize_cache(workspace13,
(num_tokens, top_k_num, N))
intermediate_cache2 = _resize_cache(workspace2,
(num_tokens * top_k_num, N // 2))
intermediate_cache3 = _resize_cache(workspace13,
(num_tokens, top_k_num, K))
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
global_num_experts, expert_map))
invoke_fused_moe_kernel(hidden_states,
w1,
intermediate_cache1,
a1q_scale,
w1_scale,
w1_zp,
None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
top_k_num,
config,
compute_type=compute_type,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_channel_quant,
block_shape=self.block_shape)
self.activation(activation, intermediate_cache2,
intermediate_cache1.view(-1, N))
a2q_scale: Optional[torch.Tensor] = None
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
intermediate_cache2, a2_scale, self.qtype, self.per_channel_quant,
self.block_shape)
invoke_fused_moe_kernel(qintermediate_cache2,
w2,
intermediate_cache3,
a2q_scale,
w2_scale,
w2_zp,
None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
1,
config,
compute_type=compute_type,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_channel_quant,
block_shape=self.block_shape)
return intermediate_cache3
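For context, moe_align_block_size (called in apply above) sorts the flattened token-expert pairs by expert and pads each expert's segment to a multiple of BLOCK_SIZE_M, which is what lets invoke_fused_moe_kernel launch one tile per padded block. A minimal sketch (sizes illustrative):

import torch
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    moe_align_block_size)

# 4 tokens, top-2 routing over 8 experts, padded to blocks of 16.
topk_ids = torch.randint(0, 8, (4, 2), dtype=torch.int32, device="cuda")
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
    topk_ids, 16, 8)
# sorted_token_ids: flattened pair indices grouped by expert, with padding;
# expert_ids: the expert owning each BLOCK_SIZE_M-sized block.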
def modular_triton_fused_moe(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
block_shape: Optional[list[int]] = None,
) -> mk.FusedMoEModularKernel:
qtype = get_config_qtype(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
)
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
quant_dtype=qtype,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
),
TritonExperts(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
),
)
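A hedged end-to-end sketch of the new factory, mirroring how this PR's tests drive it (the FusedMoEModularKernel call signature is assumed to match fused_experts; sizes illustrative):

import torch
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_topk, modular_triton_fused_moe)

m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=False,
                                       use_int8_w8a8=False,
                                       use_int8_w8a16=False,
                                       use_int4_w4a16=False,
                                       per_channel_quant=False)
E, tokens, K, N, topk = 8, 16, 1024, 2048, 2
a = torch.randn(tokens, K, device="cuda", dtype=torch.bfloat16)
# w1 fuses the gate and up projections, hence 2x the intermediate size.
w1 = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16) / 10
w2 = torch.randn(E, K, N // 2, device="cuda", dtype=torch.bfloat16) / 10
score = torch.randn(tokens, E, device="cuda", dtype=torch.float32)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
out = m_fused_moe(a, w1, w2, topk_weights, topk_ids, global_num_experts=E)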