[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-09-17 19:43:31 -04:00
committed by GitHub
parent e6585ddb45
commit 5963b98b46
68 changed files with 2698 additions and 2526 deletions

View File

@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from .mk_objects import (expert_info, make_fused_experts,
from .mk_objects import (TestMoEQuantConfig, expert_info, make_fused_experts,
make_prepare_finalize, prepare_finalize_info)
from .parallel_utils import ProcessGroupInfo
@@ -40,7 +40,7 @@ class Config:
E: int
topks: Union[list[int], int]
dtype: torch.dtype
quant_config: Optional[FusedMoEQuantConfig]
quant_config: Optional[TestMoEQuantConfig]
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute
@@ -52,7 +52,7 @@ class Config:
def __post_init__(self):
if self.quant_config is None:
self.quant_config = FusedMoEQuantConfig()
self.quant_config = TestMoEQuantConfig(None, False, False, None)
def describe(self) -> str:
s = ""
@@ -275,21 +275,19 @@ class WeightTensors:
or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8)
def to_current_device(self):
self.w1 = self.w1.to(device=torch.cuda.current_device())
self.w2 = self.w2.to(device=torch.cuda.current_device())
device = torch.cuda.current_device()
self.w1 = self.w1.to(device=device)
self.w2 = self.w2.to(device=device)
if self.is_quantized():
assert self.w1_scale is not None
assert self.w2_scale is not None
self.w1_scale = self.w1_scale.to(
device=torch.cuda.current_device())
self.w2_scale = self.w2_scale.to(
device=torch.cuda.current_device())
if self.w1_scale is not None:
self.w1_scale = self.w1_scale.to(device=device)
if self.w2_scale is not None:
self.w2_scale = self.w2_scale.to(device=device)
if self.w1_gs is not None:
assert self.w2_gs is not None
self.w1_gs = self.w1_gs.to(device=torch.cuda.current_device())
self.w2_gs = self.w2_gs.to(device=torch.cuda.current_device())
self.w1_gs = self.w1_gs.to(device=device)
if self.w2_gs is not None:
self.w2_gs = self.w2_gs.to(device=device)
def slice_weights(self, rank: int,
num_local_experts: int) -> "WeightTensors":
@@ -297,20 +295,12 @@ class WeightTensors:
e = s + num_local_experts
w1 = self.w1[s:e, :, :]
w2 = self.w2[s:e, :, :]
w1_scale, w2_scale = (None, None)
if self.is_quantized():
assert self.w1_scale is not None
assert self.w2_scale is not None
w1_scale = self.w1_scale[s:e, :, :]
w2_scale = self.w2_scale[s:e, :, :]
w1_gs = self.w1_gs
w2_gs = self.w2_gs
if w1_gs is not None:
assert w2_gs is not None
w1_gs = w1_gs[s:e]
w2_gs = w2_gs[s:e]
w1_scale = self.w1_scale[
s:e, :, :] if self.w1_scale is not None else None
w2_scale = self.w2_scale[
s:e, :, :] if self.w2_scale is not None else None
w1_gs = self.w1_gs[s:e] if self.w1_gs is not None else None
w2_gs = self.w2_gs[s:e] if self.w2_gs is not None else None
return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs)
@@ -323,7 +313,8 @@ class WeightTensors:
in_dtype=config.dtype,
quant_dtype=config.quant_dtype,
block_shape=config.quant_block_shape,
per_act_token_quant=config.is_per_out_ch_quant,
per_out_ch_quant=config.
is_per_act_token_quant, # or config.is_per_out_ch_quant
)
return WeightTensors(w1=w1,
w2=w2,
@@ -342,8 +333,6 @@ class RankTensors:
topk_ids: torch.Tensor
expert_map: Optional[torch.Tensor]
quant_config: Optional[FusedMoEQuantConfig]
def describe(self):
s = ""
s += "== Rank Tensors: \n"
@@ -426,7 +415,6 @@ class RankTensors:
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
quant_config=config.quant_config,
)
@@ -522,10 +510,16 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
and config.supports_apply_weight_on_input())
def _make_gscale(num_experts: int) -> torch.Tensor:
return torch.ones((num_experts, ),
device=torch.cuda.current_device(),
dtype=torch.float32)
def make_modular_kernel(
config: Config,
vllm_config: VllmConfig,
weights: WeightTensors,
quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEModularKernel:
def next_power_of_2(x):
@@ -548,20 +542,20 @@ def make_modular_kernel(
num_local_experts=config.num_local_experts,
moe_parallel_config=moe_parallel_config,
in_dtype=config.dtype,
quant_config=config.quant_config,
max_num_tokens=next_power_of_2(config.M),
)
# make modular kernel
prepare_finalize = make_prepare_finalize(config.prepare_finalize_type,
config.all2all_backend(), moe)
config.all2all_backend(), moe,
quant_config)
fused_experts = make_fused_experts(
config.fused_experts_type,
moe,
quant_config,
prepare_finalize.num_dispatchers(),
weights.w1_gs,
weights.w2_gs,
config.N,
)
modular_kernel = mk.FusedMoEModularKernel(
@@ -583,12 +577,38 @@ def run_modular_kernel(
# weights for rank
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
mk = make_modular_kernel(config, vllm_config, weights)
if config.quant_dtype == "nvfp4":
gscale = _make_gscale(config.num_local_experts)
else:
gscale = None
quant_config = FusedMoEQuantConfig.make(
config.quant_dtype,
w1_scale=rank_weights.w1_scale,
w2_scale=rank_weights.w2_scale,
a1_scale=rank_tensors.hidden_states_scale,
g1_alphas=(1 / rank_weights.w1_gs)
if rank_weights.w1_gs is not None else None,
g2_alphas=(1 / rank_weights.w2_gs)
if rank_weights.w2_gs is not None else None,
a1_gscale=gscale,
a2_gscale=gscale,
block_shape=config.quant_block_shape,
per_act_token_quant=config.is_per_act_token_quant,
per_out_ch_quant=config.is_per_out_ch_quant,
)
mk = make_modular_kernel(config, vllm_config, quant_config)
# impls might update the tensor in place
hidden_states = rank_tensors.hidden_states.clone()
topk_ids = rank_tensors.topk_ids.to(
mk.prepare_finalize.topk_indices_dtype())
mk_kwargs = {
"hidden_states":
rank_tensors.hidden_states.clone(
), # impls might update the tensor in place
hidden_states,
"w1":
rank_weights.w1,
"w2":
@@ -596,15 +616,9 @@ def run_modular_kernel(
"topk_weights":
rank_tensors.topk_weights,
"topk_ids":
rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()),
topk_ids,
"expert_map":
rank_tensors.expert_map,
"w1_scale":
rank_weights.w1_scale,
"w2_scale":
rank_weights.w2_scale,
"a1_scale":
rank_tensors.hidden_states_scale,
"global_num_experts":
config.E,
"apply_router_weight_on_input":

View File

@@ -10,7 +10,8 @@ import torch
from tqdm import tqdm
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG)
from vllm.platforms import current_platform
from .common import (Config, RankTensors, WeightTensors, reference_moe_impl,
@@ -86,7 +87,7 @@ def make_feature_matrix(csv_file_path: str):
quant_config_dict = config_dict['quant_config']
del config_dict['quant_config']
if quant_config_dict is None:
quant_config = FusedMoEQuantConfig(None)
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
quant_config_dict = asdict(quant_config)
config_dict |= quant_config_dict

View File

@@ -32,6 +32,14 @@ from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
@dataclass
class TestMoEQuantConfig:
quant_dtype: Union[torch.dtype, str, None]
per_out_ch_quant: bool
per_act_token_quant: bool
block_shape: Optional[list[int]]
@dataclass
class PrepareFinalizeInfo:
activation_format: mk.FusedMoEActivationFormat
@@ -66,7 +74,7 @@ common_float_types: list[Union[torch.dtype, str]] = [
torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32
]
common_float_and_int_types = common_float_types + [torch.int8]
nv_fp4_types = ["nvfp4"]
nvfp4_types = ["nvfp4"]
fp8_types = [torch.float8_e4m3fn]
@@ -219,7 +227,7 @@ if (has_flashinfer_cutlass_fused_moe()
register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize,
standard_format,
nv_fp4_types,
nvfp4_types,
blocked_quantization_support=True,
backend=None,
force_multigpu=True,
@@ -229,7 +237,7 @@ if (has_flashinfer_cutlass_fused_moe()
register_experts(
FlashInferExperts,
standard_format,
nv_fp4_types,
nvfp4_types,
blocked_quantization_support=True,
supports_chunking=True,
# Note: this is a hack to get it to run for now
@@ -306,39 +314,39 @@ if cutlass_fp4_supported():
register_experts(
CutlassExpertsFp4,
standard_format,
nv_fp4_types,
nvfp4_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=False,
)
MK_QUANT_CONFIGS = [
MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [
None,
# per-channel / per-column weights and per-tensor activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=False,
block_shape=None),
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=False,
block_shape=None),
# per-channel / per-column weights and per-token activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=True,
block_shape=None),
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=True,
block_shape=None),
# per-tensor weights and per-tensor activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
# per-tensor weights and per-token activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=True,
block_shape=None),
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=True,
block_shape=None),
# block-quantized weights and 128 block per-token activations
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=[128, 128]),
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=[128, 128]),
# TODO (varun) : Should we test the following combinations ?
# block-quantized weights and per-token activations
# block-quantized weights and per-tensor activations
@@ -346,33 +354,27 @@ MK_QUANT_CONFIGS = [
if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
MK_QUANT_CONFIGS += [
FusedMoEQuantConfig(quant_dtype="nvfp4",
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
TestMoEQuantConfig(quant_dtype="nvfp4",
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
]
def _make_gscale(num_experts: int) -> torch.Tensor:
return torch.ones((num_experts, ),
device=torch.cuda.current_device(),
dtype=torch.float32)
def make_prepare_finalize(
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
backend: Optional[str],
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEPrepareAndFinalize:
if backend != "naive" and backend is not None:
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(
moe, quant_config)
assert prepare_finalize is not None
return prepare_finalize
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
return FlashInferCutlassMoEPrepareAndFinalize(
use_dp=moe.moe_parallel_config.dp_size > 1,
a1_gscale=_make_gscale(moe.num_local_experts),
)
use_dp=moe.moe_parallel_config.dp_size > 1)
else:
return MoEPrepareAndFinalizeNoEP()
@@ -383,34 +385,39 @@ def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
return t[s:e]
def make_cutlass_strides(
e: int,
n: int,
k: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
return ab_strides1, ab_strides2, c_strides1, c_strides2
def make_fused_experts(
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
num_dispatchers: int,
w1_gs: Optional[torch.Tensor],
w2_gs: Optional[torch.Tensor],
N: int,
) -> mk.FusedMoEPermuteExpertsUnpermute:
use_fp8 = moe.quant_dtype == torch.float8_e4m3fn
batch_kwargs = {
"max_num_tokens": moe.max_num_tokens,
"num_dispatchers": num_dispatchers,
}
quant_kwargs = {
"use_fp8_w8a8": use_fp8,
"use_int8_w8a8": False,
"use_int8_w8a16": False,
"use_int4_w4a16": False,
"block_shape": moe.block_shape,
"per_act_token_quant": moe.per_act_token_quant,
"quant_config": quant_config,
}
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
torch.set_printoptions(threshold=0, edgeitems=0, linewidth=10000)
if fused_experts_type == BatchedDeepGemmExperts:
kwargs = batch_kwargs | {
"block_shape": moe.block_shape,
"per_act_token_quant": moe.per_act_token_quant,
}
kwargs = batch_kwargs | quant_kwargs
print(f"Making BatchedDeepGemmExperts {kwargs} ...")
experts = BatchedDeepGemmExperts(**kwargs)
elif fused_experts_type == BatchedTritonExperts:
@@ -422,8 +429,8 @@ def make_fused_experts(
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
elif fused_experts_type == DeepGemmExperts:
print("Making DeepGemmExperts () ...")
experts = DeepGemmExperts()
print("Making DeepGemmExperts {quant_config} ...")
experts = DeepGemmExperts(quant_config)
elif fused_experts_type == TritonExperts:
kwargs = quant_kwargs
print(f"Making TritonExperts {kwargs} ...")
@@ -437,62 +444,50 @@ def make_fused_experts(
print(f"Making NaiveBatchedExperts {kwargs} ...")
experts = NaiveBatchedExperts(**kwargs)
elif fused_experts_type == CutlassExpertsFp8:
strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
kwargs = {
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
}
"ab_strides1": strides[0],
"ab_strides2": strides[1],
"c_strides1": strides[2],
"c_strides2": strides[3],
} | quant_kwargs
print(f"Making CutlassExpertsFp8 {kwargs} ...")
experts = CutlassExpertsFp8(**kwargs)
elif fused_experts_type == CutlassBatchedExpertsFp8:
strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
kwargs = {
"max_experts_per_worker": moe.num_local_experts,
"num_dispatchers": num_dispatchers,
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
}
"ab_strides1": strides[0],
"ab_strides2": strides[1],
"c_strides1": strides[2],
"c_strides2": strides[3],
} | quant_kwargs
print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...")
experts = CutlassBatchedExpertsFp8(**kwargs)
elif fused_experts_type == CutlassExpertsFp4:
assert w1_gs is not None and w2_gs is not None
num_experts = moe.num_local_experts
rank = moe.moe_parallel_config.dp_rank
kwargs = {
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
"a1_gscale": _make_gscale(num_experts),
"a2_gscale": _make_gscale(num_experts),
"max_experts_per_worker": num_experts,
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
"max_experts_per_worker": moe.num_local_experts,
"num_dispatchers": num_dispatchers,
}
"out_dtype": moe.in_dtype,
} | quant_kwargs
print(f"Making CutlassExpertsFp4 {kwargs} ...")
experts = CutlassExpertsFp4(**kwargs)
elif fused_experts_type == FlashInferExperts:
assert w1_gs is not None and w2_gs is not None
num_experts = moe.num_local_experts
rank = moe.moe_parallel_config.dp_rank
kwargs = {
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
"a1_gscale": _make_gscale(num_experts),
"a2_gscale": _make_gscale(num_experts),
"out_dtype": moe.in_dtype,
"quant_dtype": "nvfp4",
"ep_rank": moe.ep_rank,
"ep_size": moe.ep_size,
"tp_rank": moe.tp_rank,
"tp_size": moe.tp_size,
}
} | quant_kwargs
print(f"Making FlashInferExperts {kwargs} ...")
experts = FlashInferExperts(**kwargs)
else:
raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}")
torch.set_printoptions(threshold=1000, edgeitems=5, linewidth=80)
return experts

View File

@@ -6,6 +6,8 @@ import torch
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
@@ -56,13 +58,18 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
rank=0,
)
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_s,
w2_scale=w2_s,
per_act_token_quant=False,
block_shape=BLOCK_SIZE,
)
# triton (reference)
triton_experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
use_fp8_w8a8=True,
per_act_token_quant=False,
block_shape=BLOCK_SIZE,
quant_config=quant_config,
)
mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)
@@ -73,8 +80,6 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
w1_scale=w1_s,
w2_scale=w2_s,
global_num_experts=E,
)
@@ -82,8 +87,7 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
deepgemm_experts = BatchedDeepGemmExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
block_shape=BLOCK_SIZE,
per_act_token_quant=False,
quant_config=quant_config,
)
mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)
@@ -94,8 +98,6 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
w1_scale=w1_s,
w2_scale=w2_s,
global_num_experts=E,
)

