[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-09-17 19:43:31 -04:00
committed by GitHub
parent e6585ddb45
commit 5963b98b46
68 changed files with 2698 additions and 2526 deletions

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import dataclasses
from math import prod
from typing import Optional
@@ -9,6 +10,8 @@ import torch
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8, run_cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
@@ -154,7 +157,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
def slice_experts():
slice_params = [
"w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1",
"c_strides2", "w1_scale", "w2_scale"
"c_strides2"
]
full_tensors = {
k: v
@@ -162,6 +165,8 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
if k in slice_params and k in cutlass_moe_kwargs
}
quant_config = cutlass_moe_kwargs["quant_config"]
for i in range(0, num_experts, num_local_experts):
s, e = i, i + num_local_experts
@@ -178,6 +183,12 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
for k, t in full_tensors.items():
cutlass_moe_kwargs[k] = t[s:e]
new_quant_config = copy.deepcopy(quant_config)
new_quant_config._w1.scale = quant_config.w1_scale[s:e]
new_quant_config._w2.scale = quant_config.w2_scale[s:e]
cutlass_moe_kwargs["quant_config"] = new_quant_config
yield cutlass_moe_kwargs
out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"])
@@ -191,6 +202,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
per_out_ch: bool,
num_local_experts: Optional[int] = None) -> torch.Tensor:
assert not any([
t is None for t in [
@@ -199,20 +211,27 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
]
])
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=moe_tensors.w1_scale,
w2_scale=moe_tensors.w2_scale,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
# Set to moe_tensors.a_scale iff static scales + per tensor.
# This is not currently being tested.
a1_scale=None,
)
kwargs = {
'a': moe_tensors.a,
'w1_q': moe_tensors.w1_q, # type: ignore[union-attr]
'w2_q': moe_tensors.w2_q, # type: ignore[union-attr]
'topk_weights': topk_weights,
'topk_ids': topk_ids,
'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale,
'ab_strides1': moe_tensors.ab_strides1,
'ab_strides2': moe_tensors.ab_strides2,
'c_strides1': moe_tensors.c_strides1,
'c_strides2': moe_tensors.c_strides2,
'per_act_token': per_act_token,
'a1_scale': None #moe_tensors.a_scale
'quant_config': quant_config,
}
num_experts = moe_tensors.w1.size(0)
@@ -261,16 +280,23 @@ def test_cutlass_moe_8_bit_no_graph(
# Note that we are using the dequantized versions of the tensors.
# Using a, w1 and w2 directly results in minor output differences.
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
topk_ids)
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
triton_output = fused_experts(mt.a_d,
mt.w1_d,
mt.w2_d,
topk_weights,
topk_ids,
quant_config=quant_config)
if ep_size is not None:
assert e % ep_size == 0, "Cannot distribute experts evenly"
number_local_experts = e // ep_size
else:
number_local_experts = None
cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token,
number_local_experts)
per_out_ch, number_local_experts)
# Note 5.5 only needed for larger problem sizes, 5 works ok for
# the rest.
@@ -315,14 +341,19 @@ def test_cutlass_moe_8_bit_cuda_graph(
# Note that we are using the dequantized versions of the tensors.
# Using a, w1 and w2 directly results in minor output differences.
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
topk_ids)
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
triton_output = fused_experts(mt.a_d,
mt.w1_d,
mt.w2_d,
topk_weights,
topk_ids,
quant_config=quant_config)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
per_act_token)
per_act_token, per_out_ch)
torch.cuda.synchronize()
graph.replay()