[Kernels] Clean up FusedMoEMethodBase and modular kernel setup. Remove extra arguments from modular kernel methods. (#22035)

Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Author: bnellnm
Date: 2025-08-15 14:46:00 -04:00
Committed by: GitHub
Parent: 48b01fd4d4
Commit: 8ad7285ea2
54 changed files with 2010 additions and 1293 deletions
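The core of the change, in rough outline: kernel-specific values that previously rode along at call time as extra_expert_args / extra_prepare_args / extra_finalize_args dicts (and were unpacked with extract_required_args inside apply()) now live as constructor state on the experts object, so the modular-kernel methods keep a fixed signature. The sketch below illustrates the pattern only; the class names ExpertsBefore / ExpertsAfter are hypothetical stand-ins for the real FusedMoEPermuteExpertsUnpermute subclasses, and the signatures are heavily simplified.

    # Hypothetical, simplified illustration of the refactor; not the actual
    # vLLM interfaces, which carry many more parameters than shown here.
    from typing import Any, Optional

    import torch


    class ExpertsBefore:
        # Old pattern: per-kernel values arrive at call time in an untyped
        # dict and must be looked up inside apply().
        def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
                  extra_expert_args: Optional[dict[str, Any]]) -> None:
            assert extra_expert_args is not None
            g1_alphas = extra_expert_args["g1_alphas"]
            a1_gscale = extra_expert_args["a1_gscale"]
            output.copy_(hidden_states * a1_gscale * g1_alphas)


    class ExpertsAfter:
        # New pattern: per-kernel values are constructor state, so apply()
        # has the same signature for every experts implementation.
        def __init__(self, g1_alphas: torch.Tensor, a1_gscale: torch.Tensor):
            self.g1_alphas = g1_alphas
            self.a1_gscale = a1_gscale

        def apply(self, output: torch.Tensor,
                  hidden_states: torch.Tensor) -> None:
            output.copy_(hidden_states * self.a1_gscale * self.g1_alphas)

The same motivation drives the diff below, which splits CutlassExpertsFp8 into a shared CutlassExpertsFp8Base with separate CutlassExpertsFp8 and CutlassBatchedExpertsFp8 subclasses: the batched versus non-batched behaviour moves from a use_batched_format flag checked at runtime into the class hierarchy.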


@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" CUTLASS based Fused MoE kernels."""
from typing import Any, Callable, Optional
from typing import Callable, Optional
import torch
@@ -12,11 +12,10 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
_fp8_quantize,
_resize_cache,
extract_required_args)
_resize_cache)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@@ -213,19 +212,14 @@ def run_cutlass_moe_fp8(
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
# TODO (bnell): split class batched vs. non-batched?
# maybe remove need for passing aq to workspace_shapes
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
max_experts_per_worker: int,
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
block_shape: Optional[list[int]] = None,
num_dispatchers: Optional[int] = None,
use_batched_format: bool = False,
):
super().__init__(
FusedMoEQuantConfig(
@@ -234,33 +228,84 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
))
assert max_experts_per_worker > 0
assert not use_batched_format or num_dispatchers is not None
self.max_experts_per_worker = max_experts_per_worker
self.num_dispatchers = num_dispatchers
self.out_dtype = out_dtype
self.use_batched_format = use_batched_format
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
expert_num_tokens = None
if expert_tokens_meta is not None:
expert_num_tokens = expert_tokens_meta.expert_num_tokens
activation_callable = lambda o, i: self.activation(activation, o, i)
use_batched_format = self.activation_formats[
0] == mk.FusedMoEActivationFormat.BatchedExperts
in_dtype = hidden_states.dtype
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2, expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant,
use_batched_format)
class CutlassExpertsFp8(CutlassExpertsFp8Base):
def __init__(
self,
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
block_shape: Optional[list[int]] = None,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
block_shape,
)
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
if self.use_batched_format:
return (mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts)
else:
return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard)
return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard)
def supports_chunking(self) -> bool:
return not self.use_batched_format
return True
def supports_expert_map(self) -> bool:
return not self.use_batched_format
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
return True
def workspace_shapes(
self,
@@ -274,54 +319,69 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1: tuple[int, ...] = ()
workspace2: tuple[int, ...] = ()
output: tuple[int, ...] = ()
if self.use_batched_format:
padded_M = aq.size(1)
num_dp = self.num_dispatchers
assert num_dp is not None
workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp,
(N // 2))
output = (self.max_experts_per_worker, padded_M, K)
else:
workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, N // 2)
output = (M * topk, K)
workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, N // 2)
output = (M * topk, K)
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
expert_num_tokens = None
if expert_tokens_meta is not None:
expert_num_tokens = expert_tokens_meta.expert_num_tokens
class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
activation_callable = lambda o, i: self.activation(activation, o, i)
def __init__(
self,
max_experts_per_worker: int,
num_dispatchers: int,
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
block_shape: Optional[list[int]] = None,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
block_shape,
)
assert max_experts_per_worker > 0
self.max_experts_per_worker = max_experts_per_worker
self.num_dispatchers = num_dispatchers
in_dtype = hidden_states.dtype
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2, expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant,
self.use_batched_format)
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts)
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool:
return False
# TODO(bnell): maybe remove need for passing aq to workspace_shapes
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
padded_M = aq.size(1)
num_dp = self.num_dispatchers
assert num_dp is not None
workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp, (N // 2))
output = (self.max_experts_per_worker, padded_M, K)
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
def cutlass_moe_fp8(
@@ -387,11 +447,9 @@ def cutlass_moe_fp8(
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
max_experts_per_worker=num_experts,
out_dtype=a.dtype,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
use_batched_format=False,
),
)
@@ -476,8 +534,9 @@ def run_cutlass_moe_fp4(
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
e_w2, k_w2, half_n_w2 = w2_fp4.shape
assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match",
" between weights.")
assert (e_w1 == e_w2
and e_w1 == e), ("Number of experts must match",
f" between weights. {e_w1}, {e_w2}, {e}")
assert (k_a == half_k_w1 * 2
and k == k_w2), ("Hidden size mismatch between a, w1 and w2")
assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in "
@@ -554,6 +613,10 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
max_experts_per_worker: int,
out_dtype: torch.dtype,
per_act_token_quant: bool,
@@ -562,8 +625,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
use_batched_format: bool = False,
):
super().__init__(
# NVFP4 requires two levels of quantization, which involves
# computing some scaling factors dynamically. This makes it
# incompatible with the typical prepare -> MoE -> finalize
# pipeline. Move the quantization logic into the MoE body.
FusedMoEQuantConfig(
quant_dtype=torch.uint8,
quant_dtype=None, # skip quantization in prepare/finalize
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
@@ -572,6 +639,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
self.out_dtype = out_dtype
self.use_batched_format = use_batched_format
# TODO(bnell): put this stuff into quant config?
self.g1_alphas = g1_alphas
self.g2_alphas = g2_alphas
self.a1_gscale = a1_gscale
self.a2_gscale = a2_gscale
@property
def activation_formats(
self
@@ -590,8 +663,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
@@ -620,34 +692,42 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor], w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor],
workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
required_keys = [
"g1_alphas", "g2_alphas", "a1_gscale", "a2_gscale", "m", "n", "k",
"e", "device"
]
(g1_alphas, g2_alphas, a1_gscale, a2_gscale, m, n, k, e,
device) = extract_required_args(extra_expert_args, required_keys)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: torch.Tensor,
workspace13: Optional[torch.Tensor],
workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
e, m, n, k, _ = mk._moe_problem_size(hidden_states, w1, w2, topk_ids)
n = w2.shape[2] * 2
run_cutlass_moe_fp4(
output=output,
a=hidden_states,
a1_gscale=a1_gscale,
a1_gscale=self.a1_gscale,
w1_fp4=w1,
w1_blockscale=w1_scale,
w1_alphas=g1_alphas,
a2_gscale=a2_gscale,
w1_alphas=self.g1_alphas,
a2_gscale=self.a2_gscale,
w2_fp4=w2,
w2_blockscale=w2_scale,
w2_alphas=g2_alphas,
w2_alphas=self.g2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
workspace13=workspace13,
@@ -656,7 +736,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
n=n,
k=k,
e=e,
device=device,
device=hidden_states.device,
apply_router_weight_on_input=apply_router_weight_on_input,
)
@@ -677,7 +757,6 @@ def cutlass_moe_fp4(
n: int,
k: int,
e: int,
device: torch.device,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False) -> torch.Tensor:
assert expert_map is None, ("Expert Parallelism / expert_map "
@@ -686,6 +765,10 @@ def cutlass_moe_fp4(
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4(
g1_alphas,
g2_alphas,
a1_gscale,
a2_gscale,
max_experts_per_worker=e,
out_dtype=a.dtype,
per_act_token_quant=False,
@@ -693,29 +776,7 @@ def cutlass_moe_fp4(
use_batched_format=False,
),
)
extra_expert_args = {
'g1_alphas': g1_alphas,
'g2_alphas': g2_alphas,
'a1_gscale': a1_gscale,
'a2_gscale': a2_gscale,
'm': m,
'n': n,
'k': k,
'e': e,
'device': device,
}
# NVFP4 requires two levels of quantization, which involves computing some
# scaling factors dynamically. This makes it incompatible with the typical
# prepare -> MoE -> finalize pipeline. Move the quantization logic into the
# MoE body.
extra_prepare_args = {
'skip_quant': True,
}
# Similar reason as above.
extra_finalize_args = {
'skip_weight_reduce': True,
}
return fn(
hidden_states=a,
w1=w1_fp4,
@@ -731,9 +792,6 @@ def cutlass_moe_fp4(
a1_scale=None,
a2_scale=None,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args,
extra_prepare_args=extra_prepare_args,
extra_finalize_args=extra_finalize_args,
)
@@ -824,16 +882,6 @@ def run_cutlass_block_scaled_fused_experts(
k = w1_q.size(1)
n = w2_q.size(1)
expert_offsets = torch.empty((num_experts + 1, ),
dtype=torch.int32,
device="cuda")
problem_sizes1 = torch.empty((num_experts, 3),
dtype=torch.int32,
device="cuda")
problem_sizes2 = torch.empty((num_experts, 3),
dtype=torch.int32,
device="cuda")
topk = topk_ids.size(1)
a_q, a1_scale = _fp8_quantize(a,
@@ -842,6 +890,16 @@ def run_cutlass_block_scaled_fused_experts(
block_shape=[128, 128])
device = a_q.device
expert_offsets = torch.empty((num_experts + 1, ),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((num_experts, 3),
dtype=torch.int32,
device=device)
problem_sizes2 = torch.empty((num_experts, 3),
dtype=torch.int32,
device=device)
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)