Modularize fused experts and integrate PPLX kernels (#15956)
This commit is contained in:
@@ -8,16 +8,17 @@ from typing import Any, Callable, Optional
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
_valid_deep_gemm, deep_gemm_moe_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
per_token_group_quant_int8, per_token_quant_int8)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache, moe_kernel_quantize_input)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import direct_register_custom_op
|
||||
@@ -484,6 +485,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
assert topk_weights is None or topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
|
||||
if use_fp8_w8a8 or use_int8_w8a8:
|
||||
assert B_scale is not None
|
||||
assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
|
||||
== B_scale.shape[-2])
|
||||
assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
|
||||
== B_scale.shape[-1])
|
||||
|
||||
elif use_int8_w8a16 or use_int4_w4a16:
|
||||
assert B_scale is not None
|
||||
assert block_shape is None or block_shape[0] == 0
|
||||
else:
|
||||
assert A_scale is None
|
||||
assert B_scale is None
|
||||
|
||||
M = A.shape[0]
|
||||
num_tokens = M * top_k
|
||||
|
||||
@@ -855,6 +870,7 @@ def fused_topk(
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
indices_type: Optional[torch.dtype] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], (
|
||||
"Number of tokens mismatch")
|
||||
@@ -865,10 +881,11 @@ def fused_topk(
|
||||
topk,
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device)
|
||||
topk_ids = torch.empty(M,
|
||||
topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
topk_ids = torch.empty(
|
||||
M,
|
||||
topk,
|
||||
dtype=torch.int32 if indices_type is None else indices_type,
|
||||
device=hidden_states.device)
|
||||
token_expert_indices = torch.empty(M,
|
||||
topk,
|
||||
dtype=torch.int32,
|
||||
@@ -962,6 +979,20 @@ def get_config_dtype_str(
|
||||
return None
|
||||
|
||||
|
||||
# TODO (bnell): use scalar_type instead of bools?
|
||||
def get_config_qtype(
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
) -> Optional[torch.dtype]:
|
||||
if use_fp8_w8a8:
|
||||
return torch.float8_e4m3fn
|
||||
elif use_int8_w8a8:
|
||||
return torch.int8
|
||||
return None
|
||||
|
||||
|
||||
def inplace_fused_experts(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
@@ -1128,7 +1159,10 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
allow_deep_gemm: bool = False) -> torch.Tensor:
|
||||
if (allow_deep_gemm and use_fp8_w8a8
|
||||
# For now, disable DeepGemm for small N (<= 512) until better
|
||||
# permute/unpermute ops are available.
|
||||
N = w1.shape[1]
|
||||
if (allow_deep_gemm and use_fp8_w8a8 and N > 512
|
||||
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
|
||||
assert apply_router_weight_on_input is False
|
||||
return deep_gemm_moe_fp8(
|
||||
@@ -1145,6 +1179,7 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
return dispatch_fused_experts_func(inplace)(
|
||||
@@ -1171,87 +1206,37 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
block_shape=block_shape)
|
||||
|
||||
|
||||
def moe_kernel_prepare_input(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
B_scale: Optional[torch.Tensor],
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
per_channel_quant: bool,
|
||||
def fused_experts_impl(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if use_fp8_w8a8:
|
||||
assert B_scale is not None
|
||||
if block_shape is None:
|
||||
# If weights are per-channel (per_channel_quant=True), then
|
||||
# activations apply per-token quantization. Otherwise, assume
|
||||
# activation tensor-wise fp8 quantization, dynamic or static
|
||||
A, A_scale = ops.scaled_fp8_quant(
|
||||
A, A_scale, use_per_token_if_dynamic=per_channel_quant)
|
||||
else:
|
||||
# activation block-wise fp8 quantization
|
||||
assert len(block_shape) == 2
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
A, A_scale = per_token_group_quant_fp8(A, block_k)
|
||||
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
|
||||
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
|
||||
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
|
||||
elif use_int8_w8a8:
|
||||
assert B_scale is not None
|
||||
if block_shape is None:
|
||||
# activation channel-wise int8 quantization
|
||||
assert (per_channel_quant
|
||||
), "int8 quantization only supports block or channel-wise"
|
||||
A, A_scale = per_token_quant_int8(A)
|
||||
else:
|
||||
# activation block-wise int8 quantization
|
||||
assert len(block_shape) == 2
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
A, A_scale = per_token_group_quant_int8(A, block_k)
|
||||
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
|
||||
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
|
||||
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
|
||||
elif use_int8_w8a16 or use_int4_w4a16:
|
||||
assert B_scale is not None
|
||||
assert block_shape is None or block_shape[0] == 0
|
||||
else:
|
||||
assert A_scale is None
|
||||
assert B_scale is None
|
||||
|
||||
return A, A_scale
|
||||
|
||||
|
||||
def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None):
|
||||
) -> torch.Tensor:
|
||||
# Check constraints.
|
||||
if use_int4_w4a16:
|
||||
assert hidden_states.shape[1] // 2 == w1.shape[
|
||||
2], "Hidden size mismatch"
|
||||
else:
|
||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||
assert hidden_states.shape[1] == w1.shape[2], (
|
||||
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
|
||||
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
@@ -1261,7 +1246,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
torch.float32, torch.float16, torch.bfloat16
|
||||
]
|
||||
|
||||
num_tokens, _ = hidden_states.shape
|
||||
num_tokens = hidden_states.shape[0]
|
||||
E, N, _ = w1.shape
|
||||
K = w2.shape[1]
|
||||
if global_num_experts == -1:
|
||||
@@ -1276,6 +1261,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16)
|
||||
|
||||
get_config_func = functools.partial(
|
||||
try_get_optimal_moe_config,
|
||||
w1.shape,
|
||||
@@ -1338,15 +1328,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
|
||||
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
|
||||
|
||||
qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input(
|
||||
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
|
||||
A=curr_hidden_states,
|
||||
B=w1,
|
||||
A_scale=a1_scale,
|
||||
B_scale=w1_scale,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
qtype=qtype,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape)
|
||||
|
||||
@@ -1357,7 +1342,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
invoke_fused_moe_kernel(qcurr_hidden_states,
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
qa1_scale,
|
||||
a1q_scale,
|
||||
w1_scale,
|
||||
w1_zp,
|
||||
curr_topk_weights,
|
||||
@@ -1384,22 +1369,17 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
else:
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||
|
||||
qintermediate_cache2, qa2_scale = moe_kernel_prepare_input(
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
A=intermediate_cache2,
|
||||
B=w2,
|
||||
A_scale=a2_scale,
|
||||
B_scale=w2_scale,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
qtype=qtype,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape)
|
||||
|
||||
invoke_fused_moe_kernel(qintermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
qa2_scale,
|
||||
a2q_scale,
|
||||
w2_scale,
|
||||
w2_zp,
|
||||
curr_topk_weights,
|
||||
@@ -1534,3 +1514,209 @@ def fused_moe(
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape)
|
||||
|
||||
|
||||
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
per_channel_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_m: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_fp8_w8a8 = use_fp8_w8a8
|
||||
self.use_int4_w4a16 = use_int4_w4a16
|
||||
self.use_int8_w8a8 = use_int8_w8a8
|
||||
self.use_int8_w8a16 = use_int8_w8a16
|
||||
self.block_shape = block_shape
|
||||
self.block_m = block_m
|
||||
self.qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16)
|
||||
self.per_channel_quant = per_channel_quant
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
) -> tuple[int, int, torch.dtype]:
|
||||
factor = num_experts if a.dim() == 3 else 1
|
||||
workspace1 = M * topk * max(N * 2, K) * factor
|
||||
workspace2 = M * topk * N * factor
|
||||
return (workspace1, workspace2, a.dtype)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Check constraints.
|
||||
if self.use_int4_w4a16:
|
||||
assert hidden_states.size(-1) // 2 == w1.size(2), (
|
||||
"Hidden size mismatch")
|
||||
else:
|
||||
assert hidden_states.size(-1) == w1.size(2), \
|
||||
(f"Hidden size mismatch {hidden_states.size(-1)} "
|
||||
f"!= {w1.size(2)}")
|
||||
|
||||
assert hidden_states.is_contiguous(
|
||||
), "Hidden_states must be contiguous"
|
||||
assert hidden_states.dim() == 2
|
||||
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert hidden_states.dtype in [
|
||||
torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
|
||||
]
|
||||
|
||||
E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
|
||||
hidden_states, w1, w2, topk_ids)
|
||||
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
|
||||
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
config = try_get_optimal_moe_config(
|
||||
w1.shape,
|
||||
w2.shape,
|
||||
top_k_num,
|
||||
config_dtype,
|
||||
num_tokens,
|
||||
block_shape=self.block_shape,
|
||||
)
|
||||
|
||||
if hidden_states.dtype == torch.bfloat16:
|
||||
compute_type = tl.bfloat16
|
||||
elif hidden_states.dtype == torch.float16:
|
||||
compute_type = tl.float16
|
||||
elif hidden_states.dtype == torch.float32:
|
||||
compute_type = tl.float32
|
||||
elif hidden_states.dtype == torch.float8_e4m3fn:
|
||||
compute_type = tl.bfloat16
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported compute_type: {hidden_states.dtype}")
|
||||
|
||||
# We can reuse the memory between these because by the time we need
|
||||
# cache3, we're done with cache1
|
||||
intermediate_cache1 = _resize_cache(workspace13,
|
||||
(num_tokens, top_k_num, N))
|
||||
intermediate_cache2 = _resize_cache(workspace2,
|
||||
(num_tokens * top_k_num, N // 2))
|
||||
intermediate_cache3 = _resize_cache(workspace13,
|
||||
(num_tokens, top_k_num, K))
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = (
|
||||
moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
|
||||
global_num_experts, expert_map))
|
||||
|
||||
invoke_fused_moe_kernel(hidden_states,
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
a1q_scale,
|
||||
w1_scale,
|
||||
w1_zp,
|
||||
None,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
False,
|
||||
top_k_num,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a8=self.use_int8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
per_channel_quant=self.per_channel_quant,
|
||||
block_shape=self.block_shape)
|
||||
|
||||
self.activation(activation, intermediate_cache2,
|
||||
intermediate_cache1.view(-1, N))
|
||||
|
||||
a2q_scale: Optional[torch.Tensor] = None
|
||||
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
intermediate_cache2, a2_scale, self.qtype, self.per_channel_quant,
|
||||
self.block_shape)
|
||||
|
||||
invoke_fused_moe_kernel(qintermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
a2q_scale,
|
||||
w2_scale,
|
||||
w2_zp,
|
||||
None,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
False,
|
||||
1,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a8=self.use_int8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
per_channel_quant=self.per_channel_quant,
|
||||
block_shape=self.block_shape)
|
||||
|
||||
return intermediate_cache3
|
||||
|
||||
|
||||
def modular_triton_fused_moe(
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
per_channel_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
qtype = get_config_qtype(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
)
|
||||
return mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(
|
||||
quant_dtype=qtype,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
),
|
||||
TritonExperts(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user