Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -9,18 +9,19 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from .common import Config
from .mk_objects import (MK_ALL_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES,
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
from .mk_objects import (
MK_ALL_PREPARE_FINALIZE_TYPES,
MK_FUSED_EXPERT_TYPES,
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES,
)
def make_config_arg_parser(description: str):
def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize:
for pf in MK_ALL_PREPARE_FINALIZE_TYPES:
if pf.__name__ == s:
return pf
raise ValueError(
f"Cannot find a PrepareFinalize type that matches {s}")
raise ValueError(f"Cannot find a PrepareFinalize type that matches {s}")
def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute:
for fe in MK_FUSED_EXPERT_TYPES:
@@ -45,15 +46,18 @@ def make_config_arg_parser(description: str):
"--pf-type",
type=to_pf_class_type,
required=True,
help=("Choose a PrepareFinalize Type : "
f"{[x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]}"),
help=(
"Choose a PrepareFinalize Type : "
f"{[x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]}"
),
)
parser.add_argument(
"--experts-type",
type=to_experts_class_type,
required=True,
help=(f"Choose a FusedExpert type : "
f"{[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}"),
help=(
f"Choose a FusedExpert type : {[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}"
),
)
parser.add_argument(
"-m",
@@ -74,66 +78,65 @@ def make_config_arg_parser(description: str):
default=1024,
help="N dimension of the first fused-moe matmul",
)
parser.add_argument("--num-experts",
type=int,
default=32,
help="Global num experts")
parser.add_argument("--topk",
nargs="+",
type=int,
default=[4, 1],
help="num topk")
parser.add_argument(
"--num-experts", type=int, default=32, help="Global num experts"
)
parser.add_argument("--topk", nargs="+", type=int, default=[4, 1], help="num topk")
parser.add_argument(
"--fused-moe-chunk-size",
type=int,
help="Fused moe chunk size used for the non-batched fused experts impl."
help="Fused moe chunk size used for the non-batched fused experts impl.",
)
# Quant args
parser.add_argument("--quant-dtype",
type=to_quant_torch_dtype,
help="Quant datatype")
parser.add_argument("--per-token-quantized-activations",
action='store_true',
help=("The input activations must be per-token "
"quantized"))
parser.add_argument("--per-channel-quantized-weights",
action="store_true",
help="The weights must be per-channel quantized.")
parser.add_argument("--block-shape",
nargs="+",
type=int,
help="Quantization block shape")
parser.add_argument(
"--quant-dtype", type=to_quant_torch_dtype, help="Quant datatype"
)
parser.add_argument(
"--per-token-quantized-activations",
action="store_true",
help=("The input activations must be per-token quantized"),
)
parser.add_argument(
"--per-channel-quantized-weights",
action="store_true",
help="The weights must be per-channel quantized.",
)
parser.add_argument(
"--block-shape", nargs="+", type=int, help="Quantization block shape"
)
# Torch trace profile generation args
parser.add_argument("--torch-trace-dir-path",
type=str,
default=None,
help="Get torch trace for single execution")
parser.add_argument(
"--torch-trace-dir-path",
type=str,
default=None,
help="Get torch trace for single execution",
)
return parser
def _validate_args(args: argparse.Namespace):
if args.quant_dtype is not None:
assert args.quant_dtype == torch.float8_e4m3fn
if args.block_shape is not None:
assert len(args.block_shape) == 2, (
f"block shape must have 2 elements. got {args.block_shape}")
f"block shape must have 2 elements. got {args.block_shape}"
)
if args.experts_type in MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES:
assert args.world_size == 1, (
"Single GPU objects need world size set to 1")
assert args.world_size == 1, "Single GPU objects need world size set to 1"
if args.torch_trace_dir_path is not None:
from pathlib import Path
assert Path(args.torch_trace_dir_path).is_dir(), (
f"Please create {args.torch_trace_dir_path}")
f"Please create {args.torch_trace_dir_path}"
)
def make_config(args: argparse.Namespace) -> Config:
_validate_args(args)
quant_config = None
@@ -142,7 +145,8 @@ def make_config(args: argparse.Namespace) -> Config:
quant_dtype=args.quant_dtype,
per_act_token_quant=args.per_token_quantized_activations,
per_out_ch_quant=args.per_channel_quantized_weights,
block_shape=args.block_shape)
block_shape=args.block_shape,
)
return Config(
Ms=args.m,
@@ -156,4 +160,5 @@ def make_config(args: argparse.Namespace) -> Config:
fused_experts_type=args.experts_type,
fused_moe_chunk_size=args.fused_moe_chunk_size,
world_size=args.world_size,
torch_trace_dir_path=args.torch_trace_dir_path)
torch_trace_dir_path=args.torch_trace_dir_path,
)

View File

