[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||
|
||||
from .mk_objects import (expert_info, make_fused_experts,
|
||||
from .mk_objects import (TestMoEQuantConfig, expert_info, make_fused_experts,
|
||||
make_prepare_finalize, prepare_finalize_info)
|
||||
from .parallel_utils import ProcessGroupInfo
|
||||
|
||||
@@ -40,7 +40,7 @@ class Config:
|
||||
E: int
|
||||
topks: Union[list[int], int]
|
||||
dtype: torch.dtype
|
||||
quant_config: Optional[FusedMoEQuantConfig]
|
||||
quant_config: Optional[TestMoEQuantConfig]
|
||||
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
|
||||
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute
|
||||
@@ -52,7 +52,7 @@ class Config:
|
||||
|
||||
def __post_init__(self):
|
||||
if self.quant_config is None:
|
||||
self.quant_config = FusedMoEQuantConfig()
|
||||
self.quant_config = TestMoEQuantConfig(None, False, False, None)
|
||||
|
||||
def describe(self) -> str:
|
||||
s = ""
|
||||
@@ -275,21 +275,19 @@ class WeightTensors:
|
||||
or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8)
|
||||
|
||||
def to_current_device(self):
|
||||
self.w1 = self.w1.to(device=torch.cuda.current_device())
|
||||
self.w2 = self.w2.to(device=torch.cuda.current_device())
|
||||
device = torch.cuda.current_device()
|
||||
self.w1 = self.w1.to(device=device)
|
||||
self.w2 = self.w2.to(device=device)
|
||||
|
||||
if self.is_quantized():
|
||||
assert self.w1_scale is not None
|
||||
assert self.w2_scale is not None
|
||||
self.w1_scale = self.w1_scale.to(
|
||||
device=torch.cuda.current_device())
|
||||
self.w2_scale = self.w2_scale.to(
|
||||
device=torch.cuda.current_device())
|
||||
if self.w1_scale is not None:
|
||||
self.w1_scale = self.w1_scale.to(device=device)
|
||||
if self.w2_scale is not None:
|
||||
self.w2_scale = self.w2_scale.to(device=device)
|
||||
|
||||
if self.w1_gs is not None:
|
||||
assert self.w2_gs is not None
|
||||
self.w1_gs = self.w1_gs.to(device=torch.cuda.current_device())
|
||||
self.w2_gs = self.w2_gs.to(device=torch.cuda.current_device())
|
||||
self.w1_gs = self.w1_gs.to(device=device)
|
||||
if self.w2_gs is not None:
|
||||
self.w2_gs = self.w2_gs.to(device=device)
|
||||
|
||||
def slice_weights(self, rank: int,
|
||||
num_local_experts: int) -> "WeightTensors":
|
||||
@@ -297,20 +295,12 @@ class WeightTensors:
|
||||
e = s + num_local_experts
|
||||
w1 = self.w1[s:e, :, :]
|
||||
w2 = self.w2[s:e, :, :]
|
||||
|
||||
w1_scale, w2_scale = (None, None)
|
||||
if self.is_quantized():
|
||||
assert self.w1_scale is not None
|
||||
assert self.w2_scale is not None
|
||||
w1_scale = self.w1_scale[s:e, :, :]
|
||||
w2_scale = self.w2_scale[s:e, :, :]
|
||||
|
||||
w1_gs = self.w1_gs
|
||||
w2_gs = self.w2_gs
|
||||
if w1_gs is not None:
|
||||
assert w2_gs is not None
|
||||
w1_gs = w1_gs[s:e]
|
||||
w2_gs = w2_gs[s:e]
|
||||
w1_scale = self.w1_scale[
|
||||
s:e, :, :] if self.w1_scale is not None else None
|
||||
w2_scale = self.w2_scale[
|
||||
s:e, :, :] if self.w2_scale is not None else None
|
||||
w1_gs = self.w1_gs[s:e] if self.w1_gs is not None else None
|
||||
w2_gs = self.w2_gs[s:e] if self.w2_gs is not None else None
|
||||
|
||||
return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs)
|
||||
|
||||
@@ -323,7 +313,8 @@ class WeightTensors:
|
||||
in_dtype=config.dtype,
|
||||
quant_dtype=config.quant_dtype,
|
||||
block_shape=config.quant_block_shape,
|
||||
per_act_token_quant=config.is_per_out_ch_quant,
|
||||
per_out_ch_quant=config.
|
||||
is_per_act_token_quant, # or config.is_per_out_ch_quant
|
||||
)
|
||||
return WeightTensors(w1=w1,
|
||||
w2=w2,
|
||||
@@ -342,8 +333,6 @@ class RankTensors:
|
||||
topk_ids: torch.Tensor
|
||||
expert_map: Optional[torch.Tensor]
|
||||
|
||||
quant_config: Optional[FusedMoEQuantConfig]
|
||||
|
||||
def describe(self):
|
||||
s = ""
|
||||
s += "== Rank Tensors: \n"
|
||||
@@ -426,7 +415,6 @@ class RankTensors:
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map,
|
||||
quant_config=config.quant_config,
|
||||
)
|
||||
|
||||
|
||||
@@ -522,10 +510,16 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
|
||||
and config.supports_apply_weight_on_input())
|
||||
|
||||
|
||||
def _make_gscale(num_experts: int) -> torch.Tensor:
|
||||
return torch.ones((num_experts, ),
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=torch.float32)
|
||||
|
||||
|
||||
def make_modular_kernel(
|
||||
config: Config,
|
||||
vllm_config: VllmConfig,
|
||||
weights: WeightTensors,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
|
||||
def next_power_of_2(x):
|
||||
@@ -548,20 +542,20 @@ def make_modular_kernel(
|
||||
num_local_experts=config.num_local_experts,
|
||||
moe_parallel_config=moe_parallel_config,
|
||||
in_dtype=config.dtype,
|
||||
quant_config=config.quant_config,
|
||||
max_num_tokens=next_power_of_2(config.M),
|
||||
)
|
||||
|
||||
# make modular kernel
|
||||
prepare_finalize = make_prepare_finalize(config.prepare_finalize_type,
|
||||
config.all2all_backend(), moe)
|
||||
config.all2all_backend(), moe,
|
||||
quant_config)
|
||||
|
||||
fused_experts = make_fused_experts(
|
||||
config.fused_experts_type,
|
||||
moe,
|
||||
quant_config,
|
||||
prepare_finalize.num_dispatchers(),
|
||||
weights.w1_gs,
|
||||
weights.w2_gs,
|
||||
config.N,
|
||||
)
|
||||
|
||||
modular_kernel = mk.FusedMoEModularKernel(
|
||||
@@ -583,12 +577,38 @@ def run_modular_kernel(
|
||||
# weights for rank
|
||||
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
|
||||
|
||||
mk = make_modular_kernel(config, vllm_config, weights)
|
||||
if config.quant_dtype == "nvfp4":
|
||||
gscale = _make_gscale(config.num_local_experts)
|
||||
else:
|
||||
gscale = None
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
config.quant_dtype,
|
||||
w1_scale=rank_weights.w1_scale,
|
||||
w2_scale=rank_weights.w2_scale,
|
||||
a1_scale=rank_tensors.hidden_states_scale,
|
||||
g1_alphas=(1 / rank_weights.w1_gs)
|
||||
if rank_weights.w1_gs is not None else None,
|
||||
g2_alphas=(1 / rank_weights.w2_gs)
|
||||
if rank_weights.w2_gs is not None else None,
|
||||
a1_gscale=gscale,
|
||||
a2_gscale=gscale,
|
||||
block_shape=config.quant_block_shape,
|
||||
per_act_token_quant=config.is_per_act_token_quant,
|
||||
per_out_ch_quant=config.is_per_out_ch_quant,
|
||||
)
|
||||
|
||||
mk = make_modular_kernel(config, vllm_config, quant_config)
|
||||
|
||||
# impls might update the tensor in place
|
||||
hidden_states = rank_tensors.hidden_states.clone()
|
||||
|
||||
topk_ids = rank_tensors.topk_ids.to(
|
||||
mk.prepare_finalize.topk_indices_dtype())
|
||||
|
||||
mk_kwargs = {
|
||||
"hidden_states":
|
||||
rank_tensors.hidden_states.clone(
|
||||
), # impls might update the tensor in place
|
||||
hidden_states,
|
||||
"w1":
|
||||
rank_weights.w1,
|
||||
"w2":
|
||||
@@ -596,15 +616,9 @@ def run_modular_kernel(
|
||||
"topk_weights":
|
||||
rank_tensors.topk_weights,
|
||||
"topk_ids":
|
||||
rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()),
|
||||
topk_ids,
|
||||
"expert_map":
|
||||
rank_tensors.expert_map,
|
||||
"w1_scale":
|
||||
rank_weights.w1_scale,
|
||||
"w2_scale":
|
||||
rank_weights.w2_scale,
|
||||
"a1_scale":
|
||||
rank_tensors.hidden_states_scale,
|
||||
"global_num_experts":
|
||||
config.E,
|
||||
"apply_router_weight_on_input":
|
||||
|
||||
@@ -10,7 +10,8 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .common import (Config, RankTensors, WeightTensors, reference_moe_impl,
|
||||
@@ -86,7 +87,7 @@ def make_feature_matrix(csv_file_path: str):
|
||||
quant_config_dict = config_dict['quant_config']
|
||||
del config_dict['quant_config']
|
||||
if quant_config_dict is None:
|
||||
quant_config = FusedMoEQuantConfig(None)
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
quant_config_dict = asdict(quant_config)
|
||||
|
||||
config_dict |= quant_config_dict
|
||||
|
||||
@@ -32,6 +32,14 @@ from vllm.utils.deep_gemm import is_deep_gemm_supported
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestMoEQuantConfig:
|
||||
quant_dtype: Union[torch.dtype, str, None]
|
||||
per_out_ch_quant: bool
|
||||
per_act_token_quant: bool
|
||||
block_shape: Optional[list[int]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PrepareFinalizeInfo:
|
||||
activation_format: mk.FusedMoEActivationFormat
|
||||
@@ -66,7 +74,7 @@ common_float_types: list[Union[torch.dtype, str]] = [
|
||||
torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32
|
||||
]
|
||||
common_float_and_int_types = common_float_types + [torch.int8]
|
||||
nv_fp4_types = ["nvfp4"]
|
||||
nvfp4_types = ["nvfp4"]
|
||||
fp8_types = [torch.float8_e4m3fn]
|
||||
|
||||
|
||||
@@ -219,7 +227,7 @@ if (has_flashinfer_cutlass_fused_moe()
|
||||
register_prepare_and_finalize(
|
||||
FlashInferCutlassMoEPrepareAndFinalize,
|
||||
standard_format,
|
||||
nv_fp4_types,
|
||||
nvfp4_types,
|
||||
blocked_quantization_support=True,
|
||||
backend=None,
|
||||
force_multigpu=True,
|
||||
@@ -229,7 +237,7 @@ if (has_flashinfer_cutlass_fused_moe()
|
||||
register_experts(
|
||||
FlashInferExperts,
|
||||
standard_format,
|
||||
nv_fp4_types,
|
||||
nvfp4_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
# Note: this is a hack to get it to run for now
|
||||
@@ -306,39 +314,39 @@ if cutlass_fp4_supported():
|
||||
register_experts(
|
||||
CutlassExpertsFp4,
|
||||
standard_format,
|
||||
nv_fp4_types,
|
||||
nvfp4_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=False,
|
||||
)
|
||||
|
||||
MK_QUANT_CONFIGS = [
|
||||
MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [
|
||||
None,
|
||||
# per-channel / per-column weights and per-tensor activations
|
||||
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=True,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None),
|
||||
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=True,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None),
|
||||
# per-channel / per-column weights and per-token activations
|
||||
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=True,
|
||||
per_act_token_quant=True,
|
||||
block_shape=None),
|
||||
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=True,
|
||||
per_act_token_quant=True,
|
||||
block_shape=None),
|
||||
# per-tensor weights and per-tensor activations
|
||||
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None),
|
||||
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None),
|
||||
# per-tensor weights and per-token activations
|
||||
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=True,
|
||||
block_shape=None),
|
||||
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=True,
|
||||
block_shape=None),
|
||||
# block-quantized weights and 128 block per-token activations
|
||||
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=[128, 128]),
|
||||
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=[128, 128]),
|
||||
# TODO (varun) : Should we test the following combinations ?
|
||||
# block-quantized weights and per-token activations
|
||||
# block-quantized weights and per-tensor activations
|
||||
@@ -346,33 +354,27 @@ MK_QUANT_CONFIGS = [
|
||||
|
||||
if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
|
||||
MK_QUANT_CONFIGS += [
|
||||
FusedMoEQuantConfig(quant_dtype="nvfp4",
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None),
|
||||
TestMoEQuantConfig(quant_dtype="nvfp4",
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None),
|
||||
]
|
||||
|
||||
|
||||
def _make_gscale(num_experts: int) -> torch.Tensor:
|
||||
return torch.ones((num_experts, ),
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=torch.float32)
|
||||
|
||||
|
||||
def make_prepare_finalize(
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
|
||||
backend: Optional[str],
|
||||
moe: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.FusedMoEPrepareAndFinalize:
|
||||
if backend != "naive" and backend is not None:
|
||||
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
|
||||
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(
|
||||
moe, quant_config)
|
||||
assert prepare_finalize is not None
|
||||
return prepare_finalize
|
||||
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
|
||||
return FlashInferCutlassMoEPrepareAndFinalize(
|
||||
use_dp=moe.moe_parallel_config.dp_size > 1,
|
||||
a1_gscale=_make_gscale(moe.num_local_experts),
|
||||
)
|
||||
use_dp=moe.moe_parallel_config.dp_size > 1)
|
||||
else:
|
||||
return MoEPrepareAndFinalizeNoEP()
|
||||
|
||||
@@ -383,34 +385,39 @@ def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
|
||||
return t[s:e]
|
||||
|
||||
|
||||
def make_cutlass_strides(
|
||||
e: int,
|
||||
n: int,
|
||||
k: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
return ab_strides1, ab_strides2, c_strides1, c_strides2
|
||||
|
||||
|
||||
def make_fused_experts(
|
||||
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
|
||||
moe: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
num_dispatchers: int,
|
||||
w1_gs: Optional[torch.Tensor],
|
||||
w2_gs: Optional[torch.Tensor],
|
||||
N: int,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
|
||||
use_fp8 = moe.quant_dtype == torch.float8_e4m3fn
|
||||
batch_kwargs = {
|
||||
"max_num_tokens": moe.max_num_tokens,
|
||||
"num_dispatchers": num_dispatchers,
|
||||
}
|
||||
quant_kwargs = {
|
||||
"use_fp8_w8a8": use_fp8,
|
||||
"use_int8_w8a8": False,
|
||||
"use_int8_w8a16": False,
|
||||
"use_int4_w4a16": False,
|
||||
"block_shape": moe.block_shape,
|
||||
"per_act_token_quant": moe.per_act_token_quant,
|
||||
"quant_config": quant_config,
|
||||
}
|
||||
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
|
||||
|
||||
torch.set_printoptions(threshold=0, edgeitems=0, linewidth=10000)
|
||||
|
||||
if fused_experts_type == BatchedDeepGemmExperts:
|
||||
kwargs = batch_kwargs | {
|
||||
"block_shape": moe.block_shape,
|
||||
"per_act_token_quant": moe.per_act_token_quant,
|
||||
}
|
||||
kwargs = batch_kwargs | quant_kwargs
|
||||
print(f"Making BatchedDeepGemmExperts {kwargs} ...")
|
||||
experts = BatchedDeepGemmExperts(**kwargs)
|
||||
elif fused_experts_type == BatchedTritonExperts:
|
||||
@@ -422,8 +429,8 @@ def make_fused_experts(
|
||||
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
|
||||
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
|
||||
elif fused_experts_type == DeepGemmExperts:
|
||||
print("Making DeepGemmExperts () ...")
|
||||
experts = DeepGemmExperts()
|
||||
print("Making DeepGemmExperts {quant_config} ...")
|
||||
experts = DeepGemmExperts(quant_config)
|
||||
elif fused_experts_type == TritonExperts:
|
||||
kwargs = quant_kwargs
|
||||
print(f"Making TritonExperts {kwargs} ...")
|
||||
@@ -437,62 +444,50 @@ def make_fused_experts(
|
||||
print(f"Making NaiveBatchedExperts {kwargs} ...")
|
||||
experts = NaiveBatchedExperts(**kwargs)
|
||||
elif fused_experts_type == CutlassExpertsFp8:
|
||||
strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
|
||||
kwargs = {
|
||||
"out_dtype": moe.in_dtype,
|
||||
"per_act_token_quant": moe.per_act_token_quant,
|
||||
"per_out_ch_quant": moe.per_out_ch_quant,
|
||||
"block_shape": moe.block_shape,
|
||||
}
|
||||
"ab_strides1": strides[0],
|
||||
"ab_strides2": strides[1],
|
||||
"c_strides1": strides[2],
|
||||
"c_strides2": strides[3],
|
||||
} | quant_kwargs
|
||||
print(f"Making CutlassExpertsFp8 {kwargs} ...")
|
||||
experts = CutlassExpertsFp8(**kwargs)
|
||||
elif fused_experts_type == CutlassBatchedExpertsFp8:
|
||||
strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
|
||||
kwargs = {
|
||||
"max_experts_per_worker": moe.num_local_experts,
|
||||
"num_dispatchers": num_dispatchers,
|
||||
"out_dtype": moe.in_dtype,
|
||||
"per_act_token_quant": moe.per_act_token_quant,
|
||||
"per_out_ch_quant": moe.per_out_ch_quant,
|
||||
"block_shape": moe.block_shape,
|
||||
}
|
||||
"ab_strides1": strides[0],
|
||||
"ab_strides2": strides[1],
|
||||
"c_strides1": strides[2],
|
||||
"c_strides2": strides[3],
|
||||
} | quant_kwargs
|
||||
print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...")
|
||||
experts = CutlassBatchedExpertsFp8(**kwargs)
|
||||
elif fused_experts_type == CutlassExpertsFp4:
|
||||
assert w1_gs is not None and w2_gs is not None
|
||||
num_experts = moe.num_local_experts
|
||||
rank = moe.moe_parallel_config.dp_rank
|
||||
kwargs = {
|
||||
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
|
||||
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
|
||||
"a1_gscale": _make_gscale(num_experts),
|
||||
"a2_gscale": _make_gscale(num_experts),
|
||||
"max_experts_per_worker": num_experts,
|
||||
"out_dtype": moe.in_dtype,
|
||||
"per_act_token_quant": moe.per_act_token_quant,
|
||||
"per_out_ch_quant": moe.per_out_ch_quant,
|
||||
"block_shape": moe.block_shape,
|
||||
"max_experts_per_worker": moe.num_local_experts,
|
||||
"num_dispatchers": num_dispatchers,
|
||||
}
|
||||
"out_dtype": moe.in_dtype,
|
||||
} | quant_kwargs
|
||||
print(f"Making CutlassExpertsFp4 {kwargs} ...")
|
||||
experts = CutlassExpertsFp4(**kwargs)
|
||||
elif fused_experts_type == FlashInferExperts:
|
||||
assert w1_gs is not None and w2_gs is not None
|
||||
num_experts = moe.num_local_experts
|
||||
rank = moe.moe_parallel_config.dp_rank
|
||||
kwargs = {
|
||||
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
|
||||
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
|
||||
"a1_gscale": _make_gscale(num_experts),
|
||||
"a2_gscale": _make_gscale(num_experts),
|
||||
"out_dtype": moe.in_dtype,
|
||||
"quant_dtype": "nvfp4",
|
||||
"ep_rank": moe.ep_rank,
|
||||
"ep_size": moe.ep_size,
|
||||
"tp_rank": moe.tp_rank,
|
||||
"tp_size": moe.tp_size,
|
||||
}
|
||||
} | quant_kwargs
|
||||
print(f"Making FlashInferExperts {kwargs} ...")
|
||||
experts = FlashInferExperts(**kwargs)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}")
|
||||
|
||||
torch.set_printoptions(threshold=1000, edgeitems=5, linewidth=80)
|
||||
|
||||
return experts
|
||||
|
||||
@@ -6,6 +6,8 @@ import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize, BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
@@ -56,13 +58,18 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
|
||||
rank=0,
|
||||
)
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
per_act_token_quant=False,
|
||||
block_shape=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
# triton (reference)
|
||||
triton_experts = BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
use_fp8_w8a8=True,
|
||||
per_act_token_quant=False,
|
||||
block_shape=BLOCK_SIZE,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)
|
||||
|
||||
@@ -73,8 +80,6 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
global_num_experts=E,
|
||||
)
|
||||
|
||||
@@ -82,8 +87,7 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
|
||||
deepgemm_experts = BatchedDeepGemmExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
block_shape=BLOCK_SIZE,
|
||||
per_act_token_quant=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)
|
||||
|
||||
@@ -94,8 +98,6 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
global_num_experts=E,
|
||||
)
|
||||
|
||||
|
||||
@@ -140,7 +140,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
in_dtype=act_dtype,
|
||||
quant_dtype=quant_dtype,
|
||||
block_shape=block_shape,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_act_token_quant,
|
||||
)
|
||||
|
||||
out_shape = (num_experts, max_tokens_per_expert, N)
|
||||
@@ -250,7 +250,7 @@ def test_fused_moe_batched_experts(
|
||||
block_shape=block_shape,
|
||||
in_dtype=act_dtype,
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_act_token_quant,
|
||||
)
|
||||
|
||||
if input_scales and quant_dtype is not None:
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_weights
|
||||
from tests.kernels.moe.utils import make_test_quant_config, make_test_weights
|
||||
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
|
||||
native_w8a8_block_matmul)
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
@@ -161,22 +161,17 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
(_, w1, w1_s, _), (_, w2, w2_s,
|
||||
_) = make_test_weights(E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
torch.float8_e4m3fn,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size)
|
||||
w1, w2, quant_config = make_test_quant_config(
|
||||
E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size,
|
||||
)
|
||||
|
||||
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
use_mxfp4_w4a4=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size)
|
||||
m_fused_moe = modular_triton_fused_moe(quant_config)
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
|
||||
|
||||
@@ -186,37 +181,24 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
w1_s,
|
||||
w2_s,
|
||||
quant_config.w1_scale,
|
||||
quant_config.w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
block_size,
|
||||
)
|
||||
|
||||
out = fused_experts(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
block_shape=block_size,
|
||||
)
|
||||
out = fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config)
|
||||
|
||||
m_out = m_fused_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
)
|
||||
m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids)
|
||||
|
||||
# 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0]
|
||||
tol = 0.035 if M < 40000 else 0.039
|
||||
# 0.039 only needed for M >= 8192
|
||||
tol = 0.035 if M < 8192 else 0.039
|
||||
torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)
|
||||
torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol)
|
||||
|
||||
@@ -248,14 +230,15 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
(_, w1, w1_s, _), (_, w2, w2_s,
|
||||
_) = make_test_weights(E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
torch.float8_e4m3fn,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size)
|
||||
(_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
|
||||
E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
block_shape=block_size,
|
||||
)
|
||||
|
||||
# Note: for now use_compile will error out if the problem size is
|
||||
# large enough to trigger chunking. I'm leaving the flag and
|
||||
|
||||
@@ -4,12 +4,12 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_weights
|
||||
from tests.kernels.moe.utils import make_test_quant_config
|
||||
from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
|
||||
native_w8a8_block_matmul)
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.get_device_capability() < (7, 0):
|
||||
@@ -50,7 +50,7 @@ MNK_FACTORS = [
|
||||
(2048, 128, 128),
|
||||
(2048, 1024, 7168),
|
||||
(2048, 4096, 512),
|
||||
(2048, 4096, 7168),
|
||||
(2048, 4096, 4096),
|
||||
]
|
||||
|
||||
E = [8, 24]
|
||||
@@ -117,31 +117,28 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
|
||||
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
|
||||
|
||||
(_, w1, w1_s, _), (_, w2, w2_s,
|
||||
_) = make_test_weights(E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
torch.int8,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size)
|
||||
w1, w2, quant_config = make_test_quant_config(
|
||||
E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
quant_dtype=torch.int8,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size,
|
||||
)
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
with set_current_vllm_config(vllm_config):
|
||||
out = fused_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
use_int8_w8a8=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
block_shape=block_size,
|
||||
)
|
||||
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk,
|
||||
out = fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config)
|
||||
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, quant_config.w1_scale,
|
||||
quant_config.w2_scale, score, topk,
|
||||
block_size)
|
||||
|
||||
# Check results
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import dataclasses
|
||||
from math import prod
|
||||
from typing import Optional
|
||||
@@ -9,6 +10,8 @@ import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG, fp8_w8a8_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp8, run_cutlass_moe_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
|
||||
@@ -154,7 +157,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
|
||||
def slice_experts():
|
||||
slice_params = [
|
||||
"w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1",
|
||||
"c_strides2", "w1_scale", "w2_scale"
|
||||
"c_strides2"
|
||||
]
|
||||
full_tensors = {
|
||||
k: v
|
||||
@@ -162,6 +165,8 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
|
||||
if k in slice_params and k in cutlass_moe_kwargs
|
||||
}
|
||||
|
||||
quant_config = cutlass_moe_kwargs["quant_config"]
|
||||
|
||||
for i in range(0, num_experts, num_local_experts):
|
||||
s, e = i, i + num_local_experts
|
||||
|
||||
@@ -178,6 +183,12 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
|
||||
for k, t in full_tensors.items():
|
||||
cutlass_moe_kwargs[k] = t[s:e]
|
||||
|
||||
new_quant_config = copy.deepcopy(quant_config)
|
||||
new_quant_config._w1.scale = quant_config.w1_scale[s:e]
|
||||
new_quant_config._w2.scale = quant_config.w2_scale[s:e]
|
||||
|
||||
cutlass_moe_kwargs["quant_config"] = new_quant_config
|
||||
|
||||
yield cutlass_moe_kwargs
|
||||
|
||||
out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"])
|
||||
@@ -191,6 +202,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
num_local_experts: Optional[int] = None) -> torch.Tensor:
|
||||
assert not any([
|
||||
t is None for t in [
|
||||
@@ -199,20 +211,27 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||
]
|
||||
])
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=moe_tensors.w1_scale,
|
||||
w2_scale=moe_tensors.w2_scale,
|
||||
per_act_token_quant=per_act_token,
|
||||
per_out_ch_quant=per_out_ch,
|
||||
# Set to moe_tensors.a_scale iff static scales + per tensor.
|
||||
# This is not currently being tested.
|
||||
a1_scale=None,
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
'a': moe_tensors.a,
|
||||
'w1_q': moe_tensors.w1_q, # type: ignore[union-attr]
|
||||
'w2_q': moe_tensors.w2_q, # type: ignore[union-attr]
|
||||
'topk_weights': topk_weights,
|
||||
'topk_ids': topk_ids,
|
||||
'w1_scale': moe_tensors.w1_scale,
|
||||
'w2_scale': moe_tensors.w2_scale,
|
||||
'ab_strides1': moe_tensors.ab_strides1,
|
||||
'ab_strides2': moe_tensors.ab_strides2,
|
||||
'c_strides1': moe_tensors.c_strides1,
|
||||
'c_strides2': moe_tensors.c_strides2,
|
||||
'per_act_token': per_act_token,
|
||||
'a1_scale': None #moe_tensors.a_scale
|
||||
'quant_config': quant_config,
|
||||
}
|
||||
|
||||
num_experts = moe_tensors.w1.size(0)
|
||||
@@ -261,16 +280,23 @@ def test_cutlass_moe_8_bit_no_graph(
|
||||
|
||||
# Note that we are using the dequantized versions of the tensors.
|
||||
# Using a, w1 and w2 directly results in minor output differences.
|
||||
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
|
||||
topk_ids)
|
||||
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
triton_output = fused_experts(mt.a_d,
|
||||
mt.w1_d,
|
||||
mt.w2_d,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config)
|
||||
|
||||
if ep_size is not None:
|
||||
assert e % ep_size == 0, "Cannot distribute experts evenly"
|
||||
number_local_experts = e // ep_size
|
||||
else:
|
||||
number_local_experts = None
|
||||
|
||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token,
|
||||
number_local_experts)
|
||||
per_out_ch, number_local_experts)
|
||||
|
||||
# Note 5.5 only needed for larger problem sizes, 5 works ok for
|
||||
# the rest.
|
||||
@@ -315,14 +341,19 @@ def test_cutlass_moe_8_bit_cuda_graph(
|
||||
|
||||
# Note that we are using the dequantized versions of the tensors.
|
||||
# Using a, w1 and w2 directly results in minor output differences.
|
||||
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
|
||||
topk_ids)
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
triton_output = fused_experts(mt.a_d,
|
||||
mt.w1_d,
|
||||
mt.w2_d,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config)
|
||||
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
|
||||
per_act_token)
|
||||
per_act_token, per_out_ch)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
graph.replay()
|
||||
|
||||
@@ -15,6 +15,8 @@ from torch.distributed import ProcessGroup
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
@@ -71,9 +73,12 @@ def make_block_quant_fp8_weights(
|
||||
Return weights w1q, w2q, w1_scale, w2_scale
|
||||
"""
|
||||
(_, w1q, w1_scale, _), (_, w2q, w2_scale,
|
||||
_) = make_test_weights(e, n, k, torch.bfloat16,
|
||||
_) = make_test_weights(e,
|
||||
n,
|
||||
k,
|
||||
torch.bfloat16,
|
||||
torch.float8_e4m3fn,
|
||||
block_size)
|
||||
block_shape=block_size)
|
||||
return w1q, w2q, w1_scale, w2_scale
|
||||
|
||||
|
||||
@@ -130,10 +135,11 @@ class TestTensors:
|
||||
config=config)
|
||||
|
||||
|
||||
def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
max_tokens_per_rank: int, dp_size: int,
|
||||
hidden_size: int, q_dtype: Optional[torch.dtype],
|
||||
test_config: TestConfig) -> FusedMoEModularKernel:
|
||||
def make_ll_modular_kernel(
|
||||
pg: ProcessGroup, pgi: ProcessGroupInfo, max_tokens_per_rank: int,
|
||||
dp_size: int, hidden_size: int, q_dtype: Optional[torch.dtype],
|
||||
test_config: TestConfig,
|
||||
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
|
||||
|
||||
assert test_config.low_latency
|
||||
assert test_config.use_fp8_dispatch is not None
|
||||
@@ -154,17 +160,18 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
fused_experts = BatchedDeepGemmExperts(
|
||||
max_num_tokens=max_tokens_per_rank,
|
||||
num_dispatchers=pgi.world_size // dp_size,
|
||||
block_shape=test_config.block_size,
|
||||
per_act_token_quant=test_config.per_act_token_quant)
|
||||
quant_config=quant_config,
|
||||
)
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
||||
fused_experts=fused_experts)
|
||||
return mk
|
||||
|
||||
|
||||
def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
dp_size: int, num_local_experts: int,
|
||||
q_dtype: Optional[torch.dtype],
|
||||
test_config: TestConfig) -> FusedMoEModularKernel:
|
||||
def make_ht_modular_kernel(
|
||||
pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
|
||||
num_local_experts: int, q_dtype: Optional[torch.dtype],
|
||||
test_config: TestConfig,
|
||||
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
|
||||
|
||||
assert not test_config.low_latency
|
||||
assert test_config.use_fp8_dispatch is None
|
||||
@@ -178,15 +185,16 @@ def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
q_dtype=q_dtype,
|
||||
block_shape=test_config.block_size)
|
||||
|
||||
fused_experts = DeepGemmExperts()
|
||||
fused_experts = DeepGemmExperts(quant_config)
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
||||
fused_experts=fused_experts)
|
||||
return mk
|
||||
|
||||
|
||||
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
|
||||
num_local_experts: int,
|
||||
test_tensors: TestTensors) -> FusedMoEModularKernel:
|
||||
def make_modular_kernel(
|
||||
pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
|
||||
num_local_experts: int, test_tensors: TestTensors,
|
||||
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
|
||||
|
||||
q_dtype = torch.float8_e4m3fn
|
||||
test_config = test_tensors.config
|
||||
@@ -204,10 +212,16 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
|
||||
dp_size=dp_size,
|
||||
hidden_size=hidden_size,
|
||||
q_dtype=q_dtype,
|
||||
test_config=test_config)
|
||||
test_config=test_config,
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
mk = make_ht_modular_kernel(pg, pgi, dp_size, num_local_experts,
|
||||
q_dtype, test_config)
|
||||
mk = make_ht_modular_kernel(pg,
|
||||
pgi,
|
||||
dp_size,
|
||||
num_local_experts,
|
||||
q_dtype,
|
||||
test_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
return mk
|
||||
|
||||
@@ -233,17 +247,23 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
return expert_map.to(device=torch.cuda.current_device(),
|
||||
dtype=torch.int32)
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
# Low-Latency kernels can't dispatch scales.
|
||||
a1_scale=(None if test_config.low_latency else
|
||||
test_tensors.rank_token_scales),
|
||||
block_shape=test_config.block_size,
|
||||
)
|
||||
|
||||
# Make modular kernel
|
||||
mk: FusedMoEModularKernel = make_modular_kernel(
|
||||
pg=pg,
|
||||
pgi=pgi,
|
||||
dp_size=dp_size,
|
||||
num_local_experts=num_local_experts,
|
||||
test_tensors=test_tensors)
|
||||
|
||||
# Low-Latency kernels can't dispatch scales.
|
||||
a1_scale = (None
|
||||
if test_config.low_latency else test_tensors.rank_token_scales)
|
||||
test_tensors=test_tensors,
|
||||
quant_config=quant_config)
|
||||
|
||||
out = mk.forward(hidden_states=test_tensors.rank_tokens,
|
||||
w1=w1,
|
||||
@@ -254,12 +274,6 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=None,
|
||||
w2_zp=None,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=None,
|
||||
apply_router_weight_on_input=False)
|
||||
return out
|
||||
|
||||
@@ -269,6 +283,13 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
|
||||
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
|
||||
a1_scale: torch.Tensor, block_shape: list[int]):
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
return fused_experts(
|
||||
hidden_states=a,
|
||||
w1=w1,
|
||||
@@ -276,11 +297,7 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
block_shape=block_shape,
|
||||
quant_config=quant_config,
|
||||
# Make sure this is set to False so we
|
||||
# don't end up comparing the same implementation.
|
||||
allow_deep_gemm=False)
|
||||
|
||||
@@ -15,6 +15,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import TritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
@@ -129,11 +130,9 @@ def make_modular_kernel(
|
||||
num_local_experts: int,
|
||||
q_dtype: Optional[torch.dtype],
|
||||
use_fp8_dispatch: bool,
|
||||
per_act_token_quant: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> FusedMoEModularKernel:
|
||||
|
||||
is_quantized = q_dtype is not None
|
||||
|
||||
ht_args: Optional[DeepEPHTArgs] = None
|
||||
ll_args: Optional[DeepEPLLArgs] = None
|
||||
|
||||
@@ -159,24 +158,14 @@ def make_modular_kernel(
|
||||
num_dispatchers = pgi.world_size // dp_size
|
||||
|
||||
if low_latency_mode:
|
||||
assert not per_act_token_quant, "not supported in ll mode"
|
||||
assert not quant_config.per_act_token_quant, "not supported in ll mode"
|
||||
fused_experts = BatchedTritonExperts(
|
||||
max_num_tokens=MAX_TOKENS_PER_RANK,
|
||||
num_dispatchers=num_dispatchers,
|
||||
use_fp8_w8a8=is_quantized,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
per_act_token_quant=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
fused_experts = TritonExperts(
|
||||
use_fp8_w8a8=is_quantized,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
)
|
||||
fused_experts = TritonExperts(quant_config=quant_config)
|
||||
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
||||
fused_experts=fused_experts)
|
||||
@@ -217,11 +206,6 @@ def deep_ep_moe_impl(
|
||||
if is_quantized:
|
||||
q_dtype = torch.float8_e4m3fn
|
||||
|
||||
# Make modular kernel
|
||||
mk: FusedMoEModularKernel = make_modular_kernel(
|
||||
pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
|
||||
num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant)
|
||||
|
||||
out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
|
||||
total_num_tokens = test_tensors.rank_tokens.size(0)
|
||||
|
||||
@@ -236,6 +220,19 @@ def deep_ep_moe_impl(
|
||||
rank_token_scales_chunk = rank_token_scales_chunk[
|
||||
chunk_start:chunk_end]
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
q_dtype,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
a1_scale=rank_token_scales_chunk,
|
||||
)
|
||||
|
||||
# Make modular kernel
|
||||
mk: FusedMoEModularKernel = make_modular_kernel(
|
||||
pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
|
||||
num_local_experts, q_dtype, use_fp8_dispatch, quant_config)
|
||||
|
||||
out = mk.forward(hidden_states=rank_tokens_chunk,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
@@ -245,12 +242,6 @@ def deep_ep_moe_impl(
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=None,
|
||||
w2_zp=None,
|
||||
a1_scale=rank_token_scales_chunk,
|
||||
a2_scale=None,
|
||||
apply_router_weight_on_input=False)
|
||||
|
||||
if not skip_result_store:
|
||||
@@ -407,7 +398,7 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("mnk", MNKs)
|
||||
@pytest.mark.parametrize("m,n,k", MNKs)
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("topk", [6])
|
||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||
@@ -416,7 +407,9 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
|
||||
@requires_deep_ep
|
||||
def test_deep_ep_moe(
|
||||
dtype: torch.dtype,
|
||||
mnk: tuple[int, int, int],
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
@@ -424,7 +417,6 @@ def test_deep_ep_moe(
|
||||
):
|
||||
low_latency_mode = False
|
||||
use_fp8_dispatch = False
|
||||
m, n, k = mnk
|
||||
|
||||
current_platform.seed_everything(7)
|
||||
world_size, dp_size = world_dp_size
|
||||
@@ -456,20 +448,24 @@ USE_FP8_DISPATCH = [True, False]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("mnk", MNKs)
|
||||
@pytest.mark.parametrize("m,n,k", MNKs)
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("topk", [6])
|
||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@requires_deep_ep
|
||||
def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
|
||||
num_experts: int, topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
use_fp8_dispatch: bool):
|
||||
|
||||
def test_low_latency_deep_ep_moe(
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
use_fp8_dispatch: bool,
|
||||
):
|
||||
low_latency_mode = True
|
||||
m, n, k = mnk
|
||||
|
||||
if (low_latency_mode
|
||||
and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES):
|
||||
|
||||
@@ -11,6 +11,8 @@ import math
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config)
|
||||
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
@@ -94,6 +96,13 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
||||
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
|
||||
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
a1_scale=a1_scale,
|
||||
block_shape=block_size,
|
||||
)
|
||||
|
||||
# triton reference
|
||||
out_triton = fused_experts(
|
||||
hidden_states=tokens_bf16,
|
||||
@@ -102,11 +111,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
a1_scale=a1_scale,
|
||||
block_shape=block_size,
|
||||
quant_config=quant_config,
|
||||
allow_deep_gemm=False,
|
||||
)
|
||||
|
||||
@@ -118,19 +123,14 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
a1_scale=a1_scale,
|
||||
block_shape=block_size,
|
||||
quant_config=quant_config,
|
||||
allow_deep_gemm=True,
|
||||
)
|
||||
diff = calc_diff(out_deepgemm, out_triton)
|
||||
assert diff < 0.001, f"Diff exceeded 1%: {diff}"
|
||||
|
||||
|
||||
# Note: W1 has shape (E, 2N, K), so N = 512
|
||||
# can trigger the deepgemm path.
|
||||
# Note: N <= 512 will disable the deepgemm path due to performance issues.
|
||||
MNKs = [
|
||||
(1024, 768, 128),
|
||||
(1024, 768, 512),
|
||||
@@ -144,15 +144,15 @@ TOPKS = [2, 6]
|
||||
NUM_EXPERTS = [32]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mnk", MNKs)
|
||||
@pytest.mark.parametrize(("m", "n", "k"), MNKs)
|
||||
@pytest.mark.parametrize("topk", TOPKS)
|
||||
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
|
||||
@pytest.mark.skipif(not is_deep_gemm_supported(),
|
||||
reason="Requires deep_gemm kernels")
|
||||
def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):
|
||||
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_DEEP_GEMM", "1")
|
||||
with monkeypatch.context() as mp:
|
||||
mp.setenv("VLLM_USE_DEEP_GEMM", "1")
|
||||
|
||||
_fused_moe_mod = importlib.import_module(
|
||||
"vllm.model_executor.layers.fused_moe.fused_moe")
|
||||
@@ -168,8 +168,6 @@ def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):
|
||||
monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8",
|
||||
_spy_deep_gemm_moe_fp8)
|
||||
|
||||
m, n, k = mnk
|
||||
|
||||
if topk > num_experts:
|
||||
pytest.skip(f"topk={topk} > num_experts={num_experts}")
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
@@ -145,6 +147,14 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
||||
custom_routing_function=Llama4MoE.custom_routing_function,
|
||||
scoring_func="softmax")
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=td.w13_weight_scale,
|
||||
w2_scale=td.w2_weight_scale,
|
||||
a1_scale=td.a1_scale,
|
||||
a2_scale=td.a2_scale,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
output = fused_experts(
|
||||
td.hidden_states,
|
||||
td.w13_quantized,
|
||||
@@ -153,15 +163,10 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=False,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
w1_scale=td.w13_weight_scale,
|
||||
w2_scale=td.w2_weight_scale,
|
||||
a1_scale=td.a1_scale,
|
||||
a2_scale=td.a2_scale,
|
||||
apply_router_weight_on_input=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
|
||||
@@ -210,6 +215,14 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
custom_routing_function=Llama4MoE.custom_routing_function,
|
||||
scoring_func="softmax")
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=td.w13_weight_scale,
|
||||
w2_scale=td.w2_weight_scale,
|
||||
a1_scale=td.a1_scale,
|
||||
a2_scale=td.a2_scale,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
output = fused_experts(
|
||||
td.hidden_states,
|
||||
td.w13_quantized,
|
||||
@@ -218,15 +231,10 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=False,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
w1_scale=td.w13_weight_scale,
|
||||
w2_scale=td.w2_weight_scale,
|
||||
a1_scale=td.a1_scale,
|
||||
a2_scale=td.a2_scale,
|
||||
apply_router_weight_on_input=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
td.layer.dp_size = 1
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_weights
|
||||
from tests.kernels.moe.utils import make_test_quant_config
|
||||
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype)
|
||||
@@ -41,7 +41,6 @@ MNK_FACTORS = [
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", [40, 64, 256])
|
||||
#@pytest.mark.parametrize("e", [128, 256])
|
||||
@pytest.mark.parametrize("topk", [1, 6, 8])
|
||||
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
|
||||
@torch.inference_mode()
|
||||
@@ -56,16 +55,15 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
||||
|
||||
quant_blocksize = 16
|
||||
|
||||
(_, w1_q, w1_blockscale,
|
||||
w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
|
||||
e,
|
||||
n,
|
||||
k,
|
||||
in_dtype=dtype,
|
||||
quant_dtype="nvfp4",
|
||||
block_shape=None, # use quant_blocksize?
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
w1_q, w2_q, quant_config = make_test_quant_config(
|
||||
e,
|
||||
n,
|
||||
k,
|
||||
in_dtype=dtype,
|
||||
quant_dtype="nvfp4",
|
||||
block_shape=None,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(a,
|
||||
@@ -73,35 +71,17 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
||||
topk,
|
||||
renormalize=False)
|
||||
|
||||
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
|
||||
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
|
||||
|
||||
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
|
||||
|
||||
assert w1_gs is not None
|
||||
assert w2_gs is not None
|
||||
assert w1_blockscale is not None
|
||||
assert w2_blockscale is not None
|
||||
|
||||
flashinfer_experts = FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
FlashInferExperts(
|
||||
a1_gscale=a1_gs,
|
||||
g1_alphas=(1 / w1_gs),
|
||||
a2_gscale=a2_gs,
|
||||
g2_alphas=(1 / w2_gs),
|
||||
out_dtype=dtype,
|
||||
quant_dtype="nvfp4",
|
||||
))
|
||||
FlashInferExperts(out_dtype=dtype, quant_config=quant_config),
|
||||
)
|
||||
|
||||
flashinfer_output = flashinfer_experts(
|
||||
hidden_states=a,
|
||||
w1=w1_q,
|
||||
w1_scale=w1_blockscale,
|
||||
w2=w2_q,
|
||||
w2_scale=w2_blockscale,
|
||||
a1_scale=a1_gs,
|
||||
a2_scale=a2_gs,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
)
|
||||
@@ -122,18 +102,18 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
||||
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
|
||||
|
||||
for idx in range(0, e):
|
||||
w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
|
||||
w1_blockscale[idx],
|
||||
w1_gs[idx],
|
||||
dtype=dtype,
|
||||
device=w1_q.device,
|
||||
block_size=quant_blocksize)
|
||||
w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
|
||||
w2_blockscale[idx],
|
||||
w2_gs[idx],
|
||||
dtype=dtype,
|
||||
device=w2_q.device,
|
||||
block_size=quant_blocksize)
|
||||
w1_d[idx] = dequantize_nvfp4_to_dtype(
|
||||
w1_q[idx],
|
||||
quant_config.w1_scale[idx], (1 / quant_config.g1_alphas[idx]),
|
||||
dtype=dtype,
|
||||
device=w1_q.device,
|
||||
block_size=quant_blocksize)
|
||||
w2_d[idx] = dequantize_nvfp4_to_dtype(
|
||||
w2_q[idx],
|
||||
quant_config.w2_scale[idx], (1 / quant_config.g2_alphas[idx]),
|
||||
dtype=dtype,
|
||||
device=w2_q.device,
|
||||
block_size=quant_blocksize)
|
||||
|
||||
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.testing import assert_close
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
@@ -293,6 +294,13 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
|
||||
pc2,
|
||||
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
w1_bias=w1_bias_tri,
|
||||
w2_bias=w2_bias_tri,
|
||||
w1_precision=pc1,
|
||||
w2_precision=pc2,
|
||||
)
|
||||
|
||||
out_triton_monolithic = triton_kernel_moe_forward(
|
||||
hidden_states=x_tri,
|
||||
w1=w1_tri,
|
||||
@@ -300,10 +308,7 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
|
||||
gating_output=exp_data_tri,
|
||||
topk=topk,
|
||||
renormalize=True,
|
||||
w1_bias=w1_bias_tri,
|
||||
w2_bias=w2_bias_tri,
|
||||
w1_precision=pc1,
|
||||
w2_precision=pc2,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
out_triton_monolithic = out_triton_monolithic[..., :K]
|
||||
|
||||
@@ -336,6 +341,13 @@ def batched_moe(
|
||||
) -> torch.Tensor:
|
||||
max_num_tokens = round_up(a.shape[0], 64)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
w1_precision=w1_precision,
|
||||
w2_precision=w2_precision,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(
|
||||
max_num_tokens,
|
||||
@@ -344,19 +356,12 @@ def batched_moe(
|
||||
rank=0,
|
||||
),
|
||||
BatchedOAITritonExperts(
|
||||
None,
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
w1_precision=w1_precision,
|
||||
w2_precision=w2_precision,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
)
|
||||
|
||||
extra_expert_args = {
|
||||
"w1_bias": w1_bias,
|
||||
"w2_bias": w2_bias,
|
||||
}
|
||||
|
||||
topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize)
|
||||
|
||||
return fused_experts(
|
||||
@@ -365,7 +370,6 @@ def batched_moe(
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
extra_expert_args=extra_expert_args,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.config import VllmConfig, current_platform, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
@@ -22,7 +21,8 @@ from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
|
||||
run_modular_kernel)
|
||||
from .modular_kernel_tools.mk_objects import (
|
||||
MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
|
||||
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, expert_info)
|
||||
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, TestMoEQuantConfig,
|
||||
expert_info)
|
||||
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
|
||||
parallel_launch_with_config)
|
||||
|
||||
@@ -55,7 +55,7 @@ def rank_worker(
|
||||
pgi: ProcessGroupInfo,
|
||||
vllm_config: VllmConfig,
|
||||
cpu_group,
|
||||
config: Config,
|
||||
base_config: Config,
|
||||
weights: WeightTensors,
|
||||
verbose: bool,
|
||||
):
|
||||
@@ -63,42 +63,44 @@ def rank_worker(
|
||||
|
||||
# sanity check
|
||||
from vllm import envs
|
||||
if config.fused_moe_chunk_size is not None:
|
||||
assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
|
||||
if base_config.fused_moe_chunk_size is not None:
|
||||
assert (
|
||||
base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
|
||||
|
||||
# get weights to this device
|
||||
weights.to_current_device()
|
||||
|
||||
Ms = config.Ms
|
||||
Ms = base_config.Ms
|
||||
assert isinstance(Ms, list)
|
||||
TOPKs = config.topks
|
||||
TOPKs = base_config.topks
|
||||
assert isinstance(TOPKs, list)
|
||||
|
||||
exceptions = []
|
||||
count = 0
|
||||
|
||||
for m, topk in product(Ms, TOPKs):
|
||||
# override m and topk
|
||||
config = copy.deepcopy(base_config)
|
||||
config.Ms = m
|
||||
config.topks = topk
|
||||
|
||||
try:
|
||||
print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
|
||||
count = count + 1
|
||||
# override m and topk
|
||||
cfgx = copy.deepcopy(config)
|
||||
cfgx.Ms = m
|
||||
cfgx.topks = topk
|
||||
|
||||
# inputs for rank
|
||||
rank_tensors = RankTensors.make(cfgx, pgi)
|
||||
rank_tensors = RankTensors.make(config, pgi)
|
||||
|
||||
# modular kernel out
|
||||
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
|
||||
mk_out = run_modular_kernel(pgi, vllm_config, config, weights,
|
||||
rank_tensors)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ref_out = reference_moe_impl(cfgx, weights, rank_tensors)
|
||||
ref_out = reference_moe_impl(config, weights, rank_tensors)
|
||||
|
||||
if config.quant_dtype == "nvfp4":
|
||||
atol = 1e-1
|
||||
rtol = 1e-1
|
||||
atol = 1e-1 if config.K < 4096 else 2e-1
|
||||
rtol = 1e-1 if config.K < 4096 else 2e-1
|
||||
else:
|
||||
atol = 3e-2
|
||||
rtol = 3e-2
|
||||
@@ -132,7 +134,7 @@ Ms = [32, 64]
|
||||
# hidden sizes, making this too large will cause fp4 tests to fail.
|
||||
# Also needs to be a multiple of 1024 for deep_gemm.
|
||||
Ks = [2048]
|
||||
Ns = [2048]
|
||||
Ns = [1024]
|
||||
TOPKs = [4, 1]
|
||||
Es = [32]
|
||||
DTYPEs = [torch.bfloat16]
|
||||
@@ -167,7 +169,7 @@ def is_nyi_config(config: Config) -> bool:
|
||||
@meets_multi_gpu_requirements
|
||||
def test_modular_kernel_combinations_multigpu(
|
||||
k: int, n: int, e: int, dtype: torch.dtype,
|
||||
quant_config: Optional[FusedMoEQuantConfig],
|
||||
quant_config: Optional[TestMoEQuantConfig],
|
||||
combination: tuple[mk.FusedMoEPrepareAndFinalize,
|
||||
mk.FusedMoEPermuteExpertsUnpermute],
|
||||
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
|
||||
@@ -208,7 +210,7 @@ def test_modular_kernel_combinations_multigpu(
|
||||
@pytest.mark.parametrize("world_size", [1])
|
||||
def test_modular_kernel_combinations_singlegpu(
|
||||
k: int, n: int, e: int, dtype: torch.dtype,
|
||||
quant_config: Optional[FusedMoEQuantConfig],
|
||||
quant_config: Optional[TestMoEQuantConfig],
|
||||
combination: tuple[mk.FusedMoEPrepareAndFinalize,
|
||||
mk.FusedMoEPermuteExpertsUnpermute],
|
||||
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
|
||||
|
||||
@@ -15,11 +15,14 @@ from transformers import MixtralConfig
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from tests.kernels.moe.utils import fused_moe
|
||||
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.parallel_state import init_distributed_environment
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG, int4_w4a16_moe_quant_config,
|
||||
int8_w8a16_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_topk, modular_triton_fused_moe)
|
||||
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
|
||||
@@ -187,14 +190,9 @@ def test_fused_moe(
|
||||
#
|
||||
# Setup test functions
|
||||
#
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
use_mxfp4_w4a4=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None)
|
||||
m_fused_moe_fn = modular_triton_fused_moe(quant_config)
|
||||
|
||||
def m_fused_moe(
|
||||
a: torch.Tensor,
|
||||
@@ -340,6 +338,18 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
||||
else:
|
||||
e_map = None
|
||||
|
||||
if weight_bits == 4:
|
||||
quant_config_builder = int4_w4a16_moe_quant_config
|
||||
else:
|
||||
assert weight_bits == 8
|
||||
quant_config_builder = int8_w8a16_moe_quant_config
|
||||
|
||||
quant_config = quant_config_builder(w1_scale=w1_scales,
|
||||
w2_scale=w2_scales,
|
||||
w1_zp=w1_qzeros if has_zp else None,
|
||||
w2_zp=w2_qzeros if has_zp else None,
|
||||
block_shape=[0, group_size])
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
triton_output = fused_moe(a,
|
||||
w1_qweight,
|
||||
@@ -347,15 +357,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
use_int4_w4a16=weight_bits == 4,
|
||||
use_int8_w8a16=weight_bits == 8,
|
||||
global_num_experts=e,
|
||||
expert_map=e_map,
|
||||
w1_scale=w1_scales,
|
||||
w2_scale=w2_scales,
|
||||
w1_zp=w1_qzeros if has_zp else None,
|
||||
w2_zp=w2_qzeros if has_zp else None,
|
||||
block_shape=[0, group_size])
|
||||
quant_config=quant_config)
|
||||
torch_output = torch_moe(a,
|
||||
w1_ref,
|
||||
w2_ref,
|
||||
|
||||
@@ -10,6 +10,7 @@ from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||
from tests.kernels.utils import torch_moe
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.platforms import current_platform
|
||||
@@ -56,7 +57,7 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
||||
in_dtype=dtype,
|
||||
quant_dtype="nvfp4",
|
||||
block_shape=None, # use quant_blocksize?
|
||||
per_act_token_quant=False,
|
||||
per_out_ch_quant=False,
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
@@ -73,18 +74,22 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
||||
assert w1_blockscale is not None
|
||||
assert w2_blockscale is not None
|
||||
|
||||
quant_config = nvfp4_moe_quant_config(
|
||||
g1_alphas=(1 / w1_gs),
|
||||
g2_alphas=(1 / w2_gs),
|
||||
a1_gscale=a1_gs,
|
||||
a2_gscale=a2_gs,
|
||||
w1_scale=w1_blockscale,
|
||||
w2_scale=w2_blockscale,
|
||||
)
|
||||
|
||||
cutlass_output = cutlass_moe_fp4(
|
||||
a=a,
|
||||
a1_gscale=a1_gs,
|
||||
w1_fp4=w1_q,
|
||||
w1_blockscale=w1_blockscale,
|
||||
g1_alphas=(1 / w1_gs),
|
||||
a2_gscale=a2_gs,
|
||||
w2_fp4=w2_q,
|
||||
w2_blockscale=w2_blockscale,
|
||||
g2_alphas=(1 / w2_gs),
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
quant_config=quant_config,
|
||||
m=m,
|
||||
n=n,
|
||||
k=k,
|
||||
|
||||
@@ -9,6 +9,8 @@ import torch
|
||||
from tests.kernels.utils import torch_experts
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
CutlassBatchedExpertsFp8)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
@@ -143,10 +145,16 @@ def pplx_cutlass_moe(
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
|
||||
experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
|
||||
out_dtype, per_act_token, per_out_ch,
|
||||
ab_strides1, ab_strides2, c_strides1,
|
||||
c_strides2)
|
||||
experts = CutlassBatchedExpertsFp8(
|
||||
num_local_experts, num_dispatchers, out_dtype, ab_strides1,
|
||||
ab_strides2, c_strides1, c_strides2,
|
||||
fp8_w8a8_moe_quant_config(
|
||||
per_act_token_quant=per_act_token,
|
||||
per_out_ch_quant=per_out_ch,
|
||||
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
|
||||
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
|
||||
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
|
||||
if per_act_token else a1_scale[rank]))
|
||||
|
||||
fused_cutlass_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
@@ -167,10 +175,7 @@ def pplx_cutlass_moe(
|
||||
chunk_topk_ids,
|
||||
global_num_experts=num_experts,
|
||||
expert_map=None, #TODO
|
||||
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
|
||||
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
|
||||
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
|
||||
if per_act_token else a1_scale[rank])
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ BATCHED_MOE_MNK_FACTORS = [
|
||||
]
|
||||
|
||||
PPLX_COMBOS = [
|
||||
# TODO: figure out why this fails, seems to be test problem
|
||||
# TODO(bnell): figure out why this fails, seems to be test problem
|
||||
#(1, 128, 128),
|
||||
(2, 128, 512),
|
||||
(3, 1024, 2048),
|
||||
@@ -360,18 +360,18 @@ def pplx_prepare_finalize(
|
||||
|
||||
b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
|
||||
a_chunk,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
num_experts,
|
||||
None,
|
||||
False,
|
||||
FusedMoEQuantConfig(
|
||||
FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
per_act_token_quant,
|
||||
False,
|
||||
block_shape,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=False,
|
||||
block_shape=block_shape,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -540,20 +540,6 @@ def pplx_moe(
|
||||
|
||||
topk_ids = topk_ids.to(dtype=torch.uint32)
|
||||
|
||||
experts = BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
|
||||
block_shape=block_shape,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
shared_experts,
|
||||
)
|
||||
|
||||
# Note: workers with the same dp_rank must use the exact same inputs.
|
||||
a_chunk = chunk_by_rank(a, rank, world_size)
|
||||
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size)
|
||||
@@ -567,6 +553,28 @@ def pplx_moe(
|
||||
a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size)
|
||||
a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
block_shape=block_shape,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
w1_scale=w1_scale_chunk,
|
||||
w2_scale=w2_scale_chunk,
|
||||
a1_scale=a1_scale_chunk,
|
||||
a2_scale=a2_scale_chunk,
|
||||
)
|
||||
|
||||
experts = BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
shared_experts,
|
||||
)
|
||||
|
||||
# Note: for now use_compile will error out if the problem size is
|
||||
# large enough to trigger chunking. I'm leaving the flag and
|
||||
# setup code in case we are able to revisit this later.
|
||||
@@ -585,10 +593,6 @@ def pplx_moe(
|
||||
w2_chunk,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
w1_scale=w1_scale_chunk,
|
||||
w2_scale=w2_scale_chunk,
|
||||
a1_scale=a1_scale_chunk,
|
||||
a2_scale=a2_scale_chunk,
|
||||
global_num_experts=num_experts)
|
||||
|
||||
if use_cudagraphs:
|
||||
@@ -605,10 +609,6 @@ def pplx_moe(
|
||||
w2_chunk,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
w1_scale=w1_scale_chunk,
|
||||
w2_scale=w2_scale_chunk,
|
||||
a1_scale=a1_scale_chunk,
|
||||
a2_scale=a2_scale_chunk,
|
||||
global_num_experts=num_experts)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
@@ -820,7 +820,7 @@ def test_pplx_moe_slow(
|
||||
k,
|
||||
quant_dtype=quant_dtype,
|
||||
block_shape=block_shape,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_act_token_quant,
|
||||
)
|
||||
|
||||
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e,
|
||||
@@ -897,7 +897,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
|
||||
k,
|
||||
quant_dtype=quant_dtype,
|
||||
block_shape=block_shape,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_act_token_quant,
|
||||
)
|
||||
args["w1"] = w1
|
||||
args["w2"] = w2
|
||||
|
||||
@@ -7,10 +7,12 @@ import itertools
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import fused_moe
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.get_device_capability() < (9, 0):
|
||||
@@ -152,11 +154,12 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
use_fp8_w8a8=True, # using fp8
|
||||
per_channel_quant=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
block_shape=None, # Not using block quantization
|
||||
quant_config=fp8_w8a8_moe_quant_config(
|
||||
per_act_token_quant=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
block_shape=None, # Not using block quantization
|
||||
),
|
||||
)
|
||||
|
||||
# Check results
|
||||
|
||||
@@ -9,7 +9,8 @@ from tests.kernels.quant_utils import per_block_cast_to_int8
|
||||
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
@@ -34,18 +35,22 @@ def triton_moe(
|
||||
per_act_token_quant=False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
return fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
per_channel_quant=per_act_token_quant,
|
||||
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
|
||||
block_shape=block_shape)
|
||||
quant_config=quant_config)
|
||||
|
||||
|
||||
def batched_moe(
|
||||
@@ -64,6 +69,16 @@ def batched_moe(
|
||||
) -> torch.Tensor:
|
||||
max_num_tokens = round_up(a.shape[0], 64)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
@@ -72,21 +87,11 @@ def batched_moe(
|
||||
BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
)
|
||||
|
||||
return fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale)
|
||||
return fused_experts(a, w1, w2, topk_weight, topk_ids)
|
||||
|
||||
|
||||
def naive_batched_moe(
|
||||
@@ -105,6 +110,16 @@ def naive_batched_moe(
|
||||
) -> torch.Tensor:
|
||||
max_num_tokens = round_up(a.shape[0], 64)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
@@ -113,21 +128,11 @@ def naive_batched_moe(
|
||||
NaiveBatchedExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
)
|
||||
|
||||
return fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale)
|
||||
return fused_experts(a, w1, w2, topk_weight, topk_ids)
|
||||
|
||||
|
||||
def chunk_scales(scales: Optional[torch.Tensor], start: int,
|
||||
@@ -216,7 +221,7 @@ def make_test_weight(
|
||||
in_dtype: torch.dtype = torch.bfloat16,
|
||||
quant_dtype: Union[torch.dtype, str, None] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
per_out_ch_quant: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
|
||||
@@ -228,7 +233,7 @@ def make_test_weight(
|
||||
w_gs_l = [None] * e
|
||||
for idx in range(e):
|
||||
w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
|
||||
w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)
|
||||
w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape)
|
||||
|
||||
w = torch.stack(w_l)
|
||||
w_s = torch.stack(w_s_l)
|
||||
@@ -258,16 +263,16 @@ def make_test_weights(
|
||||
in_dtype: torch.dtype = torch.bfloat16,
|
||||
quant_dtype: Union[torch.dtype, str, None] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
per_out_ch_quant: bool = False,
|
||||
) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]],
|
||||
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]]:
|
||||
return (
|
||||
make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
|
||||
per_act_token_quant),
|
||||
per_out_ch_quant),
|
||||
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
|
||||
per_act_token_quant),
|
||||
per_out_ch_quant),
|
||||
)
|
||||
|
||||
|
||||
@@ -285,6 +290,76 @@ def per_token_cast_to_fp8(
|
||||
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
|
||||
|
||||
|
||||
def make_test_quant_config(
|
||||
e: int,
|
||||
n: int,
|
||||
k: int,
|
||||
in_dtype: torch.dtype,
|
||||
quant_dtype: Union[torch.dtype, str, None] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
|
||||
(_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
|
||||
e,
|
||||
n,
|
||||
k,
|
||||
in_dtype,
|
||||
quant_dtype,
|
||||
per_out_ch_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
# Hacky/trivial scales for nvfp4.
|
||||
a1_gscale: Optional[torch.Tensor] = None
|
||||
a2_gscale: Optional[torch.Tensor] = None
|
||||
if quant_dtype == "nvfp4":
|
||||
a1_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
|
||||
a2_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
|
||||
a1_scale = a1_gscale
|
||||
a2_scale = a2_gscale
|
||||
else:
|
||||
a1_scale = None
|
||||
a2_scale = None
|
||||
|
||||
return w1, w2, FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
a1_gscale=a1_gscale,
|
||||
a2_gscale=a2_gscale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
# TODO: make sure this is handled properly
|
||||
g1_alphas=(1 / w1_gs) if w1_gs is not None else None,
|
||||
g2_alphas=(1 / w2_gs) if w2_gs is not None else None,
|
||||
)
|
||||
|
||||
|
||||
def fused_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
score: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool = False,
|
||||
quant_config: Optional[FusedMoEQuantConfig] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
topk_weights, topk_ids, _ = fused_topk(hidden_states, score.float(), topk,
|
||||
renormalize)
|
||||
return fused_experts(hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
quant_config=quant_config)
|
||||
|
||||
|
||||
# CustomOp?
|
||||
class BaselineMM(torch.nn.Module):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user