View File

@@ -140,7 +140,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
in_dtype=act_dtype,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_act_token_quant,
)
out_shape = (num_experts, max_tokens_per_expert, N)
@@ -250,7 +250,7 @@ def test_fused_moe_batched_experts(
block_shape=block_shape,
in_dtype=act_dtype,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_act_token_quant,
)
if input_scales and quant_dtype is not None:

View File

@@ -4,7 +4,7 @@
import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.moe.utils import make_test_quant_config, make_test_weights
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
native_w8a8_block_matmul)
from vllm.config import VllmConfig, set_current_vllm_config
@@ -161,22 +161,17 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)
(_, w1, w1_s, _), (_, w2, w2_s,
_) = make_test_weights(E,
N,
K,
dtype,
torch.float8_e4m3fn,
per_act_token_quant=False,
block_shape=block_size)
w1, w2, quant_config = make_test_quant_config(
E,
N,
K,
dtype,
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=False,
block_shape=block_size,
)
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
use_mxfp4_w4a4=False,
per_act_token_quant=False,
block_shape=block_size)
m_fused_moe = modular_triton_fused_moe(quant_config)
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
@@ -186,37 +181,24 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
a,
w1,
w2,
w1_s,
w2_s,
quant_config.w1_scale,
quant_config.w2_scale,
topk_weights,
topk_ids,
block_size,
)
out = fused_experts(
a,
w1,
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=block_size,
)
out = fused_experts(a,
w1,
w2,
topk_weights,
topk_ids,
quant_config=quant_config)
m_out = m_fused_moe(
a,
w1,
w2,
topk_weights,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
)
m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids)
# 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0]
tol = 0.035 if M < 40000 else 0.039
# 0.039 only needed for M >= 8192
tol = 0.035 if M < 8192 else 0.039
torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)
torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol)
@@ -248,14 +230,15 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)
(_, w1, w1_s, _), (_, w2, w2_s,
_) = make_test_weights(E,
N,
K,
dtype,
torch.float8_e4m3fn,
per_act_token_quant=False,
block_shape=block_size)
(_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
E,
N,
K,
dtype,
torch.float8_e4m3fn,
per_out_ch_quant=False,
block_shape=block_size,
)
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and