@@ -8,20 +8,30 @@ import torch
import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_test_weights, per_token_cast_to_fp8
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from tests.kernels.quantization.nvfp4_utils import (
FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype,
)
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig)
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from .mk_objects import (TestMoEQuantConfig, expert_info, make_fused_experts,
make_prepare_finalize, prepare_finalize_info)
from .mk_objects import (
TestMoEQuantConfig,
expert_info,
make_fused_experts,
make_prepare_finalize,
prepare_finalize_info,
)
from .parallel_utils import ProcessGroupInfo
@@ -94,8 +104,7 @@ class Config:
@property
def is_per_tensor_act_quant(self) -> bool:
return (not self.is_per_act_token_quant
and self.quant_block_shape is None)
return not self.is_per_act_token_quant and self.quant_block_shape is None
@property
def is_per_out_ch_quant(self) -> bool:
@@ -134,23 +143,24 @@ class Config:
if self.fused_moe_chunk_size is not None:
env_dict.update(
{"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)})
{"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}
)
return vllm_config, env_dict
def is_fp8_block_quantized(self):
return (self.quant_dtype == torch.float8_e4m3fn
and self.quant_block_shape is not None)
return (
self.quant_dtype == torch.float8_e4m3fn
and self.quant_block_shape is not None
)
def is_batched_prepare_finalize(self):
info = prepare_finalize_info(self.prepare_finalize_type)
return (mk.FusedMoEActivationFormat.BatchedExperts ==
info.activation_format)
return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format
def is_batched_fused_experts(self):
info = expert_info(self.fused_experts_type)
return (mk.FusedMoEActivationFormat.BatchedExperts ==
info.activation_format)
return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format
def is_standard_fused_experts(self):
info = expert_info(self.fused_experts_type)
@@ -190,8 +200,10 @@ class Config:
def needs_deep_ep(self):
info = prepare_finalize_info(self.prepare_finalize_type)
return (info.backend == "deepep_high_throughput"
or info.backend == "deepep_low_latency")
return (
info.backend == "deepep_high_throughput"
or info.backend == "deepep_low_latency"
)
def all2all_backend(self):
info = prepare_finalize_info(self.prepare_finalize_type)
@@ -211,20 +223,26 @@ class Config:
return False
# Check quantization sanity
if (int(self.is_per_act_token_quant) +
int(self.is_per_tensor_act_quant) +
int(self.quant_block_shape is not None)) > 1:
if (
int(self.is_per_act_token_quant)
+ int(self.is_per_tensor_act_quant)
+ int(self.quant_block_shape is not None)
) > 1:
# invalid quant config
return False
# check type support
if self.quant_dtype is None:
if (self.dtype not in self.pf_supported_types()
or self.dtype not in self.fe_supported_types()):
if (
self.dtype not in self.pf_supported_types()
or self.dtype not in self.fe_supported_types()
):
return False
else:
if (self.quant_dtype not in self.pf_supported_types()
or self.quant_dtype not in self.fe_supported_types()):
if (
self.quant_dtype not in self.pf_supported_types()
or self.quant_dtype not in self.fe_supported_types()
):
return False
# Check block quanization support
@@ -261,18 +279,21 @@ class WeightTensors:
def describe(self):
s = ""
s += "== Weight Tensors: \n"
s += f' - {_describe_tensor(self.w1, "w1")} \n'
s += f' - {_describe_tensor(self.w2, "w2")} \n'
s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n'
s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n'
s += f' - {_describe_tensor(self.w1_gs, "w1_gs")} \n'
s += f' - {_describe_tensor(self.w2_gs, "w2_gs")} \n'
s += f" - {_describe_tensor(self.w1, 'w1')} \n"
s += f" - {_describe_tensor(self.w2, 'w2')} \n"
s += f" - {_describe_tensor(self.w1_scale, 'w1_scale')} \n"
s += f" - {_describe_tensor(self.w2_scale, 'w2_scale')} \n"
s += f" - {_describe_tensor(self.w1_gs, 'w1_gs')} \n"
s += f" - {_describe_tensor(self.w2_gs, 'w2_gs')} \n"
return s
def is_quantized(self) -> bool:
# or w1_scale is not None?
return (self.w1.dtype == torch.float8_e4m3fn
or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8)
return (
self.w1.dtype == torch.float8_e4m3fn
or self.w1.dtype == torch.uint8
or self.w1.dtype == torch.int8
)
def to_current_device(self):
device = torch.cuda.current_device()
@@ -289,16 +310,13 @@ class WeightTensors:
if self.w2_gs is not None:
self.w2_gs = self.w2_gs.to(device=device)
def slice_weights(self, rank: int,
num_local_experts: int) -> "WeightTensors":
def slice_weights(self, rank: int, num_local_experts: int) -> "WeightTensors":
s = rank * num_local_experts
e = s + num_local_experts
w1 = self.w1[s:e, :, :]
w2 = self.w2[s:e, :, :]
w1_scale = self.w1_scale[
s:e, :, :] if self.w1_scale is not None else None
w2_scale = self.w2_scale[
s:e, :, :] if self.w2_scale is not None else None
w1_scale = self.w1_scale[s:e, :, :] if self.w1_scale is not None else None
w2_scale = self.w2_scale[s:e, :, :] if self.w2_scale is not None else None
w1_gs = self.w1_gs[s:e] if self.w1_gs is not None else None
w2_gs = self.w2_gs[s:e] if self.w2_gs is not None else None
@@ -313,15 +331,11 @@ class WeightTensors:
in_dtype=config.dtype,
quant_dtype=config.quant_dtype,
block_shape=config.quant_block_shape,
per_out_ch_quant=config.
is_per_act_token_quant, # or config.is_per_out_ch_quant
per_out_ch_quant=config.is_per_act_token_quant, # or config.is_per_out_ch_quant
)
return WeightTensors(
w1=w1, w2=w2, w1_scale=w1_scale, w2_scale=w2_scale, w1_gs=w1_gs, w2_gs=w2_gs
)
return WeightTensors(w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_gs=w1_gs,
w2_gs=w2_gs)
@dataclass
@@ -336,22 +350,22 @@ class RankTensors:
def describe(self):
s = ""
s += "== Rank Tensors: \n"
s += f' - {_describe_tensor(self.hidden_states, "HS")} \n'
s += f' - {_describe_tensor(self.hidden_states_scale, "HS_scale")} \n'
s += f' - {_describe_tensor(self.topk_weights, "topk_weights")} \n'
s += f' - {_describe_tensor(self.topk_ids, "topk_ids")} \n'
s += f' - {_describe_tensor(self.expert_map, "expert_map")} \n'
s += f" - {_describe_tensor(self.hidden_states, 'HS')} \n"
s += f" - {_describe_tensor(self.hidden_states_scale, 'HS_scale')} \n"
s += f" - {_describe_tensor(self.topk_weights, 'topk_weights')} \n"
s += f" - {_describe_tensor(self.topk_ids, 'topk_ids')} \n"
s += f" - {_describe_tensor(self.expert_map, 'expert_map')} \n"
return s
@staticmethod
def make_hidden_states(
config: Config) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
config: Config,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Return hidden_states
"""
m, k, dtype = (config.M, config.K, config.dtype)
a = (torch.randn(
(m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0)
a = torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0
if config.quant_dtype is None:
return a, None
@@ -362,36 +376,29 @@ class RankTensors:
# first - so further quantize and dequantize will yield the same
# values.
if config.is_per_tensor_act_quant:
a_q, a_scales = ops.scaled_fp8_quant(
a, use_per_token_if_dynamic=False)
a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=False)
return a_q.float().mul(a_scales).to(dtype), a_scales
if config.is_per_act_token_quant:
a_q, a_scales = ops.scaled_fp8_quant(a,
use_per_token_if_dynamic=True)
a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=True)
return a_q.float().mul(a_scales).to(dtype), None
assert config.quant_block_shape is not None
block_k = config.quant_block_shape[1]
a_q, a_scales = per_token_cast_to_fp8(a, block_size=block_k)
return a_q.float().view(
(-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to(dtype), None
return a_q.float().view((-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to(
dtype
), None
@staticmethod
def make(config: Config, pgi: ProcessGroupInfo):
dtype = config.dtype
topk, m, _ = (config.topk, config.M, config.K)
hidden_states, hidden_states_scale = RankTensors.make_hidden_states(
config)
hidden_states, hidden_states_scale = RankTensors.make_hidden_states(config)
num_local_experts, global_num_experts = (config.num_local_experts,
config.E)
score = torch.randn((m, global_num_experts),
device="cuda",
dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk,
False)
num_local_experts, global_num_experts = (config.num_local_experts, config.E)
score = torch.randn((m, global_num_experts), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, False)
# distribute topk_ids evenly
for mi in range(m):
@@ -400,14 +407,15 @@ class RankTensors:
expert_map = None
if config.world_size > 1 and config.supports_expert_map():
expert_map = torch.full((global_num_experts, ),
fill_value=-1,
dtype=torch.int32)
expert_map = torch.full(
(global_num_experts,), fill_value=-1, dtype=torch.int32
)
s = pgi.rank * num_local_experts
e = s + num_local_experts
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
expert_map = expert_map.to(device=torch.cuda.current_device(),
dtype=torch.int32)
expert_map = expert_map.to(
device=torch.cuda.current_device(), dtype=torch.int32
)
return RankTensors(
hidden_states=hidden_states,
@@ -418,9 +426,9 @@ class RankTensors:
)
def reference_moe_impl(config: Config, weights: WeightTensors,
rank_tensors: RankTensors) -> torch.Tensor:
def reference_moe_impl(
config: Config, weights: WeightTensors, rank_tensors: RankTensors
) -> torch.Tensor:
if config.quant_dtype == "nvfp4":
quant_blocksize = 16
dtype = config.dtype
@@ -433,8 +441,10 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
w2_blockscale = weights.w2_scale
w2_gs = weights.w2_gs
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(
rank_tensors.hidden_states.flatten(), dim=-1)).to(torch.float32)
a_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX)
/ torch.amax(rank_tensors.hidden_states.flatten(), dim=-1)
).to(torch.float32)
assert w1_gs is not None
assert w2_gs is not None
@@ -447,14 +457,17 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
assert w2_blockscale.shape[2] % 4 == 0
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
rank_tensors.hidden_states, a_global_scale)
rank_tensors.hidden_states, a_global_scale
)
a = dequantize_nvfp4_to_dtype(a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=dtype,
device=a_fp4.device,
block_size=quant_blocksize)
a = dequantize_nvfp4_to_dtype(
a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=dtype,
device=a_fp4.device,
block_size=quant_blocksize,
)
e = w1_q.shape[0]
n = w1_q.shape[1] // 2
@@ -464,18 +477,22 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
w2 = torch.zeros((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e):
w1[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize)
w2[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize)
w1[idx] = dequantize_nvfp4_to_dtype(
w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize,
)
w2[idx] = dequantize_nvfp4_to_dtype(
w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize,
)
a_scale = None
w1_scale = None
w2_scale = None
@@ -493,27 +510,29 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
per_act_token_quant = config.is_per_act_token_quant
block_shape = config.quant_block_shape
return torch_experts(a=a,
w1=w1,
w2=w2,
topk_weight=rank_tensors.topk_weights,
topk_ids=rank_tensors.topk_ids,
global_num_experts=config.E,
expert_map=None,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
apply_router_weights_on_input=config.topk == 1
and config.supports_apply_weight_on_input())
return torch_experts(
a=a,
w1=w1,
w2=w2,
topk_weight=rank_tensors.topk_weights,
topk_ids=rank_tensors.topk_ids,
global_num_experts=config.E,
expert_map=None,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
apply_router_weights_on_input=config.topk == 1
and config.supports_apply_weight_on_input(),
)
def _make_gscale(num_experts: int) -> torch.Tensor:
return torch.ones((num_experts, ),
device=torch.cuda.current_device(),
dtype=torch.float32)
return torch.ones(
(num_experts,), device=torch.cuda.current_device(), dtype=torch.float32
)
def make_modular_kernel(
@@ -521,12 +540,12 @@ def make_modular_kernel(
vllm_config: VllmConfig,
quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEModularKernel:
def next_power_of_2(x):
import math
if x == 0:
return 1
return 2**math.ceil(math.log2(x))
return 2 ** math.ceil(math.log2(x))
# make moe config
moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
@@ -546,9 +565,9 @@ def make_modular_kernel(
)
# make modular kernel
prepare_finalize = make_prepare_finalize(config.prepare_finalize_type,
config.all2all_backend(), moe,
quant_config)
prepare_finalize = make_prepare_finalize(
config.prepare_finalize_type, config.all2all_backend(), moe, quant_config
)
fused_experts = make_fused_experts(
config.fused_experts_type,
@@ -559,7 +578,8 @@ def make_modular_kernel(
)
modular_kernel = mk.FusedMoEModularKernel(
prepare_finalize=prepare_finalize, fused_experts=fused_experts)
prepare_finalize=prepare_finalize, fused_experts=fused_experts
)
return modular_kernel
@@ -587,10 +607,8 @@ def run_modular_kernel(
w1_scale=rank_weights.w1_scale,
w2_scale=rank_weights.w2_scale,
a1_scale=rank_tensors.hidden_states_scale,
g1_alphas=(1 / rank_weights.w1_gs)
if rank_weights.w1_gs is not None else None,
g2_alphas=(1 / rank_weights.w2_gs)
if rank_weights.w2_gs is not None else None,
g1_alphas=(1 / rank_weights.w1_gs) if rank_weights.w1_gs is not None else None,
g2_alphas=(1 / rank_weights.w2_gs) if rank_weights.w2_gs is not None else None,
a1_gscale=gscale,
a2_gscale=gscale,
block_shape=config.quant_block_shape,
@@ -603,38 +621,30 @@ def run_modular_kernel(
# impls might update the tensor in place
hidden_states = rank_tensors.hidden_states.clone()
topk_ids = rank_tensors.topk_ids.to(
mk.prepare_finalize.topk_indices_dtype())
topk_ids = rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype())
mk_kwargs = {
"hidden_states":
hidden_states,
"w1":
rank_weights.w1,
"w2":
rank_weights.w2,
"topk_weights":
rank_tensors.topk_weights,
"topk_ids":
topk_ids,
"expert_map":
rank_tensors.expert_map,
"global_num_experts":
config.E,
"apply_router_weight_on_input":
config.topk == 1 and config.supports_apply_weight_on_input(),
"hidden_states": hidden_states,
"w1": rank_weights.w1,
"w2": rank_weights.w2,
"topk_weights": rank_tensors.topk_weights,
"topk_ids": topk_ids,
"expert_map": rank_tensors.expert_map,
"global_num_experts": config.E,
"apply_router_weight_on_input": config.topk == 1
and config.supports_apply_weight_on_input(),
}
num_tokens = rank_tensors.hidden_states.shape[0]
num_tokens_across_dp = torch.tensor([num_tokens] * config.world_size,
device="cuda",
dtype=torch.int)
num_tokens_across_dp = torch.tensor(
[num_tokens] * config.world_size, device="cuda", dtype=torch.int
)
with set_forward_context(
None,
vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
None,
vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
):
out = mk.forward(**mk_kwargs)

View File

@@ -10,14 +10,21 @@ import torch
from tqdm import tqdm
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG)
from vllm.model_executor.layers.fused_moe.config import FUSED_MOE_UNQUANTIZED_CONFIG
from vllm.platforms import current_platform
from .common import (Config, RankTensors, WeightTensors, reference_moe_impl,
run_modular_kernel)
from .mk_objects import (MK_FUSED_EXPERT_TYPES,
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_QUANT_CONFIGS)
from .common import (
Config,
RankTensors,
WeightTensors,
reference_moe_impl,
run_modular_kernel,
)
from .mk_objects import (
MK_FUSED_EXPERT_TYPES,
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
MK_QUANT_CONFIGS,
)
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config
@@ -38,8 +45,9 @@ def rank_worker(
# sanity check
from vllm import envs
if config.fused_moe_chunk_size is not None:
assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
# get weights to this device
weights.to_current_device()
@@ -60,8 +68,7 @@ def rank_worker(
rank_tensors = RankTensors.make(cfgx, pgi)
# modular kernel out
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
rank_tensors)
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, rank_tensors)
with set_current_vllm_config(vllm_config):
ref_out = reference_moe_impl(cfgx, weights, rank_tensors)
@@ -70,28 +77,27 @@ def rank_worker(
def make_feature_matrix(csv_file_path: str):
from dataclasses import asdict
import pandas as pd
def add_to_results(config: Config,
success: Result,
results_df: Optional[pd.DataFrame] = None):
def add_to_results(
config: Config, success: Result, results_df: Optional[pd.DataFrame] = None
):
config_dict = asdict(config)
config_dict['prepare_finalize_type'] = config_dict[
'prepare_finalize_type'].__name__
config_dict['fused_experts_type'] = config_dict[
'fused_experts_type'].__name__
config_dict['per_tensor_act_quant'] = config.is_per_tensor_act_quant
quant_config_dict = config_dict['quant_config']
del config_dict['quant_config']
config_dict["prepare_finalize_type"] = config_dict[
"prepare_finalize_type"
].__name__
config_dict["fused_experts_type"] = config_dict["fused_experts_type"].__name__
config_dict["per_tensor_act_quant"] = config.is_per_tensor_act_quant
quant_config_dict = config_dict["quant_config"]
del config_dict["quant_config"]
if quant_config_dict is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
quant_config_dict = asdict(quant_config)
config_dict |= quant_config_dict
result_dict = config_dict | {'success': success.name}
result_dict = config_dict | {"success": success.name}
result_df = pd.DataFrame([result_dict])
if results_df is None:
@@ -112,22 +118,26 @@ def make_feature_matrix(csv_file_path: str):
Q_TYPES = MK_QUANT_CONFIGS
combinations = list(
product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES))
product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES)
)
results_df: Optional[pd.DataFrame] = None
for m, k, n, e, topks, dtype, pf_type, experts_type, quant_config in tqdm(
combinations): #noqa: E501
config = Config(Ms=[m],
K=k,
N=n,
E=e,
topks=topks,
dtype=dtype,
prepare_finalize_type=pf_type,
fused_experts_type=experts_type,
quant_config=quant_config,
world_size=2,
fused_moe_chunk_size=None)
combinations
): # noqa: E501
config = Config(
Ms=[m],
K=k,
N=n,
E=e,
topks=topks,
dtype=dtype,
prepare_finalize_type=pf_type,
fused_experts_type=experts_type,
quant_config=quant_config,
world_size=2,
fused_moe_chunk_size=None,
)
success = None
if config.is_valid():
@@ -135,9 +145,14 @@ def make_feature_matrix(csv_file_path: str):
try:
weights: WeightTensors = WeightTensors.make(config)
vllm_config, env_dict = config.make_env_data()
parallel_launch_with_config(config.world_size, rank_worker,
vllm_config, env_dict, config,
weights)
parallel_launch_with_config(
config.world_size,
rank_worker,
vllm_config,
env_dict,
config,
weights,
)
success = Result.PASS
except Exception as _:
success = Result.FAIL
@@ -150,25 +165,33 @@ def make_feature_matrix(csv_file_path: str):
results_df.to_csv(f"{csv_file_path}")
if __name__ == '__main__':
if __name__ == "__main__":
import argparse
from pathlib import Path
parser = argparse.ArgumentParser(description=(
"Make ModularKernel feature matrix \n"
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix " #noqa: E501
"-f ./feature_matrices/feature_matrix.csv"))
parser.add_argument("-f",
"--feature-matrix-csv-file-path",
type=str,
required=True,
help="File name to Generate a .csv file")
parser = argparse.ArgumentParser(
description=(
"Make ModularKernel feature matrix \n"
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix " # noqa: E501
"-f ./feature_matrices/feature_matrix.csv"
)
)
parser.add_argument(
"-f",
"--feature-matrix-csv-file-path",
type=str,
required=True,
help="File name to Generate a .csv file",
)
args = parser.parse_args()
csv_path = args.feature_matrix_csv_file_path
assert csv_path.endswith(
'csv'), f"Need a file path ending with .csv, got {csv_path}"
assert Path(csv_path).parent.is_dir(
), f"Cannot find parent directory for {Path(csv_path).parent}"
assert csv_path.endswith("csv"), (
f"Need a file path ending with .csv, got {csv_path}"
)
assert Path(csv_path).parent.is_dir(), (
f"Cannot find parent directory for {Path(csv_path).parent}"
)
make_feature_matrix(args.feature_matrix_csv_file_path)

View File

@@ -8,24 +8,33 @@ import torch
# Fused experts and PrepareFinalize imports
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
BatchedTritonOrDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
TritonExperts)
BatchedTritonExperts,
NaiveBatchedExperts,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, TritonExperts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
TritonOrDeepGemmExperts,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported)
cutlass_fp4_supported,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported)
cutlass_fp8_supported,
)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.deep_gemm import is_deep_gemm_supported
@@ -60,8 +69,7 @@ class ExpertInfo:
needs_deep_gemm: bool = False
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize,
PrepareFinalizeInfo] = {}
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize, PrepareFinalizeInfo] = {}
EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {}
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
@@ -71,7 +79,10 @@ MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = []
standard_format = mk.FusedMoEActivationFormat.Standard
batched_format = mk.FusedMoEActivationFormat.BatchedExperts
common_float_types: list[Union[torch.dtype, str]] = [
torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32
torch.float8_e4m3fn,
torch.bfloat16,
torch.float16,
torch.float32,
]
common_float_and_int_types = common_float_types + [torch.int8]
nvfp4_types = ["nvfp4"]
@@ -186,9 +197,11 @@ register_experts(
# Disable on blackwell for now
if has_deep_ep() and not current_platform.has_device_capability(100):
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
DeepEPHTPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
DeepEPLLPrepareAndFinalize,
)
register_prepare_and_finalize(
DeepEPHTPrepareAndFinalize,
@@ -208,7 +221,9 @@ if has_deep_ep() and not current_platform.has_device_capability(100):
if has_pplx():
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
PplxPrepareAndFinalize,
)
register_prepare_and_finalize(
PplxPrepareAndFinalize,
batched_format,
@@ -217,13 +232,14 @@ if has_pplx():
backend="pplx",
)
if (has_flashinfer_cutlass_fused_moe()
and current_platform.has_device_capability(100)):
if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100):
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
FlashInferExperts)
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize,
create_flashinfer_prepare_finalize)
create_flashinfer_prepare_finalize,
)
register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize,
@@ -258,16 +274,18 @@ if has_deep_gemm() and is_deep_gemm_supported():
needs_matching_quant=False,
needs_deep_gemm=True,
)
register_experts(
DeepGemmExperts,
standard_format,
fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=False,
needs_deep_gemm=True,
),
(
register_experts(
DeepGemmExperts,
standard_format,
fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=False,
needs_deep_gemm=True,
),
)
register_experts(
BatchedTritonOrDeepGemmExperts,
batched_format,
@@ -290,8 +308,11 @@ if has_deep_gemm() and is_deep_gemm_supported():
)
if cutlass_fp8_supported():
from vllm.model_executor.layers.fused_moe import (CutlassBatchedExpertsFp8,
CutlassExpertsFp8)
from vllm.model_executor.layers.fused_moe import (
CutlassBatchedExpertsFp8,
CutlassExpertsFp8,
)
register_experts(
CutlassExpertsFp8,
standard_format,
@@ -310,8 +331,8 @@ if cutlass_fp8_supported():
)
if cutlass_fp4_supported():
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4)
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp4
register_experts(
CutlassExpertsFp4,
standard_format,
@@ -324,30 +345,40 @@ if cutlass_fp4_supported():
MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [
None,
# per-channel / per-column weights and per-tensor activations
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=False,
block_shape=None),
TestMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=False,
block_shape=None,
),
# per-channel / per-column weights and per-token activations
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=True,
block_shape=None),
TestMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=True,
per_act_token_quant=True,
block_shape=None,
),
# per-tensor weights and per-tensor activations
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
TestMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None,
),
# per-tensor weights and per-token activations
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=True,
block_shape=None),
TestMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=True,
block_shape=None,
),
# block-quantized weights and 128 block per-token activations
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=[128, 128]),
TestMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=[128, 128],
),
# TODO (varun) : Should we test the following combinations ?
# block-quantized weights and per-token activations
# block-quantized weights and per-tensor activations
@@ -355,10 +386,12 @@ MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [
if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
MK_QUANT_CONFIGS += [
TestMoEQuantConfig(quant_dtype="nvfp4",
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
TestMoEQuantConfig(
quant_dtype="nvfp4",
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None,
),
]
@@ -370,12 +403,14 @@ def make_prepare_finalize(
) -> mk.FusedMoEPrepareAndFinalize:
if backend != "naive" and backend is not None:
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(
moe, quant_config)
moe, quant_config
)
assert prepare_finalize is not None
return prepare_finalize
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
return create_flashinfer_prepare_finalize(
use_dp=moe.moe_parallel_config.dp_size > 1)
use_dp=moe.moe_parallel_config.dp_size > 1
)
else:
return MoEPrepareAndFinalizeNoEP()
@@ -391,10 +426,10 @@ def make_cutlass_strides(
n: int,
k: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
return ab_strides1, ab_strides2, c_strides1, c_strides2
@@ -405,7 +440,6 @@ def make_fused_experts(
num_dispatchers: int,
N: int,
) -> mk.FusedMoEPermuteExpertsUnpermute:
batch_kwargs = {
"max_num_tokens": moe.max_num_tokens,
"num_dispatchers": num_dispatchers,

View File

@@ -6,13 +6,11 @@ import traceback
from typing import Any, Callable, Optional
import torch
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed import (init_distributed_environment,
initialize_model_parallel)
from vllm.distributed import init_distributed_environment, initialize_model_parallel
from vllm.utils import get_open_port
## Parallel Processes Utils
@@ -30,10 +28,11 @@ class ProcessGroupInfo:
device: torch.device
def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int,
local_rank: int):
def _set_vllm_config(
vllm_config: VllmConfig, world_size: int, rank: int, local_rank: int
):
import tempfile
temp_file = tempfile.mkstemp()[1]
with set_current_vllm_config(vllm_config):
@@ -46,13 +45,10 @@ def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int,
)
initialize_model_parallel(
tensor_model_parallel_size=vllm_config.parallel_config.
tensor_parallel_size,
pipeline_model_parallel_size=vllm_config.parallel_config.
pipeline_parallel_size,
tensor_model_parallel_size=vllm_config.parallel_config.tensor_parallel_size,
pipeline_model_parallel_size=vllm_config.parallel_config.pipeline_parallel_size,
)
cpu_group = torch.distributed.new_group(list(range(world_size)),
backend="gloo")
cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")
return cpu_group
@@ -62,8 +58,7 @@ def _worker_parallel_launch(
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any,
P], None],
worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any, P], None],
vllm_config: Optional[VllmConfig],
env_dict: Optional[dict],
*args: P.args,
@@ -131,7 +126,8 @@ def parallel_launch_with_config(
worker,
vllm_config,
env_dict,
) + args,
)
+ args,
nprocs=world_size,
join=True,
)

View File

@@ -14,28 +14,31 @@ from .common import Config, RankTensors, WeightTensors, make_modular_kernel
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config
def do_profile(fn: Callable,
fn_kwargs: dict[Any, Any],
pgi: ProcessGroupInfo,
config: Config,
num_warmups: int = 5):
def do_profile(
fn: Callable,
fn_kwargs: dict[Any, Any],
pgi: ProcessGroupInfo,
config: Config,
num_warmups: int = 5,
):
for _ in range(num_warmups):
fn(**fn_kwargs)
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
record_shapes=True,
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
record_shapes=True,
) as tprof:
fn(**fn_kwargs)
torch.cuda.synchronize(torch.cuda.current_device())
# TODO (varun): Add a descriptive trace file name
tprof.export_chrome_trace(
f"{config.torch_trace_dir_path}/m{config.M}_{pgi.rank}_trace.json")
f"{config.torch_trace_dir_path}/m{config.M}_{pgi.rank}_trace.json"
)
def profile_modular_kernel(
@@ -82,6 +85,7 @@ def rank_worker(
# sanity check
from vllm import envs
if config.fused_moe_chunk_size is not None:
assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
@@ -108,20 +112,25 @@ def rank_worker(
def run(config: Config):
weights: WeightTensors = WeightTensors.make(config)
vllm_config, env_dict = config.make_env_data()
parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
env_dict, config, weights)
parallel_launch_with_config(
config.world_size, rank_worker, vllm_config, env_dict, config, weights
)
if __name__ == '__main__':
if __name__ == "__main__":
from .cli_args import make_config, make_config_arg_parser
parser = make_config_arg_parser(description=(
"Run single prepare-finalize & fused-experts combination test"
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel " #noqa: E501
"--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
))
parser = make_config_arg_parser(
description=(
"Run single prepare-finalize & fused-experts combination test"
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel " # noqa: E501
"--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
)
)
args = parser.parse_args()
assert args.torch_trace_dir_path is not None, (
"Please pass in a directory to store torch traces")
"Please pass in a directory to store torch traces"
)
config = make_config(args)
run(config)

View File

@@ -3,6 +3,7 @@
"""
DeepEP test utilities
"""
import dataclasses
import os
import traceback
@@ -10,17 +11,18 @@ from typing import Callable, Optional
import torch
from torch.distributed import ProcessGroup
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from vllm.utils import get_open_port, has_deep_ep
if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
DeepEPHTPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
DeepEPLLPrepareAndFinalize,
)
## Parallel Processes Utils
@@ -96,7 +98,8 @@ def parallel_launch(
0,
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
worker,
) + args,
)
+ args,
nprocs=world_size,
join=True,
)
@@ -118,48 +121,57 @@ class DeepEPLLArgs:
use_fp8_dispatch: bool
def make_deepep_ht_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
ht_args: DeepEPHTArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
def make_deepep_ht_a2a(
pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
ht_args: DeepEPHTArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
):
import deep_ep
# high throughput a2a
num_nvl_bytes = 1024 * 1024 * 1024 # 1GB
num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
buffer = deep_ep.Buffer(group=pg,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=low_latency_mode,
num_qps_per_rank=num_qps_per_rank)
return DeepEPHTPrepareAndFinalize(buffer=buffer,
num_dispatchers=pgi.world_size,
dp_size=dp_size,
rank_expert_offset=pgi.rank *
ht_args.num_local_experts)
buffer = deep_ep.Buffer(
group=pg,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=low_latency_mode,
num_qps_per_rank=num_qps_per_rank,
)
return DeepEPHTPrepareAndFinalize(
buffer=buffer,
num_dispatchers=pgi.world_size,
dp_size=dp_size,
rank_expert_offset=pgi.rank * ht_args.num_local_experts,
)
def make_deepep_ll_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
deepep_ll_args: DeepEPLLArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
def make_deepep_ll_a2a(
pg: ProcessGroup,
pgi: ProcessGroupInfo,
deepep_ll_args: DeepEPLLArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
):
import deep_ep
# low-latency a2a
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
pgi.world_size, deepep_ll_args.num_experts)
deepep_ll_args.max_tokens_per_rank,
deepep_ll_args.hidden_size,
pgi.world_size,
deepep_ll_args.num_experts,
)
buffer = deep_ep.Buffer(group=pg,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=deepep_ll_args.num_experts //
pgi.world_size)
buffer = deep_ep.Buffer(
group=pg,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=deepep_ll_args.num_experts // pgi.world_size,
)
return DeepEPLLPrepareAndFinalize(
buffer=buffer,
@@ -169,17 +181,20 @@ def make_deepep_ll_a2a(pg: ProcessGroup,
)
def make_deepep_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
deepep_ht_args: Optional[DeepEPHTArgs],
deepep_ll_args: Optional[DeepEPLLArgs],
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
def make_deepep_a2a(
pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
deepep_ht_args: Optional[DeepEPHTArgs],
deepep_ll_args: Optional[DeepEPLLArgs],
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
):
if deepep_ht_args is not None:
assert deepep_ll_args is None
return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
block_shape)
return make_deepep_ht_a2a(
pg, pgi, dp_size, deepep_ht_args, q_dtype, block_shape
)
assert deepep_ll_args is not None
return make_deepep_ll_a2a(pg, pgi, deepep_ll_args, q_dtype, block_shape)

View File

@@ -5,13 +5,14 @@ import pytest
import torch
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
BatchedPrepareAndFinalize,
BatchedTritonExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported
from .test_deepgemm import make_block_quant_fp8_weights
@@ -19,15 +20,15 @@ from .test_deepgemm import make_block_quant_fp8_weights
BLOCK_SIZE = [128, 128]
@pytest.mark.skipif(not is_deep_gemm_supported(),
reason="Requires deep_gemm kernels")
@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
@pytest.mark.parametrize("E", [16, 32]) # number of experts
@pytest.mark.parametrize("T", [256, 512]) # tokens per expert
@pytest.mark.parametrize("K", [128, 256]) # hidden dim
@pytest.mark.parametrize("N", [512, 1024]) # intermediate dim per expert
@pytest.mark.parametrize("topk", [2, 4])
def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
monkeypatch):
def test_batched_deepgemm_vs_triton(
E: int, T: int, K: int, N: int, topk: int, monkeypatch
):
"""Compare BatchedDeepGemmExperts to BatchedTritonExperts."""
monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")

View File

@@ -7,14 +7,18 @@ from typing import Optional
import pytest
import torch
from tests.kernels.moe.utils import (batched_moe,
make_quantized_test_activations,
make_test_weights, naive_batched_moe)
from tests.kernels.moe.utils import (
batched_moe,
make_quantized_test_activations,
make_test_weights,
naive_batched_moe,
)
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
invoke_moe_batched_triton_kernel)
invoke_moe_batched_triton_kernel,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform
from vllm.triton_utils import tl
@@ -68,23 +72,32 @@ class BatchedMMTensors:
@staticmethod
def make_tensors(config: BatchedMMConfig):
A = torch.randn(
(config.num_experts, config.max_tokens_per_expert, config.K),
A = (
torch.randn(
(config.num_experts, config.max_tokens_per_expert, config.K),
device="cuda",
dtype=config.in_dtype,
)
/ 10
)
B = torch.randn(
(config.num_experts, config.N, config.K),
device="cuda",
dtype=config.in_dtype) / 10
B = torch.randn((config.num_experts, config.N, config.K),
device="cuda",
dtype=config.in_dtype)
dtype=config.in_dtype,
)
C = torch.zeros(
(config.num_experts, config.max_tokens_per_expert, config.N),
device="cuda",
dtype=config.out_dtype)
dtype=config.out_dtype,
)
num_expert_tokens = torch.randint(low=0,
high=config.max_tokens_per_expert,
size=(config.num_experts, ),
device="cuda",
dtype=torch.int32)
num_expert_tokens = torch.randint(
low=0,
high=config.max_tokens_per_expert,
size=(config.num_experts,),
device="cuda",
dtype=torch.int32,
)
return BatchedMMTensors(A, B, C, num_expert_tokens)
@@ -96,10 +109,15 @@ class BatchedMMTensors:
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
N: int, dtype: torch.dtype,
block_shape: Optional[list[int]],
per_act_token_quant: bool):
def test_batched_mm(
num_experts: int,
max_tokens_per_expert: int,
K: int,
N: int,
dtype: torch.dtype,
block_shape: Optional[list[int]],
per_act_token_quant: bool,
):
current_platform.seed_everything(7)
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
@@ -117,11 +135,13 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
act_dtype = dtype
quant_dtype = None
num_expert_tokens = torch.randint(low=0,
high=max_tokens_per_expert,
size=(num_experts, ),
device="cuda",
dtype=torch.int32)
num_expert_tokens = torch.randint(
low=0,
high=max_tokens_per_expert,
size=(num_experts,),
device="cuda",
dtype=torch.int32,
)
A, A_q, A_scale = make_quantized_test_activations(
num_experts,
@@ -151,7 +171,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
compute_tl_dtype = {
torch.float16: tl.float16,
torch.bfloat16: tl.bfloat16,
torch.float32: tl.float32
torch.float32: tl.float32,
}[test_output.dtype]
assert A_q.dtype == B_q.dtype
@@ -173,7 +193,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
config={
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
"BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32,
},
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
@@ -186,11 +206,16 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
num_expert_tokens,
)
q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
num_expert_tokens,
A_scale, B_scale,
block_shape,
per_act_token_quant)
q_ref_output = native_batched_masked_quant_matmul(
A_q,
B_q,
q_ref_output,
num_expert_tokens,
A_scale,
B_scale,
block_shape,
per_act_token_quant,
)
rtol, atol = {
torch.float16: (6e-2, 6e-2),
@@ -308,12 +333,6 @@ def test_fused_moe_batched_experts(
block_shape=block_shape,
)
torch.testing.assert_close(batched_output,
baseline_output,
atol=3e-2,
rtol=2e-2)
torch.testing.assert_close(batched_output, baseline_output, atol=3e-2, rtol=2e-2)
torch.testing.assert_close(triton_output,
batched_output,
atol=2e-2,
rtol=2e-2)
torch.testing.assert_close(triton_output, batched_output, atol=2e-2, rtol=2e-2)

View File

@@ -5,15 +5,21 @@ import pytest
import torch
from tests.kernels.moe.utils import make_test_quant_config, make_test_weights
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
native_w8a8_block_matmul)
from tests.kernels.quant_utils import (
native_per_token_group_quant_fp8,
native_w8a8_block_matmul,
)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
_valid_deep_gemm_shape,
deep_gemm_moe_fp8,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
fused_topk,
modular_triton_fused_moe,
)
from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
@@ -24,8 +30,7 @@ if dg_available:
from deep_gemm import get_m_alignment_for_contiguous_layout
if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
allow_module_level=True)
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
@@ -97,8 +102,7 @@ TOP_KS = [1, 2, 6]
SEEDS = [0]
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
block_shape):
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape):
"""Fused moe with block-wise quantization using native torch."""
B, D = a.shape
topk = topk_ids.size(1)
@@ -114,23 +118,17 @@ def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
inter_out = native_w8a8_block_matmul(a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
block_shape,
output_dtype=a.dtype)
inter_out = native_w8a8_block_matmul(
a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype
)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = native_per_token_group_quant_fp8(
act_out, block_k)
out[mask] = native_w8a8_block_matmul(act_out_q,
w2[i],
act_out_s,
w2_s[i],
block_shape,
output_dtype=a.dtype)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
act_out_q, act_out_s = native_per_token_group_quant_fp8(act_out, block_k)
out[mask] = native_w8a8_block_matmul(
act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype
)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
# Skip all tests if CUDA is not available
@@ -149,8 +147,9 @@ def setup_cuda():
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
monkeypatch):
def test_w8a8_block_fp8_fused_moe(
M, N, K, E, topk, block_size, dtype, seed, monkeypatch
):
if topk > E:
pytest.skip(f"Skipping test; topk={topk} > E={E}")
@@ -188,12 +187,9 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
block_size,
)
out = fused_experts(a,
w1,
w2,
topk_weights,
topk_ids,
quant_config=quant_config)
out = fused_experts(
a, w1, w2, topk_weights, topk_ids, quant_config=quant_config
)
m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids)
@@ -210,8 +206,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
monkeypatch):
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch):
if topk > E:
pytest.skip(f"Skipping test: topk={topk} > E={E}")
@@ -245,36 +240,38 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
# setup code in case we are able to revisit this later.
use_compile = False
use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024
and current_platform.is_cuda_alike())
use_cudagraph = (
chunk_size < M and N >= 1024 and K >= 1024 and current_platform.is_cuda_alike()
)
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids, block_size)
ref_out = torch_w8a8_block_fp8_moe(
a, w1, w2, w1_s, w2_s, topk_weights, topk_ids, block_size
)
if use_compile:
deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
backend="inductor",
fullgraph=True)
deep_gemm_moe_fp8_fn = torch.compile(
deep_gemm_moe_fp8, backend="inductor", fullgraph=True
)
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_dynamic(topk_weights, 0)
torch._dynamo.mark_dynamic(topk_ids, 0)
else:
deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids)
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
if use_cudagraph:
out.fill_(0)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids)
out = deep_gemm_moe_fp8_fn(
a, w1, w2, w1_s, w2_s, topk_weights, topk_ids
)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()

View File

@@ -5,16 +5,17 @@ import pytest
import torch
from tests.kernels.moe.utils import make_test_quant_config
from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
native_w8a8_block_matmul)
from tests.kernels.quant_utils import (
native_per_token_group_quant_int8,
native_w8a8_block_matmul,
)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.platforms import current_platform
if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
allow_module_level=True)
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
@@ -77,24 +78,18 @@ def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
inter_out = native_w8a8_block_matmul(a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
block_shape,
output_dtype=a.dtype)
inter_out = native_w8a8_block_matmul(
a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype
)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = native_per_token_group_quant_int8(
act_out, block_k)
act_out_q, act_out_s = native_per_token_group_quant_int8(act_out, block_k)
act_out = act_out.to(torch.float32)
out[mask] = native_w8a8_block_matmul(act_out_q,
w2[i],
act_out_s,
w2_s[i],
block_shape,
output_dtype=a.dtype)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
out[mask] = native_w8a8_block_matmul(
act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype
)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
@pytest.fixture(autouse=True, scope="module")
@@ -131,15 +126,19 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
out = fused_experts(a,
w1,
w2,
topk_weights,
topk_ids,
quant_config=quant_config)
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, quant_config.w1_scale,
quant_config.w2_scale, score, topk,
block_size)
out = fused_experts(
a, w1, w2, topk_weights, topk_ids, quant_config=quant_config
)
ref_out = torch_w8a8_block_int8_moe(
a,
w1,
w2,
quant_config.w1_scale,
quant_config.w2_scale,
score,
topk,
block_size,
)
# Check results
torch.testing.assert_close(out, ref_out, atol=0.065, rtol=0.065)

View File

@@ -15,7 +15,6 @@ from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens
@dataclasses.dataclass
class TestTensors:
topk_ids: torch.Tensor
expert_map: Optional[torch.Tensor] = None
@@ -25,32 +24,31 @@ class TestTensors:
self.expert_map = self.expert_map.to(device=device)
@staticmethod
def make(num_tokens: int, num_topk: int, num_experts: int, device: str,
topk_ids_dtype: torch.dtype) -> "TestTensors":
def make(
num_tokens: int,
num_topk: int,
num_experts: int,
device: str,
topk_ids_dtype: torch.dtype,
) -> "TestTensors":
# make topk ids
topk_ids = torch.empty((num_tokens, num_topk),
device=device,
dtype=torch.int64)
topk_ids = torch.empty((num_tokens, num_topk), device=device, dtype=torch.int64)
for x in range(num_tokens):
topk_ids[x] = torch.randperm(num_experts)[:num_topk]
topk_ids = topk_ids.to(dtype=torch.int64)
return TestTensors(topk_ids=topk_ids)
def with_ep_rank(self, ep_rank: int, num_global_experts: int,
num_local_experts: int, device: str):
def with_ep_rank(
self, ep_rank: int, num_global_experts: int, num_local_experts: int, device: str
):
# make an expert map
expert_map = torch.empty((num_global_experts),
device=device,
dtype=torch.int32)
expert_map = torch.empty((num_global_experts), device=device, dtype=torch.int32)
expert_map.fill_(-1)
s = ep_rank * num_local_experts
e = s + num_local_experts
expert_map[s:e] = torch.tensor(list(range(num_local_experts)),
device=device)
expert_map[s:e] = torch.tensor(list(range(num_local_experts)), device=device)
return TestTensors(topk_ids=self.topk_ids.clone(),
expert_map=expert_map)
return TestTensors(topk_ids=self.topk_ids.clone(), expert_map=expert_map)
def ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor):
@@ -68,49 +66,49 @@ def ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor):
expert_num_tokens[eid] += count
def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
num_experts: int, ep_size: int,
topk_ids_dtype: torch.dtype):
def do_test_compute_expert_num_tokens(
num_tokens: int,
num_topk: int,
num_experts: int,
ep_size: int,
topk_ids_dtype: torch.dtype,
):
assert num_topk <= num_experts
tt = TestTensors.make(num_tokens,
num_topk,
num_experts,
topk_ids_dtype=topk_ids_dtype,
device="cpu")
tt = TestTensors.make(
num_tokens, num_topk, num_experts, topk_ids_dtype=topk_ids_dtype, device="cpu"
)
num_global_experts = num_experts
assert num_global_experts % ep_size == 0
num_local_experts = num_global_experts // ep_size
for ep_rank in range(ep_size):
tt_rank = tt.with_ep_rank(ep_rank, num_global_experts,
num_local_experts, "cpu")
tt_rank = tt.with_ep_rank(ep_rank, num_global_experts, num_local_experts, "cpu")
ref_expert_num_tokens = torch.zeros((num_local_experts),
device="cpu",
dtype=torch.int32)
ref_expert_num_tokens = torch.zeros(
(num_local_experts), device="cpu", dtype=torch.int32
)
ref_impl(tt_rank, ref_expert_num_tokens)
ref_expert_num_tokens = ref_expert_num_tokens.to("cuda")
tt_rank.to_device("cuda")
# Test with expert_map
triton_expert_num_tokens_w_emap = count_expert_num_tokens(
tt_rank.topk_ids, num_local_experts, tt_rank.expert_map)
tt_rank.topk_ids, num_local_experts, tt_rank.expert_map
)
# Test without expert map
topk_ids = tt_rank.expert_map[tt_rank.topk_ids].to(topk_ids_dtype)
triton_expert_num_tokens_wo_emap = count_expert_num_tokens(
topk_ids, num_local_experts, expert_map=None)
topk_ids, num_local_experts, expert_map=None
)
torch.testing.assert_close(ref_expert_num_tokens,
triton_expert_num_tokens_w_emap,
atol=0,
rtol=0)
torch.testing.assert_close(ref_expert_num_tokens,
triton_expert_num_tokens_wo_emap,
atol=0,
rtol=0)
torch.testing.assert_close(
ref_expert_num_tokens, triton_expert_num_tokens_w_emap, atol=0, rtol=0
)
torch.testing.assert_close(
ref_expert_num_tokens, triton_expert_num_tokens_wo_emap, atol=0, rtol=0
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 11, 127, 128, 3333, 7317])
@@ -118,22 +116,29 @@ def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
@pytest.mark.parametrize("num_experts", [64])
@pytest.mark.parametrize("ep_size", [1, 2, 4])
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
def test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
num_experts: int, ep_size: int,
topk_ids_dtype: torch.dtype):
do_test_compute_expert_num_tokens(num_tokens, num_topk, num_experts,
ep_size, topk_ids_dtype)
def test_compute_expert_num_tokens(
num_tokens: int,
num_topk: int,
num_experts: int,
ep_size: int,
topk_ids_dtype: torch.dtype,
):
do_test_compute_expert_num_tokens(
num_tokens, num_topk, num_experts, ep_size, topk_ids_dtype
)
@pytest.mark.parametrize("numel", list(range(1, 8192, 111)))
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("ep_size", [2])
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
def test_compute_expert_num_tokens_from_numel(numel: int, num_experts: int,
ep_size: int,
topk_ids_dtype: torch.dtype):
do_test_compute_expert_num_tokens(num_tokens=numel,
num_topk=1,
num_experts=num_experts,
ep_size=ep_size,
topk_ids_dtype=topk_ids_dtype)
def test_compute_expert_num_tokens_from_numel(
numel: int, num_experts: int, ep_size: int, topk_ids_dtype: torch.dtype
):
do_test_compute_expert_num_tokens(
num_tokens=numel,
num_topk=1,
num_experts=num_experts,
ep_size=ep_size,
topk_ids_dtype=topk_ids_dtype,
)

View File

@@ -17,19 +17,24 @@ from vllm.utils import cdiv
from vllm.utils.deep_gemm import per_block_cast_to_fp8
@pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [
(4, 8192, 7168, 4096),
(4, 8192, 2048, 7168),
(8, 4096, 7168, 4096),
(8, 4096, 2048, 7168),
(32, 1024, 7168, 4096),
(32, 1024, 2048, 7168),
])
@pytest.mark.parametrize(
"num_groups, expected_m_per_group, k, n",
[
(4, 8192, 7168, 4096),
(4, 8192, 2048, 7168),
(8, 4096, 7168, 4096),
(8, 4096, 2048, 7168),
(32, 1024, 7168, 4096),
(32, 1024, 2048, 7168),
],
)
@pytest.mark.parametrize("out_dtype", [torch.float16])
@pytest.mark.skipif(
(lambda x: x is None or x.to_int() != 100)(
current_platform.get_device_capability()),
reason="Block Scaled Grouped GEMM is only supported on SM100.")
current_platform.get_device_capability()
),
reason="Block Scaled Grouped GEMM is only supported on SM100.",
)
def test_cutlass_grouped_gemm(
num_groups: int,
expected_m_per_group: int,
@@ -40,8 +45,7 @@ def test_cutlass_grouped_gemm(
device = "cuda"
alignment = 128
group_ms = [
int(expected_m_per_group * random.uniform(0.7, 1.3))
for _ in range(num_groups)
int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)
]
m = sum([cdiv(m, alignment) * alignment for m in group_ms])
@@ -58,20 +62,22 @@ def test_cutlass_grouped_gemm(
expert_offsets = torch.tensor(ep_offset, device=device, dtype=torch.int32)
x_fp8 = per_token_cast_to_fp8(x)
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn),
torch.empty((num_groups, cdiv(n, 128), k // 128),
device=device,
dtype=torch.float))
y_fp8 = (
torch.empty_like(y, dtype=torch.float8_e4m3fn),
torch.empty(
(num_groups, cdiv(n, 128), k // 128), device=device, dtype=torch.float
),
)
for i in range(num_groups):
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], [128, 128])
for i in range(num_groups):
a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]]
a_scale = x_fp8[1][ep_offset[i]:ep_offset[i + 1]]
a = x_fp8[0][ep_offset[i] : ep_offset[i + 1]]
a_scale = x_fp8[1][ep_offset[i] : ep_offset[i + 1]]
b = y_fp8[0][i].t()
b_scale = y_fp8[1][i].t()
baseline = baseline_scaled_mm(a, b, a_scale, b_scale, out_dtype)
ref_out[ep_offset[i]:ep_offset[i + 1]] = baseline
ref_out[ep_offset[i] : ep_offset[i + 1]] = baseline
ops.cutlass_blockwise_scaled_grouped_mm(
out,

View File

@@ -11,13 +11,15 @@ import torch
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, fp8_w8a8_moe_quant_config)
FUSED_MOE_UNQUANTIZED_CONFIG,
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8, run_cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
fused_topk)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
cutlass_moe_fp8,
run_cutlass_moe_fp8,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.platforms import current_platform
NUM_EXPERTS = [40, 64]
@@ -39,12 +41,11 @@ MNK_FACTORS = [
(224, 3072, 1536),
(32768, 1024, 1024),
# These sizes trigger wrong answers.
#(7232, 2048, 5120),
#(40000, 2048, 5120),
# (7232, 2048, 5120),
# (40000, 2048, 5120),
]
vllm_config = VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@@ -60,22 +61,25 @@ class MOETensors:
c_strides2: torch.Tensor
@staticmethod
def make_moe_tensors(m: int, k: int, n: int, e: int,
dtype: torch.dtype) -> "MOETensors":
def make_moe_tensors(
m: int, k: int, n: int, e: int, dtype: torch.dtype
) -> "MOETensors":
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
return MOETensors(a=a,
w1=w1,
w2=w2,
ab_strides1=ab_strides1,
c_strides1=c_strides1,
ab_strides2=ab_strides2,
c_strides2=c_strides2)
ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
return MOETensors(
a=a,
w1=w1,
w2=w2,
ab_strides1=ab_strides1,
c_strides1=c_strides1,
ab_strides2=ab_strides2,
c_strides2=c_strides2,
)
@dataclasses.dataclass
@@ -93,9 +97,9 @@ class MOETensors8Bit(MOETensors):
w2_d: Optional[torch.Tensor] = None # w2 -> w2_q -> w2_d
@staticmethod
def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
per_act_token: bool,
per_out_channel: bool) -> "MOETensors8Bit":
def make_moe_tensors_8bit(
m: int, k: int, n: int, e: int, per_act_token: bool, per_out_channel: bool
) -> "MOETensors8Bit":
dtype = torch.half
q_dtype = torch.float8_e4m3fn
@@ -106,24 +110,21 @@ class MOETensors8Bit(MOETensors):
k_b_scales = k if per_out_channel else 1
# Get the right scale for tests.
a_q, a_scale = ops.scaled_fp8_quant(
moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token)
moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token
)
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
w1_scale = torch.empty((e, n_b_scales, 1),
device="cuda",
dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1),
device="cuda",
dtype=torch.float32)
w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32)
for expert in range(e):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
moe_tensors_fp16.w1[expert],
use_per_token_if_dynamic=per_out_channel)
moe_tensors_fp16.w1[expert], use_per_token_if_dynamic=per_out_channel
)
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
moe_tensors_fp16.w2[expert],
use_per_token_if_dynamic=per_out_channel)
moe_tensors_fp16.w2[expert], use_per_token_if_dynamic=per_out_channel
)
# a_q -> a_d, w1_q -> w1_d, w2_q -> w2_d
a_d = a_q.float().mul(a_scale).to(dtype)
@@ -133,31 +134,37 @@ class MOETensors8Bit(MOETensors):
w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half()
w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half()
return MOETensors8Bit(a=moe_tensors_fp16.a,
w1=moe_tensors_fp16.w1,
w2=moe_tensors_fp16.w2,
ab_strides1=moe_tensors_fp16.ab_strides1,
c_strides1=moe_tensors_fp16.c_strides1,
ab_strides2=moe_tensors_fp16.ab_strides2,
c_strides2=moe_tensors_fp16.c_strides2,
a_q=a_q,
w1_q=w1_q,
w2_q=w2_q,
a_scale=a_scale,
w1_scale=w1_scale,
w2_scale=w2_scale,
a_d=a_d,
w1_d=w1_d,
w2_d=w2_d)
return MOETensors8Bit(
a=moe_tensors_fp16.a,
w1=moe_tensors_fp16.w1,
w2=moe_tensors_fp16.w2,
ab_strides1=moe_tensors_fp16.ab_strides1,
c_strides1=moe_tensors_fp16.c_strides1,
ab_strides2=moe_tensors_fp16.ab_strides2,
c_strides2=moe_tensors_fp16.c_strides2,
a_q=a_q,
w1_q=w1_q,
w2_q=w2_q,
a_scale=a_scale,
w1_scale=w1_scale,
w2_scale=w2_scale,
a_d=a_d,
w1_d=w1_d,
w2_d=w2_d,
)
def run_with_expert_maps(num_experts: int, num_local_experts: int,
**cutlass_moe_kwargs):
def run_with_expert_maps(
num_experts: int, num_local_experts: int, **cutlass_moe_kwargs
):
def slice_experts():
slice_params = [
"w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1",
"c_strides2"
"w1_q",
"w2_q",
"ab_strides1",
"ab_strides2",
"c_strides1",
"c_strides2",
]
full_tensors = {
k: v
@@ -173,9 +180,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
# make expert map
expert_map = [-1] * num_experts
expert_map[s:e] = list(range(num_local_experts))
expert_map = torch.tensor(expert_map,
dtype=torch.int32,
device="cuda")
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
# update cutlass moe arg with expert_map
cutlass_moe_kwargs["expert_map"] = expert_map
@@ -198,18 +203,26 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
return out_tensor
def run_8_bit(moe_tensors: MOETensors8Bit,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
per_out_ch: bool,
num_local_experts: Optional[int] = None) -> torch.Tensor:
assert not any([
t is None for t in [
moe_tensors.w1_q, moe_tensors.w2_q, moe_tensors.w1_scale,
moe_tensors.w2_scale, moe_tensors.a_scale
def run_8_bit(
moe_tensors: MOETensors8Bit,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
per_out_ch: bool,
num_local_experts: Optional[int] = None,
) -> torch.Tensor:
assert not any(
[
t is None
for t in [
moe_tensors.w1_q,
moe_tensors.w2_q,
moe_tensors.w1_scale,
moe_tensors.w2_scale,
moe_tensors.a_scale,
]
]
])
)
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=moe_tensors.w1_scale,
@@ -222,16 +235,16 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
)
kwargs = {
'a': moe_tensors.a,
'w1_q': moe_tensors.w1_q, # type: ignore[union-attr]
'w2_q': moe_tensors.w2_q, # type: ignore[union-attr]
'topk_weights': topk_weights,
'topk_ids': topk_ids,
'ab_strides1': moe_tensors.ab_strides1,
'ab_strides2': moe_tensors.ab_strides2,
'c_strides1': moe_tensors.c_strides1,
'c_strides2': moe_tensors.c_strides2,
'quant_config': quant_config,
"a": moe_tensors.a,
"w1_q": moe_tensors.w1_q, # type: ignore[union-attr]
"w2_q": moe_tensors.w2_q, # type: ignore[union-attr]
"topk_weights": topk_weights,
"topk_ids": topk_ids,
"ab_strides1": moe_tensors.ab_strides1,
"ab_strides2": moe_tensors.ab_strides2,
"c_strides1": moe_tensors.c_strides1,
"c_strides2": moe_tensors.c_strides2,
"quant_config": quant_config,
}
num_experts = moe_tensors.w1.size(0)
@@ -243,7 +256,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
return run_with_expert_maps(
num_experts,
num_local_experts, # type: ignore[arg-type]
**kwargs)
**kwargs,
)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@@ -253,8 +267,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
current_platform.get_device_capability()
),
reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_moe_8_bit_no_graph(
m: int,
n: int,
@@ -269,25 +285,18 @@ def test_cutlass_moe_8_bit_no_graph(
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_ch)
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch)
score = torch.randn((m, e), device="cuda", dtype=torch.half)
topk_weights, topk_ids, _ = fused_topk(mt.a,
score,
topk,
renormalize=False)
topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False)
# Note that we are using the dequantized versions of the tensors.
# Using a, w1 and w2 directly results in minor output differences.
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
triton_output = fused_experts(mt.a_d,
mt.w1_d,
mt.w2_d,
topk_weights,
topk_ids,
quant_config=quant_config)
triton_output = fused_experts(
mt.a_d, mt.w1_d, mt.w2_d, topk_weights, topk_ids, quant_config=quant_config
)
if ep_size is not None:
assert e % ep_size == 0, "Cannot distribute experts evenly"
@@ -295,15 +304,15 @@ def test_cutlass_moe_8_bit_no_graph(
else:
number_local_experts = None
cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token,
per_out_ch, number_local_experts)
cutlass_output = run_8_bit(
mt, topk_weights, topk_ids, per_act_token, per_out_ch, number_local_experts
)
# Note 5.5 only needed for larger problem sizes, 5 works ok for
# the rest.
torch.testing.assert_close(triton_output,
cutlass_output,
atol=5.5e-2,
rtol=1e-2)
torch.testing.assert_close(
triton_output, cutlass_output, atol=5.5e-2, rtol=1e-2
)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@@ -313,8 +322,10 @@ def test_cutlass_moe_8_bit_no_graph(
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
current_platform.get_device_capability()
),
reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_moe_8_bit_cuda_graph(
m: int,
n: int,
@@ -330,39 +341,30 @@ def test_cutlass_moe_8_bit_cuda_graph(
with set_current_vllm_config(vllm_config):
dtype = torch.half
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_ch)
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(mt.a,
score,
topk,
renormalize=False)
topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False)
# Note that we are using the dequantized versions of the tensors.
# Using a, w1 and w2 directly results in minor output differences.
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
triton_output = fused_experts(mt.a_d,
mt.w1_d,
mt.w2_d,
topk_weights,
topk_ids,
quant_config=quant_config)
triton_output = fused_experts(
mt.a_d, mt.w1_d, mt.w2_d, topk_weights, topk_ids, quant_config=quant_config
)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
per_act_token, per_out_ch)
cutlass_output = run_8_bit(
mt, topk_weights, topk_ids, per_act_token, per_out_ch
)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
torch.testing.assert_close(triton_output,
cutlass_output,
atol=9e-2,
rtol=1e-2)
torch.testing.assert_close(triton_output, cutlass_output, atol=9e-2, rtol=1e-2)
@pytest.mark.parametrize("m", [64])
@@ -375,8 +377,10 @@ def test_cutlass_moe_8_bit_cuda_graph(
@pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16])
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
current_platform.get_device_capability()
),
reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_moe_8_bit_EP(
m: int,
n: int,
@@ -388,8 +392,9 @@ def test_cutlass_moe_8_bit_EP(
ep_size: int,
monkeypatch,
):
test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
per_out_channel, monkeypatch, ep_size)
test_cutlass_moe_8_bit_no_graph(
m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size
)
LARGE_MNK_FACTORS = [
@@ -406,8 +411,10 @@ LARGE_MNK_FACTORS = [
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
current_platform.get_device_capability()
),
reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_moe_8_bit_EP_large(
m: int,
n: int,
@@ -419,8 +426,9 @@ def test_cutlass_moe_8_bit_EP_large(
ep_size: int,
monkeypatch,
):
test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
per_out_channel, monkeypatch, ep_size)
test_cutlass_moe_8_bit_no_graph(
m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size
)
@pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)])
@@ -430,8 +438,10 @@ def test_cutlass_moe_8_bit_EP_large(
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
current_platform.get_device_capability()
),
reason="Grouped gemm is not supported on this GPU type.",
)
def test_run_cutlass_moe_fp8(
m: int,
n: int,
@@ -444,14 +454,12 @@ def test_run_cutlass_moe_fp8(
):
current_platform.seed_everything(7)
with set_current_vllm_config(vllm_config):
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_channel)
mt = MOETensors8Bit.make_moe_tensors_8bit(
m, k, n, e, per_act_token, per_out_channel
)
score = torch.randn((m, e), device="cuda", dtype=torch.half)
topk_weights, topk_ids, _ = fused_topk(mt.a,
score,
topk,
renormalize=False)
topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False)
# we want to make sure there is at least one token that's generated in
# this expert shard and at least one token that's NOT generated in this
# expert shard
@@ -462,12 +470,12 @@ def test_run_cutlass_moe_fp8(
workspace2_shape = (m * topk, max(n, k))
output_shape = (m, k)
workspace13 = torch.empty(prod(workspace13_shape),
device="cuda",
dtype=mt.a.dtype)
workspace2 = torch.empty(prod(workspace2_shape),
device="cuda",
dtype=mt.a.dtype)
workspace13 = torch.empty(
prod(workspace13_shape), device="cuda", dtype=mt.a.dtype
)
workspace2 = torch.empty(
prod(workspace2_shape), device="cuda", dtype=mt.a.dtype
)
num_local_experts = e // ep_size
start, end = 0, num_local_experts
@@ -475,36 +483,55 @@ def test_run_cutlass_moe_fp8(
expert_map[start:end] = list(range(num_local_experts))
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
torch.float8_e4m3fn,
per_act_token)
a1q, a1q_scale = moe_kernel_quantize_input(
mt.a, mt.a_scale, torch.float8_e4m3fn, per_act_token
)
global_num_experts = -1 if mt.w1_q is None else mt.w1_q.size(0)
func = lambda output: run_cutlass_moe_fp8(
output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
workspace13, workspace2, None, mt.a.dtype, per_act_token,
per_out_channel, False, topk_weights)
output,
a1q,
mt.w1_q,
mt.w2_q,
topk_ids,
activation,
global_num_experts,
expert_map,
mt.w1_scale,
mt.w2_scale,
a1q_scale,
None,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
workspace13,
workspace2,
None,
mt.a.dtype,
per_act_token,
per_out_channel,
False,
topk_weights,
)
workspace13.random_()
output_random_workspace = torch.empty(output_shape,
device="cuda",
dtype=mt.a.dtype)
output_random_workspace = torch.empty(
output_shape, device="cuda", dtype=mt.a.dtype
)
func(output_random_workspace)
workspace13.fill_(0)
output_zero_workspace = torch.zeros(output_shape,
device="cuda",
dtype=mt.a.dtype)
output_zero_workspace = torch.zeros(
output_shape, device="cuda", dtype=mt.a.dtype
)
func(output_zero_workspace)
torch.testing.assert_close(output_random_workspace,
output_zero_workspace,
atol=5e-3,
rtol=1e-3)
torch.testing.assert_close(
output_random_workspace, output_zero_workspace, atol=5e-3, rtol=1e-3
)

View File

@@ -16,10 +16,11 @@ from typing_extensions import ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
@@ -30,18 +31,19 @@ from .utils import make_test_weights
if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
DeepEPHTPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
DeepEPLLPrepareAndFinalize,
)
from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
if has_deep_gemm():
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts)
BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
requires_deep_ep = pytest.mark.skipif(
not has_deep_ep(),
@@ -58,9 +60,10 @@ P = ParamSpec("P")
def next_power_of_2(x):
import math
if x == 0:
return 1
return 2**math.ceil(math.log2(x))
return 2 ** math.ceil(math.log2(x))
def make_block_quant_fp8_weights(
@@ -72,13 +75,9 @@ def make_block_quant_fp8_weights(
"""
Return weights w1q, w2q, w1_scale, w2_scale
"""
(_, w1q, w1_scale, _), (_, w2q, w2_scale,
_) = make_test_weights(e,
n,
k,
torch.bfloat16,
torch.float8_e4m3fn,
block_shape=block_size)
(_, w1q, w1_scale, _), (_, w2q, w2_scale, _) = make_test_weights(
e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_shape=block_size
)
return w1q, w2q, w1_scale, w2_scale
@@ -106,15 +105,15 @@ class TestTensors:
@staticmethod
def make(config: TestConfig, rank) -> "TestTensors":
dtype = torch.bfloat16
topk, m, k = (config.topk, config.m, config.k)
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
rank_tokens = torch.randn(
(m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
rank_tokens = (
torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
)
rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
rank_token_scales = None
@@ -122,25 +121,32 @@ class TestTensors:
low=0,
high=config.num_experts,
size=(m, topk),
device=torch.cuda.current_device()).to(dtype=torch.int64)
device=torch.cuda.current_device(),
).to(dtype=torch.int64)
topk_weights = torch.randn(topk_ids.shape,
dtype=torch.float32,
device=torch.cuda.current_device())
topk_weights = torch.randn(
topk_ids.shape, dtype=torch.float32, device=torch.cuda.current_device()
)
return TestTensors(rank_tokens=rank_tokens,
rank_token_scales=rank_token_scales,
topk=topk_ids,
topk_weights=topk_weights,
config=config)
return TestTensors(
rank_tokens=rank_tokens,
rank_token_scales=rank_token_scales,
topk=topk_ids,
topk_weights=topk_weights,
config=config,
)
def make_ll_modular_kernel(
pg: ProcessGroup, pgi: ProcessGroupInfo, max_tokens_per_rank: int,
dp_size: int, hidden_size: int, q_dtype: Optional[torch.dtype],
test_config: TestConfig,
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
pg: ProcessGroup,
pgi: ProcessGroupInfo,
max_tokens_per_rank: int,
dp_size: int,
hidden_size: int,
q_dtype: Optional[torch.dtype],
test_config: TestConfig,
quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
assert test_config.low_latency
assert test_config.use_fp8_dispatch is not None
@@ -153,26 +159,30 @@ def make_ll_modular_kernel(
max_tokens_per_rank=max_tokens_per_rank,
hidden_size=hidden_size,
num_experts=test_config.num_experts,
use_fp8_dispatch=test_config.use_fp8_dispatch),
use_fp8_dispatch=test_config.use_fp8_dispatch,
),
q_dtype=q_dtype,
block_shape=test_config.block_size)
block_shape=test_config.block_size,
)
fused_experts = BatchedDeepGemmExperts(
max_num_tokens=max_tokens_per_rank,
num_dispatchers=pgi.world_size // dp_size,
quant_config=quant_config,
)
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
return mk
def make_ht_modular_kernel(
pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
num_local_experts: int, q_dtype: Optional[torch.dtype],
test_config: TestConfig,
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
num_local_experts: int,
q_dtype: Optional[torch.dtype],
test_config: TestConfig,
quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
assert not test_config.low_latency
assert test_config.use_fp8_dispatch is None
@@ -183,76 +193,82 @@ def make_ht_modular_kernel(
deepep_ht_args=DeepEPHTArgs(num_local_experts=num_local_experts),
deepep_ll_args=None,
q_dtype=q_dtype,
block_shape=test_config.block_size)
block_shape=test_config.block_size,
)
fused_experts = DeepGemmExperts(quant_config)
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
return mk
def make_modular_kernel(
pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
num_local_experts: int, test_tensors: TestTensors,
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
num_local_experts: int,
test_tensors: TestTensors,
quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
q_dtype = torch.float8_e4m3fn
test_config = test_tensors.config
mk: FusedMoEModularKernel
# Make modular kernel
if test_config.low_latency:
max_tokens_per_rank = max(
64, next_power_of_2(test_tensors.rank_tokens.size(0)))
max_tokens_per_rank = max(64, next_power_of_2(test_tensors.rank_tokens.size(0)))
hidden_size = test_tensors.rank_tokens.size(-1)
mk = make_ll_modular_kernel(pg=pg,
pgi=pgi,
max_tokens_per_rank=max_tokens_per_rank,
dp_size=dp_size,
hidden_size=hidden_size,
q_dtype=q_dtype,
test_config=test_config,
quant_config=quant_config)
mk = make_ll_modular_kernel(
pg=pg,
pgi=pgi,
max_tokens_per_rank=max_tokens_per_rank,
dp_size=dp_size,
hidden_size=hidden_size,
q_dtype=q_dtype,
test_config=test_config,
quant_config=quant_config,
)
else:
mk = make_ht_modular_kernel(pg,
pgi,
dp_size,
num_local_experts,
q_dtype,
test_config,
quant_config=quant_config)
mk = make_ht_modular_kernel(
pg,
pgi,
dp_size,
num_local_experts,
q_dtype,
test_config,
quant_config=quant_config,
)
return mk
def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
dp_size: int, test_tensors: TestTensors,
w1: torch.Tensor, w2: torch.Tensor,
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor]) -> torch.Tensor:
def deepep_deepgemm_moe_impl(
pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
test_tensors: TestTensors,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
) -> torch.Tensor:
test_config = test_tensors.config
num_experts = test_config.num_experts
num_local_experts = w1.size(0)
def build_expert_map():
num_local_experts = w1.size(0)
expert_map = torch.full((num_experts, ),
fill_value=-1,
dtype=torch.int32)
expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32)
s = pgi.rank * num_local_experts
e = s + num_local_experts
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
return expert_map.to(device=torch.cuda.current_device(),
dtype=torch.int32)
return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32)
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
# Low-Latency kernels can't dispatch scales.
a1_scale=(None if test_config.low_latency else
test_tensors.rank_token_scales),
a1_scale=(None if test_config.low_latency else test_tensors.rank_token_scales),
block_shape=test_config.block_size,
)
@@ -263,26 +279,35 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
dp_size=dp_size,
num_local_experts=num_local_experts,
test_tensors=test_tensors,
quant_config=quant_config)
quant_config=quant_config,
)
out = mk.forward(hidden_states=test_tensors.rank_tokens,
w1=w1,
w2=w2,
topk_weights=test_tensors.topk_weights,
topk_ids=test_tensors.topk,
inplace=False,
activation="silu",
global_num_experts=num_experts,
expert_map=build_expert_map(),
apply_router_weight_on_input=False)
out = mk.forward(
hidden_states=test_tensors.rank_tokens,
w1=w1,
w2=w2,
topk_weights=test_tensors.topk_weights,
topk_ids=test_tensors.topk,
inplace=False,
activation="silu",
global_num_experts=num_experts,
expert_map=build_expert_map(),
apply_router_weight_on_input=False,
)
return out
def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
topk_weights: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
a1_scale: torch.Tensor, block_shape: list[int]):
def triton_impl(
a: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: torch.Tensor,
block_shape: list[int],
):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
@@ -300,7 +325,8 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
quant_config=quant_config,
# Make sure this is set to False so we
# don't end up comparing the same implementation.
allow_deep_gemm=False)
allow_deep_gemm=False,
)
def _test_deepep_deepgemm_moe(
@@ -321,22 +347,21 @@ def _test_deepep_deepgemm_moe(
pg = torch.distributed.new_group(list(range(pgi.world_size)))
test_tensors = TestTensors.make(config, pgi.rank)
block_shape = [
w1.size(1) // w1_scale.size(1),
w1.size(2) // w1_scale.size(2)
]
block_shape = [w1.size(1) // w1_scale.size(1), w1.size(2) // w1_scale.size(2)]
with set_current_vllm_config(VllmConfig()):
# Reference
triton_moe = triton_impl(a=test_tensors.rank_tokens,
topk_ids=test_tensors.topk,
topk_weights=test_tensors.topk_weights,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=test_tensors.rank_token_scales,
block_shape=block_shape)
triton_moe = triton_impl(
a=test_tensors.rank_tokens,
topk_ids=test_tensors.topk,
topk_weights=test_tensors.topk_weights,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=test_tensors.rank_token_scales,
block_shape=block_shape,
)
# Slice experts for this rank.
num_local_experts = config.num_experts // pgi.world_size
@@ -390,10 +415,15 @@ NUM_EXPERTS = [32]
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
@pytest.mark.skipif(is_deep_gemm_e8m0_used(),
reason="Skipping test for Blackwell DeepGEMM")
def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
topk: int, world_dp_size: tuple[int, int]):
@pytest.mark.skipif(
is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM"
)
def test_ht_deepep_deepgemm_moe(
mnk: tuple[int, int, int],
num_experts: int,
topk: int,
world_dp_size: tuple[int, int],
):
"""
Tests for High-Throughput DeepEP + DeepGemm integration.
"""
@@ -409,21 +439,32 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
block_size = [block_m, block_m]
world_size, dp_size = world_dp_size
config = TestConfig(topk=topk,
m=m,
k=k,
n=n,
num_experts=num_experts,
per_act_token_quant=False,
block_size=block_size,
low_latency=False,
use_fp8_dispatch=None)
config = TestConfig(
topk=topk,
m=m,
k=k,
n=n,
num_experts=num_experts,
per_act_token_quant=False,
block_size=block_size,
low_latency=False,
use_fp8_dispatch=None,
)
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
num_experts, n, k, block_size)
num_experts, n, k, block_size
)
parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1,
w2, w1_scale, w2_scale)
parallel_launch(
world_size,
_test_deepep_deepgemm_moe,
dp_size,
config,
w1,
w2,
w1_scale,
w2_scale,
)
MNKs = [
@@ -448,8 +489,9 @@ USE_FP8_DISPATCH = [False]
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
@pytest.mark.skipif(is_deep_gemm_e8m0_used(),
reason="Skipping test for Blackwell DeepGEMM")
@pytest.mark.skipif(
is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM"
)
def test_ll_deepep_deepgemm_moe(
mnk: tuple[int, int, int],
num_experts: int,
@@ -482,7 +524,16 @@ def test_ll_deepep_deepgemm_moe(
)
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
num_experts, n, k, block_size)
num_experts, n, k, block_size
)
parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1,
w2, w1_scale, w2_scale)
parallel_launch(
world_size,
_test_deepep_deepgemm_moe,
dp_size,
config,
w1,
w2,
w1_scale,
w2_scale,
)

View File

@@ -16,12 +16,11 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
per_token_group_quant_fp8,
)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep
@@ -30,9 +29,11 @@ from .parallel_utils import ProcessGroupInfo, parallel_launch
if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
DeepEPHTPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
DeepEPLLPrepareAndFinalize,
)
from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
@@ -45,7 +46,7 @@ MAX_TOKENS_PER_RANK = 64
def make_weights(
e, n, k, dtype
e, n, k, dtype
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Return weights w1, w2, w1_scale, w2_scale
@@ -64,17 +65,15 @@ def make_weights(
k_b_scales = k
w1_q = torch.empty_like(w1, dtype=dtype)
w2_q = torch.empty_like(w2, dtype=dtype)
w1_scale = torch.empty((e, n_b_scales, 1),
device="cuda",
dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1),
device="cuda",
dtype=torch.float32)
w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32)
for expert in range(e):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
w1[expert], use_per_token_if_dynamic=True)
w1[expert], use_per_token_if_dynamic=True
)
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
w2[expert], use_per_token_if_dynamic=True)
w2[expert], use_per_token_if_dynamic=True
)
return w1_q, w2_q, w1_scale, w2_scale
@@ -100,24 +99,25 @@ class TestTensors:
def make(config: TestConfig, low_latency_mode: bool) -> "TestTensors":
# TODO (varun) - check that float16 works ?
assert config.dtype in [torch.bfloat16, torch.float8_e4m3fn]
token_dtype = (torch.bfloat16 if config.dtype == torch.float8_e4m3fn
else config.dtype)
rank_tokens = torch.randn(
(config.m, config.k), device="cuda", dtype=token_dtype) / 10
token_dtype = (
torch.bfloat16 if config.dtype == torch.float8_e4m3fn else config.dtype
)
rank_tokens = (
torch.randn((config.m, config.k), device="cuda", dtype=token_dtype) / 10
)
rank_token_scales = None
topk = torch.randint(low=0,
high=config.num_experts,
size=(config.m, config.topk),
device="cuda").to(dtype=torch.int64)
topk_weights = torch.randn(topk.shape,
dtype=torch.float32,
device="cuda")
return TestTensors(rank_tokens=rank_tokens,
rank_token_scales=rank_token_scales,
topk=topk,
topk_weights=topk_weights,
config=config)
topk = torch.randint(
low=0, high=config.num_experts, size=(config.m, config.topk), device="cuda"
).to(dtype=torch.int64)
topk_weights = torch.randn(topk.shape, dtype=torch.float32, device="cuda")
return TestTensors(
rank_tokens=rank_tokens,
rank_token_scales=rank_token_scales,
topk=topk,
topk_weights=topk_weights,
config=config,
)
def make_modular_kernel(
@@ -132,28 +132,33 @@ def make_modular_kernel(
use_fp8_dispatch: bool,
quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
ht_args: Optional[DeepEPHTArgs] = None
ll_args: Optional[DeepEPLLArgs] = None
if low_latency_mode:
ll_args = DeepEPLLArgs(max_tokens_per_rank=MAX_TOKENS_PER_RANK,
hidden_size=hidden_size,
num_experts=num_experts,
use_fp8_dispatch=use_fp8_dispatch)
ll_args = DeepEPLLArgs(
max_tokens_per_rank=MAX_TOKENS_PER_RANK,
hidden_size=hidden_size,
num_experts=num_experts,
use_fp8_dispatch=use_fp8_dispatch,
)
else:
assert not use_fp8_dispatch, (
"FP8 Dispatch is valid only for low-latency kernels")
"FP8 Dispatch is valid only for low-latency kernels"
)
ht_args = DeepEPHTArgs(num_local_experts=num_local_experts)
a2a : Union[DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize] = \
make_deepep_a2a(pg = pg,
pgi = pgi,
dp_size = dp_size,
q_dtype = q_dtype,
block_shape = None,
deepep_ht_args = ht_args,
deepep_ll_args = ll_args)
a2a: Union[DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize] = (
make_deepep_a2a(
pg=pg,
pgi=pgi,
dp_size=dp_size,
q_dtype=q_dtype,
block_shape=None,
deepep_ht_args=ht_args,
deepep_ll_args=ll_args,
)
)
num_dispatchers = pgi.world_size // dp_size
@@ -167,8 +172,7 @@ def make_modular_kernel(
else:
fused_experts = TritonExperts(quant_config=quant_config)
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
return mk
@@ -186,19 +190,15 @@ def deep_ep_moe_impl(
use_fp8_dispatch: bool,
per_act_token_quant: bool,
) -> torch.Tensor:
num_local_experts = w1.size(0)
def build_expert_map():
num_local_experts = w1.size(0)
expert_map = torch.full((num_experts, ),
fill_value=-1,
dtype=torch.int32)
expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32)
s = pgi.rank * num_local_experts
e = s + num_local_experts
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
return expert_map.to(device=torch.cuda.current_device(),
dtype=torch.int32)
return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32)
hidden_size = test_tensors.rank_tokens.size(1)
is_quantized = w1.dtype == torch.float8_e4m3fn
@@ -214,11 +214,12 @@ def deep_ep_moe_impl(
topk_weights_chunk = test_tensors.topk_weights[chunk_start:chunk_end]
topk_chunk = test_tensors.topk[chunk_start:chunk_end]
rank_token_scales_chunk = test_tensors.rank_token_scales
if rank_token_scales_chunk is not None and rank_token_scales_chunk.size(
0) == total_num_tokens:
if (
rank_token_scales_chunk is not None
and rank_token_scales_chunk.size(0) == total_num_tokens
):
# per act token
rank_token_scales_chunk = rank_token_scales_chunk[
chunk_start:chunk_end]
rank_token_scales_chunk = rank_token_scales_chunk[chunk_start:chunk_end]
quant_config = FusedMoEQuantConfig.make(
q_dtype,
@@ -230,26 +231,37 @@ def deep_ep_moe_impl(
# Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(
pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
num_local_experts, q_dtype, use_fp8_dispatch, quant_config)
pg,
pgi,
low_latency_mode,
hidden_size,
dp_size,
num_experts,
num_local_experts,
q_dtype,
use_fp8_dispatch,
quant_config,
)
out = mk.forward(hidden_states=rank_tokens_chunk,
w1=w1,
w2=w2,
topk_weights=topk_weights_chunk,
topk_ids=topk_chunk,
inplace=False,
activation="silu",
global_num_experts=num_experts,
expert_map=build_expert_map(),
apply_router_weight_on_input=False)
out = mk.forward(
hidden_states=rank_tokens_chunk,
w1=w1,
w2=w2,
topk_weights=topk_weights_chunk,
topk_ids=topk_chunk,
inplace=False,
activation="silu",
global_num_experts=num_experts,
expert_map=build_expert_map(),
apply_router_weight_on_input=False,
)
if not skip_result_store:
out_hidden_states[chunk_start:chunk_end, :].copy_(
out, non_blocking=True)
out_hidden_states[chunk_start:chunk_end, :].copy_(out, non_blocking=True)
max_num_tokens_per_dp = (MAX_TOKENS_PER_RANK
if low_latency_mode else total_num_tokens)
max_num_tokens_per_dp = (
MAX_TOKENS_PER_RANK if low_latency_mode else total_num_tokens
)
for chunk_start_ in range(0, total_num_tokens, max_num_tokens_per_dp):
chunk_start = chunk_start_
@@ -258,9 +270,9 @@ def deep_ep_moe_impl(
chunk_start = min(chunk_start, total_num_tokens - 1)
chunk_end = min(chunk_end, total_num_tokens)
process_chunk(chunk_start,
chunk_end,
skip_result_store=chunk_start_ >= total_num_tokens)
process_chunk(
chunk_start, chunk_end, skip_result_store=chunk_start_ >= total_num_tokens
)
return out_hidden_states
@@ -274,9 +286,11 @@ def torch_moe_impl(
using_fp8_dispatch: bool,
per_act_token_quant: bool,
):
a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk,
test_tensors.topk_weights)
a, topk_ids, topk_weights = (
test_tensors.rank_tokens,
test_tensors.topk,
test_tensors.topk_weights,
)
if using_fp8_dispatch:
# The DeepEP implementation is requested to dispatch using FP8.
# For numerical stability for testing, emulate the fp8 dispatch by
@@ -284,8 +298,11 @@ def torch_moe_impl(
assert not per_act_token_quant
a = test_tensors.rank_tokens
aq, aq_scale = per_token_group_quant_fp8(a, 128)
a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view(
a.shape).to(a.dtype)
a = (
(aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1))
.view(a.shape)
.to(a.dtype)
)
is_quantized = w1.dtype == torch.float8_e4m3fn
a_dtype = a.dtype
@@ -306,8 +323,9 @@ def torch_moe_impl(
e_w = topk_weights[i][j]
w1_e = w1[e]
w2_e = w2[e]
o_i += (SiluAndMul()
(a_i @ w1_e.transpose(0, 1)) @ w2_e.transpose(0, 1)) * e_w
o_i += (
SiluAndMul()(a_i @ w1_e.transpose(0, 1)) @ w2_e.transpose(0, 1)
) * e_w
if is_quantized:
out = out.to(dtype=a_dtype)
@@ -327,28 +345,36 @@ def _deep_ep_moe(
use_fp8_dispatch: bool,
per_act_token_quant: bool,
):
if not low_latency_mode:
assert not use_fp8_dispatch, (
"FP8 dispatch interface is available only in low-latency mode")
"FP8 dispatch interface is available only in low-latency mode"
)
is_quantized = w1.dtype == torch.float8_e4m3fn
w1 = w1.to(device=torch.cuda.current_device())
w2 = w2.to(device=torch.cuda.current_device())
if is_quantized:
w1_scale = w1_scale.to( # type: ignore
device=torch.cuda.current_device())
device=torch.cuda.current_device()
)
w2_scale = w2_scale.to( # type: ignore
device=torch.cuda.current_device())
device=torch.cuda.current_device()
)
pg = torch.distributed.new_group(list(range(pgi.world_size)))
test_tensors = TestTensors.make(config, low_latency_mode)
with set_current_vllm_config(VllmConfig()):
# Reference
torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale,
w2_scale, use_fp8_dispatch,
per_act_token_quant)
torch_combined = torch_moe_impl(
test_tensors,
w1,
w2,
w1_scale,
w2_scale,
use_fp8_dispatch,
per_act_token_quant,
)
# Splice experts for this rank.
num_local_experts = config.num_experts // pgi.world_size
@@ -420,18 +446,23 @@ def test_deep_ep_moe(
current_platform.seed_everything(7)
world_size, dp_size = world_dp_size
config = TestConfig(dtype=dtype,
topk=topk,
m=m,
k=k,
n=n,
num_experts=num_experts)
config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts)
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
per_act_token_quant)
parallel_launch(
world_size,
_deep_ep_moe,
low_latency_mode,
dp_size,
config,
w1,
w2,
w1_scale,
w2_scale,
use_fp8_dispatch,
per_act_token_quant,
)
MNKs = [
@@ -467,8 +498,7 @@ def test_low_latency_deep_ep_moe(
):
low_latency_mode = True
if (low_latency_mode
and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES):
if low_latency_mode and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES:
pytest.skip(
f"Skipping test as hidden size {k} is not in list of supported "
f"hidden sizes {DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES}"
@@ -476,15 +506,20 @@ def test_low_latency_deep_ep_moe(
current_platform.seed_everything(7)
world_size, dp_size = world_dp_size
config = TestConfig(dtype=dtype,
topk=topk,
m=m,
k=k,
n=n,
num_experts=num_experts)
config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts)
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
False)
parallel_launch(
world_size,
_deep_ep_moe,
low_latency_mode,
dp_size,
config,
w1,
w2,
w1_scale,
w2_scale,
use_fp8_dispatch,
False,
)

View File

@@ -11,14 +11,18 @@ import math
import pytest
import torch
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.utils.deep_gemm import (calc_diff, is_deep_gemm_supported,
per_block_cast_to_fp8)
per_token_group_quant_fp8,
)
from vllm.utils.deep_gemm import (
calc_diff,
is_deep_gemm_supported,
per_block_cast_to_fp8,
)
BLOCK_SIZE = [128, 128]
@@ -37,8 +41,10 @@ def make_block_quant_fp8_weights(
w2 shape: (E, K, N)
"""
dtype = torch.bfloat16
fp8_max, fp8_min = torch.finfo(torch.float8_e4m3fn).max, torch.finfo(
torch.float8_e4m3fn).min
fp8_max, fp8_min = (
torch.finfo(torch.float8_e4m3fn).max,
torch.finfo(torch.float8_e4m3fn).min,
)
# bf16 reference weights
w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=dtype) / 10
@@ -54,24 +60,16 @@ def make_block_quant_fp8_weights(
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
w1_s = torch.empty(e,
n_tiles_w1,
k_tiles_w1,
device="cuda",
dtype=torch.float32)
w2_s = torch.empty(e,
n_tiles_w2,
k_tiles_w2,
device="cuda",
dtype=torch.float32)
w1_s = torch.empty(e, n_tiles_w1, k_tiles_w1, device="cuda", dtype=torch.float32)
w2_s = torch.empty(e, n_tiles_w2, k_tiles_w2, device="cuda", dtype=torch.float32)
for i in range(e):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
block_size=block_size,
use_ue8m0=True)
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
block_size=block_size,
use_ue8m0=True)
w1[i], w1_s[i] = per_block_cast_to_fp8(
w1_bf16[i], block_size=block_size, use_ue8m0=True
)
w2[i], w2_s[i] = per_block_cast_to_fp8(
w2_bf16[i], block_size=block_size, use_ue8m0=True
)
return w1, w2, w1_s, w2_s
@@ -81,18 +79,17 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
Run one (M,N,K) configuration on a single GPU and assert DeepGEMM ==
Triton baseline within tolerance.
"""
tokens_bf16 = torch.randn(
m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
tokens_bf16 = (
torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
.clamp_min_(-1)
.clamp_max_(1)
)
_, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])
# expert weight tensors
w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
block_size)
w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k, block_size)
router_logits = torch.randn(m,
num_experts,
device="cuda",
dtype=torch.float32)
router_logits = torch.randn(m, num_experts, device="cuda", dtype=torch.float32)
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
@@ -147,15 +144,14 @@ NUM_EXPERTS = [32]
@pytest.mark.parametrize(("m", "n", "k"), MNKs)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.skipif(not is_deep_gemm_supported(),
reason="Requires deep_gemm kernels")
@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
with monkeypatch.context() as mp:
mp.setenv("VLLM_USE_DEEP_GEMM", "1")
_fused_moe_mod = importlib.import_module(
"vllm.model_executor.layers.fused_moe.fused_moe")
"vllm.model_executor.layers.fused_moe.fused_moe"
)
call_counter = {"cnt": 0}
@@ -165,8 +161,7 @@ def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
call_counter["cnt"] += 1
return orig_fn(*args, **kwargs)
monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8",
_spy_deep_gemm_moe_fp8)
monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", _spy_deep_gemm_moe_fp8)
if topk > num_experts:
pytest.skip(f"topk={topk} > num_experts={num_experts}")
@@ -181,6 +176,7 @@ def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
)
# ensure that the DeepGEMM path was indeed taken.
assert call_counter["cnt"] == 1, \
f"DeepGEMM path was not executed during the test. " \
assert call_counter["cnt"] == 1, (
f"DeepGEMM path was not executed during the test. "
f"Call counter: {call_counter['cnt']}"
)

View File

@@ -6,24 +6,28 @@ import pytest
import torch
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, flashinfer_cutlass_moe_fp8,
register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
swap_w13_to_w31)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
input_to_float8)
apply_flashinfer_per_tensor_scale_fp8,
flashinfer_cutlass_moe_fp8,
register_moe_scaling_factors,
rotate_flashinfer_fp8_moe_weights,
swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
if not has_flashinfer_cutlass_fused_moe(
) or not current_platform.has_device_capability(100):
pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support",
allow_module_level=True)
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
100
):
pytest.skip(
"Requires flashinfer_cutlass_fused_moe and nvfp4 support",
allow_module_level=True,
)
NUM_EXPERTS = [16]
TOP_KS = [1]
@@ -39,8 +43,7 @@ MNK_FACTORS = [
(1, 4096, 5120),
]
vllm_config = VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@@ -74,18 +77,17 @@ class TestData:
layer: torch.nn.Module
@staticmethod
def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
reorder: bool) -> "TestData":
hidden_states = torch.randn(
(m, k), device="cuda", dtype=torch.bfloat16) / 10
def make_moe_tensors_8bit(
m: int, k: int, n: int, e: int, reorder: bool
) -> "TestData":
hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16)
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
# Scale to fp8
_, a1_scale = input_to_float8(hidden_states)
a1_scale = 1.0 / a1_scale
a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(
dtype=torch.float32)
a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(dtype=torch.float32)
w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)
@@ -102,8 +104,7 @@ class TestData:
# flashinfer expects swapped rows for w13
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
if reorder:
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
layer.w2_weight)
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
layer.custom_routing_function = Llama4MoE.custom_routing_function
layer.intermediate_size_per_partition = n
layer.ep_rank = 0
@@ -145,7 +146,8 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
top_k=topk,
renormalize=False,
custom_routing_function=Llama4MoE.custom_routing_function,
scoring_func="softmax")
scoring_func="softmax",
)
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=td.w13_weight_scale,
@@ -178,12 +180,10 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
top_k=topk,
num_expert_group=None,
topk_group=None,
apply_router_weight_on_input=True)
apply_router_weight_on_input=True,
)
torch.testing.assert_close(output,
flashinfer_output,
atol=5.5e-2,
rtol=1e-2)
torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)
@pytest.mark.skip(
@@ -213,7 +213,8 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
top_k=topk,
renormalize=False,
custom_routing_function=Llama4MoE.custom_routing_function,
scoring_func="softmax")
scoring_func="softmax",
)
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=td.w13_weight_scale,
@@ -250,7 +251,6 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
apply_router_weight_on_input=True,
)
torch.testing.assert_close(output,
flashinfer_cutlass_output,
atol=5.5e-2,
rtol=1e-2)
torch.testing.assert_close(
output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
)

View File

@@ -4,26 +4,33 @@ import pytest
import torch
from tests.kernels.moe.utils import make_test_quant_config
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from tests.kernels.quantization.nvfp4_utils import (
FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype,
)
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe)
FlashInferExperts,
is_valid_flashinfer_cutlass_fused_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
if not has_flashinfer_cutlass_fused_moe(
) or not current_platform.has_device_capability(100):
pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support",
allow_module_level=True)
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
100
):
pytest.skip(
"Requires flashinfer_cutlass_fused_moe and nvfp4 support",
allow_module_level=True,
)
MNK_FACTORS = [
(2, 1024, 1024),
@@ -44,13 +51,13 @@ MNK_FACTORS = [
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@torch.inference_mode()
def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
dtype: torch.dtype):
def test_flashinfer_fp4_moe_no_graph(
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
):
current_platform.seed_everything(7)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
quant_blocksize = 16
@@ -66,10 +73,7 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a,
score,
topk,
renormalize=False)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
@@ -87,16 +91,19 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
)
# Reference check:
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a.flatten(), dim=-1)).to(torch.float32)
a_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)
).to(torch.float32)
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
_, m_k = a_fp4.shape
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=a.dtype,
device=a.device,
block_size=quant_blocksize)
a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=a.dtype,
device=a.device,
block_size=quant_blocksize,
)
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
@@ -104,23 +111,26 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
for idx in range(0, e):
w1_d[idx] = dequantize_nvfp4_to_dtype(
w1_q[idx],
quant_config.w1_scale[idx], (1 / quant_config.g1_alphas[idx]),
quant_config.w1_scale[idx],
(1 / quant_config.g1_alphas[idx]),
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize)
block_size=quant_blocksize,
)
w2_d[idx] = dequantize_nvfp4_to_dtype(
w2_q[idx],
quant_config.w2_scale[idx], (1 / quant_config.g2_alphas[idx]),
quant_config.w2_scale[idx],
(1 / quant_config.g2_alphas[idx]),
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize)
block_size=quant_blocksize,
)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
torch.testing.assert_close(torch_output,
flashinfer_output,
atol=1e-1,
rtol=1e-1)
torch.testing.assert_close(
torch_output, flashinfer_output, atol=1e-1, rtol=1e-1
)
if __name__ == "__main__":

View File

@@ -17,20 +17,21 @@ if not has_triton_kernels():
import triton_kernels.swiglu
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
from triton_kernels.numerics import InFlexData
from triton_kernels.numerics_details.mxfp import (downcast_to_mxfp,
upcast_from_mxfp)
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize)
BatchedPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
BatchedOAITritonExperts, triton_kernel_moe_forward)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
BatchedOAITritonExperts,
triton_kernel_moe_forward,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.utils import shuffle_weight
from vllm.utils import round_up
@@ -46,13 +47,11 @@ def deshuffle(w: torch.Tensor):
def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
randbits = [torch.randperm(E) for _ in range(M)]
x_list = [
(-1)**i *
((16384 +
((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16))
(-1) ** i
* ((16384 + ((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16))
for i, bits in enumerate(randbits)
]
exp_data = torch.stack(x_list).to(
device="cuda") # simulating gate_output (M, E)
exp_data = torch.stack(x_list).to(device="cuda") # simulating gate_output (M, E)
# create input tensor
x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
@@ -120,20 +119,21 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
value=0,
)
w1_bias_tri = F.pad(w1_bias_tri, (0, w1_right_pad, 0, 0),
mode="constant",
value=0)
w2_bias_tri = F.pad(w2_bias_tri, (0, w2_right_pad, 0, 0),
mode="constant",
value=0)
w1_bias_tri = F.pad(
w1_bias_tri, (0, w1_right_pad, 0, 0), mode="constant", value=0
)
w2_bias_tri = F.pad(
w2_bias_tri, (0, w2_right_pad, 0, 0), mode="constant", value=0
)
x_tri = F.pad(x_tri, (0, x_pad, 0, 0), mode="constant", value=0)
w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
mx_axis=1)
w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
w_scale_layout, w_scale_layout_opts = (
layout.make_default_matmul_mxfp4_w_scale_layout(
mx_axis=1, num_warps=num_warps))
mx_axis=1, num_warps=num_warps
)
)
w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, torch.bfloat16, axis=1)
@@ -141,29 +141,33 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, torch.bfloat16, axis=1)
w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout,
**w_layout_opts)
w1_tri = convert_layout(
wrap_torch_tensor(w1_tri, FP4), w_layout, **w_layout_opts
)
w1_scale_tri = convert_layout(
wrap_torch_tensor(w1_scale_tri),
w_scale_layout,
**w_scale_layout_opts,
)
w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout,
**w_layout_opts)
w2_tri = convert_layout(
wrap_torch_tensor(w2_tri, FP4), w_layout, **w_layout_opts
)
w2_scale_tri = convert_layout(
wrap_torch_tensor(w2_scale_tri),
w_scale_layout,
**w_scale_layout_opts,
)
pc1 = PrecisionConfig(weight_scale=w1_scale_tri,
flex_ctx=FlexCtx(rhs_data=InFlexData()))
pc2 = PrecisionConfig(weight_scale=w2_scale_tri,
flex_ctx=FlexCtx(rhs_data=InFlexData()))
pc1 = PrecisionConfig(
weight_scale=w1_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
)
pc2 = PrecisionConfig(
weight_scale=w2_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
)
# tucuate so the rest can run properly
w1 = w1[..., :K, :2 * N]
w1 = w1[..., :K, : 2 * N]
w2 = w2[..., :N, :K]
w1 = deshuffle(w1)
@@ -261,7 +265,8 @@ class Case:
@pytest.mark.parametrize(
", ".join(f.name for f in fields(Case)),
[
tuple(getattr(case, f.name) for f in fields(Case)) for case in [
tuple(getattr(case, f.name) for f in fields(Case))
for case in [
# Case(a_dtype="bf16", w_dtype="bf16"),
# Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
Case(a_dtype="bf16", w_dtype="mx4")
@@ -321,10 +326,7 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
gating_output=exp_data,
topk=topk,
)
assert_close(ref=out_ref,
tri=out_triton_monolithic,
maxtol=0.025,
rmstol=0.005)
assert_close(ref=out_ref, tri=out_triton_monolithic, maxtol=0.025, rmstol=0.005)
def batched_moe(
@@ -376,7 +378,8 @@ def batched_moe(
@pytest.mark.parametrize(
", ".join(f.name for f in fields(Case)),
[
tuple(getattr(case, f.name) for f in fields(Case)) for case in [
tuple(getattr(case, f.name) for f in fields(Case))
for case in [
# Case(a_dtype="bf16", w_dtype="bf16"),
# Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
Case(a_dtype="bf16", w_dtype="mx4")

View File

@@ -4,16 +4,20 @@
Run `pytest tests/kernels/moe/test_grouped_topk.py`.
"""
import pytest
import torch
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_grouped_topk,
grouped_topk)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_grouped_topk,
grouped_topk,
)
from vllm.platforms import current_platform
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="This test is skipped on non-CUDA platform.")
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("n_token", [1, 33, 64])
@pytest.mark.parametrize("n_hidden", [1024, 2048])
@pytest.mark.parametrize("n_expert", [16])
@@ -23,23 +27,26 @@ from vllm.platforms import current_platform
@pytest.mark.parametrize("topk_group", [2])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
@pytest.mark.parametrize("dtype",
[torch.float16, torch.bfloat16, torch.float32])
def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int,
n_hidden: int, n_expert: int, topk: int,
renormalize: bool, num_expert_group: int,
topk_group: int, scoring_func: str,
routed_scaling_factor: float, dtype: torch.dtype):
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_grouped_topk(
monkeypatch: pytest.MonkeyPatch,
n_token: int,
n_hidden: int,
n_expert: int,
topk: int,
renormalize: bool,
num_expert_group: int,
topk_group: int,
scoring_func: str,
routed_scaling_factor: float,
dtype: torch.dtype,
):
current_platform.seed_everything(0)
hidden_states = torch.randn((n_token, n_hidden),
dtype=dtype,
device="cuda")
gating_output = torch.randn((n_token, n_expert),
dtype=dtype,
device="cuda")
e_score_correction_bias = torch.randn((n_expert, ),
dtype=torch.float32,
device="cuda")
hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, device="cuda")
gating_output = torch.randn((n_token, n_expert), dtype=dtype, device="cuda")
e_score_correction_bias = torch.randn(
(n_expert,), dtype=torch.float32, device="cuda"
)
with monkeypatch.context() as m:
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
@@ -52,7 +59,8 @@ def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
)
test_topk_weights, test_topk_ids = fused_grouped_topk(
hidden_states=hidden_states,
@@ -63,14 +71,11 @@ def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
)
if renormalize:
torch.testing.assert_close(baseline_topk_weights,
test_topk_weights,
atol=2e-2,
rtol=0)
torch.testing.assert_close(baseline_topk_ids,
test_topk_ids,
atol=0,
rtol=0)
torch.testing.assert_close(
baseline_topk_weights, test_topk_weights, atol=2e-2, rtol=0
)
torch.testing.assert_close(baseline_topk_ids, test_topk_ids, atol=0, rtol=0)

