[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)
Signed-off-by: Bill Nell <bnell@redhat.com>
@@ -9,6 +9,8 @@ import torch
 from tests.kernels.utils import torch_experts
 from vllm import _custom_ops as ops
 from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.fused_moe.config import (
+    fp8_w8a8_moe_quant_config)
 from vllm.model_executor.layers.fused_moe.cutlass_moe import (
     CutlassBatchedExpertsFp8)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
@@ -143,10 +145,16 @@ def pplx_cutlass_moe(
                             device="cuda",
                             dtype=torch.int64)
 
-    experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
-                                       out_dtype, per_act_token, per_out_ch,
-                                       ab_strides1, ab_strides2, c_strides1,
-                                       c_strides2)
+    experts = CutlassBatchedExpertsFp8(
+        num_local_experts, num_dispatchers, out_dtype, ab_strides1,
+        ab_strides2, c_strides1, c_strides2,
+        fp8_w8a8_moe_quant_config(
+            per_act_token_quant=per_act_token,
+            per_out_ch_quant=per_out_ch,
+            w1_scale=chunk_by_rank(w1_scale, rank, world_size),
+            w2_scale=chunk_by_rank(w2_scale, rank, world_size),
+            a1_scale=chunk_by_rank(a1_scale, rank, world_size)
+            if per_act_token else a1_scale[rank]))
 
     fused_cutlass_experts = FusedMoEModularKernel(
         prepare_finalize,
@@ -167,10 +175,7 @@ def pplx_cutlass_moe(
         chunk_topk_ids,
         global_num_experts=num_experts,
         expert_map=None,  #TODO
-        w1_scale=chunk_by_rank(w1_scale, rank, world_size),
-        w2_scale=chunk_by_rank(w2_scale, rank, world_size),
-        a1_scale=chunk_by_rank(a1_scale, rank, world_size)
-        if per_act_token else a1_scale[rank])
+    )
 
     torch.cuda.synchronize()
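For context, the new call pattern above can be sketched in isolation: the fp8 scales and the per_act_token/per_out_ch flags are bundled into a single quant config object up front, via fp8_w8a8_moe_quant_config, and handed to the experts implementation at construction time rather than threaded through each call as loose keyword arguments. The sizes and scale values below are hypothetical, chosen only to illustrate the keyword arguments that appear in the diff:

import torch

from vllm.model_executor.layers.fused_moe.config import (
    fp8_w8a8_moe_quant_config)

# Hypothetical size for illustration only; not taken from the PR.
num_local_experts = 4

# Per-tensor fp8 scales (the per_act_token=False path in the diff above);
# the shapes here are an assumption made for this sketch.
w1_scale = torch.ones(num_local_experts, 1, 1, dtype=torch.float32)
w2_scale = torch.ones(num_local_experts, 1, 1, dtype=torch.float32)
a1_scale = torch.ones(1, dtype=torch.float32)

# Build the quantization config once. After this change, the scales and
# quantization flags travel inside this object to the experts'
# constructor (CutlassBatchedExpertsFp8 in the test above), so the
# modular-kernel call site no longer passes w1_scale/w2_scale/a1_scale.
quant_config = fp8_w8a8_moe_quant_config(
    per_act_token_quant=False,
    per_out_ch_quant=False,
    w1_scale=w1_scale,
    w2_scale=w2_scale,
    a1_scale=a1_scale,
)

This mirrors the commit title: each FusedMoEMethodBase subclass is now responsible for constructing its own FusedMoEQuantConfig, so quantization state lives with the MoE method instead of with every caller.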