[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-09-17 19:43:31 -04:00
committed by GitHub
parent e6585ddb45
commit 5963b98b46
68 changed files with 2698 additions and 2526 deletions

View File

@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from .mk_objects import (expert_info, make_fused_experts,
from .mk_objects import (TestMoEQuantConfig, expert_info, make_fused_experts,
make_prepare_finalize, prepare_finalize_info)
from .parallel_utils import ProcessGroupInfo
@@ -40,7 +40,7 @@ class Config:
E: int
topks: Union[list[int], int]
dtype: torch.dtype
quant_config: Optional[FusedMoEQuantConfig]
quant_config: Optional[TestMoEQuantConfig]
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute
@@ -52,7 +52,7 @@ class Config:
def __post_init__(self):
if self.quant_config is None:
self.quant_config = FusedMoEQuantConfig()
self.quant_config = TestMoEQuantConfig(None, False, False, None)
def describe(self) -> str:
s = ""
@@ -275,21 +275,19 @@ class WeightTensors:
or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8)
def to_current_device(self):
self.w1 = self.w1.to(device=torch.cuda.current_device())
self.w2 = self.w2.to(device=torch.cuda.current_device())
device = torch.cuda.current_device()
self.w1 = self.w1.to(device=device)
self.w2 = self.w2.to(device=device)
if self.is_quantized():
assert self.w1_scale is not None
assert self.w2_scale is not None
self.w1_scale = self.w1_scale.to(
device=torch.cuda.current_device())
self.w2_scale = self.w2_scale.to(
device=torch.cuda.current_device())
if self.w1_scale is not None:
self.w1_scale = self.w1_scale.to(device=device)
if self.w2_scale is not None:
self.w2_scale = self.w2_scale.to(device=device)
if self.w1_gs is not None:
assert self.w2_gs is not None
self.w1_gs = self.w1_gs.to(device=torch.cuda.current_device())
self.w2_gs = self.w2_gs.to(device=torch.cuda.current_device())
self.w1_gs = self.w1_gs.to(device=device)
if self.w2_gs is not None:
self.w2_gs = self.w2_gs.to(device=device)
def slice_weights(self, rank: int,
num_local_experts: int) -> "WeightTensors":
@@ -297,20 +295,12 @@ class WeightTensors:
e = s + num_local_experts
w1 = self.w1[s:e, :, :]
w2 = self.w2[s:e, :, :]
w1_scale, w2_scale = (None, None)
if self.is_quantized():
assert self.w1_scale is not None
assert self.w2_scale is not None
w1_scale = self.w1_scale[s:e, :, :]
w2_scale = self.w2_scale[s:e, :, :]
w1_gs = self.w1_gs
w2_gs = self.w2_gs
if w1_gs is not None:
assert w2_gs is not None
w1_gs = w1_gs[s:e]
w2_gs = w2_gs[s:e]
w1_scale = self.w1_scale[
s:e, :, :] if self.w1_scale is not None else None
w2_scale = self.w2_scale[
s:e, :, :] if self.w2_scale is not None else None
w1_gs = self.w1_gs[s:e] if self.w1_gs is not None else None
w2_gs = self.w2_gs[s:e] if self.w2_gs is not None else None
return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs)
@@ -323,7 +313,8 @@ class WeightTensors:
in_dtype=config.dtype,
quant_dtype=config.quant_dtype,
block_shape=config.quant_block_shape,
per_act_token_quant=config.is_per_out_ch_quant,
per_out_ch_quant=config.
is_per_act_token_quant, # or config.is_per_out_ch_quant
)
return WeightTensors(w1=w1,
w2=w2,
@@ -342,8 +333,6 @@ class RankTensors:
topk_ids: torch.Tensor
expert_map: Optional[torch.Tensor]
quant_config: Optional[FusedMoEQuantConfig]
def describe(self):
s = ""
s += "== Rank Tensors: \n"
@@ -426,7 +415,6 @@ class RankTensors:
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
quant_config=config.quant_config,
)
@@ -522,10 +510,16 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
and config.supports_apply_weight_on_input())
def _make_gscale(num_experts: int) -> torch.Tensor:
return torch.ones((num_experts, ),
device=torch.cuda.current_device(),
dtype=torch.float32)
def make_modular_kernel(
config: Config,
vllm_config: VllmConfig,
weights: WeightTensors,
quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEModularKernel:
def next_power_of_2(x):
@@ -548,20 +542,20 @@ def make_modular_kernel(
num_local_experts=config.num_local_experts,
moe_parallel_config=moe_parallel_config,
in_dtype=config.dtype,
quant_config=config.quant_config,
max_num_tokens=next_power_of_2(config.M),
)
# make modular kernel
prepare_finalize = make_prepare_finalize(config.prepare_finalize_type,
config.all2all_backend(), moe)
config.all2all_backend(), moe,
quant_config)
fused_experts = make_fused_experts(
config.fused_experts_type,
moe,
quant_config,
prepare_finalize.num_dispatchers(),
weights.w1_gs,
weights.w2_gs,
config.N,
)
modular_kernel = mk.FusedMoEModularKernel(
@@ -583,12 +577,38 @@ def run_modular_kernel(
# weights for rank
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
mk = make_modular_kernel(config, vllm_config, weights)
if config.quant_dtype == "nvfp4":
gscale = _make_gscale(config.num_local_experts)
else:
gscale = None
quant_config = FusedMoEQuantConfig.make(
config.quant_dtype,
w1_scale=rank_weights.w1_scale,
w2_scale=rank_weights.w2_scale,
a1_scale=rank_tensors.hidden_states_scale,
g1_alphas=(1 / rank_weights.w1_gs)
if rank_weights.w1_gs is not None else None,
g2_alphas=(1 / rank_weights.w2_gs)
if rank_weights.w2_gs is not None else None,
a1_gscale=gscale,
a2_gscale=gscale,
block_shape=config.quant_block_shape,
per_act_token_quant=config.is_per_act_token_quant,
per_out_ch_quant=config.is_per_out_ch_quant,
)
mk = make_modular_kernel(config, vllm_config, quant_config)
# impls might update the tensor in place
hidden_states = rank_tensors.hidden_states.clone()
topk_ids = rank_tensors.topk_ids.to(
mk.prepare_finalize.topk_indices_dtype())
mk_kwargs = {
"hidden_states":
rank_tensors.hidden_states.clone(
), # impls might update the tensor in place
hidden_states,
"w1":
rank_weights.w1,
"w2":
@@ -596,15 +616,9 @@ def run_modular_kernel(
"topk_weights":
rank_tensors.topk_weights,
"topk_ids":
rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()),
topk_ids,
"expert_map":
rank_tensors.expert_map,
"w1_scale":
rank_weights.w1_scale,
"w2_scale":
rank_weights.w2_scale,
"a1_scale":
rank_tensors.hidden_states_scale,
"global_num_experts":
config.E,
"apply_router_weight_on_input":

View File

@@ -10,7 +10,8 @@ import torch
from tqdm import tqdm
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG)
from vllm.platforms import current_platform
from .common import (Config, RankTensors, WeightTensors, reference_moe_impl,
@@ -86,7 +87,7 @@ def make_feature_matrix(csv_file_path: str):
quant_config_dict = config_dict['quant_config']
del config_dict['quant_config']
if quant_config_dict is None:
quant_config = FusedMoEQuantConfig(None)
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
quant_config_dict = asdict(quant_config)
config_dict |= quant_config_dict

View File

@@ -32,6 +32,14 @@ from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
@dataclass
class TestMoEQuantConfig:
quant_dtype: Union[torch.dtype, str, None]
per_out_ch_quant: bool
per_act_token_quant: bool
block_shape: Optional[list[int]]
@dataclass
class PrepareFinalizeInfo:
activation_format: mk.FusedMoEActivationFormat
@@ -66,7 +74,7 @@ common_float_types: list[Union[torch.dtype, str]] = [
torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32
]
common_float_and_int_types = common_float_types + [torch.int8]
nv_fp4_types = ["nvfp4"]
nvfp4_types = ["nvfp4"]
fp8_types = [torch.float8_e4m3fn]
@@ -219,7 +227,7 @@ if (has_flashinfer_cutlass_fused_moe()
register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize,
standard_format,
nv_fp4_types,
nvfp4_types,
blocked_quantization_support=True,
backend=None,
force_multigpu=True,
@@ -229,7 +237,7 @@ if (has_flashinfer_cutlass_fused_moe()
register_experts(
FlashInferExperts,
standard_format,
nv_fp4_types,
nvfp4_types,
blocked_quantization_support=True,
supports_chunking=True,
# Note: this is a hack to get it to run for now
@@ -306,39 +314,39 @@ if cutlass_fp4_supported():
register_experts(
CutlassExpertsFp4,
standard_format,
nv_fp4_types,
nvfp4_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=False,
)
MK_QUANT_CONFIGS = [
MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [
None,
# per-channel / per-column weights and per-tensor activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=False,
block_shape=None),
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=False,
block_shape=None),
# per-channel / per-column weights and per-token activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=True,
block_shape=None),
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=True,
block_shape=None),
# per-tensor weights and per-tensor activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
# per-tensor weights and per-token activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=True,
block_shape=None),
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=True,
block_shape=None),
# block-quantized weights and 128 block per-token activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=[128, 128]),
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=[128, 128]),
# TODO (varun) : Should we test the following combinations ?
# block-quantized weights and per-token activations
# block-quantized weights and per-tensor activations
@@ -346,33 +354,27 @@ MK_QUANT_CONFIGS = [
if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
MK_QUANT_CONFIGS += [
FusedMoEQuantConfig(quant_dtype="nvfp4",
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
TestMoEQuantConfig(quant_dtype="nvfp4",
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
]
def _make_gscale(num_experts: int) -> torch.Tensor:
return torch.ones((num_experts, ),
device=torch.cuda.current_device(),
dtype=torch.float32)
def make_prepare_finalize(
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
backend: Optional[str],
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEPrepareAndFinalize:
if backend != "naive" and backend is not None:
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(
moe, quant_config)
assert prepare_finalize is not None
return prepare_finalize
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
return FlashInferCutlassMoEPrepareAndFinalize(
use_dp=moe.moe_parallel_config.dp_size > 1,
a1_gscale=_make_gscale(moe.num_local_experts),
)
use_dp=moe.moe_parallel_config.dp_size > 1)
else:
return MoEPrepareAndFinalizeNoEP()
@@ -383,34 +385,39 @@ def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
return t[s:e]
def make_cutlass_strides(
e: int,
n: int,
k: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
return ab_strides1, ab_strides2, c_strides1, c_strides2
def make_fused_experts(
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
num_dispatchers: int,
w1_gs: Optional[torch.Tensor],
w2_gs: Optional[torch.Tensor],
N: int,
) -> mk.FusedMoEPermuteExpertsUnpermute:
use_fp8 = moe.quant_dtype == torch.float8_e4m3fn
batch_kwargs = {
"max_num_tokens": moe.max_num_tokens,
"num_dispatchers": num_dispatchers,
}
quant_kwargs = {
"use_fp8_w8a8": use_fp8,
"use_int8_w8a8": False,
"use_int8_w8a16": False,
"use_int4_w4a16": False,
"block_shape": moe.block_shape,
"per_act_token_quant": moe.per_act_token_quant,
"quant_config": quant_config,
}
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
torch.set_printoptions(threshold=0, edgeitems=0, linewidth=10000)
if fused_experts_type == BatchedDeepGemmExperts:
kwargs = batch_kwargs | {
"block_shape": moe.block_shape,
"per_act_token_quant": moe.per_act_token_quant,
}
kwargs = batch_kwargs | quant_kwargs
print(f"Making BatchedDeepGemmExperts {kwargs} ...")
experts = BatchedDeepGemmExperts(**kwargs)
elif fused_experts_type == BatchedTritonExperts:
@@ -422,8 +429,8 @@ def make_fused_experts(
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
elif fused_experts_type == DeepGemmExperts:
print("Making DeepGemmExperts () ...")
experts = DeepGemmExperts()
print("Making DeepGemmExperts {quant_config} ...")
experts = DeepGemmExperts(quant_config)
elif fused_experts_type == TritonExperts:
kwargs = quant_kwargs
print(f"Making TritonExperts {kwargs} ...")
@@ -437,62 +444,50 @@ def make_fused_experts(
print(f"Making NaiveBatchedExperts {kwargs} ...")
experts = NaiveBatchedExperts(**kwargs)
elif fused_experts_type == CutlassExpertsFp8:
strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
kwargs = {
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
}
"ab_strides1": strides[0],
"ab_strides2": strides[1],
"c_strides1": strides[2],
"c_strides2": strides[3],
} | quant_kwargs
print(f"Making CutlassExpertsFp8 {kwargs} ...")
experts = CutlassExpertsFp8(**kwargs)
elif fused_experts_type == CutlassBatchedExpertsFp8:
strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
kwargs = {
"max_experts_per_worker": moe.num_local_experts,
"num_dispatchers": num_dispatchers,
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
}
"ab_strides1": strides[0],
"ab_strides2": strides[1],
"c_strides1": strides[2],
"c_strides2": strides[3],
} | quant_kwargs
print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...")
experts = CutlassBatchedExpertsFp8(**kwargs)
elif fused_experts_type == CutlassExpertsFp4:
assert w1_gs is not None and w2_gs is not None
num_experts = moe.num_local_experts
rank = moe.moe_parallel_config.dp_rank
kwargs = {
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
"a1_gscale": _make_gscale(num_experts),
"a2_gscale": _make_gscale(num_experts),
"max_experts_per_worker": num_experts,
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
"max_experts_per_worker": moe.num_local_experts,
"num_dispatchers": num_dispatchers,
}
"out_dtype": moe.in_dtype,
} | quant_kwargs
print(f"Making CutlassExpertsFp4 {kwargs} ...")
experts = CutlassExpertsFp4(**kwargs)
elif fused_experts_type == FlashInferExperts:
assert w1_gs is not None and w2_gs is not None
num_experts = moe.num_local_experts
rank = moe.moe_parallel_config.dp_rank
kwargs = {
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
"a1_gscale": _make_gscale(num_experts),
"a2_gscale": _make_gscale(num_experts),
"out_dtype": moe.in_dtype,
"quant_dtype": "nvfp4",
"ep_rank": moe.ep_rank,
"ep_size": moe.ep_size,
"tp_rank": moe.tp_rank,
"tp_size": moe.tp_size,
}
} | quant_kwargs
print(f"Making FlashInferExperts {kwargs} ...")
experts = FlashInferExperts(**kwargs)
else:
raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}")
torch.set_printoptions(threshold=1000, edgeitems=5, linewidth=80)
return experts