[Kernel] CUTLASS MoE FP8: Integrate cuda moe permute/unpermute (#23045)
Signed-off-by: Shixian Cui <shixian@amazon.com>
This commit is contained in:
@@ -9,12 +9,13 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
moe_permute, moe_unpermute)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
|
||||
_fp8_quantize,
|
||||
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
|
||||
_resize_cache)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
@@ -34,6 +35,10 @@ def run_cutlass_moe_fp8(
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
@@ -41,6 +46,7 @@ def run_cutlass_moe_fp8(
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
use_batched_format: bool,
|
||||
topk_weights: Optional[torch.Tensor],
|
||||
):
|
||||
a1q = hidden_states
|
||||
|
||||
@@ -99,6 +105,22 @@ def run_cutlass_moe_fp8(
|
||||
topk = local_topk_ids.size(1)
|
||||
local_E = w1.size(0)
|
||||
|
||||
if use_batched_format:
|
||||
mm1_out = _resize_cache(workspace13, (local_E * padded_M, N * 2))
|
||||
act_out = _resize_cache(workspace2, (local_E * padded_M, N))
|
||||
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
|
||||
(local_E * padded_M, N))
|
||||
mm2_out = _resize_cache(workspace2, (local_E * padded_M, K))
|
||||
else:
|
||||
a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
|
||||
(M * topk, K))
|
||||
mm1_out = _resize_cache(workspace13, (M * topk, N * 2))
|
||||
act_out = _resize_cache(workspace2, (M * topk, N))
|
||||
# the original workspaces are based on the input hidden_states dtype (bf16)
|
||||
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
|
||||
(M * topk, N))
|
||||
mm2_out = _resize_cache(workspace2, (M * topk, K))
|
||||
|
||||
if use_batched_format:
|
||||
assert expert_num_tokens is not None
|
||||
|
||||
@@ -120,11 +142,10 @@ def run_cutlass_moe_fp8(
|
||||
w2_scale = w2_scale.reshape(w2_scale.size(0), -1)
|
||||
a1q = a1q.reshape(-1, a1q.size(2))
|
||||
a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous()
|
||||
|
||||
# c3x get_group_gemm_starts expects int64 to avoid overflow
|
||||
# during offset calculations
|
||||
expert_offsets = expert_offsets.to(torch.int64)
|
||||
else:
|
||||
expert_offsets = torch.empty((global_num_experts + 1),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
problem_sizes1 = torch.empty((global_num_experts, 3),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
@@ -132,84 +153,57 @@ def run_cutlass_moe_fp8(
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
# With expert_map each Rank processes only a subset of experts. As
|
||||
# a result not all of a_map and c2 tensors are filled. We fill it
|
||||
# with zeros for correctness.
|
||||
if expert_map is not None:
|
||||
a_map = torch.zeros((local_topk_ids.numel()),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
else:
|
||||
a_map = torch.empty((local_topk_ids.numel()),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
c_map = torch.empty((local_topk_ids.numel()),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets,
|
||||
problem_sizes1, problem_sizes2, a_map,
|
||||
c_map, global_num_experts, N, K)
|
||||
|
||||
a1q = _fp8_perm(a1q, a_map)
|
||||
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
|
||||
num_expert = global_num_experts if expert_map is None \
|
||||
else expert_map.size(0)
|
||||
# permuted a1q reuses workspace2
|
||||
a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute(
|
||||
a1q,
|
||||
a1q_scale,
|
||||
topk_ids,
|
||||
num_expert,
|
||||
local_E,
|
||||
expert_map,
|
||||
permuted_hidden_states=a1q_perm)
|
||||
expert_offsets = expert_offsets[:-1]
|
||||
|
||||
ab_strides1 = torch.full((w1.size(0), ),
|
||||
K,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
c_strides1 = torch.full((w1.size(0), ),
|
||||
2 * N,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
ab_strides2 = torch.full((w1.size(0), ),
|
||||
N,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
c_strides2 = torch.full((w1.size(0), ),
|
||||
K,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
|
||||
if use_batched_format:
|
||||
c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
|
||||
c2 = _resize_cache(workspace2, (local_E * padded_M, N))
|
||||
c3 = _resize_cache(workspace13, (local_E * padded_M, K))
|
||||
else:
|
||||
c1 = _resize_cache(workspace13, (M * topk, N * 2))
|
||||
c2 = _resize_cache(workspace2, (M * topk, N))
|
||||
c3 = _resize_cache(workspace13, (M * topk, K))
|
||||
ops.get_cutlass_moe_mm_problem_sizes(local_topk_ids, problem_sizes1,
|
||||
problem_sizes2,
|
||||
global_num_experts, N, K)
|
||||
|
||||
if not per_act_token and (expert_map is not None or use_batched_format):
|
||||
# this is necessary to avoid imprecise scale calculation caused by
|
||||
# random data in the unused workspace. The workspace is unused when
|
||||
# this rank handles only partial tokens, or when it is batched.
|
||||
c1.fill_(0)
|
||||
mm1_out.fill_(0)
|
||||
|
||||
ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets,
|
||||
ops.cutlass_moe_mm(mm1_out, a1q, w1, a1q_scale, w1_scale, expert_offsets,
|
||||
problem_sizes1, ab_strides1, ab_strides1, c_strides1,
|
||||
per_act_token, per_out_ch)
|
||||
|
||||
activation_callable(c2, c1)
|
||||
activation_callable(act_out, mm1_out)
|
||||
|
||||
a2q, a2q_scale = ops.scaled_fp8_quant(
|
||||
c2, a2_scale, use_per_token_if_dynamic=per_act_token)
|
||||
act_out,
|
||||
a2_scale,
|
||||
use_per_token_if_dynamic=per_act_token,
|
||||
output=quant_out)
|
||||
|
||||
if expert_map is not None:
|
||||
c3.fill_(0)
|
||||
mm2_out.fill_(0)
|
||||
|
||||
ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets,
|
||||
ops.cutlass_moe_mm(mm2_out, a2q, w2, a2q_scale, w2_scale, expert_offsets,
|
||||
problem_sizes2, ab_strides2, ab_strides2, c_strides2,
|
||||
per_act_token, per_out_ch)
|
||||
|
||||
if use_batched_format:
|
||||
output.copy_(c3.reshape(local_E, padded_M, K), non_blocking=True)
|
||||
output.copy_(mm2_out.reshape(local_E, padded_M, K), non_blocking=True)
|
||||
else:
|
||||
# We can't do this inplace because output may point to the same tensor
|
||||
# as c3.
|
||||
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
|
||||
# for non-chunking mode the output is resized from workspace13
|
||||
# so we need to make sure mm2_out uses workspace2.
|
||||
moe_unpermute(out=output,
|
||||
permuted_hidden_states=mm2_out,
|
||||
topk_weights=topk_weights,
|
||||
inv_permuted_idx=inv_perm)
|
||||
|
||||
|
||||
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
@@ -219,6 +213,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
out_dtype: Optional[torch.dtype],
|
||||
per_act_token_quant: bool,
|
||||
per_out_ch_quant: bool,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
):
|
||||
super().__init__(
|
||||
@@ -229,6 +227,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
block_shape=block_shape,
|
||||
))
|
||||
self.out_dtype = out_dtype
|
||||
self.ab_strides1 = ab_strides1
|
||||
self.ab_strides2 = ab_strides2
|
||||
self.c_strides1 = c_strides1
|
||||
self.c_strides2 = c_strides2
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||
@@ -272,10 +274,11 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
run_cutlass_moe_fp8(
|
||||
output, hidden_states, w1, w2, topk_ids, activation_callable,
|
||||
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
|
||||
a2_scale, workspace13, workspace2, expert_num_tokens,
|
||||
a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
|
||||
self.c_strides2, workspace13, workspace2, expert_num_tokens,
|
||||
self.out_dtype if self.out_dtype is not None else in_dtype,
|
||||
self.per_act_token_quant, self.per_out_ch_quant,
|
||||
use_batched_format)
|
||||
use_batched_format, topk_weights)
|
||||
|
||||
|
||||
class CutlassExpertsFp8(CutlassExpertsFp8Base):
|
||||
@@ -285,12 +288,20 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
|
||||
out_dtype: Optional[torch.dtype],
|
||||
per_act_token_quant: bool,
|
||||
per_out_ch_quant: bool,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
):
|
||||
super().__init__(
|
||||
out_dtype,
|
||||
per_act_token_quant,
|
||||
per_out_ch_quant,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
block_shape,
|
||||
)
|
||||
|
||||
@@ -307,6 +318,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# topk weights and reduction are fused in the moe_unpermute CUDA kernel
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
@@ -320,8 +335,8 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
workspace1 = (M * topk, max(N, K))
|
||||
workspace2 = (M * topk, N // 2)
|
||||
output = (M * topk, K)
|
||||
workspace2 = (M * topk, max(N // 2, K))
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output,
|
||||
self.out_dtype if self.out_dtype is not None else a.dtype)
|
||||
|
||||
@@ -335,12 +350,20 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
|
||||
out_dtype: Optional[torch.dtype],
|
||||
per_act_token_quant: bool,
|
||||
per_out_ch_quant: bool,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
):
|
||||
super().__init__(
|
||||
out_dtype,
|
||||
per_act_token_quant,
|
||||
per_out_ch_quant,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
block_shape,
|
||||
)
|
||||
assert max_experts_per_worker > 0
|
||||
@@ -378,7 +401,8 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
|
||||
assert num_dp is not None
|
||||
workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
|
||||
max(N, K))
|
||||
workspace2 = (self.max_experts_per_worker, padded_M * num_dp, (N // 2))
|
||||
workspace2 = (self.max_experts_per_worker, padded_M * num_dp,
|
||||
max(N // 2, K))
|
||||
output = (self.max_experts_per_worker, padded_M, K)
|
||||
return (workspace1, workspace2, output,
|
||||
self.out_dtype if self.out_dtype is not None else a.dtype)
|
||||
@@ -392,6 +416,10 @@ def cutlass_moe_fp8(
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
per_act_token: Optional[bool] = None,
|
||||
activation: str = "silu",
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
@@ -419,6 +447,17 @@ def cutlass_moe_fp8(
|
||||
Shape: [num_experts] or [num_experts, 2N]
|
||||
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
|
||||
Shape: [num_experts] or [num_experts, K]
|
||||
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
|
||||
Shape: [num_experts]
|
||||
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
|
||||
Shape: [num_experts]
|
||||
- c_strides1 (torch.Tensor): The output strides for the first gemm.
|
||||
Shape: [num_experts]
|
||||
- c_strides2 (torch.Tensor): The output strides for the second gemm.
|
||||
Shape: [num_experts]
|
||||
- per_act_token (Optional[bool]): Whether the scale is per-token or
|
||||
per-tensor.
|
||||
- activation (str): The activation function to use.
|
||||
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
|
||||
Shape: scalar or [M]
|
||||
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
|
||||
@@ -450,6 +489,10 @@ def cutlass_moe_fp8(
|
||||
out_dtype=a.dtype,
|
||||
per_act_token_quant=per_act_token,
|
||||
per_out_ch_quant=per_out_ch,
|
||||
ab_strides1=ab_strides1,
|
||||
ab_strides2=ab_strides2,
|
||||
c_strides1=c_strides1,
|
||||
c_strides2=c_strides2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user