View File

@@ -4,12 +4,12 @@
import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.moe.utils import make_test_quant_config
from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
native_w8a8_block_matmul)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.platforms import current_platform
if current_platform.get_device_capability() < (7, 0):
@@ -50,7 +50,7 @@ MNK_FACTORS = [
(2048, 128, 128),
(2048, 1024, 7168),
(2048, 4096, 512),
(2048, 4096, 7168),
(2048, 4096, 4096),
]
E = [8, 24]
@@ -117,31 +117,28 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
(_, w1, w1_s, _), (_, w2, w2_s,
_) = make_test_weights(E,
N,
K,
dtype,
torch.int8,
per_act_token_quant=False,
block_shape=block_size)
w1, w2, quant_config = make_test_quant_config(
E,
N,
K,
dtype,
quant_dtype=torch.int8,
per_act_token_quant=False,
block_shape=block_size,
)
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
out = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
use_int8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=block_size,
)
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk,
out = fused_experts(a,
w1,
w2,
topk_weights,
topk_ids,
quant_config=quant_config)
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, quant_config.w1_scale,
quant_config.w2_scale, score, topk,
block_size)
# Check results

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import dataclasses
from math import prod
from typing import Optional
@@ -9,6 +10,8 @@ import torch
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8, run_cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
@@ -154,7 +157,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
def slice_experts():
slice_params = [
"w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1",
"c_strides2", "w1_scale", "w2_scale"
"c_strides2"
]
full_tensors = {
k: v
@@ -162,6 +165,8 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
if k in slice_params and k in cutlass_moe_kwargs
}
quant_config = cutlass_moe_kwargs["quant_config"]
for i in range(0, num_experts, num_local_experts):
s, e = i, i + num_local_experts
@@ -178,6 +183,12 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
for k, t in full_tensors.items():
cutlass_moe_kwargs[k] = t[s:e]
new_quant_config = copy.deepcopy(quant_config)
new_quant_config._w1.scale = quant_config.w1_scale[s:e]
new_quant_config._w2.scale = quant_config.w2_scale[s:e]
cutlass_moe_kwargs["quant_config"] = new_quant_config
yield cutlass_moe_kwargs
out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"])
@@ -191,6 +202,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
per_out_ch: bool,
num_local_experts: Optional[int] = None) -> torch.Tensor:
assert not any([
t is None for t in [
@@ -199,20 +211,27 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
]
])
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=moe_tensors.w1_scale,
w2_scale=moe_tensors.w2_scale,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
# Set to moe_tensors.a_scale iff static scales + per tensor.
# This is not currently being tested.
a1_scale=None,
)
kwargs = {
'a': moe_tensors.a,
'w1_q': moe_tensors.w1_q, # type: ignore[union-attr]
'w2_q': moe_tensors.w2_q, # type: ignore[union-attr]
'topk_weights': topk_weights,
'topk_ids': topk_ids,
'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale,
'ab_strides1': moe_tensors.ab_strides1,
'ab_strides2': moe_tensors.ab_strides2,
'c_strides1': moe_tensors.c_strides1,
'c_strides2': moe_tensors.c_strides2,
'per_act_token': per_act_token,
'a1_scale': None #moe_tensors.a_scale
'quant_config': quant_config,
}
num_experts = moe_tensors.w1.size(0)
@@ -261,16 +280,23 @@ def test_cutlass_moe_8_bit_no_graph(
# Note that we are using the dequantized versions of the tensors.
# Using a, w1 and w2 directly results in minor output differences.
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
topk_ids)
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
triton_output = fused_experts(mt.a_d,
mt.w1_d,
mt.w2_d,
topk_weights,
topk_ids,
quant_config=quant_config)
if ep_size is not None:
assert e % ep_size == 0, "Cannot distribute experts evenly"
number_local_experts = e // ep_size
else:
number_local_experts = None
cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token,
number_local_experts)
per_out_ch, number_local_experts)
# Note 5.5 only needed for larger problem sizes, 5 works ok for
# the rest.
@@ -315,14 +341,19 @@ def test_cutlass_moe_8_bit_cuda_graph(
# Note that we are using the dequantized versions of the tensors.
# Using a, w1 and w2 directly results in minor output differences.
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
topk_ids)
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
triton_output = fused_experts(mt.a_d,
mt.w1_d,
mt.w2_d,
topk_weights,
topk_ids,
quant_config=quant_config)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
per_act_token)
per_act_token, per_out_ch)
torch.cuda.synchronize()
graph.replay()

