[Kernels] MoE refactor (#19636)

Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: ElizaWszola <ewszola@redhat.com>
bnellnm authored on 2025-07-02 09:08:27 -04:00, committed by GitHub
parent b95877509b
commit c1909e7e8c
36 changed files with 2698 additions and 1584 deletions


@@ -7,6 +7,7 @@ import torch
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache
@@ -202,26 +203,47 @@ def run_cutlass_moe_fp8(
 # TODO (bnell): split class batched vs. non-batched?
 # maybe remove need for passing aq to workspace_shapes
 class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
 
     def __init__(
         self,
         max_experts_per_worker: int,
-        out_dtype: torch.dtype,
-        per_act_token: bool,
-        per_out_ch: bool,
+        out_dtype: Optional[torch.dtype],
+        per_act_token_quant: bool,
+        per_out_ch_quant: bool,
+        block_shape: Optional[list[int]] = None,
         use_batched_format: bool = False,
     ):
-        super().__init__()
+        super().__init__(
+            FusedMoEQuantConfig(
+                quant_dtype=torch.float8_e4m3fn,
+                per_act_token_quant=per_act_token_quant,
+                per_out_ch_quant=per_out_ch_quant,
+                block_shape=block_shape,
+            ))
         assert max_experts_per_worker > 0
         self.max_experts_per_worker = max_experts_per_worker
         self.out_dtype = out_dtype
-        self.per_act_token = per_act_token
-        self.per_out_ch = per_out_ch
         self.use_batched_format = use_batched_format
 
+    @property
+    def activation_formats(
+        self
+    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        if self.use_batched_format:
+            return (mk.FusedMoEActivationFormat.BatchedExperts,
+                    mk.FusedMoEActivationFormat.BatchedExperts)
+        else:
+            return (mk.FusedMoEActivationFormat.Standard,
+                    mk.FusedMoEActivationFormat.Standard)
+
+    def supports_chunking(self) -> bool:
+        return not self.use_batched_format
+
+    def supports_expert_map(self) -> bool:
+        return not self.use_batched_format
+
     def workspace_shapes(
         self,
         a: torch.Tensor,
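The hunk above moves the quantization flags into a FusedMoEQuantConfig owned by the modular-kernel base class; the renamed per_act_token_quant / per_out_ch_quant arguments feed that config directly. A minimal construction sketch under the new signature; the import path is an assumption based on the imports above (the file name is not shown in this view), and the values are illustrative only:

    import torch
    # Assumed module path; not confirmed by this diff.
    from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8

    experts = CutlassExpertsFp8(
        max_experts_per_worker=8,       # must be > 0 (asserted in __init__)
        out_dtype=torch.half,           # now Optional; None defers to the input dtype
        per_act_token_quant=False,      # renamed from per_act_token
        per_out_ch_quant=False,         # renamed from per_out_ch
        block_shape=None,               # new: optional block-quantization shape
        use_batched_format=False,
    )
    assert experts.supports_chunking()    # standard format supports chunking
    assert experts.supports_expert_map()  # batched format supports neither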
@@ -245,7 +267,8 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
         workspace1 = (M * topk, max(2 * N, K))
         workspace2 = (M * topk, N)
         output = (M * topk, K)
-        return (workspace1, workspace2, output, self.out_dtype)
+        return (workspace1, workspace2, output,
+                self.out_dtype if self.out_dtype is not None else a.dtype)
 
     def apply(
         self,
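The workspace_shapes change is what makes out_dtype optional: an explicit dtype wins, otherwise the output dtype follows the activations. The same rule, isolated as a hypothetical helper (not part of the diff) to make the fallback explicit:

    from typing import Optional
    import torch

    def resolve_out_dtype(configured: Optional[torch.dtype],
                          activations: torch.Tensor) -> torch.dtype:
        # Mirrors the return statement above: an explicit out_dtype wins,
        # otherwise the MoE output matches the input activation dtype.
        return configured if configured is not None else activations.dtype

    x = torch.randn(4, 8, dtype=torch.bfloat16)
    assert resolve_out_dtype(None, x) is torch.bfloat16
    assert resolve_out_dtype(torch.half, x) is torch.half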
@@ -270,13 +293,14 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
         assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
         assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
         activation_callable = lambda i, o: self.activation(activation, i, o)
-        run_cutlass_moe_fp8(output, hidden_states, w1, w2, topk_ids,
-                            activation_callable, global_num_experts,
-                            expert_map, w1_scale, w2_scale, a1q_scale,
-                            a2_scale, workspace13, workspace2,
-                            expert_num_tokens, self.out_dtype,
-                            self.per_act_token, self.per_out_ch,
-                            self.use_batched_format)
+        in_dtype = hidden_states.dtype
+        run_cutlass_moe_fp8(
+            output, hidden_states, w1, w2, topk_ids, activation_callable,
+            global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
+            a2_scale, workspace13, workspace2, expert_num_tokens,
+            self.out_dtype if self.out_dtype is not None else in_dtype,
+            self.per_act_token_quant, self.per_out_ch_quant,
+            self.use_batched_format)
 
 
 def cutlass_moe_fp8(
@@ -287,6 +311,7 @@ def cutlass_moe_fp8(
     topk_ids: torch.Tensor,
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
+    per_act_token: bool,
     activation: str = "silu",
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
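With this signature change, per_act_token becomes an explicit argument rather than being inferred from the scale shapes; the old inference is deleted in the next hunk. A caller that wants the previous behavior can reproduce the removed detection at the call site:

    # Sketch of the removed auto-detection, lifted from the old body below:
    # a scale tensor with more than one element implies per-token scales.
    per_act_token = (a1_scale.numel() != 1 if a1_scale is not None else
                     (a2_scale.numel() != 1 if a2_scale is not None else False))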
@@ -330,22 +355,18 @@ def cutlass_moe_fp8(
     Returns:
     - torch.Tensor: The fp16 output tensor after applying the MoE layer.
     """
-    per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
-        a2_scale.numel() != 1 if a2_scale is not None else False)
     per_out_ch = w1_scale.numel() != w1_q.size(0)
 
-    out_dtype = a.dtype
+    num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(
+        0)
 
     fn = mk.FusedMoEModularKernel(
-        MoEPrepareAndFinalizeNoEP(
-            quant_dtype=torch.float8_e4m3fn,
-            per_channel_quant=per_act_token,
-        ),
+        MoEPrepareAndFinalizeNoEP(),
         CutlassExpertsFp8(
-            max_experts_per_worker=global_num_experts,
-            out_dtype=out_dtype,
-            per_act_token=per_act_token,
-            per_out_ch=per_out_ch,
+            max_experts_per_worker=num_experts,
+            out_dtype=a.dtype,
+            per_act_token_quant=per_act_token,
+            per_out_ch_quant=per_out_ch,
             use_batched_format=False,
         ),
     )
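Note the composition change: MoEPrepareAndFinalizeNoEP no longer takes quant_dtype / per_channel_quant, because the quantization parameters now travel with the experts' FusedMoEQuantConfig built in CutlassExpertsFp8.__init__, so they are specified once instead of twice. A hedged sketch of the resulting division of labor; the stage names come from the mk interfaces referenced in this diff, and the flow description is an inference from it, not a verbatim copy of the implementation:

    # FusedMoEModularKernel(prepare_finalize, fused_experts), roughly:
    #   1. prepare:  quantize the activations (config supplied by the experts)
    #   2. apply:    per-expert fp8 GEMMs + activation on the prepared tensors
    #   3. finalize: combine the top-k expert outputs back to token order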
@@ -358,7 +379,7 @@ def cutlass_moe_fp8(
         topk_ids,
         False,
         activation,
-        global_num_experts if global_num_experts != -1 else w1_q.size(0),
+        num_experts,
         expert_map,
         w1_scale,
         w2_scale,
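For orientation, a hedged end-to-end usage sketch of the refactored entry point. Only the parameters visible in the hunks above are grounded in this diff; the tensor shapes, the topk_weights/w2_q names, and the exact positional order are assumptions, so treat this as illustrative rather than an API reference:

    out = cutlass_moe_fp8(
        a,                       # input activations (fp16/bf16)
        w1_q, w2_q,              # fp8-quantized expert weights (w2_q assumed)
        topk_weights, topk_ids,  # router outputs (topk_weights assumed)
        w1_scale, w2_scale,      # weight dequantization scales
        per_act_token,           # new: explicit, no longer inferred from scales
        activation="silu",
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        global_num_experts=-1,   # -1 resolves to w1_q.size(0) via num_experts
        expert_map=None,
    )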