View File

@@ -17,18 +17,29 @@ from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from ...utils import multi_gpu_test
from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
reference_moe_impl,
run_modular_kernel)
from .modular_kernel_tools.common import (
Config,
RankTensors,
WeightTensors,
reference_moe_impl,
run_modular_kernel,
)
from .modular_kernel_tools.mk_objects import (
MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, TestMoEQuantConfig,
expert_info)
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
parallel_launch_with_config)
MK_FUSED_EXPERT_TYPES,
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
MK_QUANT_CONFIGS,
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES,
TestMoEQuantConfig,
expert_info,
)
from .modular_kernel_tools.parallel_utils import (
ProcessGroupInfo,
parallel_launch_with_config,
)
has_any_multi_gpu_package = (has_deep_ep() or has_deep_gemm() or has_pplx()
or has_flashinfer_cutlass_fused_moe())
has_any_multi_gpu_package = (
has_deep_ep() or has_deep_gemm() or has_pplx() or has_flashinfer_cutlass_fused_moe()
)
meets_multi_gpu_requirements = pytest.mark.skipif(
not has_any_multi_gpu_package,
@@ -64,9 +75,9 @@ def rank_worker(
# sanity check
from vllm import envs
if base_config.fused_moe_chunk_size is not None:
assert (
base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
assert base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
# get weights to this device
weights.to_current_device()
@@ -93,8 +104,7 @@ def rank_worker(
rank_tensors = RankTensors.make(config, pgi)
# modular kernel out
mk_out = run_modular_kernel(pgi, vllm_config, config, weights,
rank_tensors)
mk_out = run_modular_kernel(pgi, vllm_config, config, weights, rank_tensors)
with set_current_vllm_config(vllm_config):
ref_out = reference_moe_impl(config, weights, rank_tensors)
@@ -115,10 +125,10 @@ def rank_worker(
if len(exceptions) > 0:
raise RuntimeError(
f"{len(exceptions)} of {count} tests failed in child process, "
f"rank={pgi.rank}.")
f"rank={pgi.rank}."
)
else:
print(f"{count} of {count} tests passed in child process, "
f"rank={pgi.rank}.")
print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.")
def run(config: Config, verbose: bool):
@@ -127,8 +137,9 @@ def run(config: Config, verbose: bool):
weights: WeightTensors = WeightTensors.make(config)
vllm_config, env_dict = config.make_env_data()
parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
env_dict, config, weights, verbose)
parallel_launch_with_config(
config.world_size, rank_worker, vllm_config, env_dict, config, weights, verbose
)
Ms = [32, 64]
@@ -149,8 +160,9 @@ def is_nyi_config(config: Config) -> bool:
if info.needs_matching_quant:
# The triton kernels expect both per-act-token-quant and
# per-out-ch-quant or neither.
unsupported_quant_config = ((config.is_per_act_token_quant +
config.is_per_out_ch_quant) == 1)
unsupported_quant_config = (
config.is_per_act_token_quant + config.is_per_out_ch_quant
) == 1
return unsupported_quant_config
return not info.supports_expert_map
@@ -162,19 +174,25 @@ def is_nyi_config(config: Config) -> bool:
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize(
"combination",
product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
"combination", product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)
)
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [2])
@multi_gpu_test(num_gpus=2)
@meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu(
k: int, n: int, e: int, dtype: torch.dtype,
quant_config: Optional[TestMoEQuantConfig],
combination: tuple[mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPermuteExpertsUnpermute],
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
k: int,
n: int,
e: int,
dtype: torch.dtype,
quant_config: Optional[TestMoEQuantConfig],
combination: tuple[
mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute
],
fused_moe_chunk_size: Optional[int],
world_size: int,
pytestconfig,
):
config = Config(
Ms=Ms,
K=k,
@@ -195,7 +213,7 @@ def test_modular_kernel_combinations_multigpu(
if is_nyi_config(config):
pytest.skip(f"Tests config {config} is nyi. Skipping ...")
verbosity = pytestconfig.getoption('verbose')
verbosity = pytestconfig.getoption("verbose")
run(config, verbosity > 0)
@@ -205,16 +223,23 @@ def test_modular_kernel_combinations_multigpu(
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize(
"combination",
product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
"combination", product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)
)
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [1])
def test_modular_kernel_combinations_singlegpu(
k: int, n: int, e: int, dtype: torch.dtype,
quant_config: Optional[TestMoEQuantConfig],
combination: tuple[mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPermuteExpertsUnpermute],
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
k: int,
n: int,
e: int,
dtype: torch.dtype,
quant_config: Optional[TestMoEQuantConfig],
combination: tuple[
mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute
],
fused_moe_chunk_size: Optional[int],
world_size: int,
pytestconfig,
):
config = Config(
Ms=Ms,
K=k,
@@ -235,19 +260,21 @@ def test_modular_kernel_combinations_singlegpu(
if is_nyi_config(config):
pytest.skip(f"Tests config {config} is nyi. Skipping ...")
verbosity = pytestconfig.getoption('verbose')
verbosity = pytestconfig.getoption("verbose")
run(config, verbosity > 0)
if __name__ == '__main__':
if __name__ == "__main__":
# Ability to test individual PrepareAndFinalize and FusedExperts combination
from .modular_kernel_tools.cli_args import (make_config,
make_config_arg_parser)
parser = make_config_arg_parser(description=(
"Run single prepare-finalize & fused-experts combination test"
"Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations " #noqa: E501
"--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
))
from .modular_kernel_tools.cli_args import make_config, make_config_arg_parser
parser = make_config_arg_parser(
description=(
"Run single prepare-finalize & fused-experts combination test"
"Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations " # noqa: E501
"--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
)
)
args = parser.parse_args()
config = make_config(args)

View File

@@ -4,6 +4,7 @@
Run `pytest tests/kernels/test_moe.py`.
"""
import functools
from typing import Callable, Optional, Union
@@ -21,22 +22,32 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config)
FUSED_MOE_UNQUANTIZED_CONFIG,
int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
fused_topk,
modular_triton_fused_moe,
)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe)
fused_moe as iterative_moe,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_permute_bias)
marlin_permute_bias,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
rand_marlin_weight_mxfp4_like, rand_marlin_weight_nvfp4_like)
rand_marlin_weight_mxfp4_like,
rand_marlin_weight_nvfp4_like,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
marlin_quant_fp8_torch)
marlin_quant_fp8_torch,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
awq_marlin_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights)
awq_marlin_quantize,
marlin_quantize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
@@ -87,13 +98,15 @@ def run_moe_test(
if isinstance(baseline, torch.Tensor):
baseline_output = baseline
else:
baseline_output = baseline(a,
w1,
w2,
score,
topk,
global_num_experts=global_num_experts,
expert_map=expert_map)
baseline_output = baseline(
a,
w1,
w2,
score,
topk,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
# Pad the weight if moe padding is enabled
if padding:
@@ -105,34 +118,35 @@ def run_moe_test(
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_dynamic(score, 0)
test_output = moe_fn(a,
w1,
w2,
score,
topk,
global_num_experts=global_num_experts,
expert_map=expert_map)
test_output = moe_fn(
a,
w1,
w2,
score,
topk,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
if use_cudagraph:
test_output.fill_(0)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
test_output = moe_fn(a,
w1,
w2,
score,
topk,
global_num_experts=global_num_experts,
expert_map=expert_map)
test_output = moe_fn(
a,
w1,
w2,
score,
topk,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
torch.testing.assert_close(test_output,
baseline_output,
atol=atol,
rtol=rtol)
torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol)
return baseline_output
@@ -176,11 +190,8 @@ def test_fused_moe(
if ep_size > 1:
local_e = e // ep_size
e_ids = torch.randint(0,
e, (local_e, ),
device="cuda",
dtype=torch.int32)
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32)
e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
w1 = w1[e_ids]
w2 = w2[e_ids]
@@ -204,13 +215,15 @@ def test_fused_moe(
expert_map: Optional[torch.Tensor] = None,
) -> torch.Tensor:
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
return m_fused_moe_fn(a,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map)
return m_fused_moe_fn(
a,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
fused_moe_fn = functools.partial(fused_moe, renormalize=False)
@@ -234,19 +247,22 @@ def test_fused_moe(
# setup code in case we are able to revisit this later.
use_compile = False
use_cudagraph = (n >= 1024 and k >= 1024
and current_platform.is_cuda_alike())
use_cudagraph = n >= 1024 and k >= 1024 and current_platform.is_cuda_alike()
with set_current_vllm_config(vllm_config):
baseline_output = runner(torch_moe, iterative_moe)
runner(baseline_output,
fused_moe_fn,
use_compile=use_compile,
use_cudagraph=use_cudagraph)
runner(baseline_output,
m_fused_moe,
use_compile=use_compile,
use_cudagraph=use_cudagraph)
runner(
baseline_output,
fused_moe_fn,
use_compile=use_compile,
use_cudagraph=use_cudagraph,
)
runner(
baseline_output,
m_fused_moe,
use_compile=use_compile,
use_cudagraph=use_cudagraph,
)
@pytest.mark.parametrize("m,n,k", FUSED_MOE_WN16_MNK_FACTORS)
@@ -257,9 +273,18 @@ def test_fused_moe(
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8])
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
ep_size: int, dtype: torch.dtype, group_size: int,
has_zp: bool, weight_bits: int):
def test_fused_moe_wn16(
m: int,
n: int,
k: int,
e: int,
topk: int,
ep_size: int,
dtype: torch.dtype,
group_size: int,
has_zp: bool,
weight_bits: int,
):
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
@@ -274,35 +299,40 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
w1_ref = w1.clone()
w2_ref = w2.clone()
w1_qweight = torch.empty((e, 2 * n, k // pack_factor),
device="cuda",
dtype=torch.uint8)
w2_qweight = torch.empty((e, k, n // pack_factor),
device="cuda",
dtype=torch.uint8)
w1_scales = torch.empty((e, 2 * n, k // group_size),
device="cuda",
dtype=dtype)
w2_scales = torch.empty((e, k, n // group_size),
device="cuda",
dtype=dtype)
w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size),
device="cuda",
dtype=torch.uint8)
w2_qzeros = torch.empty((e, k // pack_factor, n // group_size),
device="cuda",
dtype=torch.uint8)
w1_qweight = torch.empty(
(e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8
)
w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", dtype=torch.uint8)
w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda", dtype=dtype)
w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype)
w1_qzeros = torch.empty(
(e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8
)
w2_qzeros = torch.empty(
(e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8
)
for i in range(e * 2):
expert_id = i % e
if i // e == 0:
w, w_ref, w_qweight, w_scales, w_qzeros = \
w1, w1_ref, w1_qweight, w1_scales, w1_qzeros
w, w_ref, w_qweight, w_scales, w_qzeros = (
w1,
w1_ref,
w1_qweight,
w1_scales,
w1_qzeros,
)
else:
w, w_ref, w_qweight, w_scales, w_qzeros = \
w2, w2_ref, w2_qweight, w2_scales, w2_qzeros
w, w_ref, w_qweight, w_scales, w_qzeros = (
w2,
w2_ref,
w2_qweight,
w2_scales,
w2_qzeros,
)
weight, qweight, scales, qzeros = quantize_weights(
w[expert_id].T, quant_type, group_size, has_zp, False)
w[expert_id].T, quant_type, group_size, has_zp, False
)
weight = weight.T
qweight = qweight.T.contiguous().to(torch.uint8)
scales = scales.T
@@ -321,11 +351,8 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
if ep_size > 1:
local_e = e // ep_size
e_ids = torch.randint(0,
e, (local_e, ),
device="cuda",
dtype=torch.int32)
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32)
e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
w1_ref = w1_ref[e_ids]
w2_ref = w2_ref[e_ids]
@@ -344,28 +371,27 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
assert weight_bits == 8
quant_config_builder = int8_w8a16_moe_quant_config
quant_config = quant_config_builder(w1_scale=w1_scales,
w2_scale=w2_scales,
w1_zp=w1_qzeros if has_zp else None,
w2_zp=w2_qzeros if has_zp else None,
block_shape=[0, group_size])
quant_config = quant_config_builder(
w1_scale=w1_scales,
w2_scale=w2_scales,
w1_zp=w1_qzeros if has_zp else None,
w2_zp=w2_qzeros if has_zp else None,
block_shape=[0, group_size],
)
with set_current_vllm_config(vllm_config):
triton_output = fused_moe(a,
w1_qweight,
w2_qweight,
score,
topk,
renormalize=False,
global_num_experts=e,
expert_map=e_map,
quant_config=quant_config)
torch_output = torch_moe(a,
w1_ref,
w2_ref,
score,
topk,
expert_map=e_map)
triton_output = fused_moe(
a,
w1_qweight,
w2_qweight,
score,
topk,
renormalize=False,
global_num_experts=e,
expert_map=e_map,
quant_config=quant_config,
)
torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, expert_map=e_map)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
@@ -373,16 +399,20 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
@torch.inference_mode()
def test_mixtral_moe(dist_init, dtype: torch.dtype, padding: bool,
use_rocm_aiter: bool, monkeypatch):
def test_mixtral_moe(
dist_init, dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, monkeypatch
):
"""Make sure our Mixtral MoE implementation agrees with the one from
huggingface."""
# clear the cache before every test
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
is_rocm_aiter_moe_enabled,
)
is_rocm_aiter_moe_enabled.cache_clear()
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
@@ -390,17 +420,16 @@ def test_mixtral_moe(dist_init, dtype: torch.dtype, padding: bool,
if dtype == torch.float32:
pytest.skip("AITER ROCm test skip for float32")
monkeypatch.setenv('RANK', "0")
monkeypatch.setenv('LOCAL_RANK', "0")
monkeypatch.setenv('WORLD_SIZE', "1")
monkeypatch.setenv('MASTER_ADDR', 'localhost')
monkeypatch.setenv('MASTER_PORT', '12345')
monkeypatch.setenv("RANK", "0")
monkeypatch.setenv("LOCAL_RANK", "0")
monkeypatch.setenv("WORLD_SIZE", "1")
monkeypatch.setenv("MASTER_ADDR", "localhost")
monkeypatch.setenv("MASTER_PORT", "12345")
init_distributed_environment()
# Instantiate our and huggingface's MoE blocks
vllm_config.compilation_config.static_forward_context = dict()
with (set_current_vllm_config(vllm_config),
set_forward_context(None, vllm_config)):
with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config):
config = MixtralConfig()
hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
vllm_moe = MixtralMoE(
@@ -416,27 +445,30 @@ def test_mixtral_moe(dist_init, dtype: torch.dtype, padding: bool,
# Load the weights
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
for i in range(config.num_local_experts):
weights = (hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data)
weights = (
hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data,
)
vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
hf_inputs = torch.randn(
(1, 64, config.hidden_size)).to(dtype).to("cuda")
hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
# vLLM uses 1D query [num_tokens, hidden_dim]
vllm_inputs = hf_inputs.flatten(0, 1)
# Pad the weight if moe padding is enabled
if padding:
vllm_moe.experts.w13_weight = Parameter(F.pad(
vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[...,
0:-128],
requires_grad=False)
vllm_moe.experts.w2_weight = Parameter(F.pad(
vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[...,
0:-128],
requires_grad=False)
vllm_moe.experts.w13_weight = Parameter(
F.pad(vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[
..., 0:-128
],
requires_grad=False,
)
vllm_moe.experts.w2_weight = Parameter(
F.pad(vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
requires_grad=False,
)
torch.cuda.synchronize()
torch.cuda.empty_cache()
@@ -453,19 +485,21 @@ def test_mixtral_moe(dist_init, dtype: torch.dtype, padding: bool,
if use_rocm_aiter:
# The values of rtol and atol are set based on the tests in ROCM AITER package. # noqa: E501
# https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174 # noqa: E501
torch.testing.assert_close(hf_states.flatten(0, 1),
vllm_states,
rtol=0.01,
atol=100)
torch.testing.assert_close(
hf_states.flatten(0, 1), vllm_states, rtol=0.01, atol=100
)
else:
torch.testing.assert_close(hf_states.flatten(0, 1),
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])
torch.testing.assert_close(
hf_states.flatten(0, 1),
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype],
)
def marlin_moe_generate_valid_test_cases():
import itertools
m_list = [1, 123, 666]
n_list = [128, 1024]
k_list = [256, 2048]
@@ -484,16 +518,24 @@ def marlin_moe_generate_valid_test_cases():
]
is_k_full_list = [True, False]
all_combinations = itertools.product(m_list, n_list, k_list, e_list,
topk_list, ep_size_list, dtype_list,
group_size_list, act_order_list,
quant_type_list, is_k_full_list)
all_combinations = itertools.product(
m_list,
n_list,
k_list,
e_list,
topk_list,
ep_size_list,
dtype_list,
group_size_list,
act_order_list,
quant_type_list,
is_k_full_list,
)
def is_invalid(m, n, k, e, topk, ep_size, dtype, group_size, act_order,
quant_type, is_k_full):
if quant_type == scalar_types.float8_e4m3fn and \
group_size not in [-1, 128]:
def is_invalid(
m, n, k, e, topk, ep_size, dtype, group_size, act_order, quant_type, is_k_full
):
if quant_type == scalar_types.float8_e4m3fn and group_size not in [-1, 128]:
return False
if quant_type == scalar_types.float4_e2m1f:
if group_size not in [16, 32]:
@@ -522,9 +564,10 @@ def marlin_moe_generate_valid_test_cases():
@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize(("m, n, k, e, topk, ep_size, dtype, group_size,"
"act_order, quant_type, is_k_full"),
marlin_moe_generate_valid_test_cases())
@pytest.mark.parametrize(
("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"),
marlin_moe_generate_valid_test_cases(),
)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
m: int,
@@ -549,7 +592,7 @@ def test_fused_marlin_moe(
if ep_size > 1:
local_e = e // ep_size
e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
w1 = w1[e_ids]
w2 = w2[e_ids]
@@ -567,11 +610,13 @@ def test_fused_marlin_moe(
for i in range(w1.shape[0]):
if quant_type == scalar_types.float4_e2m1f:
if group_size == 16:
w_ref1, qweight1, scales1, global_scale1 = \
w_ref1, qweight1, scales1, global_scale1 = (
rand_marlin_weight_nvfp4_like(w1[i], group_size)
)
else:
w_ref1, qweight1, scales1 = \
rand_marlin_weight_mxfp4_like(w1[i], group_size)
w_ref1, qweight1, scales1 = rand_marlin_weight_mxfp4_like(
w1[i], group_size
)
global_scale1 = None
w_ref1_l.append(w_ref1.T)
@@ -580,14 +625,14 @@ def test_fused_marlin_moe(
if global_scale1 is not None:
global_scale1_l.append(global_scale1)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
w1[i], group_size)
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(w1[i], group_size)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
elif has_zp:
w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size)
w1[i].transpose(1, 0), quant_type, group_size
)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
@@ -595,9 +640,9 @@ def test_fused_marlin_moe(
zeros1_l.append(zeros1)
else:
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
marlin_quantize(w1[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
@@ -624,11 +669,13 @@ def test_fused_marlin_moe(
for i in range(w2.shape[0]):
if quant_type == scalar_types.float4_e2m1f:
if group_size == 16:
w_ref2, qweight2, scales2, global_scale2 = \
w_ref2, qweight2, scales2, global_scale2 = (
rand_marlin_weight_nvfp4_like(w2[i], group_size)
)
else:
w_ref2, qweight2, scales2 = \
rand_marlin_weight_mxfp4_like(w2[i], group_size)
w_ref2, qweight2, scales2 = rand_marlin_weight_mxfp4_like(
w2[i], group_size
)
global_scale2 = None
w_ref2_l.append(w_ref2.T)
@@ -637,14 +684,14 @@ def test_fused_marlin_moe(
if global_scale2 is not None:
global_scale2_l.append(global_scale2)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
w2[i], group_size)
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(w2[i], group_size)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
elif has_zp:
w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size)
w2[i].transpose(1, 0), quant_type, group_size
)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
@@ -652,9 +699,9 @@ def test_fused_marlin_moe(
zeros2_l.append(zeros2)
else:
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
marlin_quantize(w2[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
@@ -675,12 +722,7 @@ def test_fused_marlin_moe(
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a,
w_ref1,
w_ref2,
score,
topk,
expert_map=e_map)
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)
marlin_output = torch.ops.vllm.fused_marlin_moe(
a,
@@ -704,7 +746,8 @@ def test_fused_marlin_moe(
w1_zeros=zeros1,
w2_zeros=zeros2,
quant_type_id=quant_type.id,
is_k_full=is_k_full)
is_k_full=is_k_full,
)
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
@@ -738,9 +781,9 @@ def test_fused_marlin_moe_with_bias(m):
for i in range(w1.shape[0]):
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
marlin_quantize(w1[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
@@ -767,9 +810,9 @@ def test_fused_marlin_moe_with_bias(m):
for i in range(w2.shape[0]):
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
marlin_quantize(w2[i].transpose(1, 0), quant_type,
group_size, act_order, test_perm)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
@@ -792,8 +835,7 @@ def test_fused_marlin_moe_with_bias(m):
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1,
b_bias2)
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, b_bias2)
marlin_output = torch.ops.vllm.fused_marlin_moe(
a,
@@ -817,7 +859,8 @@ def test_fused_marlin_moe_with_bias(m):
w1_zeros=zeros1,
w2_zeros=zeros2,
quant_type_id=quant_type.id,
is_k_full=is_k_full)
is_k_full=is_k_full,
)
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
@@ -825,34 +868,36 @@ def test_fused_marlin_moe_with_bias(m):
def test_moe_align_block_size_opcheck():
num_experts = 4
block_size = 4
topk_ids = torch.randint(0,
num_experts, (3, 4),
dtype=torch.int32,
device='cuda')
topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids = torch.empty(
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
expert_ids = torch.empty(
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
)
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
opcheck(torch.ops._moe_C.moe_align_block_size,
(topk_ids, num_experts, block_size, sorted_ids, expert_ids,
num_tokens_post_pad))
opcheck(
torch.ops._moe_C.moe_align_block_size,
(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
),
)
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
input = torch.randn((m, topk, k), device="cuda", dtype=dtype)

View File

@@ -11,7 +11,8 @@ import pytest
import torch
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
moe_align_block_size,
)
from vllm.platforms import current_platform
from vllm.utils import round_up
@@ -60,30 +61,33 @@ def _verify_expert_level_sorting(
in topk_ids in the final sorted_ids however this does not impact quality.
"""
# Group tokens by expert from the golden implementation
golden_expert_tokens = _group_tokens_by_expert(golden_sorted_ids,
expert_ids, block_size,
valid_length, total_tokens)
golden_expert_tokens = _group_tokens_by_expert(
golden_sorted_ids, expert_ids, block_size, valid_length, total_tokens
)
actual_expert_tokens = _group_tokens_by_expert(actual_sorted_ids,
expert_ids, block_size,
valid_length, total_tokens)
actual_expert_tokens = _group_tokens_by_expert(
actual_sorted_ids, expert_ids, block_size, valid_length, total_tokens
)
assert set(golden_expert_tokens.keys()) == set(
actual_expert_tokens.keys()), (
f"Expert IDs mismatch: golden={set(golden_expert_tokens.keys())}, "
f"actual={set(actual_expert_tokens.keys())}")
assert set(golden_expert_tokens.keys()) == set(actual_expert_tokens.keys()), (
f"Expert IDs mismatch: golden={set(golden_expert_tokens.keys())}, "
f"actual={set(actual_expert_tokens.keys())}"
)
for expert_id in golden_expert_tokens:
golden_tokens = torch.tensor(golden_expert_tokens[expert_id],
device=actual_sorted_ids.device)
actual_tokens = torch.tensor(actual_expert_tokens[expert_id],
device=actual_sorted_ids.device)
golden_tokens = torch.tensor(
golden_expert_tokens[expert_id], device=actual_sorted_ids.device
)
actual_tokens = torch.tensor(
actual_expert_tokens[expert_id], device=actual_sorted_ids.device
)
assert torch.equal(
torch.sort(golden_tokens)[0],
torch.sort(actual_tokens)[0]), (
f"Expert {expert_id} token mismatch: "
f"golden={golden_expert_tokens[expert_id]}, "
f"actual={actual_expert_tokens[expert_id]}")
torch.sort(golden_tokens)[0], torch.sort(actual_tokens)[0]
), (
f"Expert {expert_id} token mismatch: "
f"golden={golden_expert_tokens[expert_id]}, "
f"actual={actual_expert_tokens[expert_id]}"
)
def torch_moe_align_block_size(
@@ -104,40 +108,38 @@ def torch_moe_align_block_size(
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
flattened_token_indices = torch.arange(topk_ids.numel(),
device=topk_ids.device,
dtype=torch.int32)
flattened_token_indices = torch.arange(
topk_ids.numel(), device=topk_ids.device, dtype=torch.int32
)
flattened_expert_ids = topk_ids.flatten()
sorted_expert_ids, sort_indices = torch.sort(flattened_expert_ids,
stable=True)
sorted_expert_ids, sort_indices = torch.sort(flattened_expert_ids, stable=True)
sorted_token_indices = flattened_token_indices[sort_indices]
expert_token_counts = torch.zeros(num_experts,
dtype=torch.int64,
device=topk_ids.device)
expert_token_counts = torch.zeros(
num_experts, dtype=torch.int64, device=topk_ids.device
)
for expert_id in range(num_experts):
mask = sorted_expert_ids == expert_id
expert_token_counts[expert_id] = mask.sum()
expert_padded_counts = torch.zeros(num_experts,
dtype=torch.int64,
device=topk_ids.device)
expert_padded_counts = torch.zeros(
num_experts, dtype=torch.int64, device=topk_ids.device
)
for expert_id in range(num_experts):
original_count = expert_token_counts[expert_id]
if original_count > 0:
expert_padded_counts[expert_id] = (
(original_count + block_size - 1) // block_size) * block_size
(original_count + block_size - 1) // block_size
) * block_size
sorted_token_ids = torch.full(
(max_num_tokens_padded, ),
(max_num_tokens_padded,),
topk_ids.numel(),
dtype=torch.int32,
device=topk_ids.device,
)
max_num_blocks = (max_num_tokens_padded + block_size - 1) // block_size
expert_ids = torch.zeros(max_num_blocks,
dtype=torch.int32,
device=topk_ids.device)
expert_ids = torch.zeros(max_num_blocks, dtype=torch.int32, device=topk_ids.device)
current_pos = 0
current_block = 0
@@ -147,20 +149,20 @@ def torch_moe_align_block_size(
num_expert_tokens = expert_tokens.shape[0]
if num_expert_tokens > 0:
sorted_token_ids[current_pos:current_pos +
num_expert_tokens] = (expert_tokens)
sorted_token_ids[current_pos : current_pos + num_expert_tokens] = (
expert_tokens
)
expert_blocks_needed = expert_padded_counts[expert_id] // block_size
expert_ids[current_block:current_block +
expert_blocks_needed] = (expert_id)
expert_ids[current_block : current_block + expert_blocks_needed] = expert_id
current_pos += expert_padded_counts[expert_id]
current_block += expert_blocks_needed
total_padded_tokens = expert_padded_counts.sum()
num_tokens_post_pad = torch.tensor([total_padded_tokens],
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.tensor(
[total_padded_tokens], dtype=torch.int32, device=topk_ids.device
)
if expert_map is not None:
expert_ids = expert_map[expert_ids]
@@ -173,37 +175,32 @@ def torch_moe_align_block_size(
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("pad_sorted_ids", [False, True])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_align_block_size(m: int, topk: int, num_experts: int,
block_size: int, pad_sorted_ids: bool):
def test_moe_align_block_size(
m: int, topk: int, num_experts: int, block_size: int, pad_sorted_ids: bool
):
"""Test moe_align_block_size without expert mapping"""
topk_ids = torch.zeros((m, topk), device="cuda", dtype=torch.int32)
for i in range(m):
experts = torch.randperm(num_experts, device="cuda")[:topk]
topk_ids[i] = experts
actual_sorted_ids, actual_expert_ids, actual_num_tokens = (
moe_align_block_size(
topk_ids=topk_ids,
block_size=block_size,
num_experts=num_experts,
pad_sorted_ids=pad_sorted_ids,
))
actual_sorted_ids, actual_expert_ids, actual_num_tokens = moe_align_block_size(
topk_ids=topk_ids,
block_size=block_size,
num_experts=num_experts,
pad_sorted_ids=pad_sorted_ids,
)
golden_sorted_ids, golden_expert_ids, golden_num_tokens = (
torch_moe_align_block_size(
topk_ids=topk_ids,
block_size=block_size,
num_experts=num_experts,
pad_sorted_ids=pad_sorted_ids,
))
)
)
torch.testing.assert_close(actual_num_tokens,
golden_num_tokens,
atol=0,
rtol=0)
torch.testing.assert_close(actual_expert_ids,
golden_expert_ids,
atol=0,
rtol=0)
torch.testing.assert_close(actual_num_tokens, golden_num_tokens, atol=0, rtol=0)
torch.testing.assert_close(actual_expert_ids, golden_expert_ids, atol=0, rtol=0)
# For sorted_token_ids, verify block-level correctness rather than exact
# order Tokens within each expert's blocks can be in any order, but expert
@@ -219,16 +216,18 @@ def test_moe_align_block_size(m: int, topk: int, num_experts: int,
total_tokens = m * topk
assert actual_num_tokens.item() % block_size == 0, (
"num_tokens_post_pad should be divisible by block_size")
"num_tokens_post_pad should be divisible by block_size"
)
assert actual_num_tokens.item() >= total_tokens, (
"num_tokens_post_pad should be at least total_tokens")
"num_tokens_post_pad should be at least total_tokens"
)
valid_tokens = actual_sorted_ids[actual_sorted_ids < total_tokens]
assert len(valid_tokens) == total_tokens, (
f"Should have exactly {total_tokens} valid tokens, "
f"got {len(valid_tokens)}")
assert (actual_expert_ids >= 0).all() and (
actual_expert_ids
< num_experts).all(), "expert_ids should contain valid expert indices"
f"Should have exactly {total_tokens} valid tokens, got {len(valid_tokens)}"
)
assert (actual_expert_ids >= 0).all() and (actual_expert_ids < num_experts).all(), (
"expert_ids should contain valid expert indices"
)
@pytest.mark.parametrize("m", [16, 32])
@@ -236,46 +235,37 @@ def test_moe_align_block_size(m: int, topk: int, num_experts: int,
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("block_size", [64])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_align_block_size_with_expert_map(m: int, topk: int,
num_experts: int,
block_size: int):
def test_moe_align_block_size_with_expert_map(
m: int, topk: int, num_experts: int, block_size: int
):
"""Test moe_align_block_size with expert mapping (EP scenario)"""
topk_ids = torch.zeros((m, topk), device="cuda", dtype=torch.int32)
for i in range(m):
experts = torch.randperm(num_experts, device="cuda")[:topk]
topk_ids[i] = experts
expert_map = torch.full((num_experts, ),
-1,
device="cuda",
dtype=torch.int32)
expert_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32)
local_experts = list(range(0, num_experts, 2))
for i, expert_id in enumerate(local_experts):
expert_map[expert_id] = i
actual_sorted_ids, actual_expert_ids, actual_num_tokens = (
moe_align_block_size(
topk_ids=topk_ids,
block_size=block_size,
num_experts=num_experts,
expert_map=expert_map,
))
actual_sorted_ids, actual_expert_ids, actual_num_tokens = moe_align_block_size(
topk_ids=topk_ids,
block_size=block_size,
num_experts=num_experts,
expert_map=expert_map,
)
golden_sorted_ids, golden_expert_ids, golden_num_tokens = (
torch_moe_align_block_size(
topk_ids=topk_ids,
block_size=block_size,
num_experts=num_experts,
expert_map=expert_map,
))
)
)
torch.testing.assert_close(actual_num_tokens,
golden_num_tokens,
atol=0,
rtol=0)
torch.testing.assert_close(actual_expert_ids,
golden_expert_ids,
atol=0,
rtol=0)
torch.testing.assert_close(actual_num_tokens, golden_num_tokens, atol=0, rtol=0)
torch.testing.assert_close(actual_expert_ids, golden_expert_ids, atol=0, rtol=0)
_verify_expert_level_sorting(
actual_sorted_ids,
golden_sorted_ids,
@@ -290,26 +280,25 @@ def test_moe_align_block_size_deterministic():
m, topk, num_experts, block_size = 128, 2, 32, 64
torch.manual_seed(42)
topk_ids = torch.randint(0,
num_experts, (m, topk),
device="cuda",
dtype=torch.int32)
topk_ids = torch.randint(
0, num_experts, (m, topk), device="cuda", dtype=torch.int32
)
# expect the results to be reproducible
results = []
for _ in range(5):
sorted_ids, expert_ids, num_tokens = moe_align_block_size(
topk_ids=topk_ids, block_size=block_size, num_experts=num_experts)
results.append(
(sorted_ids.clone(), expert_ids.clone(), num_tokens.clone()))
topk_ids=topk_ids, block_size=block_size, num_experts=num_experts
)
results.append((sorted_ids.clone(), expert_ids.clone(), num_tokens.clone()))
for i in range(1, len(results)):
assert torch.equal(
results[0][0],
results[i][0]), ("sorted_ids should be deterministic")
assert torch.equal(
results[0][1],
results[i][1]), ("expert_ids should be deterministic")
assert torch.equal(
results[0][2],
results[i][2]), ("num_tokens should be deterministic")
assert torch.equal(results[0][0], results[i][0]), (
"sorted_ids should be deterministic"
)
assert torch.equal(results[0][1], results[i][1]), (
"expert_ids should be deterministic"
)
assert torch.equal(results[0][2], results[i][2]), (
"num_tokens should be deterministic"
)

View File

@@ -14,7 +14,10 @@ import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_permute, moe_permute_unpermute_supported, moe_unpermute)
moe_permute,
moe_permute_unpermute_supported,
moe_unpermute,
)
from vllm.platforms import current_platform
NUM_EXPERTS = [16, 64, 256]
@@ -24,35 +27,34 @@ current_platform.seed_everything(0)
def torch_permute(
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
# token_expert_indices: torch.Tensor,
topk: int,
n_expert: int,
n_local_expert: int,
start_expert: int,
expert_map: Optional[torch.Tensor] = None,
align_block_size: Optional[int] = None,
fill_invalid_expert: int = -1) -> list[torch.Tensor]:
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
# token_expert_indices: torch.Tensor,
topk: int,
n_expert: int,
n_local_expert: int,
start_expert: int,
expert_map: Optional[torch.Tensor] = None,
align_block_size: Optional[int] = None,
fill_invalid_expert: int = -1,
) -> list[torch.Tensor]:
n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1]
if expert_map is not None:
is_local_expert = (expert_map[topk_ids] != -1)
not_local_expert = (expert_map[topk_ids] == -1)
topk_ids = is_local_expert * (
topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert)
token_expert_indices = torch.arange(0,
n_token * topk,
dtype=torch.int32,
device=hidden_states.device).reshape(
(n_token, topk))
is_local_expert = expert_map[topk_ids] != -1
not_local_expert = expert_map[topk_ids] == -1
topk_ids = is_local_expert * (topk_ids - start_expert) + not_local_expert * (
topk_ids + n_expert
)
token_expert_indices = torch.arange(
0, n_token * topk, dtype=torch.int32, device=hidden_states.device
).reshape((n_token, topk))
sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(),
stable=True)
sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), stable=True)
dst_row_id2src_row_id_map = token_expert_indices.flatten()[sorted_indices]
expert_first_token_offset = torch.zeros(n_local_expert + 1,
dtype=torch.int64,
device="cuda")
expert_first_token_offset = torch.zeros(
n_local_expert + 1, dtype=torch.int64, device="cuda"
)
idx = 0
for i in range(0, n_local_expert):
cnt = 0
@@ -64,116 +66,133 @@ def torch_permute(
_, src2dst_idx = torch.sort(dst_row_id2src_row_id_map)
valid_row_idx = []
if align_block_size is None:
permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map //
topk, ...]
permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // topk, ...]
permuted_row_size = permuted_hidden_states.shape[0]
m_indices = torch.empty(permuted_row_size,
device="cuda",
dtype=torch.int32).fill_(fill_invalid_expert)
m_indices = torch.empty(
permuted_row_size, device="cuda", dtype=torch.int32
).fill_(fill_invalid_expert)
for i in range(1, n_local_expert + 1):
first_token_offset = expert_first_token_offset[i - 1]
last_token_offset = expert_first_token_offset[i]
m_indices[first_token_offset:last_token_offset] = i - 1
src_row_id2dst_row_id_map = torch.arange(
0, n_token * topk, device="cuda",
dtype=torch.int32)[src2dst_idx].reshape((n_token, topk))
0, n_token * topk, device="cuda", dtype=torch.int32
)[src2dst_idx].reshape((n_token, topk))
valid_row_idx += [i for i in range(expert_first_token_offset[-1])]
dst_row_id2src_row_id_map[
expert_first_token_offset[-1]:] = n_token * topk
dst_row_id2src_row_id_map[expert_first_token_offset[-1] :] = n_token * topk
return [
permuted_hidden_states, expert_first_token_offset,
src_row_id2dst_row_id_map, dst_row_id2src_row_id_map, m_indices,
valid_row_idx
permuted_hidden_states,
expert_first_token_offset,
src_row_id2dst_row_id_map,
dst_row_id2src_row_id_map,
m_indices,
valid_row_idx,
]
else:
permuted_row_size = (topk * n_token + n_expert *
(align_block_size - 1) + align_block_size -
1) // align_block_size * align_block_size
permuted_idx = torch.full((permuted_row_size, ),
n_token * topk,
dtype=torch.int32,
device=hidden_states.device)
permuted_hidden_states = torch.empty((permuted_row_size, n_hidden),
device="cuda",
dtype=hidden_states.dtype)
align_src_row_id2dst_row_id = torch.empty(n_token * topk,
device="cuda",
dtype=torch.int32)
align_expert_first_token_offset = torch.zeros_like(
expert_first_token_offset)
m_indices = torch.empty(permuted_row_size,
device="cuda",
dtype=torch.int32).fill_(fill_invalid_expert)
permuted_row_size = (
(topk * n_token + n_expert * (align_block_size - 1) + align_block_size - 1)
// align_block_size
* align_block_size
)
permuted_idx = torch.full(
(permuted_row_size,),
n_token * topk,
dtype=torch.int32,
device=hidden_states.device,
)
permuted_hidden_states = torch.empty(
(permuted_row_size, n_hidden), device="cuda", dtype=hidden_states.dtype
)
align_src_row_id2dst_row_id = torch.empty(
n_token * topk, device="cuda", dtype=torch.int32
)
align_expert_first_token_offset = torch.zeros_like(expert_first_token_offset)
m_indices = torch.empty(
permuted_row_size, device="cuda", dtype=torch.int32
).fill_(fill_invalid_expert)
# get align_permuted_hidden_states,
# valid row_idx and align_expert_first_token_offset
for i in range(1, n_local_expert + 1):
first_token_offset = expert_first_token_offset[i - 1]
last_token_offset = expert_first_token_offset[i]
n_token_in_expert = last_token_offset - first_token_offset
align_expert_first_token_offset[
i] = align_expert_first_token_offset[
i - 1] + (n_token_in_expert + align_block_size -
1) // align_block_size * align_block_size
align_expert_first_token_offset[i] = (
align_expert_first_token_offset[i - 1]
+ (n_token_in_expert + align_block_size - 1)
// align_block_size
* align_block_size
)
align_first_token_offset = align_expert_first_token_offset[i - 1]
align_last_token_offset = align_expert_first_token_offset[i]
dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[
first_token_offset:first_token_offset + n_token_in_expert]
first_token_offset : first_token_offset + n_token_in_expert
]
# store token in current expert with align_first_token_offset
permuted_hidden_states[align_first_token_offset:\
align_first_token_offset+n_token_in_expert,\
...] = hidden_states[\
dst_row_id2src_row_id_in_expert // topk,\
...]
permuted_idx[align_first_token_offset:\
align_first_token_offset+\
n_token_in_expert] = dst_row_id2src_row_id_in_expert
permuted_hidden_states[
align_first_token_offset : align_first_token_offset + n_token_in_expert,
...,
] = hidden_states[dst_row_id2src_row_id_in_expert // topk, ...]
permuted_idx[
align_first_token_offset : align_first_token_offset + n_token_in_expert
] = dst_row_id2src_row_id_in_expert
# set current expert m_indices
m_indices[align_first_token_offset:align_last_token_offset] = i - 1
valid_row_idx += [
i for i in range(align_first_token_offset,
align_first_token_offset + n_token_in_expert)
i
for i in range(
align_first_token_offset,
align_first_token_offset + n_token_in_expert,
)
]
# get align_src_row_id2dst_row_id
for i in range(n_token * topk):
eid = sorted_topk_ids[i]
if (eid >= n_local_expert):
if eid >= n_local_expert:
# check token not in local expert
align_src_row_id2dst_row_id[
i] = align_expert_first_token_offset[-1]
align_src_row_id2dst_row_id[i] = align_expert_first_token_offset[-1]
continue
first_token_offset = expert_first_token_offset[eid]
align_first_token_offset = align_expert_first_token_offset[eid]
token_offset = i - first_token_offset
align_src_row_id2dst_row_id[
i] = align_first_token_offset + token_offset
align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[\
src2dst_idx].reshape((n_token, topk))
align_src_row_id2dst_row_id[i] = align_first_token_offset + token_offset
align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[src2dst_idx].reshape(
(n_token, topk)
)
return [
permuted_hidden_states, align_expert_first_token_offset,
align_src_row_id2dst_row_id, permuted_idx, m_indices, valid_row_idx
permuted_hidden_states,
align_expert_first_token_offset,
align_src_row_id2dst_row_id,
permuted_idx,
m_indices,
valid_row_idx,
]
def torch_unpermute(permuted_hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indices: torch.Tensor,
src_row_id2dst_row_id_map: torch.Tensor,
valid_row_idx: torch.Tensor, topk: int,
n_expert: int) -> torch.Tensor:
def torch_unpermute(
permuted_hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
token_expert_indices: torch.Tensor,
src_row_id2dst_row_id_map: torch.Tensor,
valid_row_idx: torch.Tensor,
topk: int,
n_expert: int,
) -> torch.Tensor:
# ignore invalid row
n_hidden = permuted_hidden_states.shape[1]
mask = torch.zeros(permuted_hidden_states.shape[0],
dtype=bool,
device="cuda")
mask = torch.zeros(permuted_hidden_states.shape[0], dtype=bool, device="cuda")
mask[valid_row_idx] = True
permuted_hidden_states[~mask] = 0
permuted_hidden_states = permuted_hidden_states[
src_row_id2dst_row_id_map.flatten(), ...]
src_row_id2dst_row_id_map.flatten(), ...
]
permuted_hidden_states = permuted_hidden_states.view(-1, topk, n_hidden)
output = (permuted_hidden_states * topk_weights.unsqueeze(2)).sum(1).to(
permuted_hidden_states.dtype)
output = (
(permuted_hidden_states * topk_weights.unsqueeze(2))
.sum(1)
.to(permuted_hidden_states.dtype)
)
return output
@@ -184,59 +203,76 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor,
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("align_block_size", [None, 128])
def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
n_expert: int, ep_size: int, dtype: torch.dtype,
align_block_size: Optional[int]):
def test_moe_permute_unpermute(
n_token: int,
n_hidden: int,
topk: int,
n_expert: int,
ep_size: int,
dtype: torch.dtype,
align_block_size: Optional[int],
):
if not moe_permute_unpermute_supported():
pytest.skip("moe_permute_unpermute is not supported on this platform.")
fill_invalid_expert = 0
ep_rank = np.random.randint(0, ep_size)
expert_map = None
n_local_expert = n_expert
if (ep_size != 1):
n_local_expert, expert_map = determine_expert_map(
ep_size, ep_rank, n_expert)
if ep_size != 1:
n_local_expert, expert_map = determine_expert_map(ep_size, ep_rank, n_expert)
expert_map = expert_map.cuda()
start_expert = n_local_expert * ep_rank
current_platform.seed_everything(0)
hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype)
gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype)
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states, gating_output, topk, False)
(gold_permuted_hidden_states, gold_expert_first_token_offset,
gold_inv_permuted_idx, gold_permuted_idx, gold_m_indices,
valid_row_idx) = torch_permute(
hidden_states,
topk_ids,
# token_expert_indices,
topk,
n_expert,
n_local_expert,
start_expert,
expert_map=expert_map,
align_block_size=align_block_size,
fill_invalid_expert=fill_invalid_expert)
hidden_states, gating_output, topk, False
)
(
gold_permuted_hidden_states,
gold_expert_first_token_offset,
gold_inv_permuted_idx,
gold_permuted_idx,
gold_m_indices,
valid_row_idx,
) = torch_permute(
hidden_states,
topk_ids,
# token_expert_indices,
topk,
n_expert,
n_local_expert,
start_expert,
expert_map=expert_map,
align_block_size=align_block_size,
fill_invalid_expert=fill_invalid_expert,
)
(permuted_hidden_states, _, expert_first_token_offset, inv_permuted_idx,
m_indices) = moe_permute(hidden_states=hidden_states,
a1q_scale=None,
topk_ids=topk_ids,
n_expert=n_expert,
n_local_expert=n_local_expert,
expert_map=expert_map,
align_block_size=align_block_size,
fill_invalid_expert=fill_invalid_expert)
(
permuted_hidden_states,
_,
expert_first_token_offset,
inv_permuted_idx,
m_indices,
) = moe_permute(
hidden_states=hidden_states,
a1q_scale=None,
topk_ids=topk_ids,
n_expert=n_expert,
n_local_expert=n_local_expert,
expert_map=expert_map,
align_block_size=align_block_size,
fill_invalid_expert=fill_invalid_expert,
)
# check expert_first_token_offset
torch.testing.assert_close(gold_expert_first_token_offset,
expert_first_token_offset,
atol=0,
rtol=0)
torch.testing.assert_close(
gold_expert_first_token_offset, expert_first_token_offset, atol=0, rtol=0
)
# check src_row_id2dst_row_id_map
torch.testing.assert_close(gold_inv_permuted_idx.flatten(),
inv_permuted_idx,
atol=0,
rtol=0)
torch.testing.assert_close(
gold_inv_permuted_idx.flatten(), inv_permuted_idx, atol=0, rtol=0
)
# check mindice
# current kernel usage assumes deepgemm requires align_block_size
# when it's not provided then we don't compute m_indices (for cutlass)
@@ -244,19 +280,28 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
# check permuted_hidden_states, only valid token
torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx],
permuted_hidden_states[valid_row_idx],
atol=0,
rtol=0)
torch.testing.assert_close(
gold_permuted_hidden_states[valid_row_idx],
permuted_hidden_states[valid_row_idx],
atol=0,
rtol=0,
)
# add a random tensor to simulate group gemm
result0 = 0.5 * permuted_hidden_states + torch.randn_like(
permuted_hidden_states)
result0 = 0.5 * permuted_hidden_states + torch.randn_like(permuted_hidden_states)
result4 = torch.empty_like(hidden_states)
moe_unpermute(result4, result0, topk_weights, inv_permuted_idx,
expert_first_token_offset)
moe_unpermute(
result4, result0, topk_weights, inv_permuted_idx, expert_first_token_offset
)
gold4 = torch_unpermute(result0, topk_weights, topk_ids,
token_expert_indices, inv_permuted_idx,
valid_row_idx, topk, n_local_expert)
gold4 = torch_unpermute(
result0,
topk_weights,
topk_ids,
token_expert_indices,
inv_permuted_idx,
valid_row_idx,
topk,
n_local_expert,
)
# check unpermuted hidden
torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0)