View File

@@ -15,6 +15,8 @@ from torch.distributed import ProcessGroup
from typing_extensions import ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
@@ -71,9 +73,12 @@ def make_block_quant_fp8_weights(
Return weights w1q, w2q, w1_scale, w2_scale
"""
(_, w1q, w1_scale, _), (_, w2q, w2_scale,
_) = make_test_weights(e, n, k, torch.bfloat16,
_) = make_test_weights(e,
n,
k,
torch.bfloat16,
torch.float8_e4m3fn,
block_size)
block_shape=block_size)
return w1q, w2q, w1_scale, w2_scale
@@ -130,10 +135,11 @@ class TestTensors:
config=config)
def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
max_tokens_per_rank: int, dp_size: int,
hidden_size: int, q_dtype: Optional[torch.dtype],
test_config: TestConfig) -> FusedMoEModularKernel:
def make_ll_modular_kernel(
pg: ProcessGroup, pgi: ProcessGroupInfo, max_tokens_per_rank: int,
dp_size: int, hidden_size: int, q_dtype: Optional[torch.dtype],
test_config: TestConfig,
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
assert test_config.low_latency
assert test_config.use_fp8_dispatch is not None
@@ -154,17 +160,18 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
fused_experts = BatchedDeepGemmExperts(
max_num_tokens=max_tokens_per_rank,
num_dispatchers=pgi.world_size // dp_size,
block_shape=test_config.block_size,
per_act_token_quant=test_config.per_act_token_quant)
quant_config=quant_config,
)
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
return mk
def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
dp_size: int, num_local_experts: int,
q_dtype: Optional[torch.dtype],
test_config: TestConfig) -> FusedMoEModularKernel:
def make_ht_modular_kernel(
pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
num_local_experts: int, q_dtype: Optional[torch.dtype],
test_config: TestConfig,
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
assert not test_config.low_latency
assert test_config.use_fp8_dispatch is None
@@ -178,15 +185,16 @@ def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
q_dtype=q_dtype,
block_shape=test_config.block_size)
fused_experts = DeepGemmExperts()
fused_experts = DeepGemmExperts(quant_config)
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
return mk
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
num_local_experts: int,
test_tensors: TestTensors) -> FusedMoEModularKernel:
def make_modular_kernel(
pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
num_local_experts: int, test_tensors: TestTensors,
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
q_dtype = torch.float8_e4m3fn
test_config = test_tensors.config
@@ -204,10 +212,16 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
dp_size=dp_size,
hidden_size=hidden_size,
q_dtype=q_dtype,
test_config=test_config)
test_config=test_config,
quant_config=quant_config)
else:
mk = make_ht_modular_kernel(pg, pgi, dp_size, num_local_experts,
q_dtype, test_config)
mk = make_ht_modular_kernel(pg,
pgi,
dp_size,
num_local_experts,
q_dtype,
test_config,
quant_config=quant_config)
return mk
@@ -233,17 +247,23 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
return expert_map.to(device=torch.cuda.current_device(),
dtype=torch.int32)
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
# Low-Latency kernels can't dispatch scales.
a1_scale=(None if test_config.low_latency else
test_tensors.rank_token_scales),
block_shape=test_config.block_size,
)
# Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(
pg=pg,
pgi=pgi,
dp_size=dp_size,
num_local_experts=num_local_experts,
test_tensors=test_tensors)
# Low-Latency kernels can't dispatch scales.
a1_scale = (None
if test_config.low_latency else test_tensors.rank_token_scales)
test_tensors=test_tensors,
quant_config=quant_config)
out = mk.forward(hidden_states=test_tensors.rank_tokens,
w1=w1,
@@ -254,12 +274,6 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
activation="silu",
global_num_experts=num_experts,
expert_map=build_expert_map(),
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=None,
w2_zp=None,
a1_scale=a1_scale,
a2_scale=None,
apply_router_weight_on_input=False)
return out
@@ -269,6 +283,13 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
a1_scale: torch.Tensor, block_shape: list[int]):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
block_shape=block_shape,
)
return fused_experts(
hidden_states=a,
w1=w1,
@@ -276,11 +297,7 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
block_shape=block_shape,
quant_config=quant_config,
# Make sure this is set to False so we
# don't end up comparing the same implementation.
allow_deep_gemm=False)

View File

@@ -15,6 +15,7 @@ from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
@@ -129,11 +130,9 @@ def make_modular_kernel(
num_local_experts: int,
q_dtype: Optional[torch.dtype],
use_fp8_dispatch: bool,
per_act_token_quant: bool,
quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
is_quantized = q_dtype is not None
ht_args: Optional[DeepEPHTArgs] = None
ll_args: Optional[DeepEPLLArgs] = None
@@ -159,24 +158,14 @@ def make_modular_kernel(
num_dispatchers = pgi.world_size // dp_size
if low_latency_mode:
assert not per_act_token_quant, "not supported in ll mode"
assert not quant_config.per_act_token_quant, "not supported in ll mode"
fused_experts = BatchedTritonExperts(
max_num_tokens=MAX_TOKENS_PER_RANK,
num_dispatchers=num_dispatchers,
use_fp8_w8a8=is_quantized,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_act_token_quant=False,
quant_config=quant_config,
)
else:
fused_experts = TritonExperts(
use_fp8_w8a8=is_quantized,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_act_token_quant=per_act_token_quant,
)
fused_experts = TritonExperts(quant_config=quant_config)
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
@@ -217,11 +206,6 @@ def deep_ep_moe_impl(
if is_quantized:
q_dtype = torch.float8_e4m3fn
# Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(
pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant)
out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
total_num_tokens = test_tensors.rank_tokens.size(0)
@@ -236,6 +220,19 @@ def deep_ep_moe_impl(
rank_token_scales_chunk = rank_token_scales_chunk[
chunk_start:chunk_end]
quant_config = FusedMoEQuantConfig.make(
q_dtype,
w1_scale=w1_scale,
w2_scale=w2_scale,
per_act_token_quant=per_act_token_quant,
a1_scale=rank_token_scales_chunk,
)
# Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(
pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
num_local_experts, q_dtype, use_fp8_dispatch, quant_config)
out = mk.forward(hidden_states=rank_tokens_chunk,
w1=w1,
w2=w2,
@@ -245,12 +242,6 @@ def deep_ep_moe_impl(
activation="silu",
global_num_experts=num_experts,
expert_map=build_expert_map(),
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=None,
w2_zp=None,
a1_scale=rank_token_scales_chunk,
a2_scale=None,
apply_router_weight_on_input=False)
if not skip_result_store:
@@ -407,7 +398,7 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("m,n,k", MNKs)
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@@ -416,7 +407,9 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
@requires_deep_ep
def test_deep_ep_moe(
dtype: torch.dtype,
mnk: tuple[int, int, int],
m: int,
n: int,
k: int,
num_experts: int,
topk: int,
world_dp_size: tuple[int, int],
@@ -424,7 +417,6 @@ def test_deep_ep_moe(
):
low_latency_mode = False
use_fp8_dispatch = False
m, n, k = mnk
current_platform.seed_everything(7)
world_size, dp_size = world_dp_size
@@ -456,20 +448,24 @@ USE_FP8_DISPATCH = [True, False]
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("m,n,k", MNKs)
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
num_experts: int, topk: int,
world_dp_size: tuple[int, int],
use_fp8_dispatch: bool):
def test_low_latency_deep_ep_moe(
dtype: torch.dtype,
m: int,
n: int,
k: int,
num_experts: int,
topk: int,
world_dp_size: tuple[int, int],
use_fp8_dispatch: bool,
):
low_latency_mode = True
m, n, k = mnk
if (low_latency_mode
and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES):

View File

@@ -11,6 +11,8 @@ import math
import pytest
import torch
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@@ -94,6 +96,13 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_s,
w2_scale=w2_s,
a1_scale=a1_scale,
block_shape=block_size,
)
# triton reference
out_triton = fused_experts(
hidden_states=tokens_bf16,
@@ -102,11 +111,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
use_fp8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
a1_scale=a1_scale,
block_shape=block_size,
quant_config=quant_config,
allow_deep_gemm=False,
)
@@ -118,19 +123,14 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
use_fp8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
a1_scale=a1_scale,
block_shape=block_size,
quant_config=quant_config,
allow_deep_gemm=True,
)
diff = calc_diff(out_deepgemm, out_triton)
assert diff < 0.001, f"Diff exceeded 1%: {diff}"
# Note: W1 has shape (E, 2N, K), so N = 512
# can trigger the deepgemm path.
# Note: N <= 512 will disable the deepgemm path due to performance issues.
MNKs = [
(1024, 768, 128),
(1024, 768, 512),
@@ -144,15 +144,15 @@ TOPKS = [2, 6]
NUM_EXPERTS = [32]
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize(("m", "n", "k"), MNKs)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.skipif(not is_deep_gemm_supported(),
reason="Requires deep_gemm kernels")
def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_DEEP_GEMM", "1")
with monkeypatch.context() as mp:
mp.setenv("VLLM_USE_DEEP_GEMM", "1")
_fused_moe_mod = importlib.import_module(
"vllm.model_executor.layers.fused_moe.fused_moe")
@@ -168,8 +168,6 @@ def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):
monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8",
_spy_deep_gemm_moe_fp8)
m, n, k = mnk
if topk > num_experts:
pytest.skip(f"topk={topk} > num_experts={num_experts}")

View File

@@ -6,6 +6,8 @@ import pytest
import torch
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
@@ -145,6 +147,14 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
custom_routing_function=Llama4MoE.custom_routing_function,
scoring_func="softmax")
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=td.w13_weight_scale,
w2_scale=td.w2_weight_scale,
a1_scale=td.a1_scale,
a2_scale=td.a2_scale,
per_act_token_quant=False,
)
output = fused_experts(
td.hidden_states,
td.w13_quantized,
@@ -153,15 +163,10 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
topk_ids=topk_ids,
inplace=False,
activation="silu",
use_fp8_w8a8=True,
per_channel_quant=False,
global_num_experts=e,
expert_map=None,
w1_scale=td.w13_weight_scale,
w2_scale=td.w2_weight_scale,
a1_scale=td.a1_scale,
a2_scale=td.a2_scale,
apply_router_weight_on_input=True,
quant_config=quant_config,
)
flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
@@ -210,6 +215,14 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
custom_routing_function=Llama4MoE.custom_routing_function,
scoring_func="softmax")
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=td.w13_weight_scale,
w2_scale=td.w2_weight_scale,
a1_scale=td.a1_scale,
a2_scale=td.a2_scale,
per_act_token_quant=False,
)
output = fused_experts(
td.hidden_states,
td.w13_quantized,
@@ -218,15 +231,10 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
topk_ids=topk_ids,
inplace=False,
activation="silu",
use_fp8_w8a8=True,
per_channel_quant=False,
global_num_experts=e,
expert_map=None,
w1_scale=td.w13_weight_scale,
w2_scale=td.w2_weight_scale,
a1_scale=td.a1_scale,
a2_scale=td.a2_scale,
apply_router_weight_on_input=True,
quant_config=quant_config,
)
td.layer.dp_size = 1

View File

@@ -3,7 +3,7 @@
import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.moe.utils import make_test_quant_config
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
@@ -41,7 +41,6 @@ MNK_FACTORS = [
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
#@pytest.mark.parametrize("e", [128, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@torch.inference_mode()
@@ -56,16 +55,15 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
quant_blocksize = 16
(_, w1_q, w1_blockscale,
w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
e,
n,
k,
in_dtype=dtype,
quant_dtype="nvfp4",
block_shape=None, # use quant_blocksize?
per_act_token_quant=False,
)
w1_q, w2_q, quant_config = make_test_quant_config(
e,
n,
k,
in_dtype=dtype,
quant_dtype="nvfp4",
block_shape=None,
per_act_token_quant=False,
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a,
@@ -73,35 +71,17 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
topk,
renormalize=False)
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
assert w1_gs is not None
assert w2_gs is not None
assert w1_blockscale is not None
assert w2_blockscale is not None
flashinfer_experts = FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(
a1_gscale=a1_gs,
g1_alphas=(1 / w1_gs),
a2_gscale=a2_gs,
g2_alphas=(1 / w2_gs),
out_dtype=dtype,
quant_dtype="nvfp4",
))
FlashInferExperts(out_dtype=dtype, quant_config=quant_config),
)
flashinfer_output = flashinfer_experts(
hidden_states=a,
w1=w1_q,
w1_scale=w1_blockscale,
w2=w2_q,
w2_scale=w2_blockscale,
a1_scale=a1_gs,
a2_scale=a2_gs,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
@@ -122,18 +102,18 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e):
w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize)
w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize)
w1_d[idx] = dequantize_nvfp4_to_dtype(
w1_q[idx],
quant_config.w1_scale[idx], (1 / quant_config.g1_alphas[idx]),
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize)
w2_d[idx] = dequantize_nvfp4_to_dtype(
w2_q[idx],
quant_config.w2_scale[idx], (1 / quant_config.g2_alphas[idx]),
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)

View File

@@ -23,6 +23,7 @@ from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
@@ -293,6 +294,13 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
pc2,
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
quant_config = FusedMoEQuantConfig.make(
w1_bias=w1_bias_tri,
w2_bias=w2_bias_tri,
w1_precision=pc1,
w2_precision=pc2,
)
out_triton_monolithic = triton_kernel_moe_forward(
hidden_states=x_tri,
w1=w1_tri,
@@ -300,10 +308,7 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
gating_output=exp_data_tri,
topk=topk,
renormalize=True,
w1_bias=w1_bias_tri,
w2_bias=w2_bias_tri,
w1_precision=pc1,
w2_precision=pc2,
quant_config=quant_config,
)
out_triton_monolithic = out_triton_monolithic[..., :K]
@@ -336,6 +341,13 @@ def batched_moe(
) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64)
quant_config = FusedMoEQuantConfig.make(
w1_precision=w1_precision,
w2_precision=w2_precision,
w1_bias=w1_bias,
w2_bias=w2_bias,
)
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(
max_num_tokens,
@@ -344,19 +356,12 @@ def batched_moe(
rank=0,
),
BatchedOAITritonExperts(
None,
max_num_tokens=max_num_tokens,
num_dispatchers=1,
w1_precision=w1_precision,
w2_precision=w2_precision,
quant_config=quant_config,
),
)
extra_expert_args = {
"w1_bias": w1_bias,
"w2_bias": w2_bias,
}
topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize)
return fused_experts(
@@ -365,7 +370,6 @@ def batched_moe(
w2,
topk_weight,
topk_ids,
extra_expert_args=extra_expert_args,
)

View File

@@ -12,7 +12,6 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, current_platform, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
@@ -22,7 +21,8 @@ from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
run_modular_kernel)
from .modular_kernel_tools.mk_objects import (
MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, expert_info)
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, TestMoEQuantConfig,
expert_info)
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
parallel_launch_with_config)
@@ -55,7 +55,7 @@ def rank_worker(
pgi: ProcessGroupInfo,
vllm_config: VllmConfig,
cpu_group,
config: Config,
base_config: Config,
weights: WeightTensors,
verbose: bool,
):
@@ -63,42 +63,44 @@ def rank_worker(
# sanity check
from vllm import envs
if config.fused_moe_chunk_size is not None:
assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
if base_config.fused_moe_chunk_size is not None:
assert (
base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
# get weights to this device
weights.to_current_device()
Ms = config.Ms
Ms = base_config.Ms
assert isinstance(Ms, list)
TOPKs = config.topks
TOPKs = base_config.topks
assert isinstance(TOPKs, list)
exceptions = []
count = 0
for m, topk in product(Ms, TOPKs):
# override m and topk
config = copy.deepcopy(base_config)
config.Ms = m
config.topks = topk
try:
print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
count = count + 1
# override m and topk
cfgx = copy.deepcopy(config)
cfgx.Ms = m
cfgx.topks = topk
# inputs for rank
rank_tensors = RankTensors.make(cfgx, pgi)
rank_tensors = RankTensors.make(config, pgi)
# modular kernel out
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
mk_out = run_modular_kernel(pgi, vllm_config, config, weights,
rank_tensors)
with set_current_vllm_config(vllm_config):
ref_out = reference_moe_impl(cfgx, weights, rank_tensors)
ref_out = reference_moe_impl(config, weights, rank_tensors)
if config.quant_dtype == "nvfp4":
atol = 1e-1
rtol = 1e-1
atol = 1e-1 if config.K < 4096 else 2e-1
rtol = 1e-1 if config.K < 4096 else 2e-1
else:
atol = 3e-2
rtol = 3e-2
@@ -132,7 +134,7 @@ Ms = [32, 64]
# hidden sizes, making this too large will cause fp4 tests to fail.
# Also needs to be a multiple of 1024 for deep_gemm.
Ks = [2048]
Ns = [2048]
Ns = [1024]
TOPKs = [4, 1]
Es = [32]
DTYPEs = [torch.bfloat16]
@@ -167,7 +169,7 @@ def is_nyi_config(config: Config) -> bool:
@meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu(
k: int, n: int, e: int, dtype: torch.dtype,
quant_config: Optional[FusedMoEQuantConfig],
quant_config: Optional[TestMoEQuantConfig],
combination: tuple[mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPermuteExpertsUnpermute],
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
@@ -208,7 +210,7 @@ def test_modular_kernel_combinations_multigpu(
@pytest.mark.parametrize("world_size", [1])
def test_modular_kernel_combinations_singlegpu(
k: int, n: int, e: int, dtype: torch.dtype,
quant_config: Optional[FusedMoEQuantConfig],
quant_config: Optional[TestMoEQuantConfig],
combination: tuple[mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPermuteExpertsUnpermute],
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):

View File

@@ -15,11 +15,14 @@ from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.moe.utils import fused_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
@@ -187,14 +190,9 @@ def test_fused_moe(
#
# Setup test functions
#
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
use_mxfp4_w4a4=False,
per_act_token_quant=False,
block_shape=None)
m_fused_moe_fn = modular_triton_fused_moe(quant_config)
def m_fused_moe(
a: torch.Tensor,
@@ -340,6 +338,18 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
else:
e_map = None
if weight_bits == 4:
quant_config_builder = int4_w4a16_moe_quant_config
else:
assert weight_bits == 8
quant_config_builder = int8_w8a16_moe_quant_config
quant_config = quant_config_builder(w1_scale=w1_scales,
w2_scale=w2_scales,
w1_zp=w1_qzeros if has_zp else None,
w2_zp=w2_qzeros if has_zp else None,
block_shape=[0, group_size])
with set_current_vllm_config(vllm_config):
triton_output = fused_moe(a,
w1_qweight,
@@ -347,15 +357,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
score,
topk,
renormalize=False,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
global_num_experts=e,
expert_map=e_map,
w1_scale=w1_scales,
w2_scale=w2_scales,
w1_zp=w1_qzeros if has_zp else None,
w2_zp=w2_qzeros if has_zp else None,
block_shape=[0, group_size])
quant_config=quant_config)
torch_output = torch_moe(a,
w1_ref,
w2_ref,

View File

@@ -10,6 +10,7 @@ from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform
@@ -56,7 +57,7 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
in_dtype=dtype,
quant_dtype="nvfp4",
block_shape=None, # use quant_blocksize?
per_act_token_quant=False,
per_out_ch_quant=False,
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
@@ -73,18 +74,22 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
assert w1_blockscale is not None
assert w2_blockscale is not None
quant_config = nvfp4_moe_quant_config(
g1_alphas=(1 / w1_gs),
g2_alphas=(1 / w2_gs),
a1_gscale=a1_gs,
a2_gscale=a2_gs,
w1_scale=w1_blockscale,
w2_scale=w2_blockscale,
)
cutlass_output = cutlass_moe_fp4(
a=a,
a1_gscale=a1_gs,
w1_fp4=w1_q,
w1_blockscale=w1_blockscale,
g1_alphas=(1 / w1_gs),
a2_gscale=a2_gs,
w2_fp4=w2_q,
w2_blockscale=w2_blockscale,
g2_alphas=(1 / w2_gs),
topk_weights=topk_weights,
topk_ids=topk_ids,
quant_config=quant_config,
m=m,
n=n,
k=k,

View File

@@ -9,6 +9,8 @@ import torch
from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassBatchedExpertsFp8)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
@@ -143,10 +145,16 @@ def pplx_cutlass_moe(
device="cuda",
dtype=torch.int64)
experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
out_dtype, per_act_token, per_out_ch,
ab_strides1, ab_strides2, c_strides1,
c_strides2)
experts = CutlassBatchedExpertsFp8(
num_local_experts, num_dispatchers, out_dtype, ab_strides1,
ab_strides2, c_strides1, c_strides2,
fp8_w8a8_moe_quant_config(
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
if per_act_token else a1_scale[rank]))
fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize,
@@ -167,10 +175,7 @@ def pplx_cutlass_moe(
chunk_topk_ids,
global_num_experts=num_experts,
expert_map=None, #TODO
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
if per_act_token else a1_scale[rank])
)
torch.cuda.synchronize()

View File

@@ -58,7 +58,7 @@ BATCHED_MOE_MNK_FACTORS = [
]
PPLX_COMBOS = [
# TODO: figure out why this fails, seems to be test problem
# TODO(bnell): figure out why this fails, seems to be test problem
#(1, 128, 128),
(2, 128, 512),
(3, 1024, 2048),
@@ -360,18 +360,18 @@ def pplx_prepare_finalize(
b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
a_chunk,
a1_scale,
a2_scale,
chunk_topk_weight,
chunk_topk_ids,
num_experts,
None,
False,
FusedMoEQuantConfig(
FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant,
False,
block_shape,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=False,
block_shape=block_shape,
a1_scale=a1_scale,
a2_scale=a2_scale,
),
)
@@ -540,20 +540,6 @@ def pplx_moe(
topk_ids = topk_ids.to(dtype=torch.uint32)
experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts,
)
# Note: workers with the same dp_rank must use the exact same inputs.
a_chunk = chunk_by_rank(a, rank, world_size)
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size)
@@ -567,6 +553,28 @@ def pplx_moe(
a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size)
a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size)
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
a1_scale=a1_scale_chunk,
a2_scale=a2_scale_chunk,
)
experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=quant_config,
)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts,
)
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
@@ -585,10 +593,6 @@ def pplx_moe(
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
a1_scale=a1_scale_chunk,
a2_scale=a2_scale_chunk,
global_num_experts=num_experts)
if use_cudagraphs:
@@ -605,10 +609,6 @@ def pplx_moe(
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
a1_scale=a1_scale_chunk,
a2_scale=a2_scale_chunk,
global_num_experts=num_experts)
torch.cuda.synchronize()
@@ -820,7 +820,7 @@ def test_pplx_moe_slow(
k,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_act_token_quant,
)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e,
@@ -897,7 +897,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
k,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_act_token_quant,
)
args["w1"] = w1
args["w2"] = w2

View File

@@ -7,10 +7,12 @@ import itertools
import pytest
import torch
from tests.kernels.moe.utils import fused_moe
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.platforms import current_platform
if current_platform.get_device_capability() < (9, 0):
@@ -152,11 +154,12 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
score,
topk,
renormalize=False,
use_fp8_w8a8=True, # using fp8
per_channel_quant=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=None, # Not using block quantization
quant_config=fp8_w8a8_moe_quant_config(
per_act_token_quant=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=None, # Not using block quantization
),
)
# Check results

View File

@@ -9,7 +9,8 @@ from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
@@ -34,18 +35,22 @@ def triton_moe(
per_act_token_quant=False,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
per_channel_quant=per_act_token_quant,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
block_shape=block_shape)
quant_config=quant_config)
def batched_moe(
@@ -64,6 +69,16 @@ def batched_moe(
) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64)
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens,
num_dispatchers=1,
@@ -72,21 +87,11 @@ def batched_moe(
BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
quant_config=quant_config,
),
)
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
return fused_experts(a, w1, w2, topk_weight, topk_ids)
def naive_batched_moe(
@@ -105,6 +110,16 @@ def naive_batched_moe(
) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64)
quant_config = FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens,
num_dispatchers=1,
@@ -113,21 +128,11 @@ def naive_batched_moe(
NaiveBatchedExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
quant_config=quant_config,
),
)
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
return fused_experts(a, w1, w2, topk_weight, topk_ids)
def chunk_scales(scales: Optional[torch.Tensor], start: int,
@@ -216,7 +221,7 @@ def make_test_weight(
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
@@ -228,7 +233,7 @@ def make_test_weight(
w_gs_l = [None] * e
for idx in range(e):
w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)
w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape)
w = torch.stack(w_l)
w_s = torch.stack(w_s_l)
@@ -258,16 +263,16 @@ def make_test_weights(
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]],
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]]:
return (
make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
per_act_token_quant),
per_out_ch_quant),
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
per_act_token_quant),
per_out_ch_quant),
)
@@ -285,6 +290,76 @@ def per_token_cast_to_fp8(
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def make_test_quant_config(
e: int,
n: int,
k: int,
in_dtype: torch.dtype,
quant_dtype: Union[torch.dtype, str, None] = None,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
(_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
e,
n,
k,
in_dtype,
quant_dtype,
per_out_ch_quant=per_act_token_quant,
block_shape=block_shape,
)
# Hacky/trivial scales for nvfp4.
a1_gscale: Optional[torch.Tensor] = None
a2_gscale: Optional[torch.Tensor] = None
if quant_dtype == "nvfp4":
a1_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
a1_scale = a1_gscale
a2_scale = a2_gscale
else:
a1_scale = None
a2_scale = None
return w1, w2, FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_s,
w2_scale=w2_s,
a1_gscale=a1_gscale,
a2_gscale=a2_gscale,
a1_scale=a1_scale,
a2_scale=a2_scale,
# TODO: make sure this is handled properly
g1_alphas=(1 / w1_gs) if w1_gs is not None else None,
g2_alphas=(1 / w2_gs) if w2_gs is not None else None,
)
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
renormalize: bool = False,
quant_config: Optional[FusedMoEQuantConfig] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
) -> torch.Tensor:
topk_weights, topk_ids, _ = fused_topk(hidden_states, score.float(), topk,
renormalize)
return fused_experts(hidden_states,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
quant_config=quant_config)
# CustomOp?
class BaselineMM(torch.nn.Module):