[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)

Signed-off-by: Bill Nell <bnell@redhat.com>
Author: bnellnm
Date: 2025-09-17 19:43:31 -04:00
Committed by: GitHub
Parent: e6585ddb45
Commit: 5963b98b46
68 changed files with 2698 additions and 2526 deletions
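In short: quantization parameters now travel with a FusedMoEQuantConfig that is built once, up front (here via the fp8_w8a8_moe_quant_config helper), instead of scale tensors being threaded through every kernel invocation. A minimal sketch of the new construction pattern, distilled from the test diff below; the scale/stride tensors and flags are assumed to be set up by the surrounding test code, and the variable name quant_config is illustrative:

    from vllm.model_executor.layers.fused_moe.config import (
        fp8_w8a8_moe_quant_config)
    from vllm.model_executor.layers.fused_moe.cutlass_moe import (
        CutlassBatchedExpertsFp8)

    # Package all fp8 w8a8 quantization state once, at construction time.
    quant_config = fp8_w8a8_moe_quant_config(
        per_act_token_quant=per_act_token,  # per-token activation quantization
        per_out_ch_quant=per_out_ch,        # per-output-channel weight quantization
        w1_scale=w1_scale,                  # scales for the first grouped GEMM
        w2_scale=w2_scale,                  # scales for the second grouped GEMM
        a1_scale=a1_scale)                  # input activation scale(s)

    # The experts implementation owns the config from here on; callers no
    # longer pass scales with each forward call.
    experts = CutlassBatchedExpertsFp8(
        num_local_experts, num_dispatchers, out_dtype,
        ab_strides1, ab_strides2, c_strides1, c_strides2,
        quant_config)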


@@ -9,6 +9,8 @@ import torch
 from tests.kernels.utils import torch_experts
 from vllm import _custom_ops as ops
 from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.fused_moe.config import (
+    fp8_w8a8_moe_quant_config)
 from vllm.model_executor.layers.fused_moe.cutlass_moe import (
     CutlassBatchedExpertsFp8)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
@@ -143,10 +145,16 @@ def pplx_cutlass_moe(
         device="cuda",
         dtype=torch.int64)
-    experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
-                                       out_dtype, per_act_token, per_out_ch,
-                                       ab_strides1, ab_strides2, c_strides1,
-                                       c_strides2)
+    experts = CutlassBatchedExpertsFp8(
+        num_local_experts, num_dispatchers, out_dtype, ab_strides1,
+        ab_strides2, c_strides1, c_strides2,
+        fp8_w8a8_moe_quant_config(
+            per_act_token_quant=per_act_token,
+            per_out_ch_quant=per_out_ch,
+            w1_scale=chunk_by_rank(w1_scale, rank, world_size),
+            w2_scale=chunk_by_rank(w2_scale, rank, world_size),
+            a1_scale=chunk_by_rank(a1_scale, rank, world_size)
+            if per_act_token else a1_scale[rank]))

     fused_cutlass_experts = FusedMoEModularKernel(
         prepare_finalize,
@@ -167,10 +175,7 @@ def pplx_cutlass_moe(
         chunk_topk_ids,
         global_num_experts=num_experts,
         expert_map=None, #TODO
-        w1_scale=chunk_by_rank(w1_scale, rank, world_size),
-        w2_scale=chunk_by_rank(w2_scale, rank, world_size),
-        a1_scale=chunk_by_rank(a1_scale, rank, world_size)
-        if per_act_token else a1_scale[rank])
+    )

     torch.cuda.synchronize()
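Correspondingly, the call into the modular kernel no longer carries any quantization arguments. Roughly, before and after (the assignment target and the leading positional arguments, elided with "...", are placeholders):

    # Before: scales were re-chunked and passed on every invocation.
    out = fused_cutlass_experts(
        ...,
        chunk_topk_ids,
        global_num_experts=num_experts,
        expert_map=None,  # TODO
        w1_scale=chunk_by_rank(w1_scale, rank, world_size),
        w2_scale=chunk_by_rank(w2_scale, rank, world_size),
        a1_scale=chunk_by_rank(a1_scale, rank, world_size)
        if per_act_token else a1_scale[rank])

    # After: the scales live in the FusedMoEQuantConfig handed to
    # CutlassBatchedExpertsFp8 at construction, so only routing
    # arguments remain at the call site.
    out = fused_cutlass_experts(
        ...,
        chunk_topk_ids,
        global_num_experts=num_experts,
        expert_map=None,  # TODO
    )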