View File

@@ -11,27 +11,39 @@ import torch
from packaging import version
from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
QuarkLinearMethod, QuarkW4A4MXFP4)
QuarkLinearMethod,
QuarkW4A4MXFP4,
)
from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501
QuarkW4A4MXFp4MoEMethod)
QuarkW4A4MXFp4MoEMethod,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
importlib.metadata.version("amd-quark")
) >= version.parse("0.8.99")
TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda(
) and current_platform.is_device_capability(100)
TRTLLM_GEN_MXFP4_AVAILABLE = (
current_platform.is_cuda() and current_platform.is_device_capability(100)
)
HOPPER_MXFP4_BF16_AVAILABLE = (current_platform.is_cuda()
and current_platform.is_device_capability(90)
and has_flashinfer())
HOPPER_MXFP4_BF16_AVAILABLE = (
current_platform.is_cuda()
and current_platform.is_device_capability(90)
and has_flashinfer()
)
if TRTLLM_GEN_MXFP4_AVAILABLE:
from flashinfer import (fp4_quantize, mxfp8_quantize,
next_positive_power_of_2,
reorder_rows_for_gated_act_gemm, shuffle_matrix_a,
shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
from flashinfer import (
fp4_quantize,
mxfp8_quantize,
next_positive_power_of_2,
reorder_rows_for_gated_act_gemm,
shuffle_matrix_a,
shuffle_matrix_sf_a,
trtllm_fp4_block_scale_moe,
)
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
@@ -48,21 +60,25 @@ def enable_pickle(monkeypatch):
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@pytest.mark.parametrize('model_case', [
ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1)
])
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE,
reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize(
"model_case",
[
ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1),
],
)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
if torch.cuda.device_count() < model_case.tp:
pytest.skip(f"This test requires >={model_case.tp} gpus, got only "
f"{torch.cuda.device_count()}")
pytest.skip(
f"This test requires >={model_case.tp} gpus, got only "
f"{torch.cuda.device_count()}"
)
with vllm_runner(model_case.model_id,
tensor_parallel_size=model_case.tp,
load_format="dummy") as llm:
with vllm_runner(
model_case.model_id, tensor_parallel_size=model_case.tp, load_format="dummy"
) as llm:
def check_model(model):
layer = model.model.layers[0]
@@ -72,21 +88,16 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)
assert isinstance(layer.mlp.experts.quant_method,
QuarkW4A4MXFp4MoEMethod)
assert isinstance(layer.mlp.experts.quant_method, QuarkW4A4MXFp4MoEMethod)
if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
llm.apply_model(check_model)
output = llm.generate_greedy("Today I am in the French Alps and",
max_tokens=20)
output = llm.generate_greedy("Today I am in the French Alps and", max_tokens=20)
assert output
def swiglu(x,
alpha: float = 1.702,
beta: float = 1.0,
limit: Optional[float] = None):
def swiglu(x, alpha: float = 1.702, beta: float = 1.0, limit: Optional[float] = None):
# Note we add an extra bias of 1 to the linear layer
x_glu, x_linear = torch.chunk(x, 2, dim=-1)
if limit is not None:
@@ -96,24 +107,19 @@ def swiglu(x,
return out_glu * (x_linear + beta)
fp4_lookup_table = [
0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6
]
fp4_lookup_table = [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6]
def mxfp4_dequantize(x, scale):
assert x.dtype == torch.uint8
x = x.view(torch.uint8).to(torch.int32)
x_unpacked = torch.zeros(*x.shape[:-1],
x.shape[-1] * 2,
dtype=torch.int32,
device=x.device)
x_unpacked = torch.zeros(
*x.shape[:-1], x.shape[-1] * 2, dtype=torch.int32, device=x.device
)
x_unpacked[..., 0::2].copy_(x & 0xF)
x_unpacked[..., 1::2].copy_((x >> 4) & 0xF)
x_float = torch.zeros(x_unpacked.shape,
dtype=torch.float32,
device=x.device)
x_float = torch.zeros(x_unpacked.shape, dtype=torch.float32, device=x.device)
for i, val in enumerate(fp4_lookup_table):
x_float[x_unpacked == i] = val
@@ -162,9 +168,10 @@ def reference_moe(
t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
t = swiglu(t, alpha=alpha, beta=beta, limit=limit)
if act_type == 'mxfp8':
t_quantized, t_scale = mxfp8_quantize(t.to(torch.bfloat16),
is_sf_swizzled_layout=False)
if act_type == "mxfp8":
t_quantized, t_scale = mxfp8_quantize(
t.to(torch.bfloat16), is_sf_swizzled_layout=False
)
t = mxfp8_dequantize(t_quantized, t_scale)
# MLP #2
mlp2_weight = w2[expert_indices, ...]
@@ -221,37 +228,53 @@ def tg_mxfp4_moe(
transpose_optimized: bool = False,
) -> torch.Tensor:
sf_block_size = 32
assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts
and w13_weight.shape[1] == intermediate_size * 2
and w13_weight.shape[2] == hidden_size // 2)
assert (w13_weight_scale.dim() == 3
and w13_weight_scale.shape[0] == num_experts
and w13_weight_scale.shape[1] == intermediate_size * 2
and w13_weight_scale.shape[2] == hidden_size // sf_block_size)
assert (w2_weight.dim() == 3 and w2_weight.shape[0] == num_experts
and w2_weight.shape[1] == hidden_size
and w2_weight.shape[2] == intermediate_size // 2)
assert (w2_weight_scale.dim() == 3
and w2_weight_scale.shape[1] == hidden_size
and w2_weight_scale.shape[2] == intermediate_size // sf_block_size)
assert (w13_bias.dim() == 2 and w13_bias.shape[0] == num_experts
and w13_bias.shape[1] == intermediate_size * 2)
assert (w2_bias.dim() == 2 and w2_bias.shape[0] == num_experts
and w2_bias.shape[1] == hidden_size)
assert (
w13_weight.dim() == 3
and w13_weight.shape[0] == num_experts
and w13_weight.shape[1] == intermediate_size * 2
and w13_weight.shape[2] == hidden_size // 2
)
assert (
w13_weight_scale.dim() == 3
and w13_weight_scale.shape[0] == num_experts
and w13_weight_scale.shape[1] == intermediate_size * 2
and w13_weight_scale.shape[2] == hidden_size // sf_block_size
)
assert (
w2_weight.dim() == 3
and w2_weight.shape[0] == num_experts
and w2_weight.shape[1] == hidden_size
and w2_weight.shape[2] == intermediate_size // 2
)
assert (
w2_weight_scale.dim() == 3
and w2_weight_scale.shape[1] == hidden_size
and w2_weight_scale.shape[2] == intermediate_size // sf_block_size
)
assert (
w13_bias.dim() == 2
and w13_bias.shape[0] == num_experts
and w13_bias.shape[1] == intermediate_size * 2
)
assert (
w2_bias.dim() == 2
and w2_bias.shape[0] == num_experts
and w2_bias.shape[1] == hidden_size
)
# Swap w1 and w3 as the definition of
# swiglu is different in the trtllm-gen
w13_weight_scale_ = w13_weight_scale.clone()
w13_weight_ = w13_weight.clone()
w13_bias_ = w13_bias.clone()
w13_weight[:, :intermediate_size, :].copy_(
w13_weight_[:, intermediate_size:, :])
w13_weight[:, intermediate_size:, :].copy_(
w13_weight_[:, :intermediate_size, :])
w13_weight[:, :intermediate_size, :].copy_(w13_weight_[:, intermediate_size:, :])
w13_weight[:, intermediate_size:, :].copy_(w13_weight_[:, :intermediate_size, :])
w13_weight_scale[:, :intermediate_size, :].copy_(
w13_weight_scale_[:, intermediate_size:, :])
w13_weight_scale_[:, intermediate_size:, :]
)
w13_weight_scale[:, intermediate_size:, :].copy_(
w13_weight_scale_[:, :intermediate_size, :])
w13_weight_scale_[:, :intermediate_size, :]
)
w13_bias[:, :intermediate_size].copy_(w13_bias_[:, intermediate_size:])
w13_bias[:, intermediate_size:].copy_(w13_bias_[:, :intermediate_size])
@@ -261,18 +284,23 @@ def tg_mxfp4_moe(
w13_bias_interleaved = []
for i in range(num_experts):
w13_weight_interleaved.append(
reorder_rows_for_gated_act_gemm(w13_weight[i].clone()))
reorder_rows_for_gated_act_gemm(w13_weight[i].clone())
)
w13_weight_scale_interleaved.append(
reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone()))
reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone())
)
w13_bias_interleaved.append(
reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1,
1)))
reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1, 1))
)
w13_weight = torch.stack(w13_weight_interleaved).reshape(
num_experts, 2 * intermediate_size, hidden_size // 2)
num_experts, 2 * intermediate_size, hidden_size // 2
)
w13_weight_scale = torch.stack(w13_weight_scale_interleaved).reshape(
num_experts, 2 * intermediate_size, hidden_size // 32)
num_experts, 2 * intermediate_size, hidden_size // 32
)
w13_bias = torch.stack(w13_bias_interleaved).reshape(
num_experts, 2 * intermediate_size)
num_experts, 2 * intermediate_size
)
# Shuffle weights and scaling factors for transposed mma output
gemm1_weights_shuffled = []
@@ -291,9 +319,11 @@ def tg_mxfp4_moe(
w13_weight[i].view(torch.uint8),
epilogue_tile_m,
)
gemm1_weights_shuffled.append(w13_weight[i].view(
torch.uint8)[permute_indices.to(
w13_weight.device)].contiguous())
gemm1_weights_shuffled.append(
w13_weight[i]
.view(torch.uint8)[permute_indices.to(w13_weight.device)]
.contiguous()
)
# w13 scale shuffling
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
_cache_permute_indices,
@@ -302,26 +332,35 @@ def tg_mxfp4_moe(
num_elts_per_sf=16,
)
gemm1_scales_shuffled.append(
nvfp4_block_scale_interleave(w13_weight_scale[i].view(
torch.uint8)[permute_sf_indices.to(
w13_weight_scale.device)].contiguous()))
nvfp4_block_scale_interleave(
w13_weight_scale[i]
.view(torch.uint8)[permute_sf_indices.to(w13_weight_scale.device)]
.contiguous()
)
)
# w13 bias shuffling
permute_bias_indices = _maybe_get_cached_w2_permute_indices(
_cache_permute_indices,
w13_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
)
gemm1_bias_shuffled.append(w13_bias[i].clone().reshape(
-1, 1)[permute_bias_indices.to(w13_bias.device)].contiguous())
gemm1_bias_shuffled.append(
w13_bias[i]
.clone()
.reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
.contiguous()
)
# w2 weight shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
_cache_permute_indices,
w2_weight[i].view(torch.uint8),
epilogue_tile_m,
)
gemm2_weights_shuffled.append(w2_weight[i].view(
torch.uint8)[permute_indices.to(
w2_weight.device)].contiguous())
gemm2_weights_shuffled.append(
w2_weight[i]
.view(torch.uint8)[permute_indices.to(w2_weight.device)]
.contiguous()
)
# w2 scale shuffling
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
_cache_permute_indices,
@@ -330,48 +369,65 @@ def tg_mxfp4_moe(
num_elts_per_sf=16,
)
gemm2_scales_shuffled.append(
nvfp4_block_scale_interleave(w2_weight_scale[i].view(
torch.uint8)[permute_sf_indices.to(
w2_weight_scale.device)].contiguous()))
nvfp4_block_scale_interleave(
w2_weight_scale[i]
.view(torch.uint8)[permute_sf_indices.to(w2_weight_scale.device)]
.contiguous()
)
)
# w2 bias shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
_cache_permute_indices,
w2_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
)
gemm2_bias_shuffled.append(w2_bias[i].clone().reshape(
-1, 1)[permute_indices.to(w2_bias.device)].contiguous())
gemm2_bias_shuffled.append(
w2_bias[i]
.clone()
.reshape(-1, 1)[permute_indices.to(w2_bias.device)]
.contiguous()
)
else:
for i in range(num_experts):
gemm1_weights_shuffled.append(
shuffle_matrix_a(w13_weight[i].view(torch.uint8),
epilogue_tile_m))
shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)
)
gemm1_scales_shuffled.append(
shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
epilogue_tile_m))
shuffle_matrix_sf_a(
w13_weight_scale[i].view(torch.uint8), epilogue_tile_m
)
)
gemm2_weights_shuffled.append(
shuffle_matrix_a(w2_weight[i].view(torch.uint8),
epilogue_tile_m))
shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)
)
gemm2_scales_shuffled.append(
shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
epilogue_tile_m))
shuffle_matrix_sf_a(
w2_weight_scale[i].view(torch.uint8), epilogue_tile_m
)
)
gemm1_bias_shuffled.append(
shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m))
shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m)
)
gemm2_bias_shuffled.append(
shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m))
shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m)
)
w13_weight = torch.stack(gemm1_weights_shuffled)
w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape(
num_experts, 2 * intermediate_size,
hidden_size // sf_block_size).view(torch.float8_e4m3fn)
w13_weight_scale = (
torch.stack(gemm1_scales_shuffled)
.reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size)
.view(torch.float8_e4m3fn)
)
w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1)
w2_weight = torch.stack(gemm2_weights_shuffled)
w2_weight_scale = torch.stack(gemm2_scales_shuffled).reshape(
num_experts, hidden_size,
intermediate_size // sf_block_size).view(torch.float8_e4m3fn)
w2_weight_scale = (
torch.stack(gemm2_scales_shuffled)
.reshape(num_experts, hidden_size, intermediate_size // sf_block_size)
.view(torch.float8_e4m3fn)
)
w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1)
tg_result = trtllm_fp4_block_scale_moe(
@@ -401,7 +457,8 @@ def tg_mxfp4_moe(
routed_scaling_factor=None,
tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts),
routing_method_type=1, # renormalize
do_finalize=True)[0]
do_finalize=True,
)[0]
return tg_result
@@ -424,20 +481,21 @@ def check_accuracy(a, b, atol, rtol, percent):
if mismatch_percent > 1 - percent:
raise Exception(
f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
f"(threshold: {1-percent:.4f})")
f"(threshold: {1 - percent:.4f})"
)
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("num_tokens", [1, 128, 1024])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
(1.702, 1.0, 7.0)])
@pytest.mark.parametrize("act_type", ['mxfp8', 'bf16'])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
@pytest.mark.parametrize("act_type", ["mxfp8", "bf16"])
@pytest.mark.parametrize("transpose_optimized", [False, True])
@pytest.mark.skipif(
not TRTLLM_GEN_MXFP4_AVAILABLE,
reason="nvidia gpu and compute capability sm100 is required for this test")
reason="nvidia gpu and compute capability sm100 is required for this test",
)
def test_trtllm_gen_mxfp4_fused_moe(
topk: int,
num_experts: int,
@@ -452,45 +510,52 @@ def test_trtllm_gen_mxfp4_fused_moe(
):
seed = 42
torch.manual_seed(seed)
hidden_states = torch.randn(num_tokens,
hidden_size,
device="cuda:0",
dtype=torch.bfloat16)
w13 = (torch.randn(num_experts,
intermediate_size * 2,
hidden_size,
device="cuda:0",
dtype=torch.bfloat16))
w2 = (torch.randn(num_experts,
hidden_size,
intermediate_size,
device="cuda:0",
dtype=torch.bfloat16))
bias13 = torch.randn(num_experts, intermediate_size * 2,
device="cuda:0") * 10
hidden_states = torch.randn(
num_tokens, hidden_size, device="cuda:0", dtype=torch.bfloat16
)
w13 = torch.randn(
num_experts,
intermediate_size * 2,
hidden_size,
device="cuda:0",
dtype=torch.bfloat16,
)
w2 = torch.randn(
num_experts,
hidden_size,
intermediate_size,
device="cuda:0",
dtype=torch.bfloat16,
)
bias13 = torch.randn(num_experts, intermediate_size * 2, device="cuda:0") * 10
bias2 = torch.randn(num_experts, hidden_size, device="cuda:0") * 10
router_logits = torch.rand(num_tokens, num_experts,
dtype=torch.float32).cuda()
router_logits = torch.rand(num_tokens, num_experts, dtype=torch.float32).cuda()
w13, w13_scale = fp4_quantize(w13,
torch.tensor(1.0, device="cuda:0"),
32,
sf_use_ue8m0=True,
is_sf_swizzled_layout=False)
w13, w13_scale = fp4_quantize(
w13,
torch.tensor(1.0, device="cuda:0"),
32,
sf_use_ue8m0=True,
is_sf_swizzled_layout=False,
)
w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape(
num_experts, intermediate_size * 2, hidden_size // 32)
w2, w2_scale = fp4_quantize(w2,
torch.tensor(1.0, device="cuda:0"),
32,
sf_use_ue8m0=True,
is_sf_swizzled_layout=False)
num_experts, intermediate_size * 2, hidden_size // 32
)
w2, w2_scale = fp4_quantize(
w2,
torch.tensor(1.0, device="cuda:0"),
32,
sf_use_ue8m0=True,
is_sf_swizzled_layout=False,
)
w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape(
num_experts, hidden_size, intermediate_size // 32)
if act_type == 'mxfp8':
num_experts, hidden_size, intermediate_size // 32
)
if act_type == "mxfp8":
hidden_states, hidden_states_scale = mxfp8_quantize(
hidden_states, is_sf_swizzled_layout=False)
hidden_states_scale = hidden_states_scale.view(
torch.float8_e4m3fn).reshape(-1)
hidden_states, is_sf_swizzled_layout=False
)
hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(-1)
else:
hidden_states_scale = None
@@ -500,9 +565,10 @@ def test_trtllm_gen_mxfp4_fused_moe(
w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone())
bias13_ref = bias13
bias2_ref = bias2
if act_type == 'mxfp8':
hidden_states_ref = mxfp8_dequantize(
hidden_states, hidden_states_scale).to(torch.float32)
if act_type == "mxfp8":
hidden_states_ref = mxfp8_dequantize(hidden_states, hidden_states_scale).to(
torch.float32
)
else:
hidden_states_ref = hidden_states.to(torch.float32)
# Process tokens in chunks of 32 to reduce memory usage
@@ -529,29 +595,31 @@ def test_trtllm_gen_mxfp4_fused_moe(
# trtllm-gen result
if alpha is not None:
alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
alpha = torch.full((num_experts,), alpha, device=hidden_states.device)
if limit is not None:
limit = torch.full((num_experts, ), limit, device=hidden_states.device)
limit = torch.full((num_experts,), limit, device=hidden_states.device)
if beta is not None:
beta = torch.full((num_experts, ), beta, device=hidden_states.device)
tg_result = tg_mxfp4_moe(router_logits,
topk,
num_experts,
intermediate_size,
hidden_size,
hidden_states,
hidden_states_scale,
w13,
w13_scale,
bias13,
w2,
w2_scale,
bias2,
act_type,
alpha=alpha,
beta=beta,
limit=limit,
transpose_optimized=transpose_optimized)
beta = torch.full((num_experts,), beta, device=hidden_states.device)
tg_result = tg_mxfp4_moe(
router_logits,
topk,
num_experts,
intermediate_size,
hidden_size,
hidden_states,
hidden_states_scale,
w13,
w13_scale,
bias13,
w2,
w2_scale,
bias2,
act_type,
alpha=alpha,
beta=beta,
limit=limit,
transpose_optimized=transpose_optimized,
)
# relatively loose check since the mxfp4 quantization is less accurate
check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)
@@ -573,8 +641,7 @@ def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor:
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
(1.702, 1.0, 7.0)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
@pytest.mark.skipif(
not HOPPER_MXFP4_BF16_AVAILABLE,
reason="nvidia gpu sm90 and flashinfer are required for this test",
@@ -593,52 +660,73 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
device = "cuda:0"
# Inputs
hidden_states = torch.randn(num_tokens,
hidden_size,
device=device,
dtype=torch.bfloat16)
hidden_states = torch.randn(
num_tokens, hidden_size, device=device, dtype=torch.bfloat16
)
# Random MXFP4 weights and scales (uint8), contiguous [w1; w3]
w13_q = torch.randint(
0,
256, (num_experts, 2 * intermediate_size, hidden_size // 2),
256,
(num_experts, 2 * intermediate_size, hidden_size // 2),
device=device,
dtype=torch.uint8)
dtype=torch.uint8,
)
w13_scale = torch.randint(
118,
123, (num_experts, 2 * intermediate_size, hidden_size // 32),
123,
(num_experts, 2 * intermediate_size, hidden_size // 32),
device=device,
dtype=torch.uint8)
dtype=torch.uint8,
)
w2_q = torch.randint(0,
256,
(num_experts, hidden_size, intermediate_size // 2),
device=device,
dtype=torch.uint8)
w2_q = torch.randint(
0,
256,
(num_experts, hidden_size, intermediate_size // 2),
device=device,
dtype=torch.uint8,
)
w2_scale = torch.randint(
118,
123, (num_experts, hidden_size, intermediate_size // 32),
123,
(num_experts, hidden_size, intermediate_size // 32),
device=device,
dtype=torch.uint8)
dtype=torch.uint8,
)
# Bias contiguous [b1; b3]
bias13 = (torch.randn(num_experts,
2 * intermediate_size,
device=device,
dtype=torch.bfloat16) * 10)
bias2 = (torch.randn(
num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
router_logits = torch.rand(num_tokens,
num_experts,
dtype=torch.float32,
device=device)
bias13 = (
torch.randn(
num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16
)
* 10
)
bias2 = (
torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10
)
router_logits = torch.rand(
num_tokens, num_experts, dtype=torch.float32, device=device
)
w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape(
num_experts, 2 * intermediate_size, hidden_size)
num_experts, 2 * intermediate_size, hidden_size
)
w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape(
num_experts, hidden_size, intermediate_size)
ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
hidden_states.to(torch.float32), w13_ref,
bias13.to(torch.float32), w2_ref,
bias2.to(torch.float32), alpha, beta, limit, 'bf16')
num_experts, hidden_size, intermediate_size
)
ref = reference_moe(
router_logits.to(torch.float32),
topk,
num_experts,
hidden_states.to(torch.float32),
w13_ref,
bias13.to(torch.float32),
w2_ref,
bias2.to(torch.float32),
alpha,
beta,
limit,
"bf16",
)
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
@@ -654,23 +742,24 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
w13_s_inter = _interleave_scales_lastdim_by4(w13_s)
w2_s_inter = _interleave_scales_lastdim_by4(w2_scale)
routing_weights = torch.nn.functional.softmax(router_logits,
dim=1,
dtype=torch.float32)
token_final_scales, token_selected_experts = torch.topk(routing_weights,
topk,
dim=-1)
token_final_scales = (token_final_scales /
token_final_scales.sum(dim=-1, keepdim=True))
routing_weights = torch.nn.functional.softmax(
router_logits, dim=1, dtype=torch.float32
)
token_final_scales, token_selected_experts = torch.topk(
routing_weights, topk, dim=-1
)
token_final_scales = token_final_scales / token_final_scales.sum(
dim=-1, keepdim=True
)
token_selected_experts = token_selected_experts.to(torch.int).contiguous()
out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
if alpha is not None:
alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
alpha = torch.full((num_experts,), alpha, device=hidden_states.device)
if beta is not None:
beta = torch.full((num_experts, ), beta, device=hidden_states.device)
beta = torch.full((num_experts,), beta, device=hidden_states.device)
if limit is not None:
limit = torch.full((num_experts, ), limit, device=hidden_states.device)
limit = torch.full((num_experts,), limit, device=hidden_states.device)
_ = flashinfer_cutlass_fused_moe(
input=hidden_states,
@@ -680,8 +769,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
fc2_expert_weights=w2_q,
output_dtype=torch.bfloat16,
output=out,
quant_scales=[w13_s_inter.to(torch.uint8),
w2_s_inter.to(torch.uint8)],
quant_scales=[w13_s_inter.to(torch.uint8), w2_s_inter.to(torch.uint8)],
fc1_expert_biases=w13_b,
fc2_expert_biases=bias2.to(torch.bfloat16),
swiglu_alpha=alpha,
@@ -702,11 +790,13 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("num_tokens", [1, 128])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
(1.702, 1.0, 7.0)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
@pytest.mark.skipif(
not (current_platform.is_cuda()
and current_platform.is_device_capability(100) and has_flashinfer()),
not (
current_platform.is_cuda()
and current_platform.is_device_capability(100)
and has_flashinfer()
),
reason="NVIDIA GPU sm100 and flashinfer are required for this test",
)
def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
@@ -723,32 +813,43 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
device = "cuda:0"
# Inputs
hidden_states = torch.randn(num_tokens,
hidden_size,
device=device,
dtype=torch.bfloat16)
hidden_states = torch.randn(
num_tokens, hidden_size, device=device, dtype=torch.bfloat16
)
# Float weights in w13 format [w1; w3]
w13 = (torch.randn(num_experts,
2 * intermediate_size,
hidden_size,
device=device,
dtype=torch.bfloat16) / 10)
w2 = (torch.randn(num_experts,
hidden_size,
intermediate_size,
device=device,
dtype=torch.bfloat16) / 10)
w13 = (
torch.randn(
num_experts,
2 * intermediate_size,
hidden_size,
device=device,
dtype=torch.bfloat16,
)
/ 10
)
w2 = (
torch.randn(
num_experts,
hidden_size,
intermediate_size,
device=device,
dtype=torch.bfloat16,
)
/ 10
)
# Bias contiguous [b1; b3]
bias13 = (torch.randn(num_experts,
2 * intermediate_size,
device=device,
dtype=torch.bfloat16) * 10)
bias2 = (torch.randn(
num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
router_logits = torch.rand(num_tokens,
num_experts,
dtype=torch.float32,
device=device)
bias13 = (
torch.randn(
num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16
)
* 10
)
bias2 = (
torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10
)
router_logits = torch.rand(
num_tokens, num_experts, dtype=torch.float32, device=device
)
# Quantize weights to MXFP4 per expert (SM100 path)
from flashinfer import mxfp4_quantize
@@ -761,36 +862,56 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
sfs.append(sf)
return torch.stack(qs), torch.stack(sfs)
def dequant_mxfp4_batches(mat_fp4: torch.Tensor,
scale_tensor: torch.Tensor):
def dequant_mxfp4_batches(mat_fp4: torch.Tensor, scale_tensor: torch.Tensor):
num_batches = mat_fp4.size(0)
scale_tensor = scale_tensor.view(num_batches, -1)
from flashinfer import mxfp4_dequantize
return torch.stack([
mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
for b in range(num_batches)
])
return torch.stack(
[
mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
for b in range(num_batches)
]
)
w13_q, w13_scale = quant_mxfp4_batches(w13, num_experts)
w2_q, w2_scale = quant_mxfp4_batches(w2, num_experts)
# Reference result using dequantized tensors and reference_moe
w13_ref = dequant_mxfp4_batches(
w13_q.view(torch.uint8),
w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
num_experts, 2 * intermediate_size, hidden_size).to(device)
w2_ref = dequant_mxfp4_batches(
w2_q.view(torch.uint8),
w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
num_experts, hidden_size, intermediate_size).to(device)
w13_ref = (
dequant_mxfp4_batches(
w13_q.view(torch.uint8), w13_scale.view(torch.uint8).reshape(-1)
)
.to(torch.float32)
.reshape(num_experts, 2 * intermediate_size, hidden_size)
.to(device)
)
w2_ref = (
dequant_mxfp4_batches(
w2_q.view(torch.uint8), w2_scale.view(torch.uint8).reshape(-1)
)
.to(torch.float32)
.reshape(num_experts, hidden_size, intermediate_size)
.to(device)
)
# Quantize activations for SM100 path and dequantize for reference
hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)
# Reference uses BF16 input but quantizes intermediate activation to MXFP8
ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
hidden_states.to(torch.float32), w13_ref,
bias13.to(torch.float32), w2_ref,
bias2.to(torch.float32), alpha, beta, limit, 'mxfp8')
ref = reference_moe(
router_logits.to(torch.float32),
topk,
num_experts,
hidden_states.to(torch.float32),
w13_ref,
bias13.to(torch.float32),
w2_ref,
bias2.to(torch.float32),
alpha,
beta,
limit,
"mxfp8",
)
# Prepare inputs for FlashInfer CUTLASS fused MoE
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
@@ -807,31 +928,28 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
# Build routing for kernel
routing_weights = torch.nn.functional.softmax(router_logits,
dim=1,
dtype=torch.float32)
token_final_scales, token_selected_experts = torch.topk(routing_weights,
topk,
dim=-1)
token_final_scales = (token_final_scales /
token_final_scales.sum(dim=-1, keepdim=True))
routing_weights = torch.nn.functional.softmax(
router_logits, dim=1, dtype=torch.float32
)
token_final_scales, token_selected_experts = torch.topk(
routing_weights, topk, dim=-1
)
token_final_scales = token_final_scales / token_final_scales.sum(
dim=-1, keepdim=True
)
token_selected_experts = token_selected_experts.to(torch.int).contiguous()
out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
if alpha is not None:
alpha_t = torch.full((num_experts, ),
alpha,
device=hidden_states.device)
alpha_t = torch.full((num_experts,), alpha, device=hidden_states.device)
else:
alpha_t = None
if beta is not None:
beta_t = torch.full((num_experts, ), beta, device=hidden_states.device)
beta_t = torch.full((num_experts,), beta, device=hidden_states.device)
else:
beta_t = None
if limit is not None:
limit_t = torch.full((num_experts, ),
limit,
device=hidden_states.device)
limit_t = torch.full((num_experts,), limit, device=hidden_states.device)
else:
limit_t = None

View File

@@ -4,9 +4,11 @@ import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from tests.kernels.quantization.nvfp4_utils import (
FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype,
)
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
@@ -16,8 +18,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform
if not current_platform.has_device_capability(100):
pytest.skip("Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True)
pytest.skip(
"Nvfp4 Requires compute capability of 10 or above.", allow_module_level=True
)
MNK_FACTORS = [
(2, 1024, 1024),
@@ -38,36 +41,34 @@ MNK_FACTORS = [
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@torch.inference_mode()
def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
dtype: torch.dtype):
def test_cutlass_fp4_moe_no_graph(
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
):
current_platform.seed_everything(7)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
quant_blocksize = 16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
(_, w1_q, w1_blockscale,
w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
e,
n,
k,
in_dtype=dtype,
quant_dtype="nvfp4",
block_shape=None, # use quant_blocksize?
per_out_ch_quant=False,
)
(_, w1_q, w1_blockscale, w1_gs), (_, w2_q, w2_blockscale, w2_gs) = (
make_test_weights(
e,
n,
k,
in_dtype=dtype,
quant_dtype="nvfp4",
block_shape=None, # use quant_blocksize?
per_out_ch_quant=False,
)
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a,
score,
topk,
renormalize=False)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
assert w1_gs is not None
assert w2_gs is not None
@@ -97,40 +98,44 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
)
# Reference check:
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a.flatten(), dim=-1)).to(torch.float32)
a_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)
).to(torch.float32)
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=a.dtype,
device=a.device,
block_size=quant_blocksize)
a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=a.dtype,
device=a.device,
block_size=quant_blocksize,
)
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e):
w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize)
w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize)
w1_d[idx] = dequantize_nvfp4_to_dtype(
w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize,
)
w2_d[idx] = dequantize_nvfp4_to_dtype(
w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize,
)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
torch.testing.assert_close(torch_output,
cutlass_output,
atol=1e-1,
rtol=1e-1)
torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1)
if __name__ == "__main__":

