[Kernel] Integrate CUTLASS MoE kernel with PPLX (#18762)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
ElizaWszola
2025-06-07 03:26:11 +02:00
committed by GitHub
parent 6e0cd10f72
commit 84166fee97
26 changed files with 918 additions and 409 deletions

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" CUTLASS based Fused MoE kernels."""
from typing import Optional
from typing import Callable, Optional
import torch
@@ -13,110 +13,109 @@ from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache
from vllm.scalar_type import scalar_types
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
def run_cutlass_moe_fp8(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation_callable: Callable,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
out_dtype: torch.dtype,
per_act_token: bool,
per_out_ch: bool,
) -> torch.Tensor:
a1q = hidden_states
def __init__(
self,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
out_dtype: torch.dtype,
):
super().__init__()
self.ab_strides1 = ab_strides1
self.c_strides1 = c_strides1
self.ab_strides2 = ab_strides2
self.c_strides2 = c_strides2
self.out_dtype = out_dtype
assert w1_scale is not None
assert w2_scale is not None
assert w1.dtype == torch.float8_e4m3fn
assert w2.dtype == torch.float8_e4m3fn
if expert_num_tokens is None:
assert a1q.shape[1] == w1.shape[2], "Hidden size mismatch w1"
else:
assert a1q.shape[2] == w1.shape[2], "Hidden size mismatch w1"
assert w1.shape[1] == w2.shape[2] * 2, "Hidden size mismatch w2"
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
1] == w1.shape[1], "W1 scale shape mismatch"
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
1] == w2.shape[1], "W2 scale shape mismatch"
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
assert a1q_scale is None or a1q_scale.dim(
) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[
0], "Input scale shape mismatch"
assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch"
assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
assert a2_scale is None or a2_scale.dim(
) == 0 or a2_scale.shape[0] == 1 or a2_scale.shape[0] == a1q.shape[
0], "Intermediate scale shape mismatch"
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
if expert_map is not None:
assert expert_num_tokens is None
def workspace_shapes(
self,
a: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
# Note that K, N are transposed
N, K = K, N
workspace1 = M * topk * max(2 * N, K)
workspace2 = M * topk * N
return (workspace1, workspace2, self.out_dtype)
# We have two modes: PPLX and non-PPLX. We differentiate them by checking
# if expert_num_tokens is None (expert_num_tokens is a tensor which PPLX
# uses to track the number of tokens per expert).
# In the non-PPLX mode, the input tokens are not padded: thus, the shape
# of the input is [total_num_tokens, hidden_size]. The input and output
# require shuffling by a_map and c_map such that the tokens assigned to
# each expert are contiguous.
# In the PPLX mode, the input tokens are padded per expert to ensure that
# the PPLX dispatch and combine functions work correctly: thus, the shape
# of the input is [num_experts, max_num_tokens_per_expert, hidden_size].
# The PPLX input and output require no shuffling by a_map and c_map since
# their tokens are already contiguous for each expert as a result of
# the dispatch function.
is_pplx = expert_num_tokens is not None
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
a1q = hidden_states
M = a1q.shape[0] # no pplx
padded_M = a1q.shape[1] # pplx
_, K, N = w2.shape
device = a1q.device
assert w1_scale is not None
assert w2_scale is not None
assert w1.dtype == torch.float8_e4m3fn
assert w2.dtype == torch.float8_e4m3fn
assert a1q.shape[1] == w1.shape[1], "Hidden size mismatch w1"
assert w1.shape[2] == w2.shape[1] * 2, "Hidden size mismatch w2"
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
assert a1q_scale is None or a1q_scale.dim(
) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[
0], "Input scale shape mismatch"
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
1] == w1.shape[2], "W1 scale shape mismatch"
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
1] == w2.shape[2], "W2 scale shape mismatch"
assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch"
assert w1.shape[0] == w1_scale.shape[
0], "w1 scales expert number mismatch"
assert w1.shape[0] == w2_scale.shape[
0], "w2 scales expert number mismatch"
assert a2_scale is None or a1q_scale is None or a2_scale.shape == a1q_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
assert self.ab_strides1.shape[0] == w1.shape[
0], "AB Strides 1 expert number mismatch"
assert self.c_strides1.shape[0] == w1.shape[
0], "C Strides 1 expert number mismatch"
assert self.ab_strides2.shape[0] == w2.shape[
0], "AB Strides 2 expert number mismatch"
assert self.c_strides2.shape[0] == w2.shape[
0], "C Strides 2 expert number mismatch"
assert self.out_dtype in [torch.half,
torch.bfloat16], "Invalid output dtype"
assert w1.shape[2] == K
assert global_num_experts != -1
assert a1q_scale is not None
M = a1q.shape[0]
_, N, K = w2.shape # because w1 + w2 are transposed
device = a1q.device
if expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(expert_map[topk_ids] != -1,
expert_map[topk_ids], -1)
else:
local_topk_ids = topk_ids
assert w1.shape[1] == K
assert global_num_experts != -1
assert a1q_scale is not None
topk = local_topk_ids.shape[1]
local_E = w1.shape[0]
if expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(expert_map[topk_ids] != -1,
expert_map[topk_ids], -1)
else:
local_topk_ids = topk_ids
if is_pplx:
expert_offsets = torch.empty((local_E),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((local_E, 3),
dtype=torch.int32,
device=device)
problem_sizes2 = torch.empty((local_E, 3),
dtype=torch.int32,
device=device)
topk = local_topk_ids.shape[1]
ops.get_cutlass_pplx_moe_mm_data(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
local_E, padded_M, N, K)
per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
w1_scale = w1_scale.reshape(w1_scale.shape[0], -1)
w2_scale = w2_scale.reshape(w2_scale.shape[0], -1)
a1q = a1q.reshape(-1, a1q.shape[2])
a1q_scale = a1q_scale.reshape(-1, a1q_scale.shape[2]).contiguous()
else:
expert_offsets = torch.empty((global_num_experts + 1),
dtype=torch.int32,
device=device)
@@ -149,50 +148,130 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
a1q = _fp8_perm(a1q, a_map)
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
expert_offsets = expert_offsets[:-1]
ab_strides1 = torch.full((w1.shape[0], ),
K,
device=device,
dtype=torch.int64)
c_strides1 = torch.full((w1.shape[0], ),
2 * N,
device=device,
dtype=torch.int64)
ab_strides2 = torch.full((w1.shape[0], ),
N,
device=device,
dtype=torch.int64)
c_strides2 = torch.full((w1.shape[0], ),
K,
device=device,
dtype=torch.int64)
if is_pplx:
c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
c2 = _resize_cache(workspace2, (local_E * padded_M, N))
c3 = _resize_cache(workspace13, (local_E * padded_M, K))
else:
c1 = _resize_cache(workspace13, (M * topk, N * 2))
c2 = _resize_cache(workspace2, (M * topk, N))
c3 = _resize_cache(workspace13, (M * topk, K))
ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale,
expert_offsets[:-1], problem_sizes1,
self.ab_strides1, self.ab_strides1, self.c_strides1)
ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets,
problem_sizes1, ab_strides1, ab_strides1, c_strides1,
per_act_token, per_out_ch)
self.activation(activation, c2, c1)
activation_callable(c2, c1)
a2q, a2q_scale = ops.scaled_fp8_quant(
c2, a2_scale, use_per_token_if_dynamic=per_act_token)
a2q, a2q_scale = ops.scaled_fp8_quant(
c2, a2_scale, use_per_token_if_dynamic=per_act_token)
if expert_map is not None:
c3.fill_(0)
if expert_map is not None:
c3.fill_(0)
ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale,
expert_offsets[:-1], problem_sizes2,
self.ab_strides2, self.ab_strides2, self.c_strides2)
ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets,
problem_sizes2, ab_strides2, ab_strides2, c_strides2,
per_act_token, per_out_ch)
c3 = c3[c_map]
return c3
if is_pplx:
return c3.reshape(local_E, padded_M, K)
else:
return c3[c_map].view(M, topk, K)
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
max_experts_per_worker: int,
out_dtype: torch.dtype,
per_act_token: bool,
per_out_ch: bool,
):
super().__init__()
self.max_experts_per_worker = max_experts_per_worker
self.out_dtype = out_dtype
self.per_act_token = per_act_token
self.per_out_ch = per_out_ch
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
padded_M = aq.shape[1]
workspace1 = self.max_experts_per_worker * padded_M * max(N, K)
workspace2 = self.max_experts_per_worker * padded_M * (N // 2)
return (workspace1, workspace2, self.out_dtype)
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
activation_callable = lambda i, o: self.activation(activation, i, o)
return run_cutlass_moe_fp8(hidden_states, w1, w2, topk_ids,
activation_callable, global_num_experts,
expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2,
expert_num_tokens, self.out_dtype,
self.per_act_token, self.per_out_ch)
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
def cutlass_moe_fp8(
a: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
activation: str = "silu",
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
out_dtype: torch.dtype = torch.half,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
) -> torch.Tensor:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
@@ -207,25 +286,17 @@ def cutlass_moe_fp8(
Shape: [num_experts, K, 2N] (the weights are passed transposed)
- w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
Shape: [num_experts, N, K] (the weights are passed transposed)
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids (torch.Tensor): The token->expert mappings.
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- ab_strides1 (torch.Tensor): The input and weights strides of the first
grouped gemm.
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
- ab_strides2 (torch.Tensor): The input and weights strides of the second
grouped gemm.
- c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
- out_dtype (torch.dtype): The output tensor type.
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
every Rank is responsible for a subset of experts. expert_map is a
mapping from global expert-id to local expert-id. When expert_map[i]
@@ -233,24 +304,27 @@ def cutlass_moe_fp8(
expert-id i.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is 1.
- global_num_experts (int): The total number of experts.
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
per_out_ch = w1_scale.numel() != w1_q.shape[0]
out_dtype = a.dtype
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
per_channel_quant=per_act_token,
quant_dtype=torch.float8_e4m3fn,
per_channel_quant=per_act_token,
),
CutlassExpertsFp8(
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
out_dtype,
max_experts_per_worker=global_num_experts,
out_dtype=out_dtype,
per_act_token=per_act_token,
per_out_ch=per_out_ch,
),
)
@@ -260,9 +334,12 @@ def cutlass_moe_fp8(
w2_q,
topk_weights,
topk_ids,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
False,
activation,
global_num_experts if global_num_experts != -1 else w1_q.size(0),
expert_map,
w1_scale,
w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,