View File

@@ -9,13 +9,10 @@ import torch
from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassBatchedExpertsFp8)
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassBatchedExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.platforms import current_platform
from vllm.utils import cdiv
@@ -24,9 +21,13 @@ from .parallel_utils import ProcessGroupInfo, parallel_launch
try:
from pplx_kernels import AllToAll
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_finalize, nvshmem_get_unique_id,
nvshmem_init)
from pplx_kernels.nvshmem import (
nvshmem_alloc_empty_unique_id,
nvshmem_finalize,
nvshmem_get_unique_id,
nvshmem_init,
)
has_pplx = True
except ImportError:
has_pplx = False
@@ -50,12 +51,12 @@ def chunk_by_rank(t, r, w):
chunk = rank_chunk(num, r, w)
rem = num % w
if rem == 0 or r < rem:
return t[(r * chunk):(r + 1) * chunk].contiguous()
return t[(r * chunk) : (r + 1) * chunk].contiguous()
else:
long_chunks = (num // w + 1) * rem
short_chunks = (r - rem) * chunk
start = long_chunks + short_chunks
return t[start:start + chunk].contiguous()
return t[start : start + chunk].contiguous()
def pplx_cutlass_moe(
@@ -75,7 +76,9 @@ def pplx_cutlass_moe(
group_name: Optional[str],
):
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
PplxPrepareAndFinalize,
)
assert torch.cuda.current_device() == pgi.local_rank
num_tokens, hidden_dim = a.shape
@@ -126,35 +129,40 @@ def pplx_cutlass_moe(
ata,
max_num_tokens=max_num_tokens,
num_local_experts=num_local_experts,
num_dispatchers=num_dispatchers)
num_dispatchers=num_dispatchers,
)
ab_strides1 = torch.full((num_local_experts, ),
hidden_dim,
device="cuda",
dtype=torch.int64)
ab_strides2 = torch.full((num_local_experts, ),
intermediate_dim,
device="cuda",
dtype=torch.int64)
c_strides1 = torch.full((num_local_experts, ),
2 * intermediate_dim,
device="cuda",
dtype=torch.int64)
c_strides2 = torch.full((num_local_experts, ),
hidden_dim,
device="cuda",
dtype=torch.int64)
ab_strides1 = torch.full(
(num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64
)
ab_strides2 = torch.full(
(num_local_experts,), intermediate_dim, device="cuda", dtype=torch.int64
)
c_strides1 = torch.full(
(num_local_experts,), 2 * intermediate_dim, device="cuda", dtype=torch.int64
)
c_strides2 = torch.full(
(num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64
)
experts = CutlassBatchedExpertsFp8(
num_local_experts, num_dispatchers, out_dtype, ab_strides1,
ab_strides2, c_strides1, c_strides2,
num_local_experts,
num_dispatchers,
out_dtype,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
fp8_w8a8_moe_quant_config(
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
if per_act_token else a1_scale[rank]))
if per_act_token
else a1_scale[rank],
),
)
fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize,
@@ -162,10 +170,10 @@ def pplx_cutlass_moe(
)
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
chunk_topk_weight = chunk_by_rank(topk_weights, rank,
world_size).to(device)
chunk_topk_ids = chunk_by_rank(topk_ids, rank,
world_size).to(torch.uint32).to(device)
chunk_topk_weight = chunk_by_rank(topk_weights, rank, world_size).to(device)
chunk_topk_ids = (
chunk_by_rank(topk_ids, rank, world_size).to(torch.uint32).to(device)
)
out = fused_cutlass_experts(
a_chunk,
@@ -174,7 +182,7 @@ def pplx_cutlass_moe(
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts,
expert_map=None, #TODO
expert_map=None, # TODO
)
torch.cuda.synchronize()
@@ -210,35 +218,48 @@ def _pplx_moe(
):
try:
if use_internode:
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
uid = (
nvshmem_get_unique_id()
if pgi.rank == 0
else nvshmem_alloc_empty_unique_id()
)
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
else:
group_ranks = list(range(pgi.world_size))
cpu_group = torch.distributed.new_group(group_ranks,
backend="gloo")
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
group_name = cpu_group.group_name
with set_current_vllm_config(vllm_config):
torch_output = torch_experts(a_full, w1_full, w2_full,
topk_weights, topk_ids)
pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
w2_scale, topk_weights, topk_ids,
a1_scale, out_dtype, per_act_token,
per_out_ch, group_name)
torch_output = torch_experts(
a_full, w1_full, w2_full, topk_weights, topk_ids
)
pplx_output = pplx_cutlass_moe(
pgi,
dp_size,
a,
w1,
w2,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
a1_scale,
out_dtype,
per_act_token,
per_out_ch,
group_name,
)
torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pplx_output.device)
torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(
pplx_output.device
)
# Uncomment if more debugging is needed
# print("PPLX OUT:", pplx_output)
# print("TORCH OUT:", torch_output)
torch.testing.assert_close(pplx_output,
torch_output,
atol=0.05,
rtol=0)
torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)
finally:
if use_internode:
nvshmem_finalize()
@@ -251,13 +272,15 @@ def _pplx_moe(
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]])
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) # , [4, 2]])
@pytest.mark.parametrize("use_internode", [False])
@multi_gpu_test(num_gpus=2)
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
current_platform.get_device_capability()
),
reason="Grouped gemm is not supported on this GPU type.",
)
@requires_pplx
def test_cutlass_moe_pplx(
m: int,
@@ -273,7 +296,6 @@ def test_cutlass_moe_pplx(
current_platform.seed_everything(7)
with set_current_vllm_config(vllm_config):
dtype = torch.half
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10.0
@@ -283,22 +305,18 @@ def test_cutlass_moe_pplx(
n_b_scales = 2 * n if per_out_ch else 1
k_b_scales = k if per_out_ch else 1
w1_q = torch.empty((e, 2 * n, k),
device="cuda",
dtype=torch.float8_e4m3fn)
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn)
w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn)
w1_scale = torch.empty((e, n_b_scales, 1),
device="cuda",
dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1),
device="cuda",
dtype=torch.float32)
w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32)
for expert in range(e):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
w1[expert], use_per_token_if_dynamic=per_out_ch)
w1[expert], use_per_token_if_dynamic=per_out_ch
)
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
w2[expert], use_per_token_if_dynamic=per_out_ch)
w2[expert], use_per_token_if_dynamic=per_out_ch
)
w1_d = torch.empty_like(w1)
w2_d = torch.empty_like(w2)
@@ -307,19 +325,35 @@ def test_cutlass_moe_pplx(
w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half()
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a,
score,
topk,
renormalize=False)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
world_size, dp_size = world_dp_size
a_scale1 = torch.randn(
(m if per_act_token else 1, 1), device="cuda",
dtype=torch.float32) / 10.0
a_scale1 = (
torch.randn(
(m if per_act_token else 1, 1), device="cuda", dtype=torch.float32
)
/ 10.0
)
if not per_act_token:
a_scale1 = a_scale1.repeat(world_size, 1)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q,
w1_scale, w2_scale, topk_weights, topk_ids, a_scale1,
dtype, a, w1_d, w2_d, per_act_token, per_out_ch,
use_internode)
parallel_launch(
world_size,
_pplx_moe,
dp_size,
a,
w1_q,
w2_q,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
a_scale1,
dtype,
a,
w1_d,
w2_d,
per_act_token,
per_out_ch,
use_internode,
)

View File

@@ -4,6 +4,7 @@
Run `pytest tests/kernels/test_pplx_moe.py`.
"""
import copy
import itertools
import textwrap
@@ -15,29 +16,34 @@ import torch
try:
from pplx_kernels import AllToAll
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_finalize, nvshmem_get_unique_id,
nvshmem_init)
from pplx_kernels.nvshmem import (
nvshmem_alloc_empty_unique_id,
nvshmem_finalize,
nvshmem_get_unique_id,
nvshmem_init,
)
has_pplx = True
except ImportError:
has_pplx = False
from tests.kernels.moe.modular_kernel_tools.parallel_utils import (
_set_vllm_config)
from tests.kernels.moe.utils import (make_shared_experts, make_test_weights,
naive_batched_moe)
from tests.kernels.moe.modular_kernel_tools.parallel_utils import _set_vllm_config
from tests.kernels.moe.utils import (
make_shared_experts,
make_test_weights,
naive_batched_moe,
)
from tests.kernels.quant_utils import dequant
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk, override_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
TopKWeightAndReduceDelegate,
)
from vllm.platforms import current_platform
from vllm.utils import round_up
@@ -59,7 +65,7 @@ BATCHED_MOE_MNK_FACTORS = [
PPLX_COMBOS = [
# TODO(bnell): figure out why this fails, seems to be test problem
#(1, 128, 128),
# (1, 128, 128),
(2, 128, 512),
(3, 1024, 2048),
(4, 128, 128),
@@ -91,17 +97,16 @@ def torch_prepare(
num_tokens, hidden_dim = a.shape
topk = topk_ids.shape[1]
tokens_per_expert = torch.bincount(topk_ids.view(-1),
minlength=num_experts)
tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
assert tokens_per_expert.numel() == num_experts
if max_num_tokens is None:
max_num_tokens = int(tokens_per_expert.max().item())
b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim),
dtype=a.dtype,
device=a.device)
b_a = torch.zeros(
(num_experts, max_num_tokens, hidden_dim), dtype=a.dtype, device=a.device
)
token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device)
@@ -109,28 +114,29 @@ def torch_prepare(
for j in range(topk):
expert_id = topk_ids[token, j]
idx = token_counts[expert_id]
b_a[expert_id, idx:idx + 1, :] = a[token, :]
b_a[expert_id, idx : idx + 1, :] = a[token, :]
token_counts[expert_id] = token_counts[expert_id] + 1
return b_a, tokens_per_expert
def torch_finalize(b_out: torch.Tensor, topk_weight: torch.Tensor,
topk_ids: torch.Tensor) -> torch.Tensor:
def torch_finalize(
b_out: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor
) -> torch.Tensor:
num_tokens = topk_ids.shape[0]
num_experts = b_out.shape[0]
K = b_out.shape[-1]
out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)
expert_counts = torch.zeros(num_experts,
dtype=torch.int,
device=b_out.device)
expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device)
for token in range(num_tokens):
expert_ids = topk_ids[token]
for i in range(expert_ids.numel()):
expert_id = expert_ids[i]
idx = expert_counts[expert_id]
out[token, :] = out[token, :] + b_out[expert_id, idx:idx +
1, :] * topk_weight[token, i]
out[token, :] = (
out[token, :]
+ b_out[expert_id, idx : idx + 1, :] * topk_weight[token, i]
)
expert_counts[expert_id] = expert_counts[expert_id] + 1
return out
@@ -149,17 +155,18 @@ def torch_batched_moe(
num_tokens, topk = topk_ids.shape
_, max_num_tokens, K = b_a.shape
assert num_experts == b_a.shape[0] and w2.shape[1] == K
out = torch.zeros((num_experts, max_num_tokens, K),
dtype=b_a.dtype,
device=b_a.device)
tmp = torch.empty((max_num_tokens, w1.shape[1] // 2),
dtype=b_a.dtype,
device=b_a.device)
out = torch.zeros(
(num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device
)
tmp = torch.empty(
(max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device
)
for expert in range(num_experts):
num = tokens_per_expert[expert]
if num > 0:
torch.ops._C.silu_and_mul(
tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1))
tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1)
)
out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1)
return torch_finalize(out, topk_weight, topk_ids)
@@ -186,20 +193,16 @@ def test_fused_moe_batched_experts(
with set_current_vllm_config(vllm_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
baseline_output = torch_experts(a, w1, w2, topk_weight,
topk_ids) # only for baseline
baseline_output = torch_experts(
a, w1, w2, topk_weight, topk_ids
) # only for baseline
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
batched_output = naive_batched_moe(
a, w1, w2, topk_weight, topk_ids) # pick torch_experts or this
a, w1, w2, topk_weight, topk_ids
) # pick torch_experts or this
torch.testing.assert_close(baseline_output,
torch_output,
atol=2e-2,
rtol=0)
torch.testing.assert_close(baseline_output,
batched_output,
atol=2e-2,
rtol=0)
torch.testing.assert_close(baseline_output, torch_output, atol=2e-2, rtol=0)
torch.testing.assert_close(baseline_output, batched_output, atol=2e-2, rtol=0)
def create_pplx_prepare_finalize(
@@ -217,7 +220,9 @@ def create_pplx_prepare_finalize(
group_name: Optional[str],
):
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes)
PplxPrepareAndFinalize,
pplx_hidden_dim_scale_bytes,
)
max_num_tokens = max(rank_chunk(num_tokens, 0, world_size), 1)
num_local_experts = rank_chunk(num_experts, 0, world_size)
@@ -266,28 +271,31 @@ def rank_chunk(num: int, r: int, w: int) -> int:
def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
chunk = rank_chunk(t.shape[0], r, w)
return t[(r * chunk):(r + 1) * chunk]
return t[(r * chunk) : (r + 1) * chunk]
def maybe_chunk_by_rank(t: Optional[torch.Tensor], r: int,
w: int) -> Optional[torch.Tensor]:
def maybe_chunk_by_rank(
t: Optional[torch.Tensor], r: int, w: int
) -> Optional[torch.Tensor]:
if t is not None:
return chunk_by_rank(t, r, w)
else:
return t
def chunk_scales_by_rank(t: Optional[torch.Tensor], r: int,
w: int) -> Optional[torch.Tensor]:
def chunk_scales_by_rank(
t: Optional[torch.Tensor], r: int, w: int
) -> Optional[torch.Tensor]:
if t is not None and t.numel() > 1:
chunk = rank_chunk(t.shape[0], r, w)
return t[(r * chunk):(r + 1) * chunk]
return t[(r * chunk) : (r + 1) * chunk]
else:
return t
def chunk_scales(t: Optional[torch.Tensor], start: int,
end: int) -> Optional[torch.Tensor]:
def chunk_scales(
t: Optional[torch.Tensor], start: int, end: int
) -> Optional[torch.Tensor]:
if t is not None and t.numel() > 1:
return t[start:end]
else:
@@ -350,8 +358,7 @@ def pplx_prepare_finalize(
device=device,
)
if (quant_dtype is not None and not per_act_token_quant
and block_shape is None):
if quant_dtype is not None and not per_act_token_quant and block_shape is None:
a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
else:
@@ -375,8 +382,7 @@ def pplx_prepare_finalize(
),
)
b_a = dummy_work(
dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype))
b_a = dummy_work(dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype))
prepare_finalize.finalize(
out,
@@ -410,15 +416,17 @@ def _pplx_prepare_finalize(
):
try:
if use_internode:
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
uid = (
nvshmem_get_unique_id()
if pgi.rank == 0
else nvshmem_alloc_empty_unique_id()
)
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
group_name = None
else:
group_ranks = list(range(pgi.world_size))
cpu_group = torch.distributed.new_group(group_ranks,
backend="gloo")
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
group_name = cpu_group.group_name
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
@@ -426,22 +434,28 @@ def _pplx_prepare_finalize(
a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0)
torch_output = (a_rep.view(m, topk, k) *
topk_weight.view(m, topk, 1).to(a_rep.dtype)).sum(
dim=1)
torch_output = (
a_rep.view(m, topk, k) * topk_weight.view(m, topk, 1).to(a_rep.dtype)
).sum(dim=1)
pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight,
topk_ids, num_experts, quant_dtype,
block_shape, per_act_token_quant,
group_name)
pplx_output = pplx_prepare_finalize(
pgi,
dp_size,
a,
topk_weight,
topk_ids,
num_experts,
quant_dtype,
block_shape,
per_act_token_quant,
group_name,
)
torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pgi.device)
torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(
pgi.device
)
torch.testing.assert_close(pplx_output,
torch_output,
atol=3e-2,
rtol=3e-2)
torch.testing.assert_close(pplx_output, torch_output, atol=3e-2, rtol=3e-2)
finally:
if use_internode:
nvshmem_finalize()
@@ -491,9 +505,19 @@ def test_pplx_prepare_finalize_slow(
a = torch.randn((m, k), device=device, dtype=act_dtype) / 10
score = torch.randn((m, e), device=device, dtype=act_dtype)
parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score,
topk, e, quant_dtype, block_shape, per_act_token_quant,
use_internode)
parallel_launch(
world_size,
_pplx_prepare_finalize,
dp_size,
a,
score,
topk,
e,
quant_dtype,
block_shape,
per_act_token_quant,
use_internode,
)
def pplx_moe(
@@ -517,7 +541,6 @@ def pplx_moe(
use_cudagraphs: bool = True,
shared_experts: Optional[torch.nn.Module] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
num_tokens, hidden_dim = a.shape
num_experts = w1.shape[0]
topk = topk_ids.shape[1]
@@ -579,21 +602,23 @@ def pplx_moe(
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
if use_compile:
_fused_experts = torch.compile(fused_experts,
backend='inductor',
fullgraph=True)
_fused_experts = torch.compile(
fused_experts, backend="inductor", fullgraph=True
)
torch._dynamo.mark_dynamic(a_chunk, 0)
torch._dynamo.mark_dynamic(chunk_topk_weight, 0)
torch._dynamo.mark_dynamic(chunk_topk_ids, 0)
else:
_fused_experts = fused_experts
out = _fused_experts(a_chunk,
w1_chunk,
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts)
out = _fused_experts(
a_chunk,
w1_chunk,
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts,
)
if use_cudagraphs:
if isinstance(out, tuple):
@@ -604,12 +629,14 @@ def pplx_moe(
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
out = _fused_experts(a_chunk,
w1_chunk,
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts)
out = _fused_experts(
a_chunk,
w1_chunk,
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts,
)
torch.cuda.synchronize()
graph.replay()
@@ -640,15 +667,17 @@ def _pplx_moe(
):
try:
if use_internode:
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
uid = (
nvshmem_get_unique_id()
if pgi.rank == 0
else nvshmem_alloc_empty_unique_id()
)
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
group_name = None
else:
group_ranks = list(range(pgi.world_size))
cpu_group = torch.distributed.new_group(group_ranks,
backend="gloo")
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
group_name = cpu_group.group_name
m, k = a.shape
@@ -666,8 +695,7 @@ def _pplx_moe(
w1_s = w1_s.to(device) if w1_s is not None else None
w2_s = w2_s.to(device) if w2_s is not None else None
if (quant_dtype is not None and not per_act_token_quant
and block_shape is None):
if quant_dtype is not None and not per_act_token_quant and block_shape is None:
a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
else:
@@ -742,31 +770,27 @@ def _pplx_moe(
if shared_output is not None:
assert pplx_shared_output is not None
chunked_shared_output = chunk_by_rank(
shared_output, pgi.rank,
pgi.world_size).to(pplx_shared_output.device)
shared_output, pgi.rank, pgi.world_size
).to(pplx_shared_output.device)
else:
chunked_shared_output = None
chunked_batch_output = chunk_by_rank(
batched_output, pgi.rank, pgi.world_size).to(pplx_output.device)
batched_output, pgi.rank, pgi.world_size
).to(pplx_output.device)
torch.testing.assert_close(batched_output,
torch_output,
atol=3e-2,
rtol=3e-2)
torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2)
torch.testing.assert_close(pplx_output,
chunked_batch_output,
atol=3e-2,
rtol=3e-2)
torch.testing.assert_close(
pplx_output, chunked_batch_output, atol=3e-2, rtol=3e-2
)
if shared_experts is not None:
assert chunked_shared_output is not None
assert pplx_shared_output is not None
torch.testing.assert_close(pplx_shared_output,
chunked_shared_output,
atol=3e-2,
rtol=3e-2)
torch.testing.assert_close(
pplx_shared_output, chunked_shared_output, atol=3e-2, rtol=3e-2
)
finally:
if use_internode:
@@ -823,15 +847,33 @@ def test_pplx_moe_slow(
per_out_ch_quant=per_act_token_quant,
)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e,
w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape,
use_internode)
parallel_launch(
world_size,
_pplx_moe,
dp_size,
a,
w1,
w2,
score,
topk,
e,
w1_s,
w2_s,
quant_dtype,
per_act_token_quant,
block_shape,
use_internode,
)
def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
use_shared_experts: bool, make_weights: bool,
test_fn: Callable):
def _pplx_test_loop(
pgi: ProcessGroupInfo,
dp_size: int,
use_internode: bool,
use_shared_experts: bool,
make_weights: bool,
test_fn: Callable,
):
def format_result(msg, ex=None):
if ex is not None:
x = str(ex)
@@ -850,12 +892,12 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
new_vllm_config = copy.deepcopy(vllm_config)
new_vllm_config.parallel_config.data_parallel_size = pgi.world_size
new_vllm_config.parallel_config.enable_expert_parallel = True
_set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank,
pgi.local_rank)
_set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank, pgi.local_rank)
current_platform.seed_everything(7)
combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES,
[False, True], [None, [128, 128]])
combos = itertools.product(
PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, [False, True], [None, [128, 128]]
)
exceptions = []
count = 0
for mnk, e, topk, dtype, per_act_token_quant, block_shape in combos:
@@ -873,13 +915,11 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, "
f"dtype={dtype}, per_act_token={per_act_token_quant}, "
f"block_shape={block_shape}, use_internode={use_internode}, "
f"use_shared_experts={use_shared_experts}")
f"use_shared_experts={use_shared_experts}"
)
if not use_fp8_w8a8 and (per_act_token_quant
or block_shape is not None):
print(
f"{test_desc} - Skip quantization test for non-quantized type."
)
if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
print(f"{test_desc} - Skip quantization test for non-quantized type.")
continue
if per_act_token_quant and block_shape is not None:
@@ -934,10 +974,10 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
if len(exceptions) > 0:
raise RuntimeError(
f"{len(exceptions)} of {count} tests failed in child process, "
f"rank={pgi.rank}.")
f"rank={pgi.rank}."
)
else:
print(f"{count} of {count} tests passed in child process, "
f"rank={pgi.rank}.")
print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.")
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@@ -950,8 +990,15 @@ def test_pplx_prepare_finalize(
):
current_platform.seed_everything(7)
world_size, dp_size = world_dp_size
parallel_launch(world_size * dp_size, _pplx_test_loop, dp_size,
use_internode, False, False, _pplx_prepare_finalize)
parallel_launch(
world_size * dp_size,
_pplx_test_loop,
dp_size,
use_internode,
False,
False,
_pplx_prepare_finalize,
)
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@@ -966,5 +1013,12 @@ def test_pplx_moe(
):
current_platform.seed_everything(7)
world_size, dp_size = world_dp_size
parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode,
use_shared_experts, True, _pplx_moe)
parallel_launch(
world_size,
_pplx_test_loop,
dp_size,
use_internode,
use_shared_experts,
True,
_pplx_moe,
)

View File

@@ -24,13 +24,14 @@ aiter_available = importlib.util.find_spec("aiter") is not None
pytestmark = pytest.mark.skipif(
not (current_platform.is_rocm() and aiter_available),
reason="AITER ops are only available on ROCm with aiter package installed")
reason="AITER ops are only available on ROCm with aiter package installed",
)
def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
"""Test that the custom op is correctly registered."""
# Check if the op exists in torch.ops.vllm
assert hasattr(torch.ops.vllm, 'rocm_aiter_biased_grouped_topk')
assert hasattr(torch.ops.vllm, "rocm_aiter_biased_grouped_topk")
# Check if the op is callable
assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk)
@@ -39,7 +40,7 @@ def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
def test_rocm_aiter_grouped_topk_custom_op_registration():
"""Test that the custom op is correctly registered."""
# Check if the op exists in torch.ops.vllm
assert hasattr(torch.ops.vllm, 'rocm_aiter_grouped_topk')
assert hasattr(torch.ops.vllm, "rocm_aiter_grouped_topk")
# Check if the op is callable
assert callable(torch.ops.vllm.rocm_aiter_grouped_topk)
@@ -56,25 +57,29 @@ def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
renormalize = True
scale_factor = 1.0
gating_output = torch.randn((token, expert),
dtype=torch.bfloat16,
device="cuda")
e_score_correction_bias = torch.randn((expert, ),
dtype=torch.bfloat16,
device="cuda")
gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda")
e_score_correction_bias = torch.randn(
(expert,), dtype=torch.bfloat16, device="cuda"
)
device = gating_output.device
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
topk_weights = torch.empty((token, topk),
dtype=torch.float32,
device=device)
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
# Define a function that uses the op
def biased_grouped_topk_fn(gating_output, e_score_correction_bias,
topk_weights, topk_ids):
def biased_grouped_topk_fn(
gating_output, e_score_correction_bias, topk_weights, topk_ids
):
return torch.ops.vllm.rocm_aiter_biased_grouped_topk(
gating_output, e_score_correction_bias, topk_weights, topk_ids,
num_expert_group, topk_group, renormalize, scale_factor)
gating_output,
e_score_correction_bias,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
renormalize,
scale_factor,
)
# Verify the op's fake implementation
torch.library.opcheck(
@@ -84,51 +89,49 @@ def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
"num_expert_group": num_expert_group,
"topk_group": topk_group,
"need_renorm": renormalize,
"routed_scaling_factor": scale_factor
"routed_scaling_factor": scale_factor,
},
test_utils=("test_faketensor"))
test_utils=("test_faketensor"),
)
# Compile the function with appropriate settings
compiled_fn = torch.compile(biased_grouped_topk_fn,
fullgraph=True,
backend="inductor",
mode="reduce-overhead",
dynamic=False)
compiled_fn = torch.compile(
biased_grouped_topk_fn,
fullgraph=True,
backend="inductor",
mode="reduce-overhead",
dynamic=False,
)
topk_weights_original = torch.empty((token, topk),
dtype=torch.float32,
device=device)
topk_ids_original = torch.empty((token, topk),
dtype=torch.int32,
device=device)
topk_weights_original = torch.empty(
(token, topk), dtype=torch.float32, device=device
)
topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device)
topk_weights_compiled = torch.empty((token, topk),
dtype=torch.float32,
device=device)
topk_ids_compiled = torch.empty((token, topk),
dtype=torch.int32,
device=device)
topk_weights_compiled = torch.empty(
(token, topk), dtype=torch.float32, device=device
)
topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device)
# Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
biased_grouped_topk_fn(gating_output, e_score_correction_bias,
topk_weights_original, topk_ids_original)
compiled_fn(gating_output, e_score_correction_bias, topk_weights_compiled,
topk_ids_compiled)
biased_grouped_topk_fn(
gating_output, e_score_correction_bias, topk_weights_original, topk_ids_original
)
compiled_fn(
gating_output, e_score_correction_bias, topk_weights_compiled, topk_ids_compiled
)
# Sort the results for comparison since the order might not be deterministic
topk_ids_original, indices_original = torch.sort(topk_ids_original)
topk_weights_original = torch.gather(topk_weights_original, 1,
indices_original)
topk_weights_original = torch.gather(topk_weights_original, 1, indices_original)
topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
topk_weights_compiled = torch.gather(topk_weights_compiled, 1,
indices_compiled)
topk_weights_compiled = torch.gather(topk_weights_compiled, 1, indices_compiled)
# Verify results match
assert torch.allclose(topk_weights_original,
topk_weights_compiled,
rtol=1e-2,
atol=1e-2)
assert torch.allclose(
topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2
)
assert torch.allclose(topk_ids_original, topk_ids_compiled)
@@ -144,73 +147,73 @@ def test_rocm_aiter_grouped_topk_torch_compile_compatibility():
scoring_func = "softmax"
scale_factor = 1.0
gating_output = torch.randn((token, expert),
dtype=torch.bfloat16,
device="cuda")
gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda")
device = gating_output.device
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
topk_weights = torch.empty((token, topk),
dtype=torch.float32,
device=device)
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
# Define a function that uses the op
def grouped_topk_fn(gating_output, topk_weights, topk_ids, scoring_func):
return torch.ops.vllm.rocm_aiter_grouped_topk(
gating_output, topk_weights, topk_ids, num_expert_group,
topk_group, renormalize, scoring_func, scale_factor)
gating_output,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
renormalize,
scoring_func,
scale_factor,
)
# Verify the op's fake implementation
torch.library.opcheck(torch.ops.vllm.rocm_aiter_grouped_topk,
(gating_output, topk_weights, topk_ids),
kwargs={
"num_expert_group": num_expert_group,
"topk_group": topk_group,
"need_renorm": renormalize,
"scoring_func": scoring_func,
"routed_scaling_factor": scale_factor
},
test_utils=("test_faketensor"))
torch.library.opcheck(
torch.ops.vllm.rocm_aiter_grouped_topk,
(gating_output, topk_weights, topk_ids),
kwargs={
"num_expert_group": num_expert_group,
"topk_group": topk_group,
"need_renorm": renormalize,
"scoring_func": scoring_func,
"routed_scaling_factor": scale_factor,
},
test_utils=("test_faketensor"),
)
# Compile the function with appropriate settings
compiled_fn = torch.compile(grouped_topk_fn,
fullgraph=True,
backend="inductor",
mode="reduce-overhead",
dynamic=False)
compiled_fn = torch.compile(
grouped_topk_fn,
fullgraph=True,
backend="inductor",
mode="reduce-overhead",
dynamic=False,
)
topk_weights_original = torch.empty((token, topk),
dtype=torch.float32,
device=device)
topk_ids_original = torch.empty((token, topk),
dtype=torch.int32,
device=device)
topk_weights_original = torch.empty(
(token, topk), dtype=torch.float32, device=device
)
topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device)
topk_weights_compiled = torch.empty((token, topk),
dtype=torch.float32,
device=device)
topk_ids_compiled = torch.empty((token, topk),
dtype=torch.int32,
device=device)
topk_weights_compiled = torch.empty(
(token, topk), dtype=torch.float32, device=device
)
topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device)
# Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
grouped_topk_fn(gating_output, topk_weights_original, topk_ids_original,
scoring_func)
compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled,
scoring_func)
grouped_topk_fn(
gating_output, topk_weights_original, topk_ids_original, scoring_func
)
compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled, scoring_func)
# Sort the results for comparison since the order might not be deterministic
topk_ids_original, indices_original = torch.sort(topk_ids_original)
topk_weights_original = torch.gather(topk_weights_original, 1,
indices_original)
topk_weights_original = torch.gather(topk_weights_original, 1, indices_original)
topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
topk_weights_compiled = torch.gather(topk_weights_compiled, 1,
indices_compiled)
topk_weights_compiled = torch.gather(topk_weights_compiled, 1, indices_compiled)
# Verify results match
assert torch.allclose(topk_weights_original,
topk_weights_compiled,
rtol=1e-2,
atol=1e-2)
assert torch.allclose(
topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2
)
assert torch.allclose(topk_ids_original, topk_ids_compiled)

View File

@@ -5,7 +5,8 @@ import pytest
import torch
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
silu_mul_fp8_quant_deep_gemm_cuda)
silu_mul_fp8_quant_deep_gemm_cuda,
)
from vllm.platforms import current_platform
from vllm.utils import cdiv
@@ -34,7 +35,6 @@ CASES = [
(256, 16, 7168, fp8_dtype),
(256, 32, 7168, fp8_dtype),
(256, 64, 7168, fp8_dtype),
# Only add a few fnuz tests to help with long CI times.
(8, 512, 7168, torch.float8_e4m3fnuz),
(8, 1024, 7168, torch.float8_e4m3fnuz),
@@ -52,15 +52,15 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
tokens_per_expert = torch.randint(
low=T // 2,
high=T,
size=(E, ),
size=(E,),
dtype=torch.int32,
device="cuda",
)
# Run the Triton kernel
y_q, y_s = silu_mul_fp8_quant_deep_gemm_cuda(y,
tokens_per_expert,
group_size=group_size)
y_q, y_s = silu_mul_fp8_quant_deep_gemm_cuda(
y, tokens_per_expert, group_size=group_size
)
torch.cuda.synchronize()
fp8_info = torch.finfo(fp8_dtype)
@@ -75,9 +75,9 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
for e in range(E):
nt = tokens_per_expert[e].item()
ref_s = torch.empty((T, cdiv(H, group_size)),
dtype=torch.float32,
device="cuda")
ref_s = torch.empty(
(T, cdiv(H, group_size)), dtype=torch.float32, device="cuda"
)
ref_q = torch.empty((T, H), dtype=fp8_dtype, device="cuda")
for t in range(nt):
@@ -87,14 +87,17 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
# process full groups
n_full_groups = H // group_size
if n_full_groups > 0:
data_grp = data[:n_full_groups * group_size].view(
n_full_groups, group_size)
data_grp = data[: n_full_groups * group_size].view(
n_full_groups, group_size
)
amax = data_grp.abs().amax(dim=1).clamp(min=eps)
scale = amax / fp8_max
scaled = data[:n_full_groups *
group_size] / scale.repeat_interleave(group_size)
ref_q_row[:n_full_groups * group_size] = scaled.clamp(
fp8_min, fp8_max).to(fp8_dtype)
scaled = data[: n_full_groups * group_size] / scale.repeat_interleave(
group_size
)
ref_q_row[: n_full_groups * group_size] = scaled.clamp(
fp8_min, fp8_max
).to(fp8_dtype)
ref_s[t, :n_full_groups] = scale
# process remainder group

View File

@@ -11,13 +11,11 @@ from tests.kernels.moe.utils import fused_moe
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.platforms import current_platform
if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
allow_module_level=True)
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
@@ -31,14 +29,13 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
assert B.ndim == 2 and B.is_contiguous(
), "B must be a 2D contiguous tensor"
assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"
# Reshape input
M = A.numel() // A.shape[-1]
B = B.t() # Transpose weight matrix
N, K = B.shape
origin_C_shape = A.shape[:-1] + (K, )
origin_C_shape = A.shape[:-1] + (K,)
A = A.reshape(M, N)
# As is per-token [M, 1], Bs is per-column [1, K]
@@ -88,17 +85,17 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
act_out = SiluAndMul().forward_native(inter_out)
# Quantize activation output with per-token
act_out_q, act_out_s = ops.scaled_fp8_quant(
act_out, use_per_token_if_dynamic=True)
act_out, use_per_token_if_dynamic=True
)
# Second MLP layer
out[mask] = native_w8a8_per_token_matmul(act_out_q,
w2[i],
act_out_s,
w2_s[i],
output_dtype=a.dtype)
out[mask] = native_w8a8_per_token_matmul(
act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
)
# Apply routing weights and sum
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
@pytest.fixture(autouse=True, scope="module")
@@ -116,8 +113,10 @@ TOP_KS = [2, 6]
SEEDS = [0]
@pytest.mark.parametrize("M, N, K, E, topk, dtype, seed",
itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS))
@pytest.mark.parametrize(
"M, N, K, E, topk, dtype, seed",
itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS),
)
@torch.inference_mode()
def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
torch.manual_seed(seed)
@@ -133,12 +132,10 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
# Generate int8 weights
w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min,
max=fp8_max).to(torch.float8_e4m3fn)
w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min,
max=fp8_max).to(torch.float8_e4m3fn)
w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
# Generate scale for each column (per-column quantization)
w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
@@ -163,7 +160,7 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
)
# Check results
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
rel_diff = torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
assert rel_diff < 0.05

View File

@@ -6,17 +6,17 @@ import torch
import vllm._custom_ops as ops
from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX)
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
BatchedPrepareAndFinalize,
BatchedTritonExperts,
NaiveBatchedExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils import round_up
from vllm.utils.deep_gemm import per_block_cast_to_fp8
@@ -45,12 +45,7 @@ def triton_moe(
a2_scale=a2_scale,
)
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
quant_config=quant_config)
return fused_experts(a, w1, w2, topk_weight, topk_ids, quant_config=quant_config)
def batched_moe(
@@ -80,10 +75,9 @@ def batched_moe(
)
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens,
num_dispatchers=1,
num_local_experts=w1.shape[0],
rank=0),
BatchedPrepareAndFinalize(
max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
),
BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
@@ -121,10 +115,9 @@ def naive_batched_moe(
)
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens,
num_dispatchers=1,
num_local_experts=w1.shape[0],
rank=0),
BatchedPrepareAndFinalize(
max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
),
NaiveBatchedExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
@@ -135,8 +128,9 @@ def naive_batched_moe(
return fused_experts(a, w1, w2, topk_weight, topk_ids)
def chunk_scales(scales: Optional[torch.Tensor], start: int,
end: int) -> Optional[torch.Tensor]:
def chunk_scales(
scales: Optional[torch.Tensor], start: int, end: int
) -> Optional[torch.Tensor]:
if scales is not None:
if scales.numel() == 1:
return scales
@@ -159,13 +153,15 @@ def make_quantized_test_activations(
a_scale = None
if quant_dtype is not None:
assert (quant_dtype == torch.float8_e4m3fn
or quant_dtype == torch.int8), "only fp8/int8 supported"
assert quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8, (
"only fp8/int8 supported"
)
a_q = torch.zeros_like(a, dtype=quant_dtype)
a_scale_l = [None] * E
for e in range(E):
a_q[e], a_scale_l[e] = moe_kernel_quantize_input(
a[e], None, quant_dtype, per_act_token_quant, block_shape)
a[e], None, quant_dtype, per_act_token_quant, block_shape
)
a_scale = torch.stack(a_scale_l)
if not per_act_token_quant and block_shape is None:
@@ -181,8 +177,11 @@ def moe_quantize_weights(
per_token_quant: bool,
block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
assert (quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8
or quant_dtype == "nvfp4"), "only fp8/int8/nvfp4 supported"
assert (
quant_dtype == torch.float8_e4m3fn
or quant_dtype == torch.int8
or quant_dtype == "nvfp4"
), "only fp8/int8/nvfp4 supported"
w_gs = None
@@ -199,10 +198,12 @@ def moe_quantize_weights(
else:
if quant_dtype == torch.int8:
w, w_s = ops.scaled_int8_quant(
w, w_s, use_per_token_if_dynamic=per_token_quant)
w, w_s, use_per_token_if_dynamic=per_token_quant
)
elif quant_dtype == torch.float8_e4m3fn:
w, w_s = ops.scaled_fp8_quant(
w, w_s, use_per_token_if_dynamic=per_token_quant)
w, w_s, use_per_token_if_dynamic=per_token_quant
)
elif quant_dtype == "nvfp4":
assert not per_token_quant
w_amax = torch.abs(w).max().to(torch.float32)
@@ -222,8 +223,7 @@ def make_test_weight(
quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None,
per_out_ch_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
w_gs = None
@@ -233,7 +233,8 @@ def make_test_weight(
w_gs_l = [None] * e
for idx in range(e):
w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape)
w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape
)
w = torch.stack(w_l)
w_s = torch.stack(w_s_l)
@@ -264,26 +265,25 @@ def make_test_weights(
quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None,
per_out_ch_quant: bool = False,
) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]],
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]]:
) -> tuple[
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
]:
return (
make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
per_out_ch_quant),
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
per_out_ch_quant),
make_test_weight(
e, 2 * n, k, in_dtype, quant_dtype, block_shape, per_out_ch_quant
),
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant),
)
def per_token_cast_to_fp8(
x: torch.Tensor,
block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (block_size - (n % block_size)) % block_size
x = torch.nn.functional.pad(x,
(0, pad_size), value=0) if pad_size > 0 else x
x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, block_size)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
@@ -313,27 +313,31 @@ def make_test_quant_config(
a1_gscale: Optional[torch.Tensor] = None
a2_gscale: Optional[torch.Tensor] = None
if quant_dtype == "nvfp4":
a1_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
a1_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
a2_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
a1_scale = a1_gscale
a2_scale = a2_gscale
else:
a1_scale = None
a2_scale = None
return w1, w2, FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_s,
w2_scale=w2_s,
a1_gscale=a1_gscale,
a2_gscale=a2_gscale,
a1_scale=a1_scale,
a2_scale=a2_scale,
# TODO: make sure this is handled properly
g1_alphas=(1 / w1_gs) if w1_gs is not None else None,
g2_alphas=(1 / w2_gs) if w2_gs is not None else None,
return (
w1,
w2,
FusedMoEQuantConfig.make(
quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
w1_scale=w1_s,
w2_scale=w2_s,
a1_gscale=a1_gscale,
a2_gscale=a2_gscale,
a1_scale=a1_scale,
a2_scale=a2_scale,
# TODO: make sure this is handled properly
g1_alphas=(1 / w1_gs) if w1_gs is not None else None,
g2_alphas=(1 / w2_gs) if w2_gs is not None else None,
),
)
@@ -348,21 +352,23 @@ def fused_moe(
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
) -> torch.Tensor:
topk_weights, topk_ids, _ = fused_topk(hidden_states, score.float(), topk,
renormalize)
return fused_experts(hidden_states,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
quant_config=quant_config)
topk_weights, topk_ids, _ = fused_topk(
hidden_states, score.float(), topk, renormalize
)
return fused_experts(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
quant_config=quant_config,
)
# CustomOp?
class BaselineMM(torch.nn.Module):
def __init__(
self,
b: torch.Tensor,
@@ -372,15 +378,11 @@ class BaselineMM(torch.nn.Module):
self.b = b.to(dtype=torch.float32)
self.out_dtype = out_dtype
def forward(
self,
a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
return torch.mm(a.to(dtype=torch.float32),
self.b).to(self.out_dtype), None
def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
return torch.mm(a.to(dtype=torch.float32), self.b).to(self.out_dtype), None
class TestMLP(torch.nn.Module):
def __init__(
self,
w1: torch.Tensor,
@@ -410,7 +412,6 @@ def make_naive_shared_experts(
class RealMLP(torch.nn.Module):
def __init__(
self,
hidden_size: int,
@@ -425,37 +426,48 @@ class RealMLP(torch.nn.Module):
w2_s: Optional[torch.Tensor] = None,
) -> None:
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, RowParallelLinear)
MergedColumnParallelLinear,
RowParallelLinear,
)
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
prefix=f"{prefix}.gate_up_proj",
)
self.gate_up_proj.register_parameter(
"weight", torch.nn.Parameter(w1, requires_grad=False))
"weight", torch.nn.Parameter(w1, requires_grad=False)
)
self.gate_up_proj.register_parameter(
"weight_scale", torch.nn.Parameter(w1_s, requires_grad=False))
"weight_scale", torch.nn.Parameter(w1_s, requires_grad=False)
)
self.gate_up_proj.register_parameter(
"input_scale",
None) #torch.nn.Parameter(None, requires_grad=False))
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj")
"input_scale", None
) # torch.nn.Parameter(None, requires_grad=False))
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
self.down_proj.register_parameter(
"weight", torch.nn.Parameter(w2, requires_grad=False))
"weight", torch.nn.Parameter(w2, requires_grad=False)
)
self.down_proj.register_parameter(
"weight_scale", torch.nn.Parameter(w2_s, requires_grad=False))
"weight_scale", torch.nn.Parameter(w2_s, requires_grad=False)
)
self.down_proj.register_parameter(
"input_scale",
None) #torch.nn.Parameter(None, requires_grad=False))
"input_scale", None
) # torch.nn.Parameter(None, requires_grad=False))
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
raise ValueError(
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self, x):
@@ -496,13 +508,6 @@ def make_shared_experts(
w2_s = None
quant_config = None
return RealMLP(K,
N,
w1,
w2,
"silu",
quant_config,
w1_s=w1_s,
w2_s=w2_s)
return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
finally:
torch.set_default_dtype(old_dtype)