Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -9,18 +9,19 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
|
||||
from .common import Config
|
||||
from .mk_objects import (MK_ALL_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES,
|
||||
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
|
||||
from .mk_objects import (
|
||||
MK_ALL_PREPARE_FINALIZE_TYPES,
|
||||
MK_FUSED_EXPERT_TYPES,
|
||||
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES,
|
||||
)
|
||||
|
||||
|
||||
def make_config_arg_parser(description: str):
|
||||
|
||||
def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize:
|
||||
for pf in MK_ALL_PREPARE_FINALIZE_TYPES:
|
||||
if pf.__name__ == s:
|
||||
return pf
|
||||
raise ValueError(
|
||||
f"Cannot find a PrepareFinalize type that matches {s}")
|
||||
raise ValueError(f"Cannot find a PrepareFinalize type that matches {s}")
|
||||
|
||||
def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
for fe in MK_FUSED_EXPERT_TYPES:
|
||||
@@ -45,15 +46,18 @@ def make_config_arg_parser(description: str):
|
||||
"--pf-type",
|
||||
type=to_pf_class_type,
|
||||
required=True,
|
||||
help=("Choose a PrepareFinalize Type : "
|
||||
f"{[x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]}"),
|
||||
help=(
|
||||
"Choose a PrepareFinalize Type : "
|
||||
f"{[x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]}"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--experts-type",
|
||||
type=to_experts_class_type,
|
||||
required=True,
|
||||
help=(f"Choose a FusedExpert type : "
|
||||
f"{[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}"),
|
||||
help=(
|
||||
f"Choose a FusedExpert type : {[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
@@ -74,66 +78,65 @@ def make_config_arg_parser(description: str):
|
||||
default=1024,
|
||||
help="N dimension of the first fused-moe matmul",
|
||||
)
|
||||
parser.add_argument("--num-experts",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Global num experts")
|
||||
parser.add_argument("--topk",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[4, 1],
|
||||
help="num topk")
|
||||
parser.add_argument(
|
||||
"--num-experts", type=int, default=32, help="Global num experts"
|
||||
)
|
||||
parser.add_argument("--topk", nargs="+", type=int, default=[4, 1], help="num topk")
|
||||
parser.add_argument(
|
||||
"--fused-moe-chunk-size",
|
||||
type=int,
|
||||
help="Fused moe chunk size used for the non-batched fused experts impl."
|
||||
help="Fused moe chunk size used for the non-batched fused experts impl.",
|
||||
)
|
||||
|
||||
# Quant args
|
||||
parser.add_argument("--quant-dtype",
|
||||
type=to_quant_torch_dtype,
|
||||
help="Quant datatype")
|
||||
parser.add_argument("--per-token-quantized-activations",
|
||||
action='store_true',
|
||||
help=("The input activations must be per-token "
|
||||
"quantized"))
|
||||
parser.add_argument("--per-channel-quantized-weights",
|
||||
action="store_true",
|
||||
help="The weights must be per-channel quantized.")
|
||||
parser.add_argument("--block-shape",
|
||||
nargs="+",
|
||||
type=int,
|
||||
help="Quantization block shape")
|
||||
parser.add_argument(
|
||||
"--quant-dtype", type=to_quant_torch_dtype, help="Quant datatype"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per-token-quantized-activations",
|
||||
action="store_true",
|
||||
help=("The input activations must be per-token quantized"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per-channel-quantized-weights",
|
||||
action="store_true",
|
||||
help="The weights must be per-channel quantized.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--block-shape", nargs="+", type=int, help="Quantization block shape"
|
||||
)
|
||||
|
||||
# Torch trace profile generation args
|
||||
parser.add_argument("--torch-trace-dir-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Get torch trace for single execution")
|
||||
parser.add_argument(
|
||||
"--torch-trace-dir-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Get torch trace for single execution",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def _validate_args(args: argparse.Namespace):
|
||||
|
||||
if args.quant_dtype is not None:
|
||||
assert args.quant_dtype == torch.float8_e4m3fn
|
||||
if args.block_shape is not None:
|
||||
assert len(args.block_shape) == 2, (
|
||||
f"block shape must have 2 elements. got {args.block_shape}")
|
||||
f"block shape must have 2 elements. got {args.block_shape}"
|
||||
)
|
||||
|
||||
if args.experts_type in MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES:
|
||||
assert args.world_size == 1, (
|
||||
"Single GPU objects need world size set to 1")
|
||||
assert args.world_size == 1, "Single GPU objects need world size set to 1"
|
||||
|
||||
if args.torch_trace_dir_path is not None:
|
||||
from pathlib import Path
|
||||
|
||||
assert Path(args.torch_trace_dir_path).is_dir(), (
|
||||
f"Please create {args.torch_trace_dir_path}")
|
||||
f"Please create {args.torch_trace_dir_path}"
|
||||
)
|
||||
|
||||
|
||||
def make_config(args: argparse.Namespace) -> Config:
|
||||
|
||||
_validate_args(args)
|
||||
|
||||
quant_config = None
|
||||
@@ -142,7 +145,8 @@ def make_config(args: argparse.Namespace) -> Config:
|
||||
quant_dtype=args.quant_dtype,
|
||||
per_act_token_quant=args.per_token_quantized_activations,
|
||||
per_out_ch_quant=args.per_channel_quantized_weights,
|
||||
block_shape=args.block_shape)
|
||||
block_shape=args.block_shape,
|
||||
)
|
||||
|
||||
return Config(
|
||||
Ms=args.m,
|
||||
@@ -156,4 +160,5 @@ def make_config(args: argparse.Namespace) -> Config:
|
||||
fused_experts_type=args.experts_type,
|
||||
fused_moe_chunk_size=args.fused_moe_chunk_size,
|
||||
world_size=args.world_size,
|
||||
torch_trace_dir_path=args.torch_trace_dir_path)
|
||||
torch_trace_dir_path=args.torch_trace_dir_path,
|
||||
)
|
||||
|
||||
@@ -8,20 +8,30 @@ import torch
|
||||
import vllm._custom_ops as ops
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from tests.kernels.moe.utils import make_test_weights, per_token_cast_to_fp8
|
||||
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype)
|
||||
from tests.kernels.quantization.nvfp4_utils import (
|
||||
FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype,
|
||||
)
|
||||
from tests.kernels.utils import torch_experts
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig)
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||
|
||||
from .mk_objects import (TestMoEQuantConfig, expert_info, make_fused_experts,
|
||||
make_prepare_finalize, prepare_finalize_info)
|
||||
from .mk_objects import (
|
||||
TestMoEQuantConfig,
|
||||
expert_info,
|
||||
make_fused_experts,
|
||||
make_prepare_finalize,
|
||||
prepare_finalize_info,
|
||||
)
|
||||
from .parallel_utils import ProcessGroupInfo
|
||||
|
||||
|
||||
@@ -94,8 +104,7 @@ class Config:
|
||||
|
||||
@property
|
||||
def is_per_tensor_act_quant(self) -> bool:
|
||||
return (not self.is_per_act_token_quant
|
||||
and self.quant_block_shape is None)
|
||||
return not self.is_per_act_token_quant and self.quant_block_shape is None
|
||||
|
||||
@property
|
||||
def is_per_out_ch_quant(self) -> bool:
|
||||
@@ -134,23 +143,24 @@ class Config:
|
||||
|
||||
if self.fused_moe_chunk_size is not None:
|
||||
env_dict.update(
|
||||
{"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)})
|
||||
{"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}
|
||||
)
|
||||
|
||||
return vllm_config, env_dict
|
||||
|
||||
def is_fp8_block_quantized(self):
|
||||
return (self.quant_dtype == torch.float8_e4m3fn
|
||||
and self.quant_block_shape is not None)
|
||||
return (
|
||||
self.quant_dtype == torch.float8_e4m3fn
|
||||
and self.quant_block_shape is not None
|
||||
)
|
||||
|
||||
def is_batched_prepare_finalize(self):
|
||||
info = prepare_finalize_info(self.prepare_finalize_type)
|
||||
return (mk.FusedMoEActivationFormat.BatchedExperts ==
|
||||
info.activation_format)
|
||||
return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format
|
||||
|
||||
def is_batched_fused_experts(self):
|
||||
info = expert_info(self.fused_experts_type)
|
||||
return (mk.FusedMoEActivationFormat.BatchedExperts ==
|
||||
info.activation_format)
|
||||
return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format
|
||||
|
||||
def is_standard_fused_experts(self):
|
||||
info = expert_info(self.fused_experts_type)
|
||||
@@ -190,8 +200,10 @@ class Config:
|
||||
|
||||
def needs_deep_ep(self):
|
||||
info = prepare_finalize_info(self.prepare_finalize_type)
|
||||
return (info.backend == "deepep_high_throughput"
|
||||
or info.backend == "deepep_low_latency")
|
||||
return (
|
||||
info.backend == "deepep_high_throughput"
|
||||
or info.backend == "deepep_low_latency"
|
||||
)
|
||||
|
||||
def all2all_backend(self):
|
||||
info = prepare_finalize_info(self.prepare_finalize_type)
|
||||
@@ -211,20 +223,26 @@ class Config:
|
||||
return False
|
||||
|
||||
# Check quantization sanity
|
||||
if (int(self.is_per_act_token_quant) +
|
||||
int(self.is_per_tensor_act_quant) +
|
||||
int(self.quant_block_shape is not None)) > 1:
|
||||
if (
|
||||
int(self.is_per_act_token_quant)
|
||||
+ int(self.is_per_tensor_act_quant)
|
||||
+ int(self.quant_block_shape is not None)
|
||||
) > 1:
|
||||
# invalid quant config
|
||||
return False
|
||||
|
||||
# check type support
|
||||
if self.quant_dtype is None:
|
||||
if (self.dtype not in self.pf_supported_types()
|
||||
or self.dtype not in self.fe_supported_types()):
|
||||
if (
|
||||
self.dtype not in self.pf_supported_types()
|
||||
or self.dtype not in self.fe_supported_types()
|
||||
):
|
||||
return False
|
||||
else:
|
||||
if (self.quant_dtype not in self.pf_supported_types()
|
||||
or self.quant_dtype not in self.fe_supported_types()):
|
||||
if (
|
||||
self.quant_dtype not in self.pf_supported_types()
|
||||
or self.quant_dtype not in self.fe_supported_types()
|
||||
):
|
||||
return False
|
||||
|
||||
# Check block quanization support
|
||||
@@ -261,18 +279,21 @@ class WeightTensors:
|
||||
def describe(self):
|
||||
s = ""
|
||||
s += "== Weight Tensors: \n"
|
||||
s += f' - {_describe_tensor(self.w1, "w1")} \n'
|
||||
s += f' - {_describe_tensor(self.w2, "w2")} \n'
|
||||
s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n'
|
||||
s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n'
|
||||
s += f' - {_describe_tensor(self.w1_gs, "w1_gs")} \n'
|
||||
s += f' - {_describe_tensor(self.w2_gs, "w2_gs")} \n'
|
||||
s += f" - {_describe_tensor(self.w1, 'w1')} \n"
|
||||
s += f" - {_describe_tensor(self.w2, 'w2')} \n"
|
||||
s += f" - {_describe_tensor(self.w1_scale, 'w1_scale')} \n"
|
||||
s += f" - {_describe_tensor(self.w2_scale, 'w2_scale')} \n"
|
||||
s += f" - {_describe_tensor(self.w1_gs, 'w1_gs')} \n"
|
||||
s += f" - {_describe_tensor(self.w2_gs, 'w2_gs')} \n"
|
||||
return s
|
||||
|
||||
def is_quantized(self) -> bool:
|
||||
# or w1_scale is not None?
|
||||
return (self.w1.dtype == torch.float8_e4m3fn
|
||||
or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8)
|
||||
return (
|
||||
self.w1.dtype == torch.float8_e4m3fn
|
||||
or self.w1.dtype == torch.uint8
|
||||
or self.w1.dtype == torch.int8
|
||||
)
|
||||
|
||||
def to_current_device(self):
|
||||
device = torch.cuda.current_device()
|
||||
@@ -289,16 +310,13 @@ class WeightTensors:
|
||||
if self.w2_gs is not None:
|
||||
self.w2_gs = self.w2_gs.to(device=device)
|
||||
|
||||
def slice_weights(self, rank: int,
|
||||
num_local_experts: int) -> "WeightTensors":
|
||||
def slice_weights(self, rank: int, num_local_experts: int) -> "WeightTensors":
|
||||
s = rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
w1 = self.w1[s:e, :, :]
|
||||
w2 = self.w2[s:e, :, :]
|
||||
w1_scale = self.w1_scale[
|
||||
s:e, :, :] if self.w1_scale is not None else None
|
||||
w2_scale = self.w2_scale[
|
||||
s:e, :, :] if self.w2_scale is not None else None
|
||||
w1_scale = self.w1_scale[s:e, :, :] if self.w1_scale is not None else None
|
||||
w2_scale = self.w2_scale[s:e, :, :] if self.w2_scale is not None else None
|
||||
w1_gs = self.w1_gs[s:e] if self.w1_gs is not None else None
|
||||
w2_gs = self.w2_gs[s:e] if self.w2_gs is not None else None
|
||||
|
||||
@@ -313,15 +331,11 @@ class WeightTensors:
|
||||
in_dtype=config.dtype,
|
||||
quant_dtype=config.quant_dtype,
|
||||
block_shape=config.quant_block_shape,
|
||||
per_out_ch_quant=config.
|
||||
is_per_act_token_quant, # or config.is_per_out_ch_quant
|
||||
per_out_ch_quant=config.is_per_act_token_quant, # or config.is_per_out_ch_quant
|
||||
)
|
||||
return WeightTensors(
|
||||
w1=w1, w2=w2, w1_scale=w1_scale, w2_scale=w2_scale, w1_gs=w1_gs, w2_gs=w2_gs
|
||||
)
|
||||
return WeightTensors(w1=w1,
|
||||
w2=w2,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_gs=w1_gs,
|
||||
w2_gs=w2_gs)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -336,22 +350,22 @@ class RankTensors:
|
||||
def describe(self):
|
||||
s = ""
|
||||
s += "== Rank Tensors: \n"
|
||||
s += f' - {_describe_tensor(self.hidden_states, "HS")} \n'
|
||||
s += f' - {_describe_tensor(self.hidden_states_scale, "HS_scale")} \n'
|
||||
s += f' - {_describe_tensor(self.topk_weights, "topk_weights")} \n'
|
||||
s += f' - {_describe_tensor(self.topk_ids, "topk_ids")} \n'
|
||||
s += f' - {_describe_tensor(self.expert_map, "expert_map")} \n'
|
||||
s += f" - {_describe_tensor(self.hidden_states, 'HS')} \n"
|
||||
s += f" - {_describe_tensor(self.hidden_states_scale, 'HS_scale')} \n"
|
||||
s += f" - {_describe_tensor(self.topk_weights, 'topk_weights')} \n"
|
||||
s += f" - {_describe_tensor(self.topk_ids, 'topk_ids')} \n"
|
||||
s += f" - {_describe_tensor(self.expert_map, 'expert_map')} \n"
|
||||
return s
|
||||
|
||||
@staticmethod
|
||||
def make_hidden_states(
|
||||
config: Config) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
config: Config,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Return hidden_states
|
||||
"""
|
||||
m, k, dtype = (config.M, config.K, config.dtype)
|
||||
a = (torch.randn(
|
||||
(m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0)
|
||||
a = torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0
|
||||
|
||||
if config.quant_dtype is None:
|
||||
return a, None
|
||||
@@ -362,36 +376,29 @@ class RankTensors:
|
||||
# first - so further quantize and dequantize will yield the same
|
||||
# values.
|
||||
if config.is_per_tensor_act_quant:
|
||||
a_q, a_scales = ops.scaled_fp8_quant(
|
||||
a, use_per_token_if_dynamic=False)
|
||||
a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=False)
|
||||
return a_q.float().mul(a_scales).to(dtype), a_scales
|
||||
|
||||
if config.is_per_act_token_quant:
|
||||
a_q, a_scales = ops.scaled_fp8_quant(a,
|
||||
use_per_token_if_dynamic=True)
|
||||
a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=True)
|
||||
return a_q.float().mul(a_scales).to(dtype), None
|
||||
|
||||
assert config.quant_block_shape is not None
|
||||
block_k = config.quant_block_shape[1]
|
||||
a_q, a_scales = per_token_cast_to_fp8(a, block_size=block_k)
|
||||
return a_q.float().view(
|
||||
(-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to(dtype), None
|
||||
return a_q.float().view((-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to(
|
||||
dtype
|
||||
), None
|
||||
|
||||
@staticmethod
|
||||
def make(config: Config, pgi: ProcessGroupInfo):
|
||||
|
||||
dtype = config.dtype
|
||||
topk, m, _ = (config.topk, config.M, config.K)
|
||||
hidden_states, hidden_states_scale = RankTensors.make_hidden_states(
|
||||
config)
|
||||
hidden_states, hidden_states_scale = RankTensors.make_hidden_states(config)
|
||||
|
||||
num_local_experts, global_num_experts = (config.num_local_experts,
|
||||
config.E)
|
||||
score = torch.randn((m, global_num_experts),
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk,
|
||||
False)
|
||||
num_local_experts, global_num_experts = (config.num_local_experts, config.E)
|
||||
score = torch.randn((m, global_num_experts), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, False)
|
||||
|
||||
# distribute topk_ids evenly
|
||||
for mi in range(m):
|
||||
@@ -400,14 +407,15 @@ class RankTensors:
|
||||
|
||||
expert_map = None
|
||||
if config.world_size > 1 and config.supports_expert_map():
|
||||
expert_map = torch.full((global_num_experts, ),
|
||||
fill_value=-1,
|
||||
dtype=torch.int32)
|
||||
expert_map = torch.full(
|
||||
(global_num_experts,), fill_value=-1, dtype=torch.int32
|
||||
)
|
||||
s = pgi.rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
|
||||
expert_map = expert_map.to(device=torch.cuda.current_device(),
|
||||
dtype=torch.int32)
|
||||
expert_map = expert_map.to(
|
||||
device=torch.cuda.current_device(), dtype=torch.int32
|
||||
)
|
||||
|
||||
return RankTensors(
|
||||
hidden_states=hidden_states,
|
||||
@@ -418,9 +426,9 @@ class RankTensors:
|
||||
)
|
||||
|
||||
|
||||
def reference_moe_impl(config: Config, weights: WeightTensors,
|
||||
rank_tensors: RankTensors) -> torch.Tensor:
|
||||
|
||||
def reference_moe_impl(
|
||||
config: Config, weights: WeightTensors, rank_tensors: RankTensors
|
||||
) -> torch.Tensor:
|
||||
if config.quant_dtype == "nvfp4":
|
||||
quant_blocksize = 16
|
||||
dtype = config.dtype
|
||||
@@ -433,8 +441,10 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
|
||||
w2_blockscale = weights.w2_scale
|
||||
w2_gs = weights.w2_gs
|
||||
|
||||
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(
|
||||
rank_tensors.hidden_states.flatten(), dim=-1)).to(torch.float32)
|
||||
a_global_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX)
|
||||
/ torch.amax(rank_tensors.hidden_states.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
|
||||
assert w1_gs is not None
|
||||
assert w2_gs is not None
|
||||
@@ -447,14 +457,17 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
|
||||
assert w2_blockscale.shape[2] % 4 == 0
|
||||
|
||||
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
|
||||
rank_tensors.hidden_states, a_global_scale)
|
||||
rank_tensors.hidden_states, a_global_scale
|
||||
)
|
||||
|
||||
a = dequantize_nvfp4_to_dtype(a_fp4,
|
||||
a_scale_interleaved,
|
||||
a_global_scale,
|
||||
dtype=dtype,
|
||||
device=a_fp4.device,
|
||||
block_size=quant_blocksize)
|
||||
a = dequantize_nvfp4_to_dtype(
|
||||
a_fp4,
|
||||
a_scale_interleaved,
|
||||
a_global_scale,
|
||||
dtype=dtype,
|
||||
device=a_fp4.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
|
||||
e = w1_q.shape[0]
|
||||
n = w1_q.shape[1] // 2
|
||||
@@ -464,18 +477,22 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
|
||||
w2 = torch.zeros((e, k, n), device="cuda", dtype=dtype)
|
||||
|
||||
for idx in range(0, e):
|
||||
w1[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
|
||||
w1_blockscale[idx],
|
||||
w1_gs[idx],
|
||||
dtype=dtype,
|
||||
device=w1_q.device,
|
||||
block_size=quant_blocksize)
|
||||
w2[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
|
||||
w2_blockscale[idx],
|
||||
w2_gs[idx],
|
||||
dtype=dtype,
|
||||
device=w2_q.device,
|
||||
block_size=quant_blocksize)
|
||||
w1[idx] = dequantize_nvfp4_to_dtype(
|
||||
w1_q[idx],
|
||||
w1_blockscale[idx],
|
||||
w1_gs[idx],
|
||||
dtype=dtype,
|
||||
device=w1_q.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
w2[idx] = dequantize_nvfp4_to_dtype(
|
||||
w2_q[idx],
|
||||
w2_blockscale[idx],
|
||||
w2_gs[idx],
|
||||
dtype=dtype,
|
||||
device=w2_q.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
a_scale = None
|
||||
w1_scale = None
|
||||
w2_scale = None
|
||||
@@ -493,27 +510,29 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
|
||||
per_act_token_quant = config.is_per_act_token_quant
|
||||
block_shape = config.quant_block_shape
|
||||
|
||||
return torch_experts(a=a,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weight=rank_tensors.topk_weights,
|
||||
topk_ids=rank_tensors.topk_ids,
|
||||
global_num_experts=config.E,
|
||||
expert_map=None,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_scale,
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
apply_router_weights_on_input=config.topk == 1
|
||||
and config.supports_apply_weight_on_input())
|
||||
return torch_experts(
|
||||
a=a,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weight=rank_tensors.topk_weights,
|
||||
topk_ids=rank_tensors.topk_ids,
|
||||
global_num_experts=config.E,
|
||||
expert_map=None,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_scale,
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
apply_router_weights_on_input=config.topk == 1
|
||||
and config.supports_apply_weight_on_input(),
|
||||
)
|
||||
|
||||
|
||||
def _make_gscale(num_experts: int) -> torch.Tensor:
|
||||
return torch.ones((num_experts, ),
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=torch.float32)
|
||||
return torch.ones(
|
||||
(num_experts,), device=torch.cuda.current_device(), dtype=torch.float32
|
||||
)
|
||||
|
||||
|
||||
def make_modular_kernel(
|
||||
@@ -521,12 +540,12 @@ def make_modular_kernel(
|
||||
vllm_config: VllmConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
|
||||
def next_power_of_2(x):
|
||||
import math
|
||||
|
||||
if x == 0:
|
||||
return 1
|
||||
return 2**math.ceil(math.log2(x))
|
||||
return 2 ** math.ceil(math.log2(x))
|
||||
|
||||
# make moe config
|
||||
moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
|
||||
@@ -546,9 +565,9 @@ def make_modular_kernel(
|
||||
)
|
||||
|
||||
# make modular kernel
|
||||
prepare_finalize = make_prepare_finalize(config.prepare_finalize_type,
|
||||
config.all2all_backend(), moe,
|
||||
quant_config)
|
||||
prepare_finalize = make_prepare_finalize(
|
||||
config.prepare_finalize_type, config.all2all_backend(), moe, quant_config
|
||||
)
|
||||
|
||||
fused_experts = make_fused_experts(
|
||||
config.fused_experts_type,
|
||||
@@ -559,7 +578,8 @@ def make_modular_kernel(
|
||||
)
|
||||
|
||||
modular_kernel = mk.FusedMoEModularKernel(
|
||||
prepare_finalize=prepare_finalize, fused_experts=fused_experts)
|
||||
prepare_finalize=prepare_finalize, fused_experts=fused_experts
|
||||
)
|
||||
|
||||
return modular_kernel
|
||||
|
||||
@@ -587,10 +607,8 @@ def run_modular_kernel(
|
||||
w1_scale=rank_weights.w1_scale,
|
||||
w2_scale=rank_weights.w2_scale,
|
||||
a1_scale=rank_tensors.hidden_states_scale,
|
||||
g1_alphas=(1 / rank_weights.w1_gs)
|
||||
if rank_weights.w1_gs is not None else None,
|
||||
g2_alphas=(1 / rank_weights.w2_gs)
|
||||
if rank_weights.w2_gs is not None else None,
|
||||
g1_alphas=(1 / rank_weights.w1_gs) if rank_weights.w1_gs is not None else None,
|
||||
g2_alphas=(1 / rank_weights.w2_gs) if rank_weights.w2_gs is not None else None,
|
||||
a1_gscale=gscale,
|
||||
a2_gscale=gscale,
|
||||
block_shape=config.quant_block_shape,
|
||||
@@ -603,38 +621,30 @@ def run_modular_kernel(
|
||||
# impls might update the tensor in place
|
||||
hidden_states = rank_tensors.hidden_states.clone()
|
||||
|
||||
topk_ids = rank_tensors.topk_ids.to(
|
||||
mk.prepare_finalize.topk_indices_dtype())
|
||||
topk_ids = rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype())
|
||||
|
||||
mk_kwargs = {
|
||||
"hidden_states":
|
||||
hidden_states,
|
||||
"w1":
|
||||
rank_weights.w1,
|
||||
"w2":
|
||||
rank_weights.w2,
|
||||
"topk_weights":
|
||||
rank_tensors.topk_weights,
|
||||
"topk_ids":
|
||||
topk_ids,
|
||||
"expert_map":
|
||||
rank_tensors.expert_map,
|
||||
"global_num_experts":
|
||||
config.E,
|
||||
"apply_router_weight_on_input":
|
||||
config.topk == 1 and config.supports_apply_weight_on_input(),
|
||||
"hidden_states": hidden_states,
|
||||
"w1": rank_weights.w1,
|
||||
"w2": rank_weights.w2,
|
||||
"topk_weights": rank_tensors.topk_weights,
|
||||
"topk_ids": topk_ids,
|
||||
"expert_map": rank_tensors.expert_map,
|
||||
"global_num_experts": config.E,
|
||||
"apply_router_weight_on_input": config.topk == 1
|
||||
and config.supports_apply_weight_on_input(),
|
||||
}
|
||||
|
||||
num_tokens = rank_tensors.hidden_states.shape[0]
|
||||
num_tokens_across_dp = torch.tensor([num_tokens] * config.world_size,
|
||||
device="cuda",
|
||||
dtype=torch.int)
|
||||
num_tokens_across_dp = torch.tensor(
|
||||
[num_tokens] * config.world_size, device="cuda", dtype=torch.int
|
||||
)
|
||||
|
||||
with set_forward_context(
|
||||
None,
|
||||
vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
None,
|
||||
vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
):
|
||||
out = mk.forward(**mk_kwargs)
|
||||
|
||||
|
||||
@@ -10,14 +10,21 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG)
|
||||
from vllm.model_executor.layers.fused_moe.config import FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .common import (Config, RankTensors, WeightTensors, reference_moe_impl,
|
||||
run_modular_kernel)
|
||||
from .mk_objects import (MK_FUSED_EXPERT_TYPES,
|
||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_QUANT_CONFIGS)
|
||||
from .common import (
|
||||
Config,
|
||||
RankTensors,
|
||||
WeightTensors,
|
||||
reference_moe_impl,
|
||||
run_modular_kernel,
|
||||
)
|
||||
from .mk_objects import (
|
||||
MK_FUSED_EXPERT_TYPES,
|
||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
|
||||
MK_QUANT_CONFIGS,
|
||||
)
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config
|
||||
|
||||
|
||||
@@ -38,8 +45,9 @@ def rank_worker(
|
||||
|
||||
# sanity check
|
||||
from vllm import envs
|
||||
|
||||
if config.fused_moe_chunk_size is not None:
|
||||
assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
|
||||
assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
|
||||
# get weights to this device
|
||||
weights.to_current_device()
|
||||
@@ -60,8 +68,7 @@ def rank_worker(
|
||||
rank_tensors = RankTensors.make(cfgx, pgi)
|
||||
|
||||
# modular kernel out
|
||||
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
|
||||
rank_tensors)
|
||||
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, rank_tensors)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ref_out = reference_moe_impl(cfgx, weights, rank_tensors)
|
||||
@@ -70,28 +77,27 @@ def rank_worker(
|
||||
|
||||
|
||||
def make_feature_matrix(csv_file_path: str):
|
||||
|
||||
from dataclasses import asdict
|
||||
|
||||
import pandas as pd
|
||||
|
||||
def add_to_results(config: Config,
|
||||
success: Result,
|
||||
results_df: Optional[pd.DataFrame] = None):
|
||||
def add_to_results(
|
||||
config: Config, success: Result, results_df: Optional[pd.DataFrame] = None
|
||||
):
|
||||
config_dict = asdict(config)
|
||||
config_dict['prepare_finalize_type'] = config_dict[
|
||||
'prepare_finalize_type'].__name__
|
||||
config_dict['fused_experts_type'] = config_dict[
|
||||
'fused_experts_type'].__name__
|
||||
config_dict['per_tensor_act_quant'] = config.is_per_tensor_act_quant
|
||||
quant_config_dict = config_dict['quant_config']
|
||||
del config_dict['quant_config']
|
||||
config_dict["prepare_finalize_type"] = config_dict[
|
||||
"prepare_finalize_type"
|
||||
].__name__
|
||||
config_dict["fused_experts_type"] = config_dict["fused_experts_type"].__name__
|
||||
config_dict["per_tensor_act_quant"] = config.is_per_tensor_act_quant
|
||||
quant_config_dict = config_dict["quant_config"]
|
||||
del config_dict["quant_config"]
|
||||
if quant_config_dict is None:
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
quant_config_dict = asdict(quant_config)
|
||||
|
||||
config_dict |= quant_config_dict
|
||||
result_dict = config_dict | {'success': success.name}
|
||||
result_dict = config_dict | {"success": success.name}
|
||||
|
||||
result_df = pd.DataFrame([result_dict])
|
||||
if results_df is None:
|
||||
@@ -112,22 +118,26 @@ def make_feature_matrix(csv_file_path: str):
|
||||
Q_TYPES = MK_QUANT_CONFIGS
|
||||
|
||||
combinations = list(
|
||||
product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES))
|
||||
product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES)
|
||||
)
|
||||
|
||||
results_df: Optional[pd.DataFrame] = None
|
||||
for m, k, n, e, topks, dtype, pf_type, experts_type, quant_config in tqdm(
|
||||
combinations): #noqa: E501
|
||||
config = Config(Ms=[m],
|
||||
K=k,
|
||||
N=n,
|
||||
E=e,
|
||||
topks=topks,
|
||||
dtype=dtype,
|
||||
prepare_finalize_type=pf_type,
|
||||
fused_experts_type=experts_type,
|
||||
quant_config=quant_config,
|
||||
world_size=2,
|
||||
fused_moe_chunk_size=None)
|
||||
combinations
|
||||
): # noqa: E501
|
||||
config = Config(
|
||||
Ms=[m],
|
||||
K=k,
|
||||
N=n,
|
||||
E=e,
|
||||
topks=topks,
|
||||
dtype=dtype,
|
||||
prepare_finalize_type=pf_type,
|
||||
fused_experts_type=experts_type,
|
||||
quant_config=quant_config,
|
||||
world_size=2,
|
||||
fused_moe_chunk_size=None,
|
||||
)
|
||||
|
||||
success = None
|
||||
if config.is_valid():
|
||||
@@ -135,9 +145,14 @@ def make_feature_matrix(csv_file_path: str):
|
||||
try:
|
||||
weights: WeightTensors = WeightTensors.make(config)
|
||||
vllm_config, env_dict = config.make_env_data()
|
||||
parallel_launch_with_config(config.world_size, rank_worker,
|
||||
vllm_config, env_dict, config,
|
||||
weights)
|
||||
parallel_launch_with_config(
|
||||
config.world_size,
|
||||
rank_worker,
|
||||
vllm_config,
|
||||
env_dict,
|
||||
config,
|
||||
weights,
|
||||
)
|
||||
success = Result.PASS
|
||||
except Exception as _:
|
||||
success = Result.FAIL
|
||||
@@ -150,25 +165,33 @@ def make_feature_matrix(csv_file_path: str):
|
||||
results_df.to_csv(f"{csv_file_path}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
parser = argparse.ArgumentParser(description=(
|
||||
"Make ModularKernel feature matrix \n"
|
||||
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix " #noqa: E501
|
||||
"-f ./feature_matrices/feature_matrix.csv"))
|
||||
|
||||
parser.add_argument("-f",
|
||||
"--feature-matrix-csv-file-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="File name to Generate a .csv file")
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Make ModularKernel feature matrix \n"
|
||||
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix " # noqa: E501
|
||||
"-f ./feature_matrices/feature_matrix.csv"
|
||||
)
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--feature-matrix-csv-file-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="File name to Generate a .csv file",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
csv_path = args.feature_matrix_csv_file_path
|
||||
assert csv_path.endswith(
|
||||
'csv'), f"Need a file path ending with .csv, got {csv_path}"
|
||||
assert Path(csv_path).parent.is_dir(
|
||||
), f"Cannot find parent directory for {Path(csv_path).parent}"
|
||||
assert csv_path.endswith("csv"), (
|
||||
f"Need a file path ending with .csv, got {csv_path}"
|
||||
)
|
||||
assert Path(csv_path).parent.is_dir(), (
|
||||
f"Cannot find parent directory for {Path(csv_path).parent}"
|
||||
)
|
||||
|
||||
make_feature_matrix(args.feature_matrix_csv_file_path)
|
||||
|
||||
@@ -8,24 +8,33 @@ import torch
|
||||
# Fused experts and PrepareFinalize imports
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts)
|
||||
BatchedDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
||||
BatchedTritonOrDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||
FusedMoEQuantConfig)
|
||||
BatchedTritonOrDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts, NaiveBatchedExperts)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
|
||||
TritonExperts)
|
||||
BatchedTritonExperts,
|
||||
NaiveBatchedExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, TritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts)
|
||||
TritonOrDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
cutlass_fp4_supported)
|
||||
cutlass_fp4_supported,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_fp8_supported)
|
||||
cutlass_fp8_supported,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_supported
|
||||
@@ -60,8 +69,7 @@ class ExpertInfo:
|
||||
needs_deep_gemm: bool = False
|
||||
|
||||
|
||||
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize,
|
||||
PrepareFinalizeInfo] = {}
|
||||
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize, PrepareFinalizeInfo] = {}
|
||||
EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {}
|
||||
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
|
||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
|
||||
@@ -71,7 +79,10 @@ MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = []
|
||||
standard_format = mk.FusedMoEActivationFormat.Standard
|
||||
batched_format = mk.FusedMoEActivationFormat.BatchedExperts
|
||||
common_float_types: list[Union[torch.dtype, str]] = [
|
||||
torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32
|
||||
torch.float8_e4m3fn,
|
||||
torch.bfloat16,
|
||||
torch.float16,
|
||||
torch.float32,
|
||||
]
|
||||
common_float_and_int_types = common_float_types + [torch.int8]
|
||||
nvfp4_types = ["nvfp4"]
|
||||
@@ -186,9 +197,11 @@ register_experts(
|
||||
# Disable on blackwell for now
|
||||
if has_deep_ep() and not current_platform.has_device_capability(100):
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
||||
DeepEPHTPrepareAndFinalize)
|
||||
DeepEPHTPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
||||
DeepEPLLPrepareAndFinalize)
|
||||
DeepEPLLPrepareAndFinalize,
|
||||
)
|
||||
|
||||
register_prepare_and_finalize(
|
||||
DeepEPHTPrepareAndFinalize,
|
||||
@@ -208,7 +221,9 @@ if has_deep_ep() and not current_platform.has_device_capability(100):
|
||||
|
||||
if has_pplx():
|
||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
||||
PplxPrepareAndFinalize)
|
||||
PplxPrepareAndFinalize,
|
||||
)
|
||||
|
||||
register_prepare_and_finalize(
|
||||
PplxPrepareAndFinalize,
|
||||
batched_format,
|
||||
@@ -217,13 +232,14 @@ if has_pplx():
|
||||
backend="pplx",
|
||||
)
|
||||
|
||||
if (has_flashinfer_cutlass_fused_moe()
|
||||
and current_platform.has_device_capability(100)):
|
||||
if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100):
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
|
||||
FlashInferExperts)
|
||||
FlashInferExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
||||
FlashInferCutlassMoEPrepareAndFinalize,
|
||||
create_flashinfer_prepare_finalize)
|
||||
create_flashinfer_prepare_finalize,
|
||||
)
|
||||
|
||||
register_prepare_and_finalize(
|
||||
FlashInferCutlassMoEPrepareAndFinalize,
|
||||
@@ -258,16 +274,18 @@ if has_deep_gemm() and is_deep_gemm_supported():
|
||||
needs_matching_quant=False,
|
||||
needs_deep_gemm=True,
|
||||
)
|
||||
register_experts(
|
||||
DeepGemmExperts,
|
||||
standard_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=True,
|
||||
needs_matching_quant=False,
|
||||
needs_deep_gemm=True,
|
||||
),
|
||||
(
|
||||
register_experts(
|
||||
DeepGemmExperts,
|
||||
standard_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=True,
|
||||
needs_matching_quant=False,
|
||||
needs_deep_gemm=True,
|
||||
),
|
||||
)
|
||||
register_experts(
|
||||
BatchedTritonOrDeepGemmExperts,
|
||||
batched_format,
|
||||
@@ -290,8 +308,11 @@ if has_deep_gemm() and is_deep_gemm_supported():
|
||||
)
|
||||
|
||||
if cutlass_fp8_supported():
|
||||
from vllm.model_executor.layers.fused_moe import (CutlassBatchedExpertsFp8,
|
||||
CutlassExpertsFp8)
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
CutlassBatchedExpertsFp8,
|
||||
CutlassExpertsFp8,
|
||||
)
|
||||
|
||||
register_experts(
|
||||
CutlassExpertsFp8,
|
||||
standard_format,
|
||||
@@ -310,8 +331,8 @@ if cutlass_fp8_supported():
|
||||
)
|
||||
|
||||
if cutlass_fp4_supported():
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
CutlassExpertsFp4)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp4
|
||||
|
||||
register_experts(
|
||||
CutlassExpertsFp4,
|
||||
standard_format,
|
||||
@@ -324,30 +345,40 @@ if cutlass_fp4_supported():
|
||||
MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [
|
||||
None,
|
||||
# per-channel / per-column weights and per-tensor activations
|
||||
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=True,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None),
|
||||
TestMoEQuantConfig(
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=True,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None,
|
||||
),
|
||||
# per-channel / per-column weights and per-token activations
|
||||
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=True,
|
||||
per_act_token_quant=True,
|
||||
block_shape=None),
|
||||
TestMoEQuantConfig(
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=True,
|
||||
per_act_token_quant=True,
|
||||
block_shape=None,
|
||||
),
|
||||
# per-tensor weights and per-tensor activations
|
||||
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None),
|
||||
TestMoEQuantConfig(
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None,
|
||||
),
|
||||
# per-tensor weights and per-token activations
|
||||
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=True,
|
||||
block_shape=None),
|
||||
TestMoEQuantConfig(
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=True,
|
||||
block_shape=None,
|
||||
),
|
||||
# block-quantized weights and 128 block per-token activations
|
||||
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=[128, 128]),
|
||||
TestMoEQuantConfig(
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=[128, 128],
|
||||
),
|
||||
# TODO (varun) : Should we test the following combinations ?
|
||||
# block-quantized weights and per-token activations
|
||||
# block-quantized weights and per-tensor activations
|
||||
@@ -355,10 +386,12 @@ MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [
|
||||
|
||||
if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
|
||||
MK_QUANT_CONFIGS += [
|
||||
TestMoEQuantConfig(quant_dtype="nvfp4",
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None),
|
||||
TestMoEQuantConfig(
|
||||
quant_dtype="nvfp4",
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -370,12 +403,14 @@ def make_prepare_finalize(
|
||||
) -> mk.FusedMoEPrepareAndFinalize:
|
||||
if backend != "naive" and backend is not None:
|
||||
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(
|
||||
moe, quant_config)
|
||||
moe, quant_config
|
||||
)
|
||||
assert prepare_finalize is not None
|
||||
return prepare_finalize
|
||||
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
|
||||
return create_flashinfer_prepare_finalize(
|
||||
use_dp=moe.moe_parallel_config.dp_size > 1)
|
||||
use_dp=moe.moe_parallel_config.dp_size > 1
|
||||
)
|
||||
else:
|
||||
return MoEPrepareAndFinalizeNoEP()
|
||||
|
||||
@@ -391,10 +426,10 @@ def make_cutlass_strides(
|
||||
n: int,
|
||||
k: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
|
||||
return ab_strides1, ab_strides2, c_strides1, c_strides2
|
||||
|
||||
|
||||
@@ -405,7 +440,6 @@ def make_fused_experts(
|
||||
num_dispatchers: int,
|
||||
N: int,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
|
||||
batch_kwargs = {
|
||||
"max_num_tokens": moe.max_num_tokens,
|
||||
"num_dispatchers": num_dispatchers,
|
||||
|
||||
@@ -6,13 +6,11 @@ import traceback
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.multiprocessing import (
|
||||
spawn) # pyright: ignore[reportPrivateImportUsage]
|
||||
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.distributed import init_distributed_environment, initialize_model_parallel
|
||||
from vllm.utils import get_open_port
|
||||
|
||||
## Parallel Processes Utils
|
||||
@@ -30,10 +28,11 @@ class ProcessGroupInfo:
|
||||
device: torch.device
|
||||
|
||||
|
||||
def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int,
|
||||
local_rank: int):
|
||||
|
||||
def _set_vllm_config(
|
||||
vllm_config: VllmConfig, world_size: int, rank: int, local_rank: int
|
||||
):
|
||||
import tempfile
|
||||
|
||||
temp_file = tempfile.mkstemp()[1]
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
@@ -46,13 +45,10 @@ def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int,
|
||||
)
|
||||
|
||||
initialize_model_parallel(
|
||||
tensor_model_parallel_size=vllm_config.parallel_config.
|
||||
tensor_parallel_size,
|
||||
pipeline_model_parallel_size=vllm_config.parallel_config.
|
||||
pipeline_parallel_size,
|
||||
tensor_model_parallel_size=vllm_config.parallel_config.tensor_parallel_size,
|
||||
pipeline_model_parallel_size=vllm_config.parallel_config.pipeline_parallel_size,
|
||||
)
|
||||
cpu_group = torch.distributed.new_group(list(range(world_size)),
|
||||
backend="gloo")
|
||||
cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")
|
||||
return cpu_group
|
||||
|
||||
|
||||
@@ -62,8 +58,7 @@ def _worker_parallel_launch(
|
||||
world_local_size: int,
|
||||
node_rank: int,
|
||||
init_method: str,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any,
|
||||
P], None],
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any, P], None],
|
||||
vllm_config: Optional[VllmConfig],
|
||||
env_dict: Optional[dict],
|
||||
*args: P.args,
|
||||
@@ -131,7 +126,8 @@ def parallel_launch_with_config(
|
||||
worker,
|
||||
vllm_config,
|
||||
env_dict,
|
||||
) + args,
|
||||
)
|
||||
+ args,
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
@@ -14,28 +14,31 @@ from .common import Config, RankTensors, WeightTensors, make_modular_kernel
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config
|
||||
|
||||
|
||||
def do_profile(fn: Callable,
|
||||
fn_kwargs: dict[Any, Any],
|
||||
pgi: ProcessGroupInfo,
|
||||
config: Config,
|
||||
num_warmups: int = 5):
|
||||
def do_profile(
|
||||
fn: Callable,
|
||||
fn_kwargs: dict[Any, Any],
|
||||
pgi: ProcessGroupInfo,
|
||||
config: Config,
|
||||
num_warmups: int = 5,
|
||||
):
|
||||
for _ in range(num_warmups):
|
||||
fn(**fn_kwargs)
|
||||
|
||||
with torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
with_stack=True,
|
||||
record_shapes=True,
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
with_stack=True,
|
||||
record_shapes=True,
|
||||
) as tprof:
|
||||
fn(**fn_kwargs)
|
||||
torch.cuda.synchronize(torch.cuda.current_device())
|
||||
|
||||
# TODO (varun): Add a descriptive trace file name
|
||||
tprof.export_chrome_trace(
|
||||
f"{config.torch_trace_dir_path}/m{config.M}_{pgi.rank}_trace.json")
|
||||
f"{config.torch_trace_dir_path}/m{config.M}_{pgi.rank}_trace.json"
|
||||
)
|
||||
|
||||
|
||||
def profile_modular_kernel(
|
||||
@@ -82,6 +85,7 @@ def rank_worker(
|
||||
|
||||
# sanity check
|
||||
from vllm import envs
|
||||
|
||||
if config.fused_moe_chunk_size is not None:
|
||||
assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
|
||||
@@ -108,20 +112,25 @@ def rank_worker(
|
||||
def run(config: Config):
|
||||
weights: WeightTensors = WeightTensors.make(config)
|
||||
vllm_config, env_dict = config.make_env_data()
|
||||
parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
|
||||
env_dict, config, weights)
|
||||
parallel_launch_with_config(
|
||||
config.world_size, rank_worker, vllm_config, env_dict, config, weights
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
from .cli_args import make_config, make_config_arg_parser
|
||||
parser = make_config_arg_parser(description=(
|
||||
"Run single prepare-finalize & fused-experts combination test"
|
||||
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel " #noqa: E501
|
||||
"--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
|
||||
))
|
||||
|
||||
parser = make_config_arg_parser(
|
||||
description=(
|
||||
"Run single prepare-finalize & fused-experts combination test"
|
||||
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel " # noqa: E501
|
||||
"--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
assert args.torch_trace_dir_path is not None, (
|
||||
"Please pass in a directory to store torch traces")
|
||||
"Please pass in a directory to store torch traces"
|
||||
)
|
||||
config = make_config(args)
|
||||
|
||||
run(config)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"""
|
||||
DeepEP test utilities
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import traceback
|
||||
@@ -10,17 +11,18 @@ from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.multiprocessing import (
|
||||
spawn) # pyright: ignore[reportPrivateImportUsage]
|
||||
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
from vllm.utils import get_open_port, has_deep_ep
|
||||
|
||||
if has_deep_ep():
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
||||
DeepEPHTPrepareAndFinalize)
|
||||
DeepEPHTPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
||||
DeepEPLLPrepareAndFinalize)
|
||||
DeepEPLLPrepareAndFinalize,
|
||||
)
|
||||
|
||||
## Parallel Processes Utils
|
||||
|
||||
@@ -96,7 +98,8 @@ def parallel_launch(
|
||||
0,
|
||||
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
|
||||
worker,
|
||||
) + args,
|
||||
)
|
||||
+ args,
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
@@ -118,48 +121,57 @@ class DeepEPLLArgs:
|
||||
use_fp8_dispatch: bool
|
||||
|
||||
|
||||
def make_deepep_ht_a2a(pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
ht_args: DeepEPHTArgs,
|
||||
q_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None):
|
||||
|
||||
def make_deepep_ht_a2a(
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
ht_args: DeepEPHTArgs,
|
||||
q_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
):
|
||||
import deep_ep
|
||||
|
||||
# high throughput a2a
|
||||
num_nvl_bytes = 1024 * 1024 * 1024 # 1GB
|
||||
num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
|
||||
buffer = deep_ep.Buffer(group=pg,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=low_latency_mode,
|
||||
num_qps_per_rank=num_qps_per_rank)
|
||||
return DeepEPHTPrepareAndFinalize(buffer=buffer,
|
||||
num_dispatchers=pgi.world_size,
|
||||
dp_size=dp_size,
|
||||
rank_expert_offset=pgi.rank *
|
||||
ht_args.num_local_experts)
|
||||
buffer = deep_ep.Buffer(
|
||||
group=pg,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=low_latency_mode,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
)
|
||||
return DeepEPHTPrepareAndFinalize(
|
||||
buffer=buffer,
|
||||
num_dispatchers=pgi.world_size,
|
||||
dp_size=dp_size,
|
||||
rank_expert_offset=pgi.rank * ht_args.num_local_experts,
|
||||
)
|
||||
|
||||
|
||||
def make_deepep_ll_a2a(pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
deepep_ll_args: DeepEPLLArgs,
|
||||
q_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None):
|
||||
|
||||
def make_deepep_ll_a2a(
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
deepep_ll_args: DeepEPLLArgs,
|
||||
q_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
):
|
||||
import deep_ep
|
||||
|
||||
# low-latency a2a
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
|
||||
pgi.world_size, deepep_ll_args.num_experts)
|
||||
deepep_ll_args.max_tokens_per_rank,
|
||||
deepep_ll_args.hidden_size,
|
||||
pgi.world_size,
|
||||
deepep_ll_args.num_experts,
|
||||
)
|
||||
|
||||
buffer = deep_ep.Buffer(group=pg,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=deepep_ll_args.num_experts //
|
||||
pgi.world_size)
|
||||
buffer = deep_ep.Buffer(
|
||||
group=pg,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=deepep_ll_args.num_experts // pgi.world_size,
|
||||
)
|
||||
|
||||
return DeepEPLLPrepareAndFinalize(
|
||||
buffer=buffer,
|
||||
@@ -169,17 +181,20 @@ def make_deepep_ll_a2a(pg: ProcessGroup,
|
||||
)
|
||||
|
||||
|
||||
def make_deepep_a2a(pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
deepep_ht_args: Optional[DeepEPHTArgs],
|
||||
deepep_ll_args: Optional[DeepEPLLArgs],
|
||||
q_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None):
|
||||
def make_deepep_a2a(
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
deepep_ht_args: Optional[DeepEPHTArgs],
|
||||
deepep_ll_args: Optional[DeepEPLLArgs],
|
||||
q_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
):
|
||||
if deepep_ht_args is not None:
|
||||
assert deepep_ll_args is None
|
||||
return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
|
||||
block_shape)
|
||||
return make_deepep_ht_a2a(
|
||||
pg, pgi, dp_size, deepep_ht_args, q_dtype, block_shape
|
||||
)
|
||||
|
||||
assert deepep_ll_args is not None
|
||||
return make_deepep_ll_a2a(pg, pgi, deepep_ll_args, q_dtype, block_shape)
|
||||
|
||||
@@ -5,13 +5,14 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config)
|
||||
BatchedDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize, BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
BatchedPrepareAndFinalize,
|
||||
BatchedTritonExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported
|
||||
|
||||
from .test_deepgemm import make_block_quant_fp8_weights
|
||||
@@ -19,15 +20,15 @@ from .test_deepgemm import make_block_quant_fp8_weights
|
||||
BLOCK_SIZE = [128, 128]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_deep_gemm_supported(),
|
||||
reason="Requires deep_gemm kernels")
|
||||
@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
|
||||
@pytest.mark.parametrize("E", [16, 32]) # number of experts
|
||||
@pytest.mark.parametrize("T", [256, 512]) # tokens per expert
|
||||
@pytest.mark.parametrize("K", [128, 256]) # hidden dim
|
||||
@pytest.mark.parametrize("N", [512, 1024]) # intermediate dim per expert
|
||||
@pytest.mark.parametrize("topk", [2, 4])
|
||||
def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
|
||||
monkeypatch):
|
||||
def test_batched_deepgemm_vs_triton(
|
||||
E: int, T: int, K: int, N: int, topk: int, monkeypatch
|
||||
):
|
||||
"""Compare BatchedDeepGemmExperts to BatchedTritonExperts."""
|
||||
|
||||
monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")
|
||||
|
||||
@@ -7,14 +7,18 @@ from typing import Optional
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import (batched_moe,
|
||||
make_quantized_test_activations,
|
||||
make_test_weights, naive_batched_moe)
|
||||
from tests.kernels.moe.utils import (
|
||||
batched_moe,
|
||||
make_quantized_test_activations,
|
||||
make_test_weights,
|
||||
naive_batched_moe,
|
||||
)
|
||||
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
|
||||
from tests.kernels.utils import torch_experts
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
invoke_moe_batched_triton_kernel)
|
||||
invoke_moe_batched_triton_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl
|
||||
@@ -68,23 +72,32 @@ class BatchedMMTensors:
|
||||
|
||||
@staticmethod
|
||||
def make_tensors(config: BatchedMMConfig):
|
||||
A = torch.randn(
|
||||
(config.num_experts, config.max_tokens_per_expert, config.K),
|
||||
A = (
|
||||
torch.randn(
|
||||
(config.num_experts, config.max_tokens_per_expert, config.K),
|
||||
device="cuda",
|
||||
dtype=config.in_dtype,
|
||||
)
|
||||
/ 10
|
||||
)
|
||||
B = torch.randn(
|
||||
(config.num_experts, config.N, config.K),
|
||||
device="cuda",
|
||||
dtype=config.in_dtype) / 10
|
||||
B = torch.randn((config.num_experts, config.N, config.K),
|
||||
device="cuda",
|
||||
dtype=config.in_dtype)
|
||||
dtype=config.in_dtype,
|
||||
)
|
||||
C = torch.zeros(
|
||||
(config.num_experts, config.max_tokens_per_expert, config.N),
|
||||
device="cuda",
|
||||
dtype=config.out_dtype)
|
||||
dtype=config.out_dtype,
|
||||
)
|
||||
|
||||
num_expert_tokens = torch.randint(low=0,
|
||||
high=config.max_tokens_per_expert,
|
||||
size=(config.num_experts, ),
|
||||
device="cuda",
|
||||
dtype=torch.int32)
|
||||
num_expert_tokens = torch.randint(
|
||||
low=0,
|
||||
high=config.max_tokens_per_expert,
|
||||
size=(config.num_experts,),
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
return BatchedMMTensors(A, B, C, num_expert_tokens)
|
||||
|
||||
@@ -96,10 +109,15 @@ class BatchedMMTensors:
|
||||
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
|
||||
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
|
||||
@pytest.mark.parametrize("per_act_token_quant", [False, True])
|
||||
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
N: int, dtype: torch.dtype,
|
||||
block_shape: Optional[list[int]],
|
||||
per_act_token_quant: bool):
|
||||
def test_batched_mm(
|
||||
num_experts: int,
|
||||
max_tokens_per_expert: int,
|
||||
K: int,
|
||||
N: int,
|
||||
dtype: torch.dtype,
|
||||
block_shape: Optional[list[int]],
|
||||
per_act_token_quant: bool,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
|
||||
@@ -117,11 +135,13 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
act_dtype = dtype
|
||||
quant_dtype = None
|
||||
|
||||
num_expert_tokens = torch.randint(low=0,
|
||||
high=max_tokens_per_expert,
|
||||
size=(num_experts, ),
|
||||
device="cuda",
|
||||
dtype=torch.int32)
|
||||
num_expert_tokens = torch.randint(
|
||||
low=0,
|
||||
high=max_tokens_per_expert,
|
||||
size=(num_experts,),
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
A, A_q, A_scale = make_quantized_test_activations(
|
||||
num_experts,
|
||||
@@ -151,7 +171,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
compute_tl_dtype = {
|
||||
torch.float16: tl.float16,
|
||||
torch.bfloat16: tl.bfloat16,
|
||||
torch.float32: tl.float32
|
||||
torch.float32: tl.float32,
|
||||
}[test_output.dtype]
|
||||
|
||||
assert A_q.dtype == B_q.dtype
|
||||
@@ -173,7 +193,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
config={
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
|
||||
"BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32,
|
||||
},
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
@@ -186,11 +206,16 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
num_expert_tokens,
|
||||
)
|
||||
|
||||
q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
|
||||
num_expert_tokens,
|
||||
A_scale, B_scale,
|
||||
block_shape,
|
||||
per_act_token_quant)
|
||||
q_ref_output = native_batched_masked_quant_matmul(
|
||||
A_q,
|
||||
B_q,
|
||||
q_ref_output,
|
||||
num_expert_tokens,
|
||||
A_scale,
|
||||
B_scale,
|
||||
block_shape,
|
||||
per_act_token_quant,
|
||||
)
|
||||
|
||||
rtol, atol = {
|
||||
torch.float16: (6e-2, 6e-2),
|
||||
@@ -308,12 +333,6 @@ def test_fused_moe_batched_experts(
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(batched_output,
|
||||
baseline_output,
|
||||
atol=3e-2,
|
||||
rtol=2e-2)
|
||||
torch.testing.assert_close(batched_output, baseline_output, atol=3e-2, rtol=2e-2)
|
||||
|
||||
torch.testing.assert_close(triton_output,
|
||||
batched_output,
|
||||
atol=2e-2,
|
||||
rtol=2e-2)
|
||||
torch.testing.assert_close(triton_output, batched_output, atol=2e-2, rtol=2e-2)
|
||||
|
||||
@@ -5,15 +5,21 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_quant_config, make_test_weights
|
||||
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
|
||||
native_w8a8_block_matmul)
|
||||
from tests.kernels.quant_utils import (
|
||||
native_per_token_group_quant_fp8,
|
||||
native_w8a8_block_matmul,
|
||||
)
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
|
||||
_valid_deep_gemm_shape,
|
||||
deep_gemm_moe_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_topk, modular_triton_fused_moe)
|
||||
fused_topk,
|
||||
modular_triton_fused_moe,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import has_deep_gemm
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
|
||||
@@ -24,8 +30,7 @@ if dg_available:
|
||||
from deep_gemm import get_m_alignment_for_contiguous_layout
|
||||
|
||||
if current_platform.get_device_capability() < (9, 0):
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
|
||||
allow_module_level=True)
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
@@ -97,8 +102,7 @@ TOP_KS = [1, 2, 6]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
|
||||
block_shape):
|
||||
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape):
|
||||
"""Fused moe with block-wise quantization using native torch."""
|
||||
B, D = a.shape
|
||||
topk = topk_ids.size(1)
|
||||
@@ -114,23 +118,17 @@ def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
inter_out = native_w8a8_block_matmul(a_q[mask],
|
||||
w1[i],
|
||||
a_s[mask],
|
||||
w1_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype)
|
||||
inter_out = native_w8a8_block_matmul(
|
||||
a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype
|
||||
)
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
act_out_q, act_out_s = native_per_token_group_quant_fp8(
|
||||
act_out, block_k)
|
||||
out[mask] = native_w8a8_block_matmul(act_out_q,
|
||||
w2[i],
|
||||
act_out_s,
|
||||
w2_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype)
|
||||
return (out.view(B, -1, w2.shape[1]) *
|
||||
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
act_out_q, act_out_s = native_per_token_group_quant_fp8(act_out, block_k)
|
||||
out[mask] = native_w8a8_block_matmul(
|
||||
act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype
|
||||
)
|
||||
return (
|
||||
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
|
||||
).sum(dim=1)
|
||||
|
||||
|
||||
# Skip all tests if CUDA is not available
|
||||
@@ -149,8 +147,9 @@ def setup_cuda():
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
|
||||
monkeypatch):
|
||||
def test_w8a8_block_fp8_fused_moe(
|
||||
M, N, K, E, topk, block_size, dtype, seed, monkeypatch
|
||||
):
|
||||
if topk > E:
|
||||
pytest.skip(f"Skipping test; topk={topk} > E={E}")
|
||||
|
||||
@@ -188,12 +187,9 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
|
||||
block_size,
|
||||
)
|
||||
|
||||
out = fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config)
|
||||
out = fused_experts(
|
||||
a, w1, w2, topk_weights, topk_ids, quant_config=quant_config
|
||||
)
|
||||
|
||||
m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids)
|
||||
|
||||
@@ -210,8 +206,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
|
||||
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
|
||||
@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE")
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
|
||||
monkeypatch):
|
||||
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch):
|
||||
if topk > E:
|
||||
pytest.skip(f"Skipping test: topk={topk} > E={E}")
|
||||
|
||||
@@ -245,36 +240,38 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
|
||||
# setup code in case we are able to revisit this later.
|
||||
use_compile = False
|
||||
|
||||
use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024
|
||||
and current_platform.is_cuda_alike())
|
||||
use_cudagraph = (
|
||||
chunk_size < M and N >= 1024 and K >= 1024 and current_platform.is_cuda_alike()
|
||||
)
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weights,
|
||||
topk_ids, block_size)
|
||||
ref_out = torch_w8a8_block_fp8_moe(
|
||||
a, w1, w2, w1_s, w2_s, topk_weights, topk_ids, block_size
|
||||
)
|
||||
|
||||
if use_compile:
|
||||
deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
|
||||
backend="inductor",
|
||||
fullgraph=True)
|
||||
deep_gemm_moe_fp8_fn = torch.compile(
|
||||
deep_gemm_moe_fp8, backend="inductor", fullgraph=True
|
||||
)
|
||||
torch._dynamo.mark_dynamic(a, 0)
|
||||
torch._dynamo.mark_dynamic(topk_weights, 0)
|
||||
torch._dynamo.mark_dynamic(topk_ids, 0)
|
||||
else:
|
||||
deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8
|
||||
|
||||
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
|
||||
topk_ids)
|
||||
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
|
||||
|
||||
if use_cudagraph:
|
||||
out.fill_(0)
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
|
||||
topk_ids)
|
||||
out = deep_gemm_moe_fp8_fn(
|
||||
a, w1, w2, w1_s, w2_s, topk_weights, topk_ids
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@@ -5,16 +5,17 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_quant_config
|
||||
from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
|
||||
native_w8a8_block_matmul)
|
||||
from tests.kernels.quant_utils import (
|
||||
native_per_token_group_quant_int8,
|
||||
native_w8a8_block_matmul,
|
||||
)
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.get_device_capability() < (7, 0):
|
||||
pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
|
||||
allow_module_level=True)
|
||||
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
@@ -77,24 +78,18 @@ def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
inter_out = native_w8a8_block_matmul(a_q[mask],
|
||||
w1[i],
|
||||
a_s[mask],
|
||||
w1_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype)
|
||||
inter_out = native_w8a8_block_matmul(
|
||||
a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype
|
||||
)
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
act_out_q, act_out_s = native_per_token_group_quant_int8(
|
||||
act_out, block_k)
|
||||
act_out_q, act_out_s = native_per_token_group_quant_int8(act_out, block_k)
|
||||
act_out = act_out.to(torch.float32)
|
||||
out[mask] = native_w8a8_block_matmul(act_out_q,
|
||||
w2[i],
|
||||
act_out_s,
|
||||
w2_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype)
|
||||
return (out.view(B, -1, w2.shape[1]) *
|
||||
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
out[mask] = native_w8a8_block_matmul(
|
||||
act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype
|
||||
)
|
||||
return (
|
||||
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
|
||||
).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
@@ -131,15 +126,19 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
with set_current_vllm_config(vllm_config):
|
||||
out = fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config)
|
||||
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, quant_config.w1_scale,
|
||||
quant_config.w2_scale, score, topk,
|
||||
block_size)
|
||||
out = fused_experts(
|
||||
a, w1, w2, topk_weights, topk_ids, quant_config=quant_config
|
||||
)
|
||||
ref_out = torch_w8a8_block_int8_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
quant_config.w1_scale,
|
||||
quant_config.w2_scale,
|
||||
score,
|
||||
topk,
|
||||
block_size,
|
||||
)
|
||||
|
||||
# Check results
|
||||
torch.testing.assert_close(out, ref_out, atol=0.065, rtol=0.065)
|
||||
|
||||
@@ -15,7 +15,6 @@ from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TestTensors:
|
||||
|
||||
topk_ids: torch.Tensor
|
||||
expert_map: Optional[torch.Tensor] = None
|
||||
|
||||
@@ -25,32 +24,31 @@ class TestTensors:
|
||||
self.expert_map = self.expert_map.to(device=device)
|
||||
|
||||
@staticmethod
|
||||
def make(num_tokens: int, num_topk: int, num_experts: int, device: str,
|
||||
topk_ids_dtype: torch.dtype) -> "TestTensors":
|
||||
|
||||
def make(
|
||||
num_tokens: int,
|
||||
num_topk: int,
|
||||
num_experts: int,
|
||||
device: str,
|
||||
topk_ids_dtype: torch.dtype,
|
||||
) -> "TestTensors":
|
||||
# make topk ids
|
||||
topk_ids = torch.empty((num_tokens, num_topk),
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
topk_ids = torch.empty((num_tokens, num_topk), device=device, dtype=torch.int64)
|
||||
for x in range(num_tokens):
|
||||
topk_ids[x] = torch.randperm(num_experts)[:num_topk]
|
||||
topk_ids = topk_ids.to(dtype=torch.int64)
|
||||
return TestTensors(topk_ids=topk_ids)
|
||||
|
||||
def with_ep_rank(self, ep_rank: int, num_global_experts: int,
|
||||
num_local_experts: int, device: str):
|
||||
def with_ep_rank(
|
||||
self, ep_rank: int, num_global_experts: int, num_local_experts: int, device: str
|
||||
):
|
||||
# make an expert map
|
||||
expert_map = torch.empty((num_global_experts),
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
expert_map = torch.empty((num_global_experts), device=device, dtype=torch.int32)
|
||||
expert_map.fill_(-1)
|
||||
s = ep_rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
expert_map[s:e] = torch.tensor(list(range(num_local_experts)),
|
||||
device=device)
|
||||
expert_map[s:e] = torch.tensor(list(range(num_local_experts)), device=device)
|
||||
|
||||
return TestTensors(topk_ids=self.topk_ids.clone(),
|
||||
expert_map=expert_map)
|
||||
return TestTensors(topk_ids=self.topk_ids.clone(), expert_map=expert_map)
|
||||
|
||||
|
||||
def ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor):
|
||||
@@ -68,49 +66,49 @@ def ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor):
|
||||
expert_num_tokens[eid] += count
|
||||
|
||||
|
||||
def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
|
||||
num_experts: int, ep_size: int,
|
||||
topk_ids_dtype: torch.dtype):
|
||||
|
||||
def do_test_compute_expert_num_tokens(
|
||||
num_tokens: int,
|
||||
num_topk: int,
|
||||
num_experts: int,
|
||||
ep_size: int,
|
||||
topk_ids_dtype: torch.dtype,
|
||||
):
|
||||
assert num_topk <= num_experts
|
||||
|
||||
tt = TestTensors.make(num_tokens,
|
||||
num_topk,
|
||||
num_experts,
|
||||
topk_ids_dtype=topk_ids_dtype,
|
||||
device="cpu")
|
||||
tt = TestTensors.make(
|
||||
num_tokens, num_topk, num_experts, topk_ids_dtype=topk_ids_dtype, device="cpu"
|
||||
)
|
||||
|
||||
num_global_experts = num_experts
|
||||
assert num_global_experts % ep_size == 0
|
||||
num_local_experts = num_global_experts // ep_size
|
||||
for ep_rank in range(ep_size):
|
||||
tt_rank = tt.with_ep_rank(ep_rank, num_global_experts,
|
||||
num_local_experts, "cpu")
|
||||
tt_rank = tt.with_ep_rank(ep_rank, num_global_experts, num_local_experts, "cpu")
|
||||
|
||||
ref_expert_num_tokens = torch.zeros((num_local_experts),
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
ref_expert_num_tokens = torch.zeros(
|
||||
(num_local_experts), device="cpu", dtype=torch.int32
|
||||
)
|
||||
ref_impl(tt_rank, ref_expert_num_tokens)
|
||||
ref_expert_num_tokens = ref_expert_num_tokens.to("cuda")
|
||||
|
||||
tt_rank.to_device("cuda")
|
||||
# Test with expert_map
|
||||
triton_expert_num_tokens_w_emap = count_expert_num_tokens(
|
||||
tt_rank.topk_ids, num_local_experts, tt_rank.expert_map)
|
||||
tt_rank.topk_ids, num_local_experts, tt_rank.expert_map
|
||||
)
|
||||
|
||||
# Test without expert map
|
||||
topk_ids = tt_rank.expert_map[tt_rank.topk_ids].to(topk_ids_dtype)
|
||||
triton_expert_num_tokens_wo_emap = count_expert_num_tokens(
|
||||
topk_ids, num_local_experts, expert_map=None)
|
||||
topk_ids, num_local_experts, expert_map=None
|
||||
)
|
||||
|
||||
torch.testing.assert_close(ref_expert_num_tokens,
|
||||
triton_expert_num_tokens_w_emap,
|
||||
atol=0,
|
||||
rtol=0)
|
||||
torch.testing.assert_close(ref_expert_num_tokens,
|
||||
triton_expert_num_tokens_wo_emap,
|
||||
atol=0,
|
||||
rtol=0)
|
||||
torch.testing.assert_close(
|
||||
ref_expert_num_tokens, triton_expert_num_tokens_w_emap, atol=0, rtol=0
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
ref_expert_num_tokens, triton_expert_num_tokens_wo_emap, atol=0, rtol=0
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 11, 127, 128, 3333, 7317])
|
||||
@@ -118,22 +116,29 @@ def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
|
||||
@pytest.mark.parametrize("num_experts", [64])
|
||||
@pytest.mark.parametrize("ep_size", [1, 2, 4])
|
||||
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
|
||||
def test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
|
||||
num_experts: int, ep_size: int,
|
||||
topk_ids_dtype: torch.dtype):
|
||||
do_test_compute_expert_num_tokens(num_tokens, num_topk, num_experts,
|
||||
ep_size, topk_ids_dtype)
|
||||
def test_compute_expert_num_tokens(
|
||||
num_tokens: int,
|
||||
num_topk: int,
|
||||
num_experts: int,
|
||||
ep_size: int,
|
||||
topk_ids_dtype: torch.dtype,
|
||||
):
|
||||
do_test_compute_expert_num_tokens(
|
||||
num_tokens, num_topk, num_experts, ep_size, topk_ids_dtype
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("numel", list(range(1, 8192, 111)))
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("ep_size", [2])
|
||||
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
|
||||
def test_compute_expert_num_tokens_from_numel(numel: int, num_experts: int,
|
||||
ep_size: int,
|
||||
topk_ids_dtype: torch.dtype):
|
||||
do_test_compute_expert_num_tokens(num_tokens=numel,
|
||||
num_topk=1,
|
||||
num_experts=num_experts,
|
||||
ep_size=ep_size,
|
||||
topk_ids_dtype=topk_ids_dtype)
|
||||
def test_compute_expert_num_tokens_from_numel(
|
||||
numel: int, num_experts: int, ep_size: int, topk_ids_dtype: torch.dtype
|
||||
):
|
||||
do_test_compute_expert_num_tokens(
|
||||
num_tokens=numel,
|
||||
num_topk=1,
|
||||
num_experts=num_experts,
|
||||
ep_size=ep_size,
|
||||
topk_ids_dtype=topk_ids_dtype,
|
||||
)
|
||||
|
||||
@@ -17,19 +17,24 @@ from vllm.utils import cdiv
|
||||
from vllm.utils.deep_gemm import per_block_cast_to_fp8
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [
|
||||
(4, 8192, 7168, 4096),
|
||||
(4, 8192, 2048, 7168),
|
||||
(8, 4096, 7168, 4096),
|
||||
(8, 4096, 2048, 7168),
|
||||
(32, 1024, 7168, 4096),
|
||||
(32, 1024, 2048, 7168),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"num_groups, expected_m_per_group, k, n",
|
||||
[
|
||||
(4, 8192, 7168, 4096),
|
||||
(4, 8192, 2048, 7168),
|
||||
(8, 4096, 7168, 4096),
|
||||
(8, 4096, 2048, 7168),
|
||||
(32, 1024, 7168, 4096),
|
||||
(32, 1024, 2048, 7168),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("out_dtype", [torch.float16])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or x.to_int() != 100)(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Block Scaled Grouped GEMM is only supported on SM100.")
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Block Scaled Grouped GEMM is only supported on SM100.",
|
||||
)
|
||||
def test_cutlass_grouped_gemm(
|
||||
num_groups: int,
|
||||
expected_m_per_group: int,
|
||||
@@ -40,8 +45,7 @@ def test_cutlass_grouped_gemm(
|
||||
device = "cuda"
|
||||
alignment = 128
|
||||
group_ms = [
|
||||
int(expected_m_per_group * random.uniform(0.7, 1.3))
|
||||
for _ in range(num_groups)
|
||||
int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)
|
||||
]
|
||||
m = sum([cdiv(m, alignment) * alignment for m in group_ms])
|
||||
|
||||
@@ -58,20 +62,22 @@ def test_cutlass_grouped_gemm(
|
||||
expert_offsets = torch.tensor(ep_offset, device=device, dtype=torch.int32)
|
||||
|
||||
x_fp8 = per_token_cast_to_fp8(x)
|
||||
y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn),
|
||||
torch.empty((num_groups, cdiv(n, 128), k // 128),
|
||||
device=device,
|
||||
dtype=torch.float))
|
||||
y_fp8 = (
|
||||
torch.empty_like(y, dtype=torch.float8_e4m3fn),
|
||||
torch.empty(
|
||||
(num_groups, cdiv(n, 128), k // 128), device=device, dtype=torch.float
|
||||
),
|
||||
)
|
||||
for i in range(num_groups):
|
||||
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], [128, 128])
|
||||
|
||||
for i in range(num_groups):
|
||||
a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]]
|
||||
a_scale = x_fp8[1][ep_offset[i]:ep_offset[i + 1]]
|
||||
a = x_fp8[0][ep_offset[i] : ep_offset[i + 1]]
|
||||
a_scale = x_fp8[1][ep_offset[i] : ep_offset[i + 1]]
|
||||
b = y_fp8[0][i].t()
|
||||
b_scale = y_fp8[1][i].t()
|
||||
baseline = baseline_scaled_mm(a, b, a_scale, b_scale, out_dtype)
|
||||
ref_out[ep_offset[i]:ep_offset[i + 1]] = baseline
|
||||
ref_out[ep_offset[i] : ep_offset[i + 1]] = baseline
|
||||
|
||||
ops.cutlass_blockwise_scaled_grouped_mm(
|
||||
out,
|
||||
|
||||
@@ -11,13 +11,15 @@ import torch
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG, fp8_w8a8_moe_quant_config)
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp8, run_cutlass_moe_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
|
||||
fused_topk)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
cutlass_moe_fp8,
|
||||
run_cutlass_moe_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
NUM_EXPERTS = [40, 64]
|
||||
@@ -39,12 +41,11 @@ MNK_FACTORS = [
|
||||
(224, 3072, 1536),
|
||||
(32768, 1024, 1024),
|
||||
# These sizes trigger wrong answers.
|
||||
#(7232, 2048, 5120),
|
||||
#(40000, 2048, 5120),
|
||||
# (7232, 2048, 5120),
|
||||
# (40000, 2048, 5120),
|
||||
]
|
||||
|
||||
vllm_config = VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))
|
||||
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
@@ -60,22 +61,25 @@ class MOETensors:
|
||||
c_strides2: torch.Tensor
|
||||
|
||||
@staticmethod
|
||||
def make_moe_tensors(m: int, k: int, n: int, e: int,
|
||||
dtype: torch.dtype) -> "MOETensors":
|
||||
def make_moe_tensors(
|
||||
m: int, k: int, n: int, e: int, dtype: torch.dtype
|
||||
) -> "MOETensors":
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
return MOETensors(a=a,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
ab_strides1=ab_strides1,
|
||||
c_strides1=c_strides1,
|
||||
ab_strides2=ab_strides2,
|
||||
c_strides2=c_strides2)
|
||||
ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
|
||||
return MOETensors(
|
||||
a=a,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
ab_strides1=ab_strides1,
|
||||
c_strides1=c_strides1,
|
||||
ab_strides2=ab_strides2,
|
||||
c_strides2=c_strides2,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -93,9 +97,9 @@ class MOETensors8Bit(MOETensors):
|
||||
w2_d: Optional[torch.Tensor] = None # w2 -> w2_q -> w2_d
|
||||
|
||||
@staticmethod
|
||||
def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
|
||||
per_act_token: bool,
|
||||
per_out_channel: bool) -> "MOETensors8Bit":
|
||||
def make_moe_tensors_8bit(
|
||||
m: int, k: int, n: int, e: int, per_act_token: bool, per_out_channel: bool
|
||||
) -> "MOETensors8Bit":
|
||||
dtype = torch.half
|
||||
q_dtype = torch.float8_e4m3fn
|
||||
|
||||
@@ -106,24 +110,21 @@ class MOETensors8Bit(MOETensors):
|
||||
k_b_scales = k if per_out_channel else 1
|
||||
# Get the right scale for tests.
|
||||
a_q, a_scale = ops.scaled_fp8_quant(
|
||||
moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token)
|
||||
moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token
|
||||
)
|
||||
|
||||
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
|
||||
w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
|
||||
|
||||
w1_scale = torch.empty((e, n_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w2_scale = torch.empty((e, k_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32)
|
||||
w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32)
|
||||
for expert in range(e):
|
||||
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
|
||||
moe_tensors_fp16.w1[expert],
|
||||
use_per_token_if_dynamic=per_out_channel)
|
||||
moe_tensors_fp16.w1[expert], use_per_token_if_dynamic=per_out_channel
|
||||
)
|
||||
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
|
||||
moe_tensors_fp16.w2[expert],
|
||||
use_per_token_if_dynamic=per_out_channel)
|
||||
moe_tensors_fp16.w2[expert], use_per_token_if_dynamic=per_out_channel
|
||||
)
|
||||
|
||||
# a_q -> a_d, w1_q -> w1_d, w2_q -> w2_d
|
||||
a_d = a_q.float().mul(a_scale).to(dtype)
|
||||
@@ -133,31 +134,37 @@ class MOETensors8Bit(MOETensors):
|
||||
w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half()
|
||||
w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half()
|
||||
|
||||
return MOETensors8Bit(a=moe_tensors_fp16.a,
|
||||
w1=moe_tensors_fp16.w1,
|
||||
w2=moe_tensors_fp16.w2,
|
||||
ab_strides1=moe_tensors_fp16.ab_strides1,
|
||||
c_strides1=moe_tensors_fp16.c_strides1,
|
||||
ab_strides2=moe_tensors_fp16.ab_strides2,
|
||||
c_strides2=moe_tensors_fp16.c_strides2,
|
||||
a_q=a_q,
|
||||
w1_q=w1_q,
|
||||
w2_q=w2_q,
|
||||
a_scale=a_scale,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a_d=a_d,
|
||||
w1_d=w1_d,
|
||||
w2_d=w2_d)
|
||||
return MOETensors8Bit(
|
||||
a=moe_tensors_fp16.a,
|
||||
w1=moe_tensors_fp16.w1,
|
||||
w2=moe_tensors_fp16.w2,
|
||||
ab_strides1=moe_tensors_fp16.ab_strides1,
|
||||
c_strides1=moe_tensors_fp16.c_strides1,
|
||||
ab_strides2=moe_tensors_fp16.ab_strides2,
|
||||
c_strides2=moe_tensors_fp16.c_strides2,
|
||||
a_q=a_q,
|
||||
w1_q=w1_q,
|
||||
w2_q=w2_q,
|
||||
a_scale=a_scale,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a_d=a_d,
|
||||
w1_d=w1_d,
|
||||
w2_d=w2_d,
|
||||
)
|
||||
|
||||
|
||||
def run_with_expert_maps(num_experts: int, num_local_experts: int,
|
||||
**cutlass_moe_kwargs):
|
||||
|
||||
def run_with_expert_maps(
|
||||
num_experts: int, num_local_experts: int, **cutlass_moe_kwargs
|
||||
):
|
||||
def slice_experts():
|
||||
slice_params = [
|
||||
"w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1",
|
||||
"c_strides2"
|
||||
"w1_q",
|
||||
"w2_q",
|
||||
"ab_strides1",
|
||||
"ab_strides2",
|
||||
"c_strides1",
|
||||
"c_strides2",
|
||||
]
|
||||
full_tensors = {
|
||||
k: v
|
||||
@@ -173,9 +180,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
|
||||
# make expert map
|
||||
expert_map = [-1] * num_experts
|
||||
expert_map[s:e] = list(range(num_local_experts))
|
||||
expert_map = torch.tensor(expert_map,
|
||||
dtype=torch.int32,
|
||||
device="cuda")
|
||||
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
|
||||
|
||||
# update cutlass moe arg with expert_map
|
||||
cutlass_moe_kwargs["expert_map"] = expert_map
|
||||
@@ -198,18 +203,26 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
|
||||
return out_tensor
|
||||
|
||||
|
||||
def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
num_local_experts: Optional[int] = None) -> torch.Tensor:
|
||||
assert not any([
|
||||
t is None for t in [
|
||||
moe_tensors.w1_q, moe_tensors.w2_q, moe_tensors.w1_scale,
|
||||
moe_tensors.w2_scale, moe_tensors.a_scale
|
||||
def run_8_bit(
|
||||
moe_tensors: MOETensors8Bit,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
num_local_experts: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
assert not any(
|
||||
[
|
||||
t is None
|
||||
for t in [
|
||||
moe_tensors.w1_q,
|
||||
moe_tensors.w2_q,
|
||||
moe_tensors.w1_scale,
|
||||
moe_tensors.w2_scale,
|
||||
moe_tensors.a_scale,
|
||||
]
|
||||
]
|
||||
])
|
||||
)
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=moe_tensors.w1_scale,
|
||||
@@ -222,16 +235,16 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
'a': moe_tensors.a,
|
||||
'w1_q': moe_tensors.w1_q, # type: ignore[union-attr]
|
||||
'w2_q': moe_tensors.w2_q, # type: ignore[union-attr]
|
||||
'topk_weights': topk_weights,
|
||||
'topk_ids': topk_ids,
|
||||
'ab_strides1': moe_tensors.ab_strides1,
|
||||
'ab_strides2': moe_tensors.ab_strides2,
|
||||
'c_strides1': moe_tensors.c_strides1,
|
||||
'c_strides2': moe_tensors.c_strides2,
|
||||
'quant_config': quant_config,
|
||||
"a": moe_tensors.a,
|
||||
"w1_q": moe_tensors.w1_q, # type: ignore[union-attr]
|
||||
"w2_q": moe_tensors.w2_q, # type: ignore[union-attr]
|
||||
"topk_weights": topk_weights,
|
||||
"topk_ids": topk_ids,
|
||||
"ab_strides1": moe_tensors.ab_strides1,
|
||||
"ab_strides2": moe_tensors.ab_strides2,
|
||||
"c_strides1": moe_tensors.c_strides1,
|
||||
"c_strides2": moe_tensors.c_strides2,
|
||||
"quant_config": quant_config,
|
||||
}
|
||||
|
||||
num_experts = moe_tensors.w1.size(0)
|
||||
@@ -243,7 +256,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||
return run_with_expert_maps(
|
||||
num_experts,
|
||||
num_local_experts, # type: ignore[arg-type]
|
||||
**kwargs)
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@@ -253,8 +267,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_moe_8_bit_no_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
@@ -269,25 +285,18 @@ def test_cutlass_moe_8_bit_no_graph(
|
||||
current_platform.seed_everything(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
|
||||
per_out_ch)
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.half)
|
||||
topk_weights, topk_ids, _ = fused_topk(mt.a,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False)
|
||||
topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False)
|
||||
|
||||
# Note that we are using the dequantized versions of the tensors.
|
||||
# Using a, w1 and w2 directly results in minor output differences.
|
||||
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
triton_output = fused_experts(mt.a_d,
|
||||
mt.w1_d,
|
||||
mt.w2_d,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config)
|
||||
triton_output = fused_experts(
|
||||
mt.a_d, mt.w1_d, mt.w2_d, topk_weights, topk_ids, quant_config=quant_config
|
||||
)
|
||||
|
||||
if ep_size is not None:
|
||||
assert e % ep_size == 0, "Cannot distribute experts evenly"
|
||||
@@ -295,15 +304,15 @@ def test_cutlass_moe_8_bit_no_graph(
|
||||
else:
|
||||
number_local_experts = None
|
||||
|
||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token,
|
||||
per_out_ch, number_local_experts)
|
||||
cutlass_output = run_8_bit(
|
||||
mt, topk_weights, topk_ids, per_act_token, per_out_ch, number_local_experts
|
||||
)
|
||||
|
||||
# Note 5.5 only needed for larger problem sizes, 5 works ok for
|
||||
# the rest.
|
||||
torch.testing.assert_close(triton_output,
|
||||
cutlass_output,
|
||||
atol=5.5e-2,
|
||||
rtol=1e-2)
|
||||
torch.testing.assert_close(
|
||||
triton_output, cutlass_output, atol=5.5e-2, rtol=1e-2
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@@ -313,8 +322,10 @@ def test_cutlass_moe_8_bit_no_graph(
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_moe_8_bit_cuda_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
@@ -330,39 +341,30 @@ def test_cutlass_moe_8_bit_cuda_graph(
|
||||
with set_current_vllm_config(vllm_config):
|
||||
dtype = torch.half
|
||||
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
|
||||
per_out_ch)
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(mt.a,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False)
|
||||
topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False)
|
||||
|
||||
# Note that we are using the dequantized versions of the tensors.
|
||||
# Using a, w1 and w2 directly results in minor output differences.
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
triton_output = fused_experts(mt.a_d,
|
||||
mt.w1_d,
|
||||
mt.w2_d,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config)
|
||||
triton_output = fused_experts(
|
||||
mt.a_d, mt.w1_d, mt.w2_d, topk_weights, topk_ids, quant_config=quant_config
|
||||
)
|
||||
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
|
||||
per_act_token, per_out_ch)
|
||||
cutlass_output = run_8_bit(
|
||||
mt, topk_weights, topk_ids, per_act_token, per_out_ch
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
torch.testing.assert_close(triton_output,
|
||||
cutlass_output,
|
||||
atol=9e-2,
|
||||
rtol=1e-2)
|
||||
torch.testing.assert_close(triton_output, cutlass_output, atol=9e-2, rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [64])
|
||||
@@ -375,8 +377,10 @@ def test_cutlass_moe_8_bit_cuda_graph(
|
||||
@pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_moe_8_bit_EP(
|
||||
m: int,
|
||||
n: int,
|
||||
@@ -388,8 +392,9 @@ def test_cutlass_moe_8_bit_EP(
|
||||
ep_size: int,
|
||||
monkeypatch,
|
||||
):
|
||||
test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
|
||||
per_out_channel, monkeypatch, ep_size)
|
||||
test_cutlass_moe_8_bit_no_graph(
|
||||
m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size
|
||||
)
|
||||
|
||||
|
||||
LARGE_MNK_FACTORS = [
|
||||
@@ -406,8 +411,10 @@ LARGE_MNK_FACTORS = [
|
||||
@pytest.mark.parametrize("ep_size", [8])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_moe_8_bit_EP_large(
|
||||
m: int,
|
||||
n: int,
|
||||
@@ -419,8 +426,9 @@ def test_cutlass_moe_8_bit_EP_large(
|
||||
ep_size: int,
|
||||
monkeypatch,
|
||||
):
|
||||
test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
|
||||
per_out_channel, monkeypatch, ep_size)
|
||||
test_cutlass_moe_8_bit_no_graph(
|
||||
m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)])
|
||||
@@ -430,8 +438,10 @@ def test_cutlass_moe_8_bit_EP_large(
|
||||
@pytest.mark.parametrize("ep_size", [8])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
def test_run_cutlass_moe_fp8(
|
||||
m: int,
|
||||
n: int,
|
||||
@@ -444,14 +454,12 @@ def test_run_cutlass_moe_fp8(
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
|
||||
per_out_channel)
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(
|
||||
m, k, n, e, per_act_token, per_out_channel
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.half)
|
||||
topk_weights, topk_ids, _ = fused_topk(mt.a,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False)
|
||||
topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False)
|
||||
# we want to make sure there is at least one token that's generated in
|
||||
# this expert shard and at least one token that's NOT generated in this
|
||||
# expert shard
|
||||
@@ -462,12 +470,12 @@ def test_run_cutlass_moe_fp8(
|
||||
workspace2_shape = (m * topk, max(n, k))
|
||||
output_shape = (m, k)
|
||||
|
||||
workspace13 = torch.empty(prod(workspace13_shape),
|
||||
device="cuda",
|
||||
dtype=mt.a.dtype)
|
||||
workspace2 = torch.empty(prod(workspace2_shape),
|
||||
device="cuda",
|
||||
dtype=mt.a.dtype)
|
||||
workspace13 = torch.empty(
|
||||
prod(workspace13_shape), device="cuda", dtype=mt.a.dtype
|
||||
)
|
||||
workspace2 = torch.empty(
|
||||
prod(workspace2_shape), device="cuda", dtype=mt.a.dtype
|
||||
)
|
||||
|
||||
num_local_experts = e // ep_size
|
||||
start, end = 0, num_local_experts
|
||||
@@ -475,36 +483,55 @@ def test_run_cutlass_moe_fp8(
|
||||
expert_map[start:end] = list(range(num_local_experts))
|
||||
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
|
||||
|
||||
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
|
||||
|
||||
activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
|
||||
torch.float8_e4m3fn,
|
||||
per_act_token)
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
mt.a, mt.a_scale, torch.float8_e4m3fn, per_act_token
|
||||
)
|
||||
global_num_experts = -1 if mt.w1_q is None else mt.w1_q.size(0)
|
||||
func = lambda output: run_cutlass_moe_fp8(
|
||||
output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
|
||||
global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
|
||||
a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
|
||||
workspace13, workspace2, None, mt.a.dtype, per_act_token,
|
||||
per_out_channel, False, topk_weights)
|
||||
output,
|
||||
a1q,
|
||||
mt.w1_q,
|
||||
mt.w2_q,
|
||||
topk_ids,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
mt.w1_scale,
|
||||
mt.w2_scale,
|
||||
a1q_scale,
|
||||
None,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
workspace13,
|
||||
workspace2,
|
||||
None,
|
||||
mt.a.dtype,
|
||||
per_act_token,
|
||||
per_out_channel,
|
||||
False,
|
||||
topk_weights,
|
||||
)
|
||||
|
||||
workspace13.random_()
|
||||
output_random_workspace = torch.empty(output_shape,
|
||||
device="cuda",
|
||||
dtype=mt.a.dtype)
|
||||
output_random_workspace = torch.empty(
|
||||
output_shape, device="cuda", dtype=mt.a.dtype
|
||||
)
|
||||
func(output_random_workspace)
|
||||
|
||||
workspace13.fill_(0)
|
||||
output_zero_workspace = torch.zeros(output_shape,
|
||||
device="cuda",
|
||||
dtype=mt.a.dtype)
|
||||
output_zero_workspace = torch.zeros(
|
||||
output_shape, device="cuda", dtype=mt.a.dtype
|
||||
)
|
||||
func(output_zero_workspace)
|
||||
|
||||
torch.testing.assert_close(output_random_workspace,
|
||||
output_zero_workspace,
|
||||
atol=5e-3,
|
||||
rtol=1e-3)
|
||||
torch.testing.assert_close(
|
||||
output_random_workspace, output_zero_workspace, atol=5e-3, rtol=1e-3
|
||||
)
|
||||
|
||||
@@ -16,10 +16,11 @@ from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
|
||||
FusedMoEQuantConfig,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import has_deep_ep, has_deep_gemm
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
|
||||
@@ -30,18 +31,19 @@ from .utils import make_test_weights
|
||||
|
||||
if has_deep_ep():
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
||||
DeepEPHTPrepareAndFinalize)
|
||||
DeepEPHTPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
||||
DeepEPLLPrepareAndFinalize)
|
||||
DeepEPLLPrepareAndFinalize,
|
||||
)
|
||||
|
||||
from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
|
||||
|
||||
if has_deep_gemm():
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
DeepGemmExperts)
|
||||
BatchedDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
||||
|
||||
requires_deep_ep = pytest.mark.skipif(
|
||||
not has_deep_ep(),
|
||||
@@ -58,9 +60,10 @@ P = ParamSpec("P")
|
||||
|
||||
def next_power_of_2(x):
|
||||
import math
|
||||
|
||||
if x == 0:
|
||||
return 1
|
||||
return 2**math.ceil(math.log2(x))
|
||||
return 2 ** math.ceil(math.log2(x))
|
||||
|
||||
|
||||
def make_block_quant_fp8_weights(
|
||||
@@ -72,13 +75,9 @@ def make_block_quant_fp8_weights(
|
||||
"""
|
||||
Return weights w1q, w2q, w1_scale, w2_scale
|
||||
"""
|
||||
(_, w1q, w1_scale, _), (_, w2q, w2_scale,
|
||||
_) = make_test_weights(e,
|
||||
n,
|
||||
k,
|
||||
torch.bfloat16,
|
||||
torch.float8_e4m3fn,
|
||||
block_shape=block_size)
|
||||
(_, w1q, w1_scale, _), (_, w2q, w2_scale, _) = make_test_weights(
|
||||
e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_shape=block_size
|
||||
)
|
||||
return w1q, w2q, w1_scale, w2_scale
|
||||
|
||||
|
||||
@@ -106,15 +105,15 @@ class TestTensors:
|
||||
|
||||
@staticmethod
|
||||
def make(config: TestConfig, rank) -> "TestTensors":
|
||||
|
||||
dtype = torch.bfloat16
|
||||
topk, m, k = (config.topk, config.m, config.k)
|
||||
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
rank_tokens = torch.randn(
|
||||
(m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
|
||||
rank_tokens = (
|
||||
torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
|
||||
)
|
||||
rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
|
||||
rank_token_scales = None
|
||||
|
||||
@@ -122,25 +121,32 @@ class TestTensors:
|
||||
low=0,
|
||||
high=config.num_experts,
|
||||
size=(m, topk),
|
||||
device=torch.cuda.current_device()).to(dtype=torch.int64)
|
||||
device=torch.cuda.current_device(),
|
||||
).to(dtype=torch.int64)
|
||||
|
||||
topk_weights = torch.randn(topk_ids.shape,
|
||||
dtype=torch.float32,
|
||||
device=torch.cuda.current_device())
|
||||
topk_weights = torch.randn(
|
||||
topk_ids.shape, dtype=torch.float32, device=torch.cuda.current_device()
|
||||
)
|
||||
|
||||
return TestTensors(rank_tokens=rank_tokens,
|
||||
rank_token_scales=rank_token_scales,
|
||||
topk=topk_ids,
|
||||
topk_weights=topk_weights,
|
||||
config=config)
|
||||
return TestTensors(
|
||||
rank_tokens=rank_tokens,
|
||||
rank_token_scales=rank_token_scales,
|
||||
topk=topk_ids,
|
||||
topk_weights=topk_weights,
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
def make_ll_modular_kernel(
|
||||
pg: ProcessGroup, pgi: ProcessGroupInfo, max_tokens_per_rank: int,
|
||||
dp_size: int, hidden_size: int, q_dtype: Optional[torch.dtype],
|
||||
test_config: TestConfig,
|
||||
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
|
||||
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
max_tokens_per_rank: int,
|
||||
dp_size: int,
|
||||
hidden_size: int,
|
||||
q_dtype: Optional[torch.dtype],
|
||||
test_config: TestConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> FusedMoEModularKernel:
|
||||
assert test_config.low_latency
|
||||
assert test_config.use_fp8_dispatch is not None
|
||||
|
||||
@@ -153,26 +159,30 @@ def make_ll_modular_kernel(
|
||||
max_tokens_per_rank=max_tokens_per_rank,
|
||||
hidden_size=hidden_size,
|
||||
num_experts=test_config.num_experts,
|
||||
use_fp8_dispatch=test_config.use_fp8_dispatch),
|
||||
use_fp8_dispatch=test_config.use_fp8_dispatch,
|
||||
),
|
||||
q_dtype=q_dtype,
|
||||
block_shape=test_config.block_size)
|
||||
block_shape=test_config.block_size,
|
||||
)
|
||||
|
||||
fused_experts = BatchedDeepGemmExperts(
|
||||
max_num_tokens=max_tokens_per_rank,
|
||||
num_dispatchers=pgi.world_size // dp_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
||||
fused_experts=fused_experts)
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
|
||||
return mk
|
||||
|
||||
|
||||
def make_ht_modular_kernel(
|
||||
pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
|
||||
num_local_experts: int, q_dtype: Optional[torch.dtype],
|
||||
test_config: TestConfig,
|
||||
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
|
||||
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
num_local_experts: int,
|
||||
q_dtype: Optional[torch.dtype],
|
||||
test_config: TestConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> FusedMoEModularKernel:
|
||||
assert not test_config.low_latency
|
||||
assert test_config.use_fp8_dispatch is None
|
||||
|
||||
@@ -183,76 +193,82 @@ def make_ht_modular_kernel(
|
||||
deepep_ht_args=DeepEPHTArgs(num_local_experts=num_local_experts),
|
||||
deepep_ll_args=None,
|
||||
q_dtype=q_dtype,
|
||||
block_shape=test_config.block_size)
|
||||
block_shape=test_config.block_size,
|
||||
)
|
||||
|
||||
fused_experts = DeepGemmExperts(quant_config)
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
||||
fused_experts=fused_experts)
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
|
||||
return mk
|
||||
|
||||
|
||||
def make_modular_kernel(
|
||||
pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
|
||||
num_local_experts: int, test_tensors: TestTensors,
|
||||
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
|
||||
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
num_local_experts: int,
|
||||
test_tensors: TestTensors,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> FusedMoEModularKernel:
|
||||
q_dtype = torch.float8_e4m3fn
|
||||
test_config = test_tensors.config
|
||||
|
||||
mk: FusedMoEModularKernel
|
||||
# Make modular kernel
|
||||
if test_config.low_latency:
|
||||
max_tokens_per_rank = max(
|
||||
64, next_power_of_2(test_tensors.rank_tokens.size(0)))
|
||||
max_tokens_per_rank = max(64, next_power_of_2(test_tensors.rank_tokens.size(0)))
|
||||
hidden_size = test_tensors.rank_tokens.size(-1)
|
||||
|
||||
mk = make_ll_modular_kernel(pg=pg,
|
||||
pgi=pgi,
|
||||
max_tokens_per_rank=max_tokens_per_rank,
|
||||
dp_size=dp_size,
|
||||
hidden_size=hidden_size,
|
||||
q_dtype=q_dtype,
|
||||
test_config=test_config,
|
||||
quant_config=quant_config)
|
||||
mk = make_ll_modular_kernel(
|
||||
pg=pg,
|
||||
pgi=pgi,
|
||||
max_tokens_per_rank=max_tokens_per_rank,
|
||||
dp_size=dp_size,
|
||||
hidden_size=hidden_size,
|
||||
q_dtype=q_dtype,
|
||||
test_config=test_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
mk = make_ht_modular_kernel(pg,
|
||||
pgi,
|
||||
dp_size,
|
||||
num_local_experts,
|
||||
q_dtype,
|
||||
test_config,
|
||||
quant_config=quant_config)
|
||||
mk = make_ht_modular_kernel(
|
||||
pg,
|
||||
pgi,
|
||||
dp_size,
|
||||
num_local_experts,
|
||||
q_dtype,
|
||||
test_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
return mk
|
||||
|
||||
|
||||
def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
dp_size: int, test_tensors: TestTensors,
|
||||
w1: torch.Tensor, w2: torch.Tensor,
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
|
||||
def deepep_deepgemm_moe_impl(
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
test_tensors: TestTensors,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
test_config = test_tensors.config
|
||||
num_experts = test_config.num_experts
|
||||
num_local_experts = w1.size(0)
|
||||
|
||||
def build_expert_map():
|
||||
num_local_experts = w1.size(0)
|
||||
expert_map = torch.full((num_experts, ),
|
||||
fill_value=-1,
|
||||
dtype=torch.int32)
|
||||
expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32)
|
||||
s = pgi.rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
|
||||
return expert_map.to(device=torch.cuda.current_device(),
|
||||
dtype=torch.int32)
|
||||
return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32)
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
# Low-Latency kernels can't dispatch scales.
|
||||
a1_scale=(None if test_config.low_latency else
|
||||
test_tensors.rank_token_scales),
|
||||
a1_scale=(None if test_config.low_latency else test_tensors.rank_token_scales),
|
||||
block_shape=test_config.block_size,
|
||||
)
|
||||
|
||||
@@ -263,26 +279,35 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
dp_size=dp_size,
|
||||
num_local_experts=num_local_experts,
|
||||
test_tensors=test_tensors,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
out = mk.forward(hidden_states=test_tensors.rank_tokens,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
topk_ids=test_tensors.topk,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
apply_router_weight_on_input=False)
|
||||
out = mk.forward(
|
||||
hidden_states=test_tensors.rank_tokens,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
topk_ids=test_tensors.topk,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
|
||||
a1_scale: torch.Tensor, block_shape: list[int]):
|
||||
|
||||
def triton_impl(
|
||||
a: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
a1_scale: torch.Tensor,
|
||||
block_shape: list[int],
|
||||
):
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
@@ -300,7 +325,8 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
|
||||
quant_config=quant_config,
|
||||
# Make sure this is set to False so we
|
||||
# don't end up comparing the same implementation.
|
||||
allow_deep_gemm=False)
|
||||
allow_deep_gemm=False,
|
||||
)
|
||||
|
||||
|
||||
def _test_deepep_deepgemm_moe(
|
||||
@@ -321,22 +347,21 @@ def _test_deepep_deepgemm_moe(
|
||||
|
||||
pg = torch.distributed.new_group(list(range(pgi.world_size)))
|
||||
test_tensors = TestTensors.make(config, pgi.rank)
|
||||
block_shape = [
|
||||
w1.size(1) // w1_scale.size(1),
|
||||
w1.size(2) // w1_scale.size(2)
|
||||
]
|
||||
block_shape = [w1.size(1) // w1_scale.size(1), w1.size(2) // w1_scale.size(2)]
|
||||
|
||||
with set_current_vllm_config(VllmConfig()):
|
||||
# Reference
|
||||
triton_moe = triton_impl(a=test_tensors.rank_tokens,
|
||||
topk_ids=test_tensors.topk,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=test_tensors.rank_token_scales,
|
||||
block_shape=block_shape)
|
||||
triton_moe = triton_impl(
|
||||
a=test_tensors.rank_tokens,
|
||||
topk_ids=test_tensors.topk,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=test_tensors.rank_token_scales,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
# Slice experts for this rank.
|
||||
num_local_experts = config.num_experts // pgi.world_size
|
||||
@@ -390,10 +415,15 @@ NUM_EXPERTS = [32]
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@requires_deep_ep
|
||||
@requires_deep_gemm
|
||||
@pytest.mark.skipif(is_deep_gemm_e8m0_used(),
|
||||
reason="Skipping test for Blackwell DeepGEMM")
|
||||
def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
|
||||
topk: int, world_dp_size: tuple[int, int]):
|
||||
@pytest.mark.skipif(
|
||||
is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM"
|
||||
)
|
||||
def test_ht_deepep_deepgemm_moe(
|
||||
mnk: tuple[int, int, int],
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
):
|
||||
"""
|
||||
Tests for High-Throughput DeepEP + DeepGemm integration.
|
||||
"""
|
||||
@@ -409,21 +439,32 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
|
||||
block_size = [block_m, block_m]
|
||||
|
||||
world_size, dp_size = world_dp_size
|
||||
config = TestConfig(topk=topk,
|
||||
m=m,
|
||||
k=k,
|
||||
n=n,
|
||||
num_experts=num_experts,
|
||||
per_act_token_quant=False,
|
||||
block_size=block_size,
|
||||
low_latency=False,
|
||||
use_fp8_dispatch=None)
|
||||
config = TestConfig(
|
||||
topk=topk,
|
||||
m=m,
|
||||
k=k,
|
||||
n=n,
|
||||
num_experts=num_experts,
|
||||
per_act_token_quant=False,
|
||||
block_size=block_size,
|
||||
low_latency=False,
|
||||
use_fp8_dispatch=None,
|
||||
)
|
||||
|
||||
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
|
||||
num_experts, n, k, block_size)
|
||||
num_experts, n, k, block_size
|
||||
)
|
||||
|
||||
parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1,
|
||||
w2, w1_scale, w2_scale)
|
||||
parallel_launch(
|
||||
world_size,
|
||||
_test_deepep_deepgemm_moe,
|
||||
dp_size,
|
||||
config,
|
||||
w1,
|
||||
w2,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
)
|
||||
|
||||
|
||||
MNKs = [
|
||||
@@ -448,8 +489,9 @@ USE_FP8_DISPATCH = [False]
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@requires_deep_ep
|
||||
@requires_deep_gemm
|
||||
@pytest.mark.skipif(is_deep_gemm_e8m0_used(),
|
||||
reason="Skipping test for Blackwell DeepGEMM")
|
||||
@pytest.mark.skipif(
|
||||
is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM"
|
||||
)
|
||||
def test_ll_deepep_deepgemm_moe(
|
||||
mnk: tuple[int, int, int],
|
||||
num_experts: int,
|
||||
@@ -482,7 +524,16 @@ def test_ll_deepep_deepgemm_moe(
|
||||
)
|
||||
|
||||
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
|
||||
num_experts, n, k, block_size)
|
||||
num_experts, n, k, block_size
|
||||
)
|
||||
|
||||
parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1,
|
||||
w2, w1_scale, w2_scale)
|
||||
parallel_launch(
|
||||
world_size,
|
||||
_test_deepep_deepgemm_moe,
|
||||
dp_size,
|
||||
config,
|
||||
w1,
|
||||
w2,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
)
|
||||
|
||||
@@ -16,12 +16,11 @@ from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import TritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import has_deep_ep
|
||||
|
||||
@@ -30,9 +29,11 @@ from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
|
||||
if has_deep_ep():
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
||||
DeepEPHTPrepareAndFinalize)
|
||||
DeepEPHTPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
||||
DeepEPLLPrepareAndFinalize)
|
||||
DeepEPLLPrepareAndFinalize,
|
||||
)
|
||||
|
||||
from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
|
||||
|
||||
@@ -45,7 +46,7 @@ MAX_TOKENS_PER_RANK = 64
|
||||
|
||||
|
||||
def make_weights(
|
||||
e, n, k, dtype
|
||||
e, n, k, dtype
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Return weights w1, w2, w1_scale, w2_scale
|
||||
@@ -64,17 +65,15 @@ def make_weights(
|
||||
k_b_scales = k
|
||||
w1_q = torch.empty_like(w1, dtype=dtype)
|
||||
w2_q = torch.empty_like(w2, dtype=dtype)
|
||||
w1_scale = torch.empty((e, n_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w2_scale = torch.empty((e, k_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32)
|
||||
w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32)
|
||||
for expert in range(e):
|
||||
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
|
||||
w1[expert], use_per_token_if_dynamic=True)
|
||||
w1[expert], use_per_token_if_dynamic=True
|
||||
)
|
||||
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
|
||||
w2[expert], use_per_token_if_dynamic=True)
|
||||
w2[expert], use_per_token_if_dynamic=True
|
||||
)
|
||||
return w1_q, w2_q, w1_scale, w2_scale
|
||||
|
||||
|
||||
@@ -100,24 +99,25 @@ class TestTensors:
|
||||
def make(config: TestConfig, low_latency_mode: bool) -> "TestTensors":
|
||||
# TODO (varun) - check that float16 works ?
|
||||
assert config.dtype in [torch.bfloat16, torch.float8_e4m3fn]
|
||||
token_dtype = (torch.bfloat16 if config.dtype == torch.float8_e4m3fn
|
||||
else config.dtype)
|
||||
rank_tokens = torch.randn(
|
||||
(config.m, config.k), device="cuda", dtype=token_dtype) / 10
|
||||
token_dtype = (
|
||||
torch.bfloat16 if config.dtype == torch.float8_e4m3fn else config.dtype
|
||||
)
|
||||
rank_tokens = (
|
||||
torch.randn((config.m, config.k), device="cuda", dtype=token_dtype) / 10
|
||||
)
|
||||
rank_token_scales = None
|
||||
|
||||
topk = torch.randint(low=0,
|
||||
high=config.num_experts,
|
||||
size=(config.m, config.topk),
|
||||
device="cuda").to(dtype=torch.int64)
|
||||
topk_weights = torch.randn(topk.shape,
|
||||
dtype=torch.float32,
|
||||
device="cuda")
|
||||
return TestTensors(rank_tokens=rank_tokens,
|
||||
rank_token_scales=rank_token_scales,
|
||||
topk=topk,
|
||||
topk_weights=topk_weights,
|
||||
config=config)
|
||||
topk = torch.randint(
|
||||
low=0, high=config.num_experts, size=(config.m, config.topk), device="cuda"
|
||||
).to(dtype=torch.int64)
|
||||
topk_weights = torch.randn(topk.shape, dtype=torch.float32, device="cuda")
|
||||
return TestTensors(
|
||||
rank_tokens=rank_tokens,
|
||||
rank_token_scales=rank_token_scales,
|
||||
topk=topk,
|
||||
topk_weights=topk_weights,
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
def make_modular_kernel(
|
||||
@@ -132,28 +132,33 @@ def make_modular_kernel(
|
||||
use_fp8_dispatch: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> FusedMoEModularKernel:
|
||||
|
||||
ht_args: Optional[DeepEPHTArgs] = None
|
||||
ll_args: Optional[DeepEPLLArgs] = None
|
||||
|
||||
if low_latency_mode:
|
||||
ll_args = DeepEPLLArgs(max_tokens_per_rank=MAX_TOKENS_PER_RANK,
|
||||
hidden_size=hidden_size,
|
||||
num_experts=num_experts,
|
||||
use_fp8_dispatch=use_fp8_dispatch)
|
||||
ll_args = DeepEPLLArgs(
|
||||
max_tokens_per_rank=MAX_TOKENS_PER_RANK,
|
||||
hidden_size=hidden_size,
|
||||
num_experts=num_experts,
|
||||
use_fp8_dispatch=use_fp8_dispatch,
|
||||
)
|
||||
else:
|
||||
assert not use_fp8_dispatch, (
|
||||
"FP8 Dispatch is valid only for low-latency kernels")
|
||||
"FP8 Dispatch is valid only for low-latency kernels"
|
||||
)
|
||||
ht_args = DeepEPHTArgs(num_local_experts=num_local_experts)
|
||||
|
||||
a2a : Union[DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize] = \
|
||||
make_deepep_a2a(pg = pg,
|
||||
pgi = pgi,
|
||||
dp_size = dp_size,
|
||||
q_dtype = q_dtype,
|
||||
block_shape = None,
|
||||
deepep_ht_args = ht_args,
|
||||
deepep_ll_args = ll_args)
|
||||
a2a: Union[DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize] = (
|
||||
make_deepep_a2a(
|
||||
pg=pg,
|
||||
pgi=pgi,
|
||||
dp_size=dp_size,
|
||||
q_dtype=q_dtype,
|
||||
block_shape=None,
|
||||
deepep_ht_args=ht_args,
|
||||
deepep_ll_args=ll_args,
|
||||
)
|
||||
)
|
||||
|
||||
num_dispatchers = pgi.world_size // dp_size
|
||||
|
||||
@@ -167,8 +172,7 @@ def make_modular_kernel(
|
||||
else:
|
||||
fused_experts = TritonExperts(quant_config=quant_config)
|
||||
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
||||
fused_experts=fused_experts)
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
|
||||
return mk
|
||||
|
||||
|
||||
@@ -186,19 +190,15 @@ def deep_ep_moe_impl(
|
||||
use_fp8_dispatch: bool,
|
||||
per_act_token_quant: bool,
|
||||
) -> torch.Tensor:
|
||||
|
||||
num_local_experts = w1.size(0)
|
||||
|
||||
def build_expert_map():
|
||||
num_local_experts = w1.size(0)
|
||||
expert_map = torch.full((num_experts, ),
|
||||
fill_value=-1,
|
||||
dtype=torch.int32)
|
||||
expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32)
|
||||
s = pgi.rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
|
||||
return expert_map.to(device=torch.cuda.current_device(),
|
||||
dtype=torch.int32)
|
||||
return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32)
|
||||
|
||||
hidden_size = test_tensors.rank_tokens.size(1)
|
||||
is_quantized = w1.dtype == torch.float8_e4m3fn
|
||||
@@ -214,11 +214,12 @@ def deep_ep_moe_impl(
|
||||
topk_weights_chunk = test_tensors.topk_weights[chunk_start:chunk_end]
|
||||
topk_chunk = test_tensors.topk[chunk_start:chunk_end]
|
||||
rank_token_scales_chunk = test_tensors.rank_token_scales
|
||||
if rank_token_scales_chunk is not None and rank_token_scales_chunk.size(
|
||||
0) == total_num_tokens:
|
||||
if (
|
||||
rank_token_scales_chunk is not None
|
||||
and rank_token_scales_chunk.size(0) == total_num_tokens
|
||||
):
|
||||
# per act token
|
||||
rank_token_scales_chunk = rank_token_scales_chunk[
|
||||
chunk_start:chunk_end]
|
||||
rank_token_scales_chunk = rank_token_scales_chunk[chunk_start:chunk_end]
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
q_dtype,
|
||||
@@ -230,26 +231,37 @@ def deep_ep_moe_impl(
|
||||
|
||||
# Make modular kernel
|
||||
mk: FusedMoEModularKernel = make_modular_kernel(
|
||||
pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
|
||||
num_local_experts, q_dtype, use_fp8_dispatch, quant_config)
|
||||
pg,
|
||||
pgi,
|
||||
low_latency_mode,
|
||||
hidden_size,
|
||||
dp_size,
|
||||
num_experts,
|
||||
num_local_experts,
|
||||
q_dtype,
|
||||
use_fp8_dispatch,
|
||||
quant_config,
|
||||
)
|
||||
|
||||
out = mk.forward(hidden_states=rank_tokens_chunk,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights_chunk,
|
||||
topk_ids=topk_chunk,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
apply_router_weight_on_input=False)
|
||||
out = mk.forward(
|
||||
hidden_states=rank_tokens_chunk,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights_chunk,
|
||||
topk_ids=topk_chunk,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
|
||||
if not skip_result_store:
|
||||
out_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||
out, non_blocking=True)
|
||||
out_hidden_states[chunk_start:chunk_end, :].copy_(out, non_blocking=True)
|
||||
|
||||
max_num_tokens_per_dp = (MAX_TOKENS_PER_RANK
|
||||
if low_latency_mode else total_num_tokens)
|
||||
max_num_tokens_per_dp = (
|
||||
MAX_TOKENS_PER_RANK if low_latency_mode else total_num_tokens
|
||||
)
|
||||
|
||||
for chunk_start_ in range(0, total_num_tokens, max_num_tokens_per_dp):
|
||||
chunk_start = chunk_start_
|
||||
@@ -258,9 +270,9 @@ def deep_ep_moe_impl(
|
||||
chunk_start = min(chunk_start, total_num_tokens - 1)
|
||||
chunk_end = min(chunk_end, total_num_tokens)
|
||||
|
||||
process_chunk(chunk_start,
|
||||
chunk_end,
|
||||
skip_result_store=chunk_start_ >= total_num_tokens)
|
||||
process_chunk(
|
||||
chunk_start, chunk_end, skip_result_store=chunk_start_ >= total_num_tokens
|
||||
)
|
||||
|
||||
return out_hidden_states
|
||||
|
||||
@@ -274,9 +286,11 @@ def torch_moe_impl(
|
||||
using_fp8_dispatch: bool,
|
||||
per_act_token_quant: bool,
|
||||
):
|
||||
|
||||
a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk,
|
||||
test_tensors.topk_weights)
|
||||
a, topk_ids, topk_weights = (
|
||||
test_tensors.rank_tokens,
|
||||
test_tensors.topk,
|
||||
test_tensors.topk_weights,
|
||||
)
|
||||
if using_fp8_dispatch:
|
||||
# The DeepEP implementation is requested to dispatch using FP8.
|
||||
# For numerical stability for testing, emulate the fp8 dispatch by
|
||||
@@ -284,8 +298,11 @@ def torch_moe_impl(
|
||||
assert not per_act_token_quant
|
||||
a = test_tensors.rank_tokens
|
||||
aq, aq_scale = per_token_group_quant_fp8(a, 128)
|
||||
a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view(
|
||||
a.shape).to(a.dtype)
|
||||
a = (
|
||||
(aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1))
|
||||
.view(a.shape)
|
||||
.to(a.dtype)
|
||||
)
|
||||
|
||||
is_quantized = w1.dtype == torch.float8_e4m3fn
|
||||
a_dtype = a.dtype
|
||||
@@ -306,8 +323,9 @@ def torch_moe_impl(
|
||||
e_w = topk_weights[i][j]
|
||||
w1_e = w1[e]
|
||||
w2_e = w2[e]
|
||||
o_i += (SiluAndMul()
|
||||
(a_i @ w1_e.transpose(0, 1)) @ w2_e.transpose(0, 1)) * e_w
|
||||
o_i += (
|
||||
SiluAndMul()(a_i @ w1_e.transpose(0, 1)) @ w2_e.transpose(0, 1)
|
||||
) * e_w
|
||||
|
||||
if is_quantized:
|
||||
out = out.to(dtype=a_dtype)
|
||||
@@ -327,28 +345,36 @@ def _deep_ep_moe(
|
||||
use_fp8_dispatch: bool,
|
||||
per_act_token_quant: bool,
|
||||
):
|
||||
|
||||
if not low_latency_mode:
|
||||
assert not use_fp8_dispatch, (
|
||||
"FP8 dispatch interface is available only in low-latency mode")
|
||||
"FP8 dispatch interface is available only in low-latency mode"
|
||||
)
|
||||
|
||||
is_quantized = w1.dtype == torch.float8_e4m3fn
|
||||
w1 = w1.to(device=torch.cuda.current_device())
|
||||
w2 = w2.to(device=torch.cuda.current_device())
|
||||
if is_quantized:
|
||||
w1_scale = w1_scale.to( # type: ignore
|
||||
device=torch.cuda.current_device())
|
||||
device=torch.cuda.current_device()
|
||||
)
|
||||
w2_scale = w2_scale.to( # type: ignore
|
||||
device=torch.cuda.current_device())
|
||||
device=torch.cuda.current_device()
|
||||
)
|
||||
|
||||
pg = torch.distributed.new_group(list(range(pgi.world_size)))
|
||||
test_tensors = TestTensors.make(config, low_latency_mode)
|
||||
|
||||
with set_current_vllm_config(VllmConfig()):
|
||||
# Reference
|
||||
torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale,
|
||||
w2_scale, use_fp8_dispatch,
|
||||
per_act_token_quant)
|
||||
torch_combined = torch_moe_impl(
|
||||
test_tensors,
|
||||
w1,
|
||||
w2,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
use_fp8_dispatch,
|
||||
per_act_token_quant,
|
||||
)
|
||||
|
||||
# Splice experts for this rank.
|
||||
num_local_experts = config.num_experts // pgi.world_size
|
||||
@@ -420,18 +446,23 @@ def test_deep_ep_moe(
|
||||
|
||||
current_platform.seed_everything(7)
|
||||
world_size, dp_size = world_dp_size
|
||||
config = TestConfig(dtype=dtype,
|
||||
topk=topk,
|
||||
m=m,
|
||||
k=k,
|
||||
n=n,
|
||||
num_experts=num_experts)
|
||||
config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts)
|
||||
|
||||
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
|
||||
|
||||
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
|
||||
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
|
||||
per_act_token_quant)
|
||||
parallel_launch(
|
||||
world_size,
|
||||
_deep_ep_moe,
|
||||
low_latency_mode,
|
||||
dp_size,
|
||||
config,
|
||||
w1,
|
||||
w2,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
use_fp8_dispatch,
|
||||
per_act_token_quant,
|
||||
)
|
||||
|
||||
|
||||
MNKs = [
|
||||
@@ -467,8 +498,7 @@ def test_low_latency_deep_ep_moe(
|
||||
):
|
||||
low_latency_mode = True
|
||||
|
||||
if (low_latency_mode
|
||||
and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES):
|
||||
if low_latency_mode and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES:
|
||||
pytest.skip(
|
||||
f"Skipping test as hidden size {k} is not in list of supported "
|
||||
f"hidden sizes {DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES}"
|
||||
@@ -476,15 +506,20 @@ def test_low_latency_deep_ep_moe(
|
||||
|
||||
current_platform.seed_everything(7)
|
||||
world_size, dp_size = world_dp_size
|
||||
config = TestConfig(dtype=dtype,
|
||||
topk=topk,
|
||||
m=m,
|
||||
k=k,
|
||||
n=n,
|
||||
num_experts=num_experts)
|
||||
config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts)
|
||||
|
||||
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
|
||||
|
||||
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
|
||||
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
|
||||
False)
|
||||
parallel_launch(
|
||||
world_size,
|
||||
_deep_ep_moe,
|
||||
low_latency_mode,
|
||||
dp_size,
|
||||
config,
|
||||
w1,
|
||||
w2,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
use_fp8_dispatch,
|
||||
False,
|
||||
)
|
||||
|
||||
@@ -11,14 +11,18 @@ import math
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||
|
||||
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.utils.deep_gemm import (calc_diff, is_deep_gemm_supported,
|
||||
per_block_cast_to_fp8)
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
from vllm.utils.deep_gemm import (
|
||||
calc_diff,
|
||||
is_deep_gemm_supported,
|
||||
per_block_cast_to_fp8,
|
||||
)
|
||||
|
||||
BLOCK_SIZE = [128, 128]
|
||||
|
||||
@@ -37,8 +41,10 @@ def make_block_quant_fp8_weights(
|
||||
w2 shape: (E, K, N)
|
||||
"""
|
||||
dtype = torch.bfloat16
|
||||
fp8_max, fp8_min = torch.finfo(torch.float8_e4m3fn).max, torch.finfo(
|
||||
torch.float8_e4m3fn).min
|
||||
fp8_max, fp8_min = (
|
||||
torch.finfo(torch.float8_e4m3fn).max,
|
||||
torch.finfo(torch.float8_e4m3fn).min,
|
||||
)
|
||||
|
||||
# bf16 reference weights
|
||||
w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=dtype) / 10
|
||||
@@ -54,24 +60,16 @@ def make_block_quant_fp8_weights(
|
||||
|
||||
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
|
||||
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
|
||||
w1_s = torch.empty(e,
|
||||
n_tiles_w1,
|
||||
k_tiles_w1,
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w2_s = torch.empty(e,
|
||||
n_tiles_w2,
|
||||
k_tiles_w2,
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w1_s = torch.empty(e, n_tiles_w1, k_tiles_w1, device="cuda", dtype=torch.float32)
|
||||
w2_s = torch.empty(e, n_tiles_w2, k_tiles_w2, device="cuda", dtype=torch.float32)
|
||||
|
||||
for i in range(e):
|
||||
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
|
||||
block_size=block_size,
|
||||
use_ue8m0=True)
|
||||
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
|
||||
block_size=block_size,
|
||||
use_ue8m0=True)
|
||||
w1[i], w1_s[i] = per_block_cast_to_fp8(
|
||||
w1_bf16[i], block_size=block_size, use_ue8m0=True
|
||||
)
|
||||
w2[i], w2_s[i] = per_block_cast_to_fp8(
|
||||
w2_bf16[i], block_size=block_size, use_ue8m0=True
|
||||
)
|
||||
|
||||
return w1, w2, w1_s, w2_s
|
||||
|
||||
@@ -81,18 +79,17 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
||||
Run one (M,N,K) configuration on a single GPU and assert DeepGEMM ==
|
||||
Triton baseline within tolerance.
|
||||
"""
|
||||
tokens_bf16 = torch.randn(
|
||||
m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
|
||||
tokens_bf16 = (
|
||||
torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
|
||||
.clamp_min_(-1)
|
||||
.clamp_max_(1)
|
||||
)
|
||||
_, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])
|
||||
|
||||
# expert weight tensors
|
||||
w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
|
||||
block_size)
|
||||
w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k, block_size)
|
||||
|
||||
router_logits = torch.randn(m,
|
||||
num_experts,
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
router_logits = torch.randn(m, num_experts, device="cuda", dtype=torch.float32)
|
||||
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
|
||||
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
|
||||
|
||||
@@ -147,15 +144,14 @@ NUM_EXPERTS = [32]
|
||||
@pytest.mark.parametrize(("m", "n", "k"), MNKs)
|
||||
@pytest.mark.parametrize("topk", TOPKS)
|
||||
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
|
||||
@pytest.mark.skipif(not is_deep_gemm_supported(),
|
||||
reason="Requires deep_gemm kernels")
|
||||
@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
|
||||
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
|
||||
|
||||
with monkeypatch.context() as mp:
|
||||
mp.setenv("VLLM_USE_DEEP_GEMM", "1")
|
||||
|
||||
_fused_moe_mod = importlib.import_module(
|
||||
"vllm.model_executor.layers.fused_moe.fused_moe")
|
||||
"vllm.model_executor.layers.fused_moe.fused_moe"
|
||||
)
|
||||
|
||||
call_counter = {"cnt": 0}
|
||||
|
||||
@@ -165,8 +161,7 @@ def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
|
||||
call_counter["cnt"] += 1
|
||||
return orig_fn(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8",
|
||||
_spy_deep_gemm_moe_fp8)
|
||||
monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", _spy_deep_gemm_moe_fp8)
|
||||
|
||||
if topk > num_experts:
|
||||
pytest.skip(f"topk={topk} > num_experts={num_experts}")
|
||||
@@ -181,6 +176,7 @@ def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
|
||||
)
|
||||
|
||||
# ensure that the DeepGEMM path was indeed taken.
|
||||
assert call_counter["cnt"] == 1, \
|
||||
f"DeepGEMM path was not executed during the test. " \
|
||||
assert call_counter["cnt"] == 1, (
|
||||
f"DeepGEMM path was not executed during the test. "
|
||||
f"Call counter: {call_counter['cnt']}"
|
||||
)
|
||||
|
||||
@@ -6,24 +6,28 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_flashinfer_per_tensor_scale_fp8, flashinfer_cutlass_moe_fp8,
|
||||
register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
|
||||
swap_w13_to_w31)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
input_to_float8)
|
||||
apply_flashinfer_per_tensor_scale_fp8,
|
||||
flashinfer_cutlass_moe_fp8,
|
||||
register_moe_scaling_factors,
|
||||
rotate_flashinfer_fp8_moe_weights,
|
||||
swap_w13_to_w31,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
if not has_flashinfer_cutlass_fused_moe(
|
||||
) or not current_platform.has_device_capability(100):
|
||||
pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support",
|
||||
allow_module_level=True)
|
||||
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
|
||||
100
|
||||
):
|
||||
pytest.skip(
|
||||
"Requires flashinfer_cutlass_fused_moe and nvfp4 support",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
NUM_EXPERTS = [16]
|
||||
TOP_KS = [1]
|
||||
@@ -39,8 +43,7 @@ MNK_FACTORS = [
|
||||
(1, 4096, 5120),
|
||||
]
|
||||
|
||||
vllm_config = VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))
|
||||
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
@@ -74,18 +77,17 @@ class TestData:
|
||||
layer: torch.nn.Module
|
||||
|
||||
@staticmethod
|
||||
def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
|
||||
reorder: bool) -> "TestData":
|
||||
hidden_states = torch.randn(
|
||||
(m, k), device="cuda", dtype=torch.bfloat16) / 10
|
||||
def make_moe_tensors_8bit(
|
||||
m: int, k: int, n: int, e: int, reorder: bool
|
||||
) -> "TestData":
|
||||
hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
|
||||
w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16)
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Scale to fp8
|
||||
_, a1_scale = input_to_float8(hidden_states)
|
||||
a1_scale = 1.0 / a1_scale
|
||||
a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(
|
||||
dtype=torch.float32)
|
||||
a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(dtype=torch.float32)
|
||||
w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
|
||||
w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)
|
||||
|
||||
@@ -102,8 +104,7 @@ class TestData:
|
||||
# flashinfer expects swapped rows for w13
|
||||
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
|
||||
if reorder:
|
||||
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
|
||||
layer.w2_weight)
|
||||
rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
|
||||
layer.custom_routing_function = Llama4MoE.custom_routing_function
|
||||
layer.intermediate_size_per_partition = n
|
||||
layer.ep_rank = 0
|
||||
@@ -145,7 +146,8 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
custom_routing_function=Llama4MoE.custom_routing_function,
|
||||
scoring_func="softmax")
|
||||
scoring_func="softmax",
|
||||
)
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=td.w13_weight_scale,
|
||||
@@ -178,12 +180,10 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
||||
top_k=topk,
|
||||
num_expert_group=None,
|
||||
topk_group=None,
|
||||
apply_router_weight_on_input=True)
|
||||
apply_router_weight_on_input=True,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(output,
|
||||
flashinfer_output,
|
||||
atol=5.5e-2,
|
||||
rtol=1e-2)
|
||||
torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
@@ -213,7 +213,8 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
custom_routing_function=Llama4MoE.custom_routing_function,
|
||||
scoring_func="softmax")
|
||||
scoring_func="softmax",
|
||||
)
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=td.w13_weight_scale,
|
||||
@@ -250,7 +251,6 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
apply_router_weight_on_input=True,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(output,
|
||||
flashinfer_cutlass_output,
|
||||
atol=5.5e-2,
|
||||
rtol=1e-2)
|
||||
torch.testing.assert_close(
|
||||
output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
|
||||
)
|
||||
|
||||
@@ -4,26 +4,33 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_quant_config
|
||||
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype)
|
||||
from tests.kernels.quantization.nvfp4_utils import (
|
||||
FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype,
|
||||
)
|
||||
from tests.kernels.utils import torch_moe
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe)
|
||||
FlashInferExperts,
|
||||
is_valid_flashinfer_cutlass_fused_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
if not has_flashinfer_cutlass_fused_moe(
|
||||
) or not current_platform.has_device_capability(100):
|
||||
pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support",
|
||||
allow_module_level=True)
|
||||
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
|
||||
100
|
||||
):
|
||||
pytest.skip(
|
||||
"Requires flashinfer_cutlass_fused_moe and nvfp4 support",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
MNK_FACTORS = [
|
||||
(2, 1024, 1024),
|
||||
@@ -44,13 +51,13 @@ MNK_FACTORS = [
|
||||
@pytest.mark.parametrize("topk", [1, 6, 8])
|
||||
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
|
||||
@torch.inference_mode()
|
||||
def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
||||
dtype: torch.dtype):
|
||||
def test_flashinfer_fp4_moe_no_graph(
|
||||
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))):
|
||||
|
||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
):
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
|
||||
quant_blocksize = 16
|
||||
@@ -66,10 +73,7 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(a,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False)
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
|
||||
|
||||
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
|
||||
|
||||
@@ -87,16 +91,19 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
||||
)
|
||||
|
||||
# Reference check:
|
||||
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
|
||||
torch.amax(a.flatten(), dim=-1)).to(torch.float32)
|
||||
a_global_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
|
||||
_, m_k = a_fp4.shape
|
||||
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
|
||||
a_scale_interleaved,
|
||||
a_global_scale,
|
||||
dtype=a.dtype,
|
||||
device=a.device,
|
||||
block_size=quant_blocksize)
|
||||
a_in_dtype = dequantize_nvfp4_to_dtype(
|
||||
a_fp4,
|
||||
a_scale_interleaved,
|
||||
a_global_scale,
|
||||
dtype=a.dtype,
|
||||
device=a.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
|
||||
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
|
||||
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
|
||||
@@ -104,23 +111,26 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
||||
for idx in range(0, e):
|
||||
w1_d[idx] = dequantize_nvfp4_to_dtype(
|
||||
w1_q[idx],
|
||||
quant_config.w1_scale[idx], (1 / quant_config.g1_alphas[idx]),
|
||||
quant_config.w1_scale[idx],
|
||||
(1 / quant_config.g1_alphas[idx]),
|
||||
dtype=dtype,
|
||||
device=w1_q.device,
|
||||
block_size=quant_blocksize)
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
w2_d[idx] = dequantize_nvfp4_to_dtype(
|
||||
w2_q[idx],
|
||||
quant_config.w2_scale[idx], (1 / quant_config.g2_alphas[idx]),
|
||||
quant_config.w2_scale[idx],
|
||||
(1 / quant_config.g2_alphas[idx]),
|
||||
dtype=dtype,
|
||||
device=w2_q.device,
|
||||
block_size=quant_blocksize)
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
|
||||
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
|
||||
|
||||
torch.testing.assert_close(torch_output,
|
||||
flashinfer_output,
|
||||
atol=1e-1,
|
||||
rtol=1e-1)
|
||||
torch.testing.assert_close(
|
||||
torch_output, flashinfer_output, atol=1e-1, rtol=1e-1
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -17,20 +17,21 @@ if not has_triton_kernels():
|
||||
import triton_kernels.swiglu
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
from triton_kernels.numerics import InFlexData
|
||||
from triton_kernels.numerics_details.mxfp import (downcast_to_mxfp,
|
||||
upcast_from_mxfp)
|
||||
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
|
||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.testing import assert_close
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize)
|
||||
BatchedPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
BatchedOAITritonExperts, triton_kernel_moe_forward)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
BatchedOAITritonExperts,
|
||||
triton_kernel_moe_forward,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.utils import shuffle_weight
|
||||
from vllm.utils import round_up
|
||||
|
||||
@@ -46,13 +47,11 @@ def deshuffle(w: torch.Tensor):
|
||||
def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
|
||||
randbits = [torch.randperm(E) for _ in range(M)]
|
||||
x_list = [
|
||||
(-1)**i *
|
||||
((16384 +
|
||||
((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16))
|
||||
(-1) ** i
|
||||
* ((16384 + ((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16))
|
||||
for i, bits in enumerate(randbits)
|
||||
]
|
||||
exp_data = torch.stack(x_list).to(
|
||||
device="cuda") # simulating gate_output (M, E)
|
||||
exp_data = torch.stack(x_list).to(device="cuda") # simulating gate_output (M, E)
|
||||
|
||||
# create input tensor
|
||||
x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
|
||||
@@ -120,20 +119,21 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
|
||||
value=0,
|
||||
)
|
||||
|
||||
w1_bias_tri = F.pad(w1_bias_tri, (0, w1_right_pad, 0, 0),
|
||||
mode="constant",
|
||||
value=0)
|
||||
w2_bias_tri = F.pad(w2_bias_tri, (0, w2_right_pad, 0, 0),
|
||||
mode="constant",
|
||||
value=0)
|
||||
w1_bias_tri = F.pad(
|
||||
w1_bias_tri, (0, w1_right_pad, 0, 0), mode="constant", value=0
|
||||
)
|
||||
w2_bias_tri = F.pad(
|
||||
w2_bias_tri, (0, w2_right_pad, 0, 0), mode="constant", value=0
|
||||
)
|
||||
|
||||
x_tri = F.pad(x_tri, (0, x_pad, 0, 0), mode="constant", value=0)
|
||||
|
||||
w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
|
||||
mx_axis=1)
|
||||
w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
|
||||
w_scale_layout, w_scale_layout_opts = (
|
||||
layout.make_default_matmul_mxfp4_w_scale_layout(
|
||||
mx_axis=1, num_warps=num_warps))
|
||||
mx_axis=1, num_warps=num_warps
|
||||
)
|
||||
)
|
||||
|
||||
w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
|
||||
w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, torch.bfloat16, axis=1)
|
||||
@@ -141,29 +141,33 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
|
||||
w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
|
||||
w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, torch.bfloat16, axis=1)
|
||||
|
||||
w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout,
|
||||
**w_layout_opts)
|
||||
w1_tri = convert_layout(
|
||||
wrap_torch_tensor(w1_tri, FP4), w_layout, **w_layout_opts
|
||||
)
|
||||
w1_scale_tri = convert_layout(
|
||||
wrap_torch_tensor(w1_scale_tri),
|
||||
w_scale_layout,
|
||||
**w_scale_layout_opts,
|
||||
)
|
||||
|
||||
w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout,
|
||||
**w_layout_opts)
|
||||
w2_tri = convert_layout(
|
||||
wrap_torch_tensor(w2_tri, FP4), w_layout, **w_layout_opts
|
||||
)
|
||||
w2_scale_tri = convert_layout(
|
||||
wrap_torch_tensor(w2_scale_tri),
|
||||
w_scale_layout,
|
||||
**w_scale_layout_opts,
|
||||
)
|
||||
|
||||
pc1 = PrecisionConfig(weight_scale=w1_scale_tri,
|
||||
flex_ctx=FlexCtx(rhs_data=InFlexData()))
|
||||
pc2 = PrecisionConfig(weight_scale=w2_scale_tri,
|
||||
flex_ctx=FlexCtx(rhs_data=InFlexData()))
|
||||
pc1 = PrecisionConfig(
|
||||
weight_scale=w1_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
|
||||
)
|
||||
pc2 = PrecisionConfig(
|
||||
weight_scale=w2_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData())
|
||||
)
|
||||
|
||||
# tucuate so the rest can run properly
|
||||
w1 = w1[..., :K, :2 * N]
|
||||
w1 = w1[..., :K, : 2 * N]
|
||||
w2 = w2[..., :N, :K]
|
||||
|
||||
w1 = deshuffle(w1)
|
||||
@@ -261,7 +265,8 @@ class Case:
|
||||
@pytest.mark.parametrize(
|
||||
", ".join(f.name for f in fields(Case)),
|
||||
[
|
||||
tuple(getattr(case, f.name) for f in fields(Case)) for case in [
|
||||
tuple(getattr(case, f.name) for f in fields(Case))
|
||||
for case in [
|
||||
# Case(a_dtype="bf16", w_dtype="bf16"),
|
||||
# Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
|
||||
Case(a_dtype="bf16", w_dtype="mx4")
|
||||
@@ -321,10 +326,7 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
|
||||
gating_output=exp_data,
|
||||
topk=topk,
|
||||
)
|
||||
assert_close(ref=out_ref,
|
||||
tri=out_triton_monolithic,
|
||||
maxtol=0.025,
|
||||
rmstol=0.005)
|
||||
assert_close(ref=out_ref, tri=out_triton_monolithic, maxtol=0.025, rmstol=0.005)
|
||||
|
||||
|
||||
def batched_moe(
|
||||
@@ -376,7 +378,8 @@ def batched_moe(
|
||||
@pytest.mark.parametrize(
|
||||
", ".join(f.name for f in fields(Case)),
|
||||
[
|
||||
tuple(getattr(case, f.name) for f in fields(Case)) for case in [
|
||||
tuple(getattr(case, f.name) for f in fields(Case))
|
||||
for case in [
|
||||
# Case(a_dtype="bf16", w_dtype="bf16"),
|
||||
# Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
|
||||
Case(a_dtype="bf16", w_dtype="mx4")
|
||||
|
||||
@@ -4,16 +4,20 @@
|
||||
|
||||
Run `pytest tests/kernels/moe/test_grouped_topk.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_grouped_topk,
|
||||
grouped_topk)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_grouped_topk,
|
||||
grouped_topk,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(),
|
||||
reason="This test is skipped on non-CUDA platform.")
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
|
||||
)
|
||||
@pytest.mark.parametrize("n_token", [1, 33, 64])
|
||||
@pytest.mark.parametrize("n_hidden", [1024, 2048])
|
||||
@pytest.mark.parametrize("n_expert", [16])
|
||||
@@ -23,23 +27,26 @@ from vllm.platforms import current_platform
|
||||
@pytest.mark.parametrize("topk_group", [2])
|
||||
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
|
||||
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
|
||||
@pytest.mark.parametrize("dtype",
|
||||
[torch.float16, torch.bfloat16, torch.float32])
|
||||
def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int,
|
||||
n_hidden: int, n_expert: int, topk: int,
|
||||
renormalize: bool, num_expert_group: int,
|
||||
topk_group: int, scoring_func: str,
|
||||
routed_scaling_factor: float, dtype: torch.dtype):
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
|
||||
def test_grouped_topk(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
n_token: int,
|
||||
n_hidden: int,
|
||||
n_expert: int,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
num_expert_group: int,
|
||||
topk_group: int,
|
||||
scoring_func: str,
|
||||
routed_scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
hidden_states = torch.randn((n_token, n_hidden),
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
gating_output = torch.randn((n_token, n_expert),
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
e_score_correction_bias = torch.randn((n_expert, ),
|
||||
dtype=torch.float32,
|
||||
device="cuda")
|
||||
hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, device="cuda")
|
||||
gating_output = torch.randn((n_token, n_expert), dtype=dtype, device="cuda")
|
||||
e_score_correction_bias = torch.randn(
|
||||
(n_expert,), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
|
||||
@@ -52,7 +59,8 @@ def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int,
|
||||
topk_group=topk_group,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
test_topk_weights, test_topk_ids = fused_grouped_topk(
|
||||
hidden_states=hidden_states,
|
||||
@@ -63,14 +71,11 @@ def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int,
|
||||
topk_group=topk_group,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
if renormalize:
|
||||
torch.testing.assert_close(baseline_topk_weights,
|
||||
test_topk_weights,
|
||||
atol=2e-2,
|
||||
rtol=0)
|
||||
torch.testing.assert_close(baseline_topk_ids,
|
||||
test_topk_ids,
|
||||
atol=0,
|
||||
rtol=0)
|
||||
torch.testing.assert_close(
|
||||
baseline_topk_weights, test_topk_weights, atol=2e-2, rtol=0
|
||||
)
|
||||
torch.testing.assert_close(baseline_topk_ids, test_topk_ids, atol=0, rtol=0)
|
||||
|
||||
@@ -17,18 +17,29 @@ from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
|
||||
reference_moe_impl,
|
||||
run_modular_kernel)
|
||||
from .modular_kernel_tools.common import (
|
||||
Config,
|
||||
RankTensors,
|
||||
WeightTensors,
|
||||
reference_moe_impl,
|
||||
run_modular_kernel,
|
||||
)
|
||||
from .modular_kernel_tools.mk_objects import (
|
||||
MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
|
||||
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, TestMoEQuantConfig,
|
||||
expert_info)
|
||||
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
|
||||
parallel_launch_with_config)
|
||||
MK_FUSED_EXPERT_TYPES,
|
||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
|
||||
MK_QUANT_CONFIGS,
|
||||
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES,
|
||||
TestMoEQuantConfig,
|
||||
expert_info,
|
||||
)
|
||||
from .modular_kernel_tools.parallel_utils import (
|
||||
ProcessGroupInfo,
|
||||
parallel_launch_with_config,
|
||||
)
|
||||
|
||||
has_any_multi_gpu_package = (has_deep_ep() or has_deep_gemm() or has_pplx()
|
||||
or has_flashinfer_cutlass_fused_moe())
|
||||
has_any_multi_gpu_package = (
|
||||
has_deep_ep() or has_deep_gemm() or has_pplx() or has_flashinfer_cutlass_fused_moe()
|
||||
)
|
||||
|
||||
meets_multi_gpu_requirements = pytest.mark.skipif(
|
||||
not has_any_multi_gpu_package,
|
||||
@@ -64,9 +75,9 @@ def rank_worker(
|
||||
|
||||
# sanity check
|
||||
from vllm import envs
|
||||
|
||||
if base_config.fused_moe_chunk_size is not None:
|
||||
assert (
|
||||
base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
|
||||
assert base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
|
||||
# get weights to this device
|
||||
weights.to_current_device()
|
||||
@@ -93,8 +104,7 @@ def rank_worker(
|
||||
rank_tensors = RankTensors.make(config, pgi)
|
||||
|
||||
# modular kernel out
|
||||
mk_out = run_modular_kernel(pgi, vllm_config, config, weights,
|
||||
rank_tensors)
|
||||
mk_out = run_modular_kernel(pgi, vllm_config, config, weights, rank_tensors)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ref_out = reference_moe_impl(config, weights, rank_tensors)
|
||||
@@ -115,10 +125,10 @@ def rank_worker(
|
||||
if len(exceptions) > 0:
|
||||
raise RuntimeError(
|
||||
f"{len(exceptions)} of {count} tests failed in child process, "
|
||||
f"rank={pgi.rank}.")
|
||||
f"rank={pgi.rank}."
|
||||
)
|
||||
else:
|
||||
print(f"{count} of {count} tests passed in child process, "
|
||||
f"rank={pgi.rank}.")
|
||||
print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.")
|
||||
|
||||
|
||||
def run(config: Config, verbose: bool):
|
||||
@@ -127,8 +137,9 @@ def run(config: Config, verbose: bool):
|
||||
weights: WeightTensors = WeightTensors.make(config)
|
||||
|
||||
vllm_config, env_dict = config.make_env_data()
|
||||
parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
|
||||
env_dict, config, weights, verbose)
|
||||
parallel_launch_with_config(
|
||||
config.world_size, rank_worker, vllm_config, env_dict, config, weights, verbose
|
||||
)
|
||||
|
||||
|
||||
Ms = [32, 64]
|
||||
@@ -149,8 +160,9 @@ def is_nyi_config(config: Config) -> bool:
|
||||
if info.needs_matching_quant:
|
||||
# The triton kernels expect both per-act-token-quant and
|
||||
# per-out-ch-quant or neither.
|
||||
unsupported_quant_config = ((config.is_per_act_token_quant +
|
||||
config.is_per_out_ch_quant) == 1)
|
||||
unsupported_quant_config = (
|
||||
config.is_per_act_token_quant + config.is_per_out_ch_quant
|
||||
) == 1
|
||||
return unsupported_quant_config
|
||||
|
||||
return not info.supports_expert_map
|
||||
@@ -162,19 +174,25 @@ def is_nyi_config(config: Config) -> bool:
|
||||
@pytest.mark.parametrize("dtype", DTYPEs)
|
||||
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
|
||||
@pytest.mark.parametrize(
|
||||
"combination",
|
||||
product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
|
||||
"combination", product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)
|
||||
)
|
||||
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
|
||||
@pytest.mark.parametrize("world_size", [2])
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@meets_multi_gpu_requirements
|
||||
def test_modular_kernel_combinations_multigpu(
|
||||
k: int, n: int, e: int, dtype: torch.dtype,
|
||||
quant_config: Optional[TestMoEQuantConfig],
|
||||
combination: tuple[mk.FusedMoEPrepareAndFinalize,
|
||||
mk.FusedMoEPermuteExpertsUnpermute],
|
||||
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
|
||||
|
||||
k: int,
|
||||
n: int,
|
||||
e: int,
|
||||
dtype: torch.dtype,
|
||||
quant_config: Optional[TestMoEQuantConfig],
|
||||
combination: tuple[
|
||||
mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute
|
||||
],
|
||||
fused_moe_chunk_size: Optional[int],
|
||||
world_size: int,
|
||||
pytestconfig,
|
||||
):
|
||||
config = Config(
|
||||
Ms=Ms,
|
||||
K=k,
|
||||
@@ -195,7 +213,7 @@ def test_modular_kernel_combinations_multigpu(
|
||||
if is_nyi_config(config):
|
||||
pytest.skip(f"Tests config {config} is nyi. Skipping ...")
|
||||
|
||||
verbosity = pytestconfig.getoption('verbose')
|
||||
verbosity = pytestconfig.getoption("verbose")
|
||||
run(config, verbosity > 0)
|
||||
|
||||
|
||||
@@ -205,16 +223,23 @@ def test_modular_kernel_combinations_multigpu(
|
||||
@pytest.mark.parametrize("dtype", DTYPEs)
|
||||
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
|
||||
@pytest.mark.parametrize(
|
||||
"combination",
|
||||
product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
|
||||
"combination", product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)
|
||||
)
|
||||
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
|
||||
@pytest.mark.parametrize("world_size", [1])
|
||||
def test_modular_kernel_combinations_singlegpu(
|
||||
k: int, n: int, e: int, dtype: torch.dtype,
|
||||
quant_config: Optional[TestMoEQuantConfig],
|
||||
combination: tuple[mk.FusedMoEPrepareAndFinalize,
|
||||
mk.FusedMoEPermuteExpertsUnpermute],
|
||||
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
|
||||
k: int,
|
||||
n: int,
|
||||
e: int,
|
||||
dtype: torch.dtype,
|
||||
quant_config: Optional[TestMoEQuantConfig],
|
||||
combination: tuple[
|
||||
mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute
|
||||
],
|
||||
fused_moe_chunk_size: Optional[int],
|
||||
world_size: int,
|
||||
pytestconfig,
|
||||
):
|
||||
config = Config(
|
||||
Ms=Ms,
|
||||
K=k,
|
||||
@@ -235,19 +260,21 @@ def test_modular_kernel_combinations_singlegpu(
|
||||
if is_nyi_config(config):
|
||||
pytest.skip(f"Tests config {config} is nyi. Skipping ...")
|
||||
|
||||
verbosity = pytestconfig.getoption('verbose')
|
||||
verbosity = pytestconfig.getoption("verbose")
|
||||
run(config, verbosity > 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
# Ability to test individual PrepareAndFinalize and FusedExperts combination
|
||||
from .modular_kernel_tools.cli_args import (make_config,
|
||||
make_config_arg_parser)
|
||||
parser = make_config_arg_parser(description=(
|
||||
"Run single prepare-finalize & fused-experts combination test"
|
||||
"Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations " #noqa: E501
|
||||
"--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
|
||||
))
|
||||
from .modular_kernel_tools.cli_args import make_config, make_config_arg_parser
|
||||
|
||||
parser = make_config_arg_parser(
|
||||
description=(
|
||||
"Run single prepare-finalize & fused-experts combination test"
|
||||
"Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations " # noqa: E501
|
||||
"--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
config = make_config(args)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
Run `pytest tests/kernels/test_moe.py`.
|
||||
"""
|
||||
|
||||
import functools
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
@@ -21,22 +22,32 @@ from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.parallel_state import init_distributed_environment
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG, int4_w4a16_moe_quant_config,
|
||||
int8_w8a16_moe_quant_config)
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
int4_w4a16_moe_quant_config,
|
||||
int8_w8a16_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_topk, modular_triton_fused_moe)
|
||||
fused_topk,
|
||||
modular_triton_fused_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
|
||||
fused_moe as iterative_moe)
|
||||
fused_moe as iterative_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_permute_bias)
|
||||
marlin_permute_bias,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
rand_marlin_weight_mxfp4_like, rand_marlin_weight_nvfp4_like)
|
||||
rand_marlin_weight_mxfp4_like,
|
||||
rand_marlin_weight_nvfp4_like,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
marlin_quant_fp8_torch)
|
||||
marlin_quant_fp8_torch,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
awq_marlin_quantize, marlin_quantize)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
quantize_weights)
|
||||
awq_marlin_quantize,
|
||||
marlin_quantize,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights
|
||||
from vllm.model_executor.models.mixtral import MixtralMoE
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
@@ -87,13 +98,15 @@ def run_moe_test(
|
||||
if isinstance(baseline, torch.Tensor):
|
||||
baseline_output = baseline
|
||||
else:
|
||||
baseline_output = baseline(a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
baseline_output = baseline(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
|
||||
# Pad the weight if moe padding is enabled
|
||||
if padding:
|
||||
@@ -105,34 +118,35 @@ def run_moe_test(
|
||||
torch._dynamo.mark_dynamic(a, 0)
|
||||
torch._dynamo.mark_dynamic(score, 0)
|
||||
|
||||
test_output = moe_fn(a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
test_output = moe_fn(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
|
||||
if use_cudagraph:
|
||||
test_output.fill_(0)
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
test_output = moe_fn(a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
test_output = moe_fn(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
torch.testing.assert_close(test_output,
|
||||
baseline_output,
|
||||
atol=atol,
|
||||
rtol=rtol)
|
||||
torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol)
|
||||
|
||||
return baseline_output
|
||||
|
||||
@@ -176,11 +190,8 @@ def test_fused_moe(
|
||||
|
||||
if ep_size > 1:
|
||||
local_e = e // ep_size
|
||||
e_ids = torch.randint(0,
|
||||
e, (local_e, ),
|
||||
device="cuda",
|
||||
dtype=torch.int32)
|
||||
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
|
||||
e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32)
|
||||
e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
|
||||
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
|
||||
w1 = w1[e_ids]
|
||||
w2 = w2[e_ids]
|
||||
@@ -204,13 +215,15 @@ def test_fused_moe(
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
return m_fused_moe_fn(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
return m_fused_moe_fn(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
|
||||
fused_moe_fn = functools.partial(fused_moe, renormalize=False)
|
||||
|
||||
@@ -234,19 +247,22 @@ def test_fused_moe(
|
||||
# setup code in case we are able to revisit this later.
|
||||
use_compile = False
|
||||
|
||||
use_cudagraph = (n >= 1024 and k >= 1024
|
||||
and current_platform.is_cuda_alike())
|
||||
use_cudagraph = n >= 1024 and k >= 1024 and current_platform.is_cuda_alike()
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
baseline_output = runner(torch_moe, iterative_moe)
|
||||
runner(baseline_output,
|
||||
fused_moe_fn,
|
||||
use_compile=use_compile,
|
||||
use_cudagraph=use_cudagraph)
|
||||
runner(baseline_output,
|
||||
m_fused_moe,
|
||||
use_compile=use_compile,
|
||||
use_cudagraph=use_cudagraph)
|
||||
runner(
|
||||
baseline_output,
|
||||
fused_moe_fn,
|
||||
use_compile=use_compile,
|
||||
use_cudagraph=use_cudagraph,
|
||||
)
|
||||
runner(
|
||||
baseline_output,
|
||||
m_fused_moe,
|
||||
use_compile=use_compile,
|
||||
use_cudagraph=use_cudagraph,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", FUSED_MOE_WN16_MNK_FACTORS)
|
||||
@@ -257,9 +273,18 @@ def test_fused_moe(
|
||||
@pytest.mark.parametrize("group_size", [64, 128])
|
||||
@pytest.mark.parametrize("has_zp", [True, False])
|
||||
@pytest.mark.parametrize("weight_bits", [4, 8])
|
||||
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
||||
ep_size: int, dtype: torch.dtype, group_size: int,
|
||||
has_zp: bool, weight_bits: int):
|
||||
def test_fused_moe_wn16(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
ep_size: int,
|
||||
dtype: torch.dtype,
|
||||
group_size: int,
|
||||
has_zp: bool,
|
||||
weight_bits: int,
|
||||
):
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
@@ -274,35 +299,40 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
||||
|
||||
w1_ref = w1.clone()
|
||||
w2_ref = w2.clone()
|
||||
w1_qweight = torch.empty((e, 2 * n, k // pack_factor),
|
||||
device="cuda",
|
||||
dtype=torch.uint8)
|
||||
w2_qweight = torch.empty((e, k, n // pack_factor),
|
||||
device="cuda",
|
||||
dtype=torch.uint8)
|
||||
w1_scales = torch.empty((e, 2 * n, k // group_size),
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
w2_scales = torch.empty((e, k, n // group_size),
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size),
|
||||
device="cuda",
|
||||
dtype=torch.uint8)
|
||||
w2_qzeros = torch.empty((e, k // pack_factor, n // group_size),
|
||||
device="cuda",
|
||||
dtype=torch.uint8)
|
||||
w1_qweight = torch.empty(
|
||||
(e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8
|
||||
)
|
||||
w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", dtype=torch.uint8)
|
||||
w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda", dtype=dtype)
|
||||
w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype)
|
||||
w1_qzeros = torch.empty(
|
||||
(e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8
|
||||
)
|
||||
w2_qzeros = torch.empty(
|
||||
(e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8
|
||||
)
|
||||
|
||||
for i in range(e * 2):
|
||||
expert_id = i % e
|
||||
if i // e == 0:
|
||||
w, w_ref, w_qweight, w_scales, w_qzeros = \
|
||||
w1, w1_ref, w1_qweight, w1_scales, w1_qzeros
|
||||
w, w_ref, w_qweight, w_scales, w_qzeros = (
|
||||
w1,
|
||||
w1_ref,
|
||||
w1_qweight,
|
||||
w1_scales,
|
||||
w1_qzeros,
|
||||
)
|
||||
else:
|
||||
w, w_ref, w_qweight, w_scales, w_qzeros = \
|
||||
w2, w2_ref, w2_qweight, w2_scales, w2_qzeros
|
||||
w, w_ref, w_qweight, w_scales, w_qzeros = (
|
||||
w2,
|
||||
w2_ref,
|
||||
w2_qweight,
|
||||
w2_scales,
|
||||
w2_qzeros,
|
||||
)
|
||||
weight, qweight, scales, qzeros = quantize_weights(
|
||||
w[expert_id].T, quant_type, group_size, has_zp, False)
|
||||
w[expert_id].T, quant_type, group_size, has_zp, False
|
||||
)
|
||||
weight = weight.T
|
||||
qweight = qweight.T.contiguous().to(torch.uint8)
|
||||
scales = scales.T
|
||||
@@ -321,11 +351,8 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
||||
|
||||
if ep_size > 1:
|
||||
local_e = e // ep_size
|
||||
e_ids = torch.randint(0,
|
||||
e, (local_e, ),
|
||||
device="cuda",
|
||||
dtype=torch.int32)
|
||||
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
|
||||
e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32)
|
||||
e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
|
||||
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
|
||||
w1_ref = w1_ref[e_ids]
|
||||
w2_ref = w2_ref[e_ids]
|
||||
@@ -344,28 +371,27 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
||||
assert weight_bits == 8
|
||||
quant_config_builder = int8_w8a16_moe_quant_config
|
||||
|
||||
quant_config = quant_config_builder(w1_scale=w1_scales,
|
||||
w2_scale=w2_scales,
|
||||
w1_zp=w1_qzeros if has_zp else None,
|
||||
w2_zp=w2_qzeros if has_zp else None,
|
||||
block_shape=[0, group_size])
|
||||
quant_config = quant_config_builder(
|
||||
w1_scale=w1_scales,
|
||||
w2_scale=w2_scales,
|
||||
w1_zp=w1_qzeros if has_zp else None,
|
||||
w2_zp=w2_qzeros if has_zp else None,
|
||||
block_shape=[0, group_size],
|
||||
)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
triton_output = fused_moe(a,
|
||||
w1_qweight,
|
||||
w2_qweight,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
global_num_experts=e,
|
||||
expert_map=e_map,
|
||||
quant_config=quant_config)
|
||||
torch_output = torch_moe(a,
|
||||
w1_ref,
|
||||
w2_ref,
|
||||
score,
|
||||
topk,
|
||||
expert_map=e_map)
|
||||
triton_output = fused_moe(
|
||||
a,
|
||||
w1_qweight,
|
||||
w2_qweight,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
global_num_experts=e,
|
||||
expert_map=e_map,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, expert_map=e_map)
|
||||
|
||||
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
|
||||
|
||||
@@ -373,16 +399,20 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("padding", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
|
||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_mixtral_moe(dist_init, dtype: torch.dtype, padding: bool,
|
||||
use_rocm_aiter: bool, monkeypatch):
|
||||
def test_mixtral_moe(
|
||||
dist_init, dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, monkeypatch
|
||||
):
|
||||
"""Make sure our Mixtral MoE implementation agrees with the one from
|
||||
huggingface."""
|
||||
|
||||
# clear the cache before every test
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled)
|
||||
is_rocm_aiter_moe_enabled,
|
||||
)
|
||||
|
||||
is_rocm_aiter_moe_enabled.cache_clear()
|
||||
if use_rocm_aiter:
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
@@ -390,17 +420,16 @@ def test_mixtral_moe(dist_init, dtype: torch.dtype, padding: bool,
|
||||
if dtype == torch.float32:
|
||||
pytest.skip("AITER ROCm test skip for float32")
|
||||
|
||||
monkeypatch.setenv('RANK', "0")
|
||||
monkeypatch.setenv('LOCAL_RANK', "0")
|
||||
monkeypatch.setenv('WORLD_SIZE', "1")
|
||||
monkeypatch.setenv('MASTER_ADDR', 'localhost')
|
||||
monkeypatch.setenv('MASTER_PORT', '12345')
|
||||
monkeypatch.setenv("RANK", "0")
|
||||
monkeypatch.setenv("LOCAL_RANK", "0")
|
||||
monkeypatch.setenv("WORLD_SIZE", "1")
|
||||
monkeypatch.setenv("MASTER_ADDR", "localhost")
|
||||
monkeypatch.setenv("MASTER_PORT", "12345")
|
||||
init_distributed_environment()
|
||||
|
||||
# Instantiate our and huggingface's MoE blocks
|
||||
vllm_config.compilation_config.static_forward_context = dict()
|
||||
with (set_current_vllm_config(vllm_config),
|
||||
set_forward_context(None, vllm_config)):
|
||||
with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config):
|
||||
config = MixtralConfig()
|
||||
hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
|
||||
vllm_moe = MixtralMoE(
|
||||
@@ -416,27 +445,30 @@ def test_mixtral_moe(dist_init, dtype: torch.dtype, padding: bool,
|
||||
# Load the weights
|
||||
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
|
||||
for i in range(config.num_local_experts):
|
||||
weights = (hf_moe.experts[i].w1.weight.data,
|
||||
hf_moe.experts[i].w3.weight.data)
|
||||
weights = (
|
||||
hf_moe.experts[i].w1.weight.data,
|
||||
hf_moe.experts[i].w3.weight.data,
|
||||
)
|
||||
vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
|
||||
vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
|
||||
|
||||
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
|
||||
hf_inputs = torch.randn(
|
||||
(1, 64, config.hidden_size)).to(dtype).to("cuda")
|
||||
hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
|
||||
# vLLM uses 1D query [num_tokens, hidden_dim]
|
||||
vllm_inputs = hf_inputs.flatten(0, 1)
|
||||
|
||||
# Pad the weight if moe padding is enabled
|
||||
if padding:
|
||||
vllm_moe.experts.w13_weight = Parameter(F.pad(
|
||||
vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[...,
|
||||
0:-128],
|
||||
requires_grad=False)
|
||||
vllm_moe.experts.w2_weight = Parameter(F.pad(
|
||||
vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[...,
|
||||
0:-128],
|
||||
requires_grad=False)
|
||||
vllm_moe.experts.w13_weight = Parameter(
|
||||
F.pad(vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[
|
||||
..., 0:-128
|
||||
],
|
||||
requires_grad=False,
|
||||
)
|
||||
vllm_moe.experts.w2_weight = Parameter(
|
||||
F.pad(vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -453,19 +485,21 @@ def test_mixtral_moe(dist_init, dtype: torch.dtype, padding: bool,
|
||||
if use_rocm_aiter:
|
||||
# The values of rtol and atol are set based on the tests in ROCM AITER package. # noqa: E501
|
||||
# https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174 # noqa: E501
|
||||
torch.testing.assert_close(hf_states.flatten(0, 1),
|
||||
vllm_states,
|
||||
rtol=0.01,
|
||||
atol=100)
|
||||
torch.testing.assert_close(
|
||||
hf_states.flatten(0, 1), vllm_states, rtol=0.01, atol=100
|
||||
)
|
||||
else:
|
||||
torch.testing.assert_close(hf_states.flatten(0, 1),
|
||||
vllm_states,
|
||||
rtol=mixtral_moe_tol[dtype],
|
||||
atol=mixtral_moe_tol[dtype])
|
||||
torch.testing.assert_close(
|
||||
hf_states.flatten(0, 1),
|
||||
vllm_states,
|
||||
rtol=mixtral_moe_tol[dtype],
|
||||
atol=mixtral_moe_tol[dtype],
|
||||
)
|
||||
|
||||
|
||||
def marlin_moe_generate_valid_test_cases():
|
||||
import itertools
|
||||
|
||||
m_list = [1, 123, 666]
|
||||
n_list = [128, 1024]
|
||||
k_list = [256, 2048]
|
||||
@@ -484,16 +518,24 @@ def marlin_moe_generate_valid_test_cases():
|
||||
]
|
||||
is_k_full_list = [True, False]
|
||||
|
||||
all_combinations = itertools.product(m_list, n_list, k_list, e_list,
|
||||
topk_list, ep_size_list, dtype_list,
|
||||
group_size_list, act_order_list,
|
||||
quant_type_list, is_k_full_list)
|
||||
all_combinations = itertools.product(
|
||||
m_list,
|
||||
n_list,
|
||||
k_list,
|
||||
e_list,
|
||||
topk_list,
|
||||
ep_size_list,
|
||||
dtype_list,
|
||||
group_size_list,
|
||||
act_order_list,
|
||||
quant_type_list,
|
||||
is_k_full_list,
|
||||
)
|
||||
|
||||
def is_invalid(m, n, k, e, topk, ep_size, dtype, group_size, act_order,
|
||||
quant_type, is_k_full):
|
||||
|
||||
if quant_type == scalar_types.float8_e4m3fn and \
|
||||
group_size not in [-1, 128]:
|
||||
def is_invalid(
|
||||
m, n, k, e, topk, ep_size, dtype, group_size, act_order, quant_type, is_k_full
|
||||
):
|
||||
if quant_type == scalar_types.float8_e4m3fn and group_size not in [-1, 128]:
|
||||
return False
|
||||
if quant_type == scalar_types.float4_e2m1f:
|
||||
if group_size not in [16, 32]:
|
||||
@@ -522,9 +564,10 @@ def marlin_moe_generate_valid_test_cases():
|
||||
|
||||
|
||||
@pytest.mark.flaky(reruns=2)
|
||||
@pytest.mark.parametrize(("m, n, k, e, topk, ep_size, dtype, group_size,"
|
||||
"act_order, quant_type, is_k_full"),
|
||||
marlin_moe_generate_valid_test_cases())
|
||||
@pytest.mark.parametrize(
|
||||
("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"),
|
||||
marlin_moe_generate_valid_test_cases(),
|
||||
)
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
def test_fused_marlin_moe(
|
||||
m: int,
|
||||
@@ -549,7 +592,7 @@ def test_fused_marlin_moe(
|
||||
if ep_size > 1:
|
||||
local_e = e // ep_size
|
||||
e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
|
||||
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
|
||||
e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
|
||||
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
|
||||
w1 = w1[e_ids]
|
||||
w2 = w2[e_ids]
|
||||
@@ -567,11 +610,13 @@ def test_fused_marlin_moe(
|
||||
for i in range(w1.shape[0]):
|
||||
if quant_type == scalar_types.float4_e2m1f:
|
||||
if group_size == 16:
|
||||
w_ref1, qweight1, scales1, global_scale1 = \
|
||||
w_ref1, qweight1, scales1, global_scale1 = (
|
||||
rand_marlin_weight_nvfp4_like(w1[i], group_size)
|
||||
)
|
||||
else:
|
||||
w_ref1, qweight1, scales1 = \
|
||||
rand_marlin_weight_mxfp4_like(w1[i], group_size)
|
||||
w_ref1, qweight1, scales1 = rand_marlin_weight_mxfp4_like(
|
||||
w1[i], group_size
|
||||
)
|
||||
global_scale1 = None
|
||||
|
||||
w_ref1_l.append(w_ref1.T)
|
||||
@@ -580,14 +625,14 @@ def test_fused_marlin_moe(
|
||||
if global_scale1 is not None:
|
||||
global_scale1_l.append(global_scale1)
|
||||
elif quant_type == scalar_types.float8_e4m3fn:
|
||||
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
|
||||
w1[i], group_size)
|
||||
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(w1[i], group_size)
|
||||
w_ref1_l.append(w_ref1.T)
|
||||
qweight1_l.append(qweight1)
|
||||
scales1_l.append(scales1)
|
||||
elif has_zp:
|
||||
w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
|
||||
w1[i].transpose(1, 0), quant_type, group_size)
|
||||
w1[i].transpose(1, 0), quant_type, group_size
|
||||
)
|
||||
|
||||
w_ref1_l.append(w_ref1.T)
|
||||
qweight1_l.append(qweight1)
|
||||
@@ -595,9 +640,9 @@ def test_fused_marlin_moe(
|
||||
zeros1_l.append(zeros1)
|
||||
else:
|
||||
test_perm = torch.randperm(k)
|
||||
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
|
||||
marlin_quantize(w1[i].transpose(1, 0), quant_type,
|
||||
group_size, act_order, test_perm)
|
||||
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
|
||||
w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
|
||||
)
|
||||
|
||||
w_ref1_l.append(w_ref1.T)
|
||||
qweight1_l.append(qweight1)
|
||||
@@ -624,11 +669,13 @@ def test_fused_marlin_moe(
|
||||
for i in range(w2.shape[0]):
|
||||
if quant_type == scalar_types.float4_e2m1f:
|
||||
if group_size == 16:
|
||||
w_ref2, qweight2, scales2, global_scale2 = \
|
||||
w_ref2, qweight2, scales2, global_scale2 = (
|
||||
rand_marlin_weight_nvfp4_like(w2[i], group_size)
|
||||
)
|
||||
else:
|
||||
w_ref2, qweight2, scales2 = \
|
||||
rand_marlin_weight_mxfp4_like(w2[i], group_size)
|
||||
w_ref2, qweight2, scales2 = rand_marlin_weight_mxfp4_like(
|
||||
w2[i], group_size
|
||||
)
|
||||
global_scale2 = None
|
||||
|
||||
w_ref2_l.append(w_ref2.T)
|
||||
@@ -637,14 +684,14 @@ def test_fused_marlin_moe(
|
||||
if global_scale2 is not None:
|
||||
global_scale2_l.append(global_scale2)
|
||||
elif quant_type == scalar_types.float8_e4m3fn:
|
||||
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
|
||||
w2[i], group_size)
|
||||
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(w2[i], group_size)
|
||||
w_ref2_l.append(w_ref2.T)
|
||||
qweight2_l.append(qweight2)
|
||||
scales2_l.append(scales2)
|
||||
elif has_zp:
|
||||
w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
|
||||
w2[i].transpose(1, 0), quant_type, group_size)
|
||||
w2[i].transpose(1, 0), quant_type, group_size
|
||||
)
|
||||
|
||||
w_ref2_l.append(w_ref2.T)
|
||||
qweight2_l.append(qweight2)
|
||||
@@ -652,9 +699,9 @@ def test_fused_marlin_moe(
|
||||
zeros2_l.append(zeros2)
|
||||
else:
|
||||
test_perm = torch.randperm(n)
|
||||
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
|
||||
marlin_quantize(w2[i].transpose(1, 0), quant_type,
|
||||
group_size, act_order, test_perm)
|
||||
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
|
||||
w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
|
||||
)
|
||||
|
||||
w_ref2_l.append(w_ref2.T)
|
||||
qweight2_l.append(qweight2)
|
||||
@@ -675,12 +722,7 @@ def test_fused_marlin_moe(
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
torch_output = torch_moe(a,
|
||||
w_ref1,
|
||||
w_ref2,
|
||||
score,
|
||||
topk,
|
||||
expert_map=e_map)
|
||||
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)
|
||||
|
||||
marlin_output = torch.ops.vllm.fused_marlin_moe(
|
||||
a,
|
||||
@@ -704,7 +746,8 @@ def test_fused_marlin_moe(
|
||||
w1_zeros=zeros1,
|
||||
w2_zeros=zeros2,
|
||||
quant_type_id=quant_type.id,
|
||||
is_k_full=is_k_full)
|
||||
is_k_full=is_k_full,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
|
||||
|
||||
@@ -738,9 +781,9 @@ def test_fused_marlin_moe_with_bias(m):
|
||||
|
||||
for i in range(w1.shape[0]):
|
||||
test_perm = torch.randperm(k)
|
||||
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
|
||||
marlin_quantize(w1[i].transpose(1, 0), quant_type,
|
||||
group_size, act_order, test_perm)
|
||||
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
|
||||
w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
|
||||
)
|
||||
|
||||
w_ref1_l.append(w_ref1.T)
|
||||
qweight1_l.append(qweight1)
|
||||
@@ -767,9 +810,9 @@ def test_fused_marlin_moe_with_bias(m):
|
||||
|
||||
for i in range(w2.shape[0]):
|
||||
test_perm = torch.randperm(n)
|
||||
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
|
||||
marlin_quantize(w2[i].transpose(1, 0), quant_type,
|
||||
group_size, act_order, test_perm)
|
||||
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
|
||||
w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
|
||||
)
|
||||
|
||||
w_ref2_l.append(w_ref2.T)
|
||||
qweight2_l.append(qweight2)
|
||||
@@ -792,8 +835,7 @@ def test_fused_marlin_moe_with_bias(m):
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1,
|
||||
b_bias2)
|
||||
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, b_bias2)
|
||||
|
||||
marlin_output = torch.ops.vllm.fused_marlin_moe(
|
||||
a,
|
||||
@@ -817,7 +859,8 @@ def test_fused_marlin_moe_with_bias(m):
|
||||
w1_zeros=zeros1,
|
||||
w2_zeros=zeros2,
|
||||
quant_type_id=quant_type.id,
|
||||
is_k_full=is_k_full)
|
||||
is_k_full=is_k_full,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
|
||||
|
||||
@@ -825,34 +868,36 @@ def test_fused_marlin_moe_with_bias(m):
|
||||
def test_moe_align_block_size_opcheck():
|
||||
num_experts = 4
|
||||
block_size = 4
|
||||
topk_ids = torch.randint(0,
|
||||
num_experts, (3, 4),
|
||||
dtype=torch.int32,
|
||||
device='cuda')
|
||||
topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")
|
||||
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
sorted_ids = torch.empty((max_num_tokens_padded, ),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
sorted_ids = torch.empty(
|
||||
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
sorted_ids.fill_(topk_ids.numel())
|
||||
max_num_m_blocks = max_num_tokens_padded // block_size
|
||||
expert_ids = torch.empty((max_num_m_blocks, ),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
num_tokens_post_pad = torch.empty((1),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
expert_ids = torch.empty(
|
||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
||||
|
||||
opcheck(torch.ops._moe_C.moe_align_block_size,
|
||||
(topk_ids, num_experts, block_size, sorted_ids, expert_ids,
|
||||
num_tokens_post_pad))
|
||||
opcheck(
|
||||
torch.ops._moe_C.moe_align_block_size,
|
||||
(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64, 222])
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("k", [128, 511, 1024])
|
||||
@pytest.mark.parametrize("dtype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
|
||||
input = torch.randn((m, topk, k), device="cuda", dtype=dtype)
|
||||
|
||||
@@ -11,7 +11,8 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size)
|
||||
moe_align_block_size,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import round_up
|
||||
|
||||
@@ -60,30 +61,33 @@ def _verify_expert_level_sorting(
|
||||
in topk_ids in the final sorted_ids however this does not impact quality.
|
||||
"""
|
||||
# Group tokens by expert from the golden implementation
|
||||
golden_expert_tokens = _group_tokens_by_expert(golden_sorted_ids,
|
||||
expert_ids, block_size,
|
||||
valid_length, total_tokens)
|
||||
golden_expert_tokens = _group_tokens_by_expert(
|
||||
golden_sorted_ids, expert_ids, block_size, valid_length, total_tokens
|
||||
)
|
||||
|
||||
actual_expert_tokens = _group_tokens_by_expert(actual_sorted_ids,
|
||||
expert_ids, block_size,
|
||||
valid_length, total_tokens)
|
||||
actual_expert_tokens = _group_tokens_by_expert(
|
||||
actual_sorted_ids, expert_ids, block_size, valid_length, total_tokens
|
||||
)
|
||||
|
||||
assert set(golden_expert_tokens.keys()) == set(
|
||||
actual_expert_tokens.keys()), (
|
||||
f"Expert IDs mismatch: golden={set(golden_expert_tokens.keys())}, "
|
||||
f"actual={set(actual_expert_tokens.keys())}")
|
||||
assert set(golden_expert_tokens.keys()) == set(actual_expert_tokens.keys()), (
|
||||
f"Expert IDs mismatch: golden={set(golden_expert_tokens.keys())}, "
|
||||
f"actual={set(actual_expert_tokens.keys())}"
|
||||
)
|
||||
|
||||
for expert_id in golden_expert_tokens:
|
||||
golden_tokens = torch.tensor(golden_expert_tokens[expert_id],
|
||||
device=actual_sorted_ids.device)
|
||||
actual_tokens = torch.tensor(actual_expert_tokens[expert_id],
|
||||
device=actual_sorted_ids.device)
|
||||
golden_tokens = torch.tensor(
|
||||
golden_expert_tokens[expert_id], device=actual_sorted_ids.device
|
||||
)
|
||||
actual_tokens = torch.tensor(
|
||||
actual_expert_tokens[expert_id], device=actual_sorted_ids.device
|
||||
)
|
||||
assert torch.equal(
|
||||
torch.sort(golden_tokens)[0],
|
||||
torch.sort(actual_tokens)[0]), (
|
||||
f"Expert {expert_id} token mismatch: "
|
||||
f"golden={golden_expert_tokens[expert_id]}, "
|
||||
f"actual={actual_expert_tokens[expert_id]}")
|
||||
torch.sort(golden_tokens)[0], torch.sort(actual_tokens)[0]
|
||||
), (
|
||||
f"Expert {expert_id} token mismatch: "
|
||||
f"golden={golden_expert_tokens[expert_id]}, "
|
||||
f"actual={actual_expert_tokens[expert_id]}"
|
||||
)
|
||||
|
||||
|
||||
def torch_moe_align_block_size(
|
||||
@@ -104,40 +108,38 @@ def torch_moe_align_block_size(
|
||||
if pad_sorted_ids:
|
||||
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
|
||||
|
||||
flattened_token_indices = torch.arange(topk_ids.numel(),
|
||||
device=topk_ids.device,
|
||||
dtype=torch.int32)
|
||||
flattened_token_indices = torch.arange(
|
||||
topk_ids.numel(), device=topk_ids.device, dtype=torch.int32
|
||||
)
|
||||
flattened_expert_ids = topk_ids.flatten()
|
||||
sorted_expert_ids, sort_indices = torch.sort(flattened_expert_ids,
|
||||
stable=True)
|
||||
sorted_expert_ids, sort_indices = torch.sort(flattened_expert_ids, stable=True)
|
||||
sorted_token_indices = flattened_token_indices[sort_indices]
|
||||
|
||||
expert_token_counts = torch.zeros(num_experts,
|
||||
dtype=torch.int64,
|
||||
device=topk_ids.device)
|
||||
expert_token_counts = torch.zeros(
|
||||
num_experts, dtype=torch.int64, device=topk_ids.device
|
||||
)
|
||||
for expert_id in range(num_experts):
|
||||
mask = sorted_expert_ids == expert_id
|
||||
expert_token_counts[expert_id] = mask.sum()
|
||||
|
||||
expert_padded_counts = torch.zeros(num_experts,
|
||||
dtype=torch.int64,
|
||||
device=topk_ids.device)
|
||||
expert_padded_counts = torch.zeros(
|
||||
num_experts, dtype=torch.int64, device=topk_ids.device
|
||||
)
|
||||
for expert_id in range(num_experts):
|
||||
original_count = expert_token_counts[expert_id]
|
||||
if original_count > 0:
|
||||
expert_padded_counts[expert_id] = (
|
||||
(original_count + block_size - 1) // block_size) * block_size
|
||||
(original_count + block_size - 1) // block_size
|
||||
) * block_size
|
||||
|
||||
sorted_token_ids = torch.full(
|
||||
(max_num_tokens_padded, ),
|
||||
(max_num_tokens_padded,),
|
||||
topk_ids.numel(),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device,
|
||||
)
|
||||
max_num_blocks = (max_num_tokens_padded + block_size - 1) // block_size
|
||||
expert_ids = torch.zeros(max_num_blocks,
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
expert_ids = torch.zeros(max_num_blocks, dtype=torch.int32, device=topk_ids.device)
|
||||
|
||||
current_pos = 0
|
||||
current_block = 0
|
||||
@@ -147,20 +149,20 @@ def torch_moe_align_block_size(
|
||||
num_expert_tokens = expert_tokens.shape[0]
|
||||
|
||||
if num_expert_tokens > 0:
|
||||
sorted_token_ids[current_pos:current_pos +
|
||||
num_expert_tokens] = (expert_tokens)
|
||||
sorted_token_ids[current_pos : current_pos + num_expert_tokens] = (
|
||||
expert_tokens
|
||||
)
|
||||
|
||||
expert_blocks_needed = expert_padded_counts[expert_id] // block_size
|
||||
expert_ids[current_block:current_block +
|
||||
expert_blocks_needed] = (expert_id)
|
||||
expert_ids[current_block : current_block + expert_blocks_needed] = expert_id
|
||||
|
||||
current_pos += expert_padded_counts[expert_id]
|
||||
current_block += expert_blocks_needed
|
||||
|
||||
total_padded_tokens = expert_padded_counts.sum()
|
||||
num_tokens_post_pad = torch.tensor([total_padded_tokens],
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
num_tokens_post_pad = torch.tensor(
|
||||
[total_padded_tokens], dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
if expert_map is not None:
|
||||
expert_ids = expert_map[expert_ids]
|
||||
@@ -173,37 +175,32 @@ def torch_moe_align_block_size(
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("pad_sorted_ids", [False, True])
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
def test_moe_align_block_size(m: int, topk: int, num_experts: int,
|
||||
block_size: int, pad_sorted_ids: bool):
|
||||
def test_moe_align_block_size(
|
||||
m: int, topk: int, num_experts: int, block_size: int, pad_sorted_ids: bool
|
||||
):
|
||||
"""Test moe_align_block_size without expert mapping"""
|
||||
topk_ids = torch.zeros((m, topk), device="cuda", dtype=torch.int32)
|
||||
for i in range(m):
|
||||
experts = torch.randperm(num_experts, device="cuda")[:topk]
|
||||
topk_ids[i] = experts
|
||||
|
||||
actual_sorted_ids, actual_expert_ids, actual_num_tokens = (
|
||||
moe_align_block_size(
|
||||
topk_ids=topk_ids,
|
||||
block_size=block_size,
|
||||
num_experts=num_experts,
|
||||
pad_sorted_ids=pad_sorted_ids,
|
||||
))
|
||||
actual_sorted_ids, actual_expert_ids, actual_num_tokens = moe_align_block_size(
|
||||
topk_ids=topk_ids,
|
||||
block_size=block_size,
|
||||
num_experts=num_experts,
|
||||
pad_sorted_ids=pad_sorted_ids,
|
||||
)
|
||||
golden_sorted_ids, golden_expert_ids, golden_num_tokens = (
|
||||
torch_moe_align_block_size(
|
||||
topk_ids=topk_ids,
|
||||
block_size=block_size,
|
||||
num_experts=num_experts,
|
||||
pad_sorted_ids=pad_sorted_ids,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
torch.testing.assert_close(actual_num_tokens,
|
||||
golden_num_tokens,
|
||||
atol=0,
|
||||
rtol=0)
|
||||
torch.testing.assert_close(actual_expert_ids,
|
||||
golden_expert_ids,
|
||||
atol=0,
|
||||
rtol=0)
|
||||
torch.testing.assert_close(actual_num_tokens, golden_num_tokens, atol=0, rtol=0)
|
||||
torch.testing.assert_close(actual_expert_ids, golden_expert_ids, atol=0, rtol=0)
|
||||
|
||||
# For sorted_token_ids, verify block-level correctness rather than exact
|
||||
# order Tokens within each expert's blocks can be in any order, but expert
|
||||
@@ -219,16 +216,18 @@ def test_moe_align_block_size(m: int, topk: int, num_experts: int,
|
||||
|
||||
total_tokens = m * topk
|
||||
assert actual_num_tokens.item() % block_size == 0, (
|
||||
"num_tokens_post_pad should be divisible by block_size")
|
||||
"num_tokens_post_pad should be divisible by block_size"
|
||||
)
|
||||
assert actual_num_tokens.item() >= total_tokens, (
|
||||
"num_tokens_post_pad should be at least total_tokens")
|
||||
"num_tokens_post_pad should be at least total_tokens"
|
||||
)
|
||||
valid_tokens = actual_sorted_ids[actual_sorted_ids < total_tokens]
|
||||
assert len(valid_tokens) == total_tokens, (
|
||||
f"Should have exactly {total_tokens} valid tokens, "
|
||||
f"got {len(valid_tokens)}")
|
||||
assert (actual_expert_ids >= 0).all() and (
|
||||
actual_expert_ids
|
||||
< num_experts).all(), "expert_ids should contain valid expert indices"
|
||||
f"Should have exactly {total_tokens} valid tokens, got {len(valid_tokens)}"
|
||||
)
|
||||
assert (actual_expert_ids >= 0).all() and (actual_expert_ids < num_experts).all(), (
|
||||
"expert_ids should contain valid expert indices"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [16, 32])
|
||||
@@ -236,46 +235,37 @@ def test_moe_align_block_size(m: int, topk: int, num_experts: int,
|
||||
@pytest.mark.parametrize("num_experts", [8])
|
||||
@pytest.mark.parametrize("block_size", [64])
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
def test_moe_align_block_size_with_expert_map(m: int, topk: int,
|
||||
num_experts: int,
|
||||
block_size: int):
|
||||
def test_moe_align_block_size_with_expert_map(
|
||||
m: int, topk: int, num_experts: int, block_size: int
|
||||
):
|
||||
"""Test moe_align_block_size with expert mapping (EP scenario)"""
|
||||
topk_ids = torch.zeros((m, topk), device="cuda", dtype=torch.int32)
|
||||
for i in range(m):
|
||||
experts = torch.randperm(num_experts, device="cuda")[:topk]
|
||||
topk_ids[i] = experts
|
||||
|
||||
expert_map = torch.full((num_experts, ),
|
||||
-1,
|
||||
device="cuda",
|
||||
dtype=torch.int32)
|
||||
expert_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32)
|
||||
local_experts = list(range(0, num_experts, 2))
|
||||
for i, expert_id in enumerate(local_experts):
|
||||
expert_map[expert_id] = i
|
||||
|
||||
actual_sorted_ids, actual_expert_ids, actual_num_tokens = (
|
||||
moe_align_block_size(
|
||||
topk_ids=topk_ids,
|
||||
block_size=block_size,
|
||||
num_experts=num_experts,
|
||||
expert_map=expert_map,
|
||||
))
|
||||
actual_sorted_ids, actual_expert_ids, actual_num_tokens = moe_align_block_size(
|
||||
topk_ids=topk_ids,
|
||||
block_size=block_size,
|
||||
num_experts=num_experts,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
golden_sorted_ids, golden_expert_ids, golden_num_tokens = (
|
||||
torch_moe_align_block_size(
|
||||
topk_ids=topk_ids,
|
||||
block_size=block_size,
|
||||
num_experts=num_experts,
|
||||
expert_map=expert_map,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
torch.testing.assert_close(actual_num_tokens,
|
||||
golden_num_tokens,
|
||||
atol=0,
|
||||
rtol=0)
|
||||
torch.testing.assert_close(actual_expert_ids,
|
||||
golden_expert_ids,
|
||||
atol=0,
|
||||
rtol=0)
|
||||
torch.testing.assert_close(actual_num_tokens, golden_num_tokens, atol=0, rtol=0)
|
||||
torch.testing.assert_close(actual_expert_ids, golden_expert_ids, atol=0, rtol=0)
|
||||
_verify_expert_level_sorting(
|
||||
actual_sorted_ids,
|
||||
golden_sorted_ids,
|
||||
@@ -290,26 +280,25 @@ def test_moe_align_block_size_deterministic():
|
||||
m, topk, num_experts, block_size = 128, 2, 32, 64
|
||||
|
||||
torch.manual_seed(42)
|
||||
topk_ids = torch.randint(0,
|
||||
num_experts, (m, topk),
|
||||
device="cuda",
|
||||
dtype=torch.int32)
|
||||
topk_ids = torch.randint(
|
||||
0, num_experts, (m, topk), device="cuda", dtype=torch.int32
|
||||
)
|
||||
|
||||
# expect the results to be reproducible
|
||||
results = []
|
||||
for _ in range(5):
|
||||
sorted_ids, expert_ids, num_tokens = moe_align_block_size(
|
||||
topk_ids=topk_ids, block_size=block_size, num_experts=num_experts)
|
||||
results.append(
|
||||
(sorted_ids.clone(), expert_ids.clone(), num_tokens.clone()))
|
||||
topk_ids=topk_ids, block_size=block_size, num_experts=num_experts
|
||||
)
|
||||
results.append((sorted_ids.clone(), expert_ids.clone(), num_tokens.clone()))
|
||||
|
||||
for i in range(1, len(results)):
|
||||
assert torch.equal(
|
||||
results[0][0],
|
||||
results[i][0]), ("sorted_ids should be deterministic")
|
||||
assert torch.equal(
|
||||
results[0][1],
|
||||
results[i][1]), ("expert_ids should be deterministic")
|
||||
assert torch.equal(
|
||||
results[0][2],
|
||||
results[i][2]), ("num_tokens should be deterministic")
|
||||
assert torch.equal(results[0][0], results[i][0]), (
|
||||
"sorted_ids should be deterministic"
|
||||
)
|
||||
assert torch.equal(results[0][1], results[i][1]), (
|
||||
"expert_ids should be deterministic"
|
||||
)
|
||||
assert torch.equal(results[0][2], results[i][2]), (
|
||||
"num_tokens should be deterministic"
|
||||
)
|
||||
|
||||
@@ -14,7 +14,10 @@ import torch
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
moe_permute, moe_permute_unpermute_supported, moe_unpermute)
|
||||
moe_permute,
|
||||
moe_permute_unpermute_supported,
|
||||
moe_unpermute,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
NUM_EXPERTS = [16, 64, 256]
|
||||
@@ -24,35 +27,34 @@ current_platform.seed_everything(0)
|
||||
|
||||
|
||||
def torch_permute(
|
||||
hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
# token_expert_indices: torch.Tensor,
|
||||
topk: int,
|
||||
n_expert: int,
|
||||
n_local_expert: int,
|
||||
start_expert: int,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
align_block_size: Optional[int] = None,
|
||||
fill_invalid_expert: int = -1) -> list[torch.Tensor]:
|
||||
hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
# token_expert_indices: torch.Tensor,
|
||||
topk: int,
|
||||
n_expert: int,
|
||||
n_local_expert: int,
|
||||
start_expert: int,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
align_block_size: Optional[int] = None,
|
||||
fill_invalid_expert: int = -1,
|
||||
) -> list[torch.Tensor]:
|
||||
n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1]
|
||||
if expert_map is not None:
|
||||
is_local_expert = (expert_map[topk_ids] != -1)
|
||||
not_local_expert = (expert_map[topk_ids] == -1)
|
||||
topk_ids = is_local_expert * (
|
||||
topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert)
|
||||
token_expert_indices = torch.arange(0,
|
||||
n_token * topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device).reshape(
|
||||
(n_token, topk))
|
||||
is_local_expert = expert_map[topk_ids] != -1
|
||||
not_local_expert = expert_map[topk_ids] == -1
|
||||
topk_ids = is_local_expert * (topk_ids - start_expert) + not_local_expert * (
|
||||
topk_ids + n_expert
|
||||
)
|
||||
token_expert_indices = torch.arange(
|
||||
0, n_token * topk, dtype=torch.int32, device=hidden_states.device
|
||||
).reshape((n_token, topk))
|
||||
|
||||
sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(),
|
||||
stable=True)
|
||||
sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), stable=True)
|
||||
dst_row_id2src_row_id_map = token_expert_indices.flatten()[sorted_indices]
|
||||
|
||||
expert_first_token_offset = torch.zeros(n_local_expert + 1,
|
||||
dtype=torch.int64,
|
||||
device="cuda")
|
||||
expert_first_token_offset = torch.zeros(
|
||||
n_local_expert + 1, dtype=torch.int64, device="cuda"
|
||||
)
|
||||
idx = 0
|
||||
for i in range(0, n_local_expert):
|
||||
cnt = 0
|
||||
@@ -64,116 +66,133 @@ def torch_permute(
|
||||
_, src2dst_idx = torch.sort(dst_row_id2src_row_id_map)
|
||||
valid_row_idx = []
|
||||
if align_block_size is None:
|
||||
|
||||
permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map //
|
||||
topk, ...]
|
||||
permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // topk, ...]
|
||||
permuted_row_size = permuted_hidden_states.shape[0]
|
||||
m_indices = torch.empty(permuted_row_size,
|
||||
device="cuda",
|
||||
dtype=torch.int32).fill_(fill_invalid_expert)
|
||||
m_indices = torch.empty(
|
||||
permuted_row_size, device="cuda", dtype=torch.int32
|
||||
).fill_(fill_invalid_expert)
|
||||
for i in range(1, n_local_expert + 1):
|
||||
first_token_offset = expert_first_token_offset[i - 1]
|
||||
last_token_offset = expert_first_token_offset[i]
|
||||
m_indices[first_token_offset:last_token_offset] = i - 1
|
||||
src_row_id2dst_row_id_map = torch.arange(
|
||||
0, n_token * topk, device="cuda",
|
||||
dtype=torch.int32)[src2dst_idx].reshape((n_token, topk))
|
||||
0, n_token * topk, device="cuda", dtype=torch.int32
|
||||
)[src2dst_idx].reshape((n_token, topk))
|
||||
valid_row_idx += [i for i in range(expert_first_token_offset[-1])]
|
||||
dst_row_id2src_row_id_map[
|
||||
expert_first_token_offset[-1]:] = n_token * topk
|
||||
dst_row_id2src_row_id_map[expert_first_token_offset[-1] :] = n_token * topk
|
||||
return [
|
||||
permuted_hidden_states, expert_first_token_offset,
|
||||
src_row_id2dst_row_id_map, dst_row_id2src_row_id_map, m_indices,
|
||||
valid_row_idx
|
||||
permuted_hidden_states,
|
||||
expert_first_token_offset,
|
||||
src_row_id2dst_row_id_map,
|
||||
dst_row_id2src_row_id_map,
|
||||
m_indices,
|
||||
valid_row_idx,
|
||||
]
|
||||
else:
|
||||
permuted_row_size = (topk * n_token + n_expert *
|
||||
(align_block_size - 1) + align_block_size -
|
||||
1) // align_block_size * align_block_size
|
||||
permuted_idx = torch.full((permuted_row_size, ),
|
||||
n_token * topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
permuted_hidden_states = torch.empty((permuted_row_size, n_hidden),
|
||||
device="cuda",
|
||||
dtype=hidden_states.dtype)
|
||||
align_src_row_id2dst_row_id = torch.empty(n_token * topk,
|
||||
device="cuda",
|
||||
dtype=torch.int32)
|
||||
align_expert_first_token_offset = torch.zeros_like(
|
||||
expert_first_token_offset)
|
||||
m_indices = torch.empty(permuted_row_size,
|
||||
device="cuda",
|
||||
dtype=torch.int32).fill_(fill_invalid_expert)
|
||||
permuted_row_size = (
|
||||
(topk * n_token + n_expert * (align_block_size - 1) + align_block_size - 1)
|
||||
// align_block_size
|
||||
* align_block_size
|
||||
)
|
||||
permuted_idx = torch.full(
|
||||
(permuted_row_size,),
|
||||
n_token * topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
permuted_hidden_states = torch.empty(
|
||||
(permuted_row_size, n_hidden), device="cuda", dtype=hidden_states.dtype
|
||||
)
|
||||
align_src_row_id2dst_row_id = torch.empty(
|
||||
n_token * topk, device="cuda", dtype=torch.int32
|
||||
)
|
||||
align_expert_first_token_offset = torch.zeros_like(expert_first_token_offset)
|
||||
m_indices = torch.empty(
|
||||
permuted_row_size, device="cuda", dtype=torch.int32
|
||||
).fill_(fill_invalid_expert)
|
||||
# get align_permuted_hidden_states,
|
||||
# valid row_idx and align_expert_first_token_offset
|
||||
for i in range(1, n_local_expert + 1):
|
||||
first_token_offset = expert_first_token_offset[i - 1]
|
||||
last_token_offset = expert_first_token_offset[i]
|
||||
n_token_in_expert = last_token_offset - first_token_offset
|
||||
align_expert_first_token_offset[
|
||||
i] = align_expert_first_token_offset[
|
||||
i - 1] + (n_token_in_expert + align_block_size -
|
||||
1) // align_block_size * align_block_size
|
||||
align_expert_first_token_offset[i] = (
|
||||
align_expert_first_token_offset[i - 1]
|
||||
+ (n_token_in_expert + align_block_size - 1)
|
||||
// align_block_size
|
||||
* align_block_size
|
||||
)
|
||||
align_first_token_offset = align_expert_first_token_offset[i - 1]
|
||||
align_last_token_offset = align_expert_first_token_offset[i]
|
||||
dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[
|
||||
first_token_offset:first_token_offset + n_token_in_expert]
|
||||
first_token_offset : first_token_offset + n_token_in_expert
|
||||
]
|
||||
# store token in current expert with align_first_token_offset
|
||||
permuted_hidden_states[align_first_token_offset:\
|
||||
align_first_token_offset+n_token_in_expert,\
|
||||
...] = hidden_states[\
|
||||
dst_row_id2src_row_id_in_expert // topk,\
|
||||
...]
|
||||
permuted_idx[align_first_token_offset:\
|
||||
align_first_token_offset+\
|
||||
n_token_in_expert] = dst_row_id2src_row_id_in_expert
|
||||
permuted_hidden_states[
|
||||
align_first_token_offset : align_first_token_offset + n_token_in_expert,
|
||||
...,
|
||||
] = hidden_states[dst_row_id2src_row_id_in_expert // topk, ...]
|
||||
permuted_idx[
|
||||
align_first_token_offset : align_first_token_offset + n_token_in_expert
|
||||
] = dst_row_id2src_row_id_in_expert
|
||||
# set current expert m_indices
|
||||
m_indices[align_first_token_offset:align_last_token_offset] = i - 1
|
||||
valid_row_idx += [
|
||||
i for i in range(align_first_token_offset,
|
||||
align_first_token_offset + n_token_in_expert)
|
||||
i
|
||||
for i in range(
|
||||
align_first_token_offset,
|
||||
align_first_token_offset + n_token_in_expert,
|
||||
)
|
||||
]
|
||||
# get align_src_row_id2dst_row_id
|
||||
for i in range(n_token * topk):
|
||||
eid = sorted_topk_ids[i]
|
||||
if (eid >= n_local_expert):
|
||||
if eid >= n_local_expert:
|
||||
# check token not in local expert
|
||||
align_src_row_id2dst_row_id[
|
||||
i] = align_expert_first_token_offset[-1]
|
||||
align_src_row_id2dst_row_id[i] = align_expert_first_token_offset[-1]
|
||||
continue
|
||||
first_token_offset = expert_first_token_offset[eid]
|
||||
align_first_token_offset = align_expert_first_token_offset[eid]
|
||||
token_offset = i - first_token_offset
|
||||
align_src_row_id2dst_row_id[
|
||||
i] = align_first_token_offset + token_offset
|
||||
align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[\
|
||||
src2dst_idx].reshape((n_token, topk))
|
||||
align_src_row_id2dst_row_id[i] = align_first_token_offset + token_offset
|
||||
align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[src2dst_idx].reshape(
|
||||
(n_token, topk)
|
||||
)
|
||||
return [
|
||||
permuted_hidden_states, align_expert_first_token_offset,
|
||||
align_src_row_id2dst_row_id, permuted_idx, m_indices, valid_row_idx
|
||||
permuted_hidden_states,
|
||||
align_expert_first_token_offset,
|
||||
align_src_row_id2dst_row_id,
|
||||
permuted_idx,
|
||||
m_indices,
|
||||
valid_row_idx,
|
||||
]
|
||||
|
||||
|
||||
def torch_unpermute(permuted_hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
token_expert_indices: torch.Tensor,
|
||||
src_row_id2dst_row_id_map: torch.Tensor,
|
||||
valid_row_idx: torch.Tensor, topk: int,
|
||||
n_expert: int) -> torch.Tensor:
|
||||
def torch_unpermute(
|
||||
permuted_hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
token_expert_indices: torch.Tensor,
|
||||
src_row_id2dst_row_id_map: torch.Tensor,
|
||||
valid_row_idx: torch.Tensor,
|
||||
topk: int,
|
||||
n_expert: int,
|
||||
) -> torch.Tensor:
|
||||
# ignore invalid row
|
||||
n_hidden = permuted_hidden_states.shape[1]
|
||||
mask = torch.zeros(permuted_hidden_states.shape[0],
|
||||
dtype=bool,
|
||||
device="cuda")
|
||||
mask = torch.zeros(permuted_hidden_states.shape[0], dtype=bool, device="cuda")
|
||||
mask[valid_row_idx] = True
|
||||
permuted_hidden_states[~mask] = 0
|
||||
|
||||
permuted_hidden_states = permuted_hidden_states[
|
||||
src_row_id2dst_row_id_map.flatten(), ...]
|
||||
src_row_id2dst_row_id_map.flatten(), ...
|
||||
]
|
||||
permuted_hidden_states = permuted_hidden_states.view(-1, topk, n_hidden)
|
||||
output = (permuted_hidden_states * topk_weights.unsqueeze(2)).sum(1).to(
|
||||
permuted_hidden_states.dtype)
|
||||
output = (
|
||||
(permuted_hidden_states * topk_weights.unsqueeze(2))
|
||||
.sum(1)
|
||||
.to(permuted_hidden_states.dtype)
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
@@ -184,59 +203,76 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor,
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
||||
@pytest.mark.parametrize("align_block_size", [None, 128])
|
||||
def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
|
||||
n_expert: int, ep_size: int, dtype: torch.dtype,
|
||||
align_block_size: Optional[int]):
|
||||
def test_moe_permute_unpermute(
|
||||
n_token: int,
|
||||
n_hidden: int,
|
||||
topk: int,
|
||||
n_expert: int,
|
||||
ep_size: int,
|
||||
dtype: torch.dtype,
|
||||
align_block_size: Optional[int],
|
||||
):
|
||||
if not moe_permute_unpermute_supported():
|
||||
pytest.skip("moe_permute_unpermute is not supported on this platform.")
|
||||
fill_invalid_expert = 0
|
||||
ep_rank = np.random.randint(0, ep_size)
|
||||
expert_map = None
|
||||
n_local_expert = n_expert
|
||||
if (ep_size != 1):
|
||||
n_local_expert, expert_map = determine_expert_map(
|
||||
ep_size, ep_rank, n_expert)
|
||||
if ep_size != 1:
|
||||
n_local_expert, expert_map = determine_expert_map(ep_size, ep_rank, n_expert)
|
||||
expert_map = expert_map.cuda()
|
||||
start_expert = n_local_expert * ep_rank
|
||||
current_platform.seed_everything(0)
|
||||
hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype)
|
||||
gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
hidden_states, gating_output, topk, False)
|
||||
(gold_permuted_hidden_states, gold_expert_first_token_offset,
|
||||
gold_inv_permuted_idx, gold_permuted_idx, gold_m_indices,
|
||||
valid_row_idx) = torch_permute(
|
||||
hidden_states,
|
||||
topk_ids,
|
||||
# token_expert_indices,
|
||||
topk,
|
||||
n_expert,
|
||||
n_local_expert,
|
||||
start_expert,
|
||||
expert_map=expert_map,
|
||||
align_block_size=align_block_size,
|
||||
fill_invalid_expert=fill_invalid_expert)
|
||||
hidden_states, gating_output, topk, False
|
||||
)
|
||||
(
|
||||
gold_permuted_hidden_states,
|
||||
gold_expert_first_token_offset,
|
||||
gold_inv_permuted_idx,
|
||||
gold_permuted_idx,
|
||||
gold_m_indices,
|
||||
valid_row_idx,
|
||||
) = torch_permute(
|
||||
hidden_states,
|
||||
topk_ids,
|
||||
# token_expert_indices,
|
||||
topk,
|
||||
n_expert,
|
||||
n_local_expert,
|
||||
start_expert,
|
||||
expert_map=expert_map,
|
||||
align_block_size=align_block_size,
|
||||
fill_invalid_expert=fill_invalid_expert,
|
||||
)
|
||||
|
||||
(permuted_hidden_states, _, expert_first_token_offset, inv_permuted_idx,
|
||||
m_indices) = moe_permute(hidden_states=hidden_states,
|
||||
a1q_scale=None,
|
||||
topk_ids=topk_ids,
|
||||
n_expert=n_expert,
|
||||
n_local_expert=n_local_expert,
|
||||
expert_map=expert_map,
|
||||
align_block_size=align_block_size,
|
||||
fill_invalid_expert=fill_invalid_expert)
|
||||
(
|
||||
permuted_hidden_states,
|
||||
_,
|
||||
expert_first_token_offset,
|
||||
inv_permuted_idx,
|
||||
m_indices,
|
||||
) = moe_permute(
|
||||
hidden_states=hidden_states,
|
||||
a1q_scale=None,
|
||||
topk_ids=topk_ids,
|
||||
n_expert=n_expert,
|
||||
n_local_expert=n_local_expert,
|
||||
expert_map=expert_map,
|
||||
align_block_size=align_block_size,
|
||||
fill_invalid_expert=fill_invalid_expert,
|
||||
)
|
||||
|
||||
# check expert_first_token_offset
|
||||
torch.testing.assert_close(gold_expert_first_token_offset,
|
||||
expert_first_token_offset,
|
||||
atol=0,
|
||||
rtol=0)
|
||||
torch.testing.assert_close(
|
||||
gold_expert_first_token_offset, expert_first_token_offset, atol=0, rtol=0
|
||||
)
|
||||
# check src_row_id2dst_row_id_map
|
||||
torch.testing.assert_close(gold_inv_permuted_idx.flatten(),
|
||||
inv_permuted_idx,
|
||||
atol=0,
|
||||
rtol=0)
|
||||
torch.testing.assert_close(
|
||||
gold_inv_permuted_idx.flatten(), inv_permuted_idx, atol=0, rtol=0
|
||||
)
|
||||
# check mindice
|
||||
# current kernel usage assumes deepgemm requires align_block_size
|
||||
# when it's not provided then we don't compute m_indices (for cutlass)
|
||||
@@ -244,19 +280,28 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
|
||||
torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
|
||||
|
||||
# check permuted_hidden_states, only valid token
|
||||
torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx],
|
||||
permuted_hidden_states[valid_row_idx],
|
||||
atol=0,
|
||||
rtol=0)
|
||||
torch.testing.assert_close(
|
||||
gold_permuted_hidden_states[valid_row_idx],
|
||||
permuted_hidden_states[valid_row_idx],
|
||||
atol=0,
|
||||
rtol=0,
|
||||
)
|
||||
# add a random tensor to simulate group gemm
|
||||
result0 = 0.5 * permuted_hidden_states + torch.randn_like(
|
||||
permuted_hidden_states)
|
||||
result0 = 0.5 * permuted_hidden_states + torch.randn_like(permuted_hidden_states)
|
||||
result4 = torch.empty_like(hidden_states)
|
||||
moe_unpermute(result4, result0, topk_weights, inv_permuted_idx,
|
||||
expert_first_token_offset)
|
||||
moe_unpermute(
|
||||
result4, result0, topk_weights, inv_permuted_idx, expert_first_token_offset
|
||||
)
|
||||
|
||||
gold4 = torch_unpermute(result0, topk_weights, topk_ids,
|
||||
token_expert_indices, inv_permuted_idx,
|
||||
valid_row_idx, topk, n_local_expert)
|
||||
gold4 = torch_unpermute(
|
||||
result0,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indices,
|
||||
inv_permuted_idx,
|
||||
valid_row_idx,
|
||||
topk,
|
||||
n_local_expert,
|
||||
)
|
||||
# check unpermuted hidden
|
||||
torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0)
|
||||
|
||||
@@ -11,27 +11,39 @@ import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
|
||||
QuarkLinearMethod, QuarkW4A4MXFP4)
|
||||
QuarkLinearMethod,
|
||||
QuarkW4A4MXFP4,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501
|
||||
QuarkW4A4MXFp4MoEMethod)
|
||||
QuarkW4A4MXFp4MoEMethod,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
|
||||
QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
|
||||
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
|
||||
importlib.metadata.version("amd-quark")
|
||||
) >= version.parse("0.8.99")
|
||||
|
||||
TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda(
|
||||
) and current_platform.is_device_capability(100)
|
||||
TRTLLM_GEN_MXFP4_AVAILABLE = (
|
||||
current_platform.is_cuda() and current_platform.is_device_capability(100)
|
||||
)
|
||||
|
||||
HOPPER_MXFP4_BF16_AVAILABLE = (current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(90)
|
||||
and has_flashinfer())
|
||||
HOPPER_MXFP4_BF16_AVAILABLE = (
|
||||
current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(90)
|
||||
and has_flashinfer()
|
||||
)
|
||||
|
||||
if TRTLLM_GEN_MXFP4_AVAILABLE:
|
||||
from flashinfer import (fp4_quantize, mxfp8_quantize,
|
||||
next_positive_power_of_2,
|
||||
reorder_rows_for_gated_act_gemm, shuffle_matrix_a,
|
||||
shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
|
||||
from flashinfer import (
|
||||
fp4_quantize,
|
||||
mxfp8_quantize,
|
||||
next_positive_power_of_2,
|
||||
reorder_rows_for_gated_act_gemm,
|
||||
shuffle_matrix_a,
|
||||
shuffle_matrix_sf_a,
|
||||
trtllm_fp4_block_scale_moe,
|
||||
)
|
||||
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
|
||||
from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
|
||||
|
||||
@@ -48,21 +60,25 @@ def enable_pickle(monkeypatch):
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_case', [
|
||||
ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
|
||||
ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
|
||||
ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1)
|
||||
])
|
||||
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE,
|
||||
reason="amd-quark>=0.9 is not available")
|
||||
@pytest.mark.parametrize(
|
||||
"model_case",
|
||||
[
|
||||
ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
|
||||
ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
|
||||
ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1),
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
|
||||
def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
|
||||
if torch.cuda.device_count() < model_case.tp:
|
||||
pytest.skip(f"This test requires >={model_case.tp} gpus, got only "
|
||||
f"{torch.cuda.device_count()}")
|
||||
pytest.skip(
|
||||
f"This test requires >={model_case.tp} gpus, got only "
|
||||
f"{torch.cuda.device_count()}"
|
||||
)
|
||||
|
||||
with vllm_runner(model_case.model_id,
|
||||
tensor_parallel_size=model_case.tp,
|
||||
load_format="dummy") as llm:
|
||||
with vllm_runner(
|
||||
model_case.model_id, tensor_parallel_size=model_case.tp, load_format="dummy"
|
||||
) as llm:
|
||||
|
||||
def check_model(model):
|
||||
layer = model.model.layers[0]
|
||||
@@ -72,21 +88,16 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
|
||||
assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
|
||||
assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)
|
||||
|
||||
assert isinstance(layer.mlp.experts.quant_method,
|
||||
QuarkW4A4MXFp4MoEMethod)
|
||||
assert isinstance(layer.mlp.experts.quant_method, QuarkW4A4MXFp4MoEMethod)
|
||||
|
||||
if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
|
||||
llm.apply_model(check_model)
|
||||
|
||||
output = llm.generate_greedy("Today I am in the French Alps and",
|
||||
max_tokens=20)
|
||||
output = llm.generate_greedy("Today I am in the French Alps and", max_tokens=20)
|
||||
assert output
|
||||
|
||||
|
||||
def swiglu(x,
|
||||
alpha: float = 1.702,
|
||||
beta: float = 1.0,
|
||||
limit: Optional[float] = None):
|
||||
def swiglu(x, alpha: float = 1.702, beta: float = 1.0, limit: Optional[float] = None):
|
||||
# Note we add an extra bias of 1 to the linear layer
|
||||
x_glu, x_linear = torch.chunk(x, 2, dim=-1)
|
||||
if limit is not None:
|
||||
@@ -96,24 +107,19 @@ def swiglu(x,
|
||||
return out_glu * (x_linear + beta)
|
||||
|
||||
|
||||
fp4_lookup_table = [
|
||||
0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6
|
||||
]
|
||||
fp4_lookup_table = [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6]
|
||||
|
||||
|
||||
def mxfp4_dequantize(x, scale):
|
||||
assert x.dtype == torch.uint8
|
||||
x = x.view(torch.uint8).to(torch.int32)
|
||||
x_unpacked = torch.zeros(*x.shape[:-1],
|
||||
x.shape[-1] * 2,
|
||||
dtype=torch.int32,
|
||||
device=x.device)
|
||||
x_unpacked = torch.zeros(
|
||||
*x.shape[:-1], x.shape[-1] * 2, dtype=torch.int32, device=x.device
|
||||
)
|
||||
x_unpacked[..., 0::2].copy_(x & 0xF)
|
||||
x_unpacked[..., 1::2].copy_((x >> 4) & 0xF)
|
||||
|
||||
x_float = torch.zeros(x_unpacked.shape,
|
||||
dtype=torch.float32,
|
||||
device=x.device)
|
||||
x_float = torch.zeros(x_unpacked.shape, dtype=torch.float32, device=x.device)
|
||||
for i, val in enumerate(fp4_lookup_table):
|
||||
x_float[x_unpacked == i] = val
|
||||
|
||||
@@ -162,9 +168,10 @@ def reference_moe(
|
||||
t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
|
||||
t = swiglu(t, alpha=alpha, beta=beta, limit=limit)
|
||||
|
||||
if act_type == 'mxfp8':
|
||||
t_quantized, t_scale = mxfp8_quantize(t.to(torch.bfloat16),
|
||||
is_sf_swizzled_layout=False)
|
||||
if act_type == "mxfp8":
|
||||
t_quantized, t_scale = mxfp8_quantize(
|
||||
t.to(torch.bfloat16), is_sf_swizzled_layout=False
|
||||
)
|
||||
t = mxfp8_dequantize(t_quantized, t_scale)
|
||||
# MLP #2
|
||||
mlp2_weight = w2[expert_indices, ...]
|
||||
@@ -221,37 +228,53 @@ def tg_mxfp4_moe(
|
||||
transpose_optimized: bool = False,
|
||||
) -> torch.Tensor:
|
||||
sf_block_size = 32
|
||||
assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts
|
||||
and w13_weight.shape[1] == intermediate_size * 2
|
||||
and w13_weight.shape[2] == hidden_size // 2)
|
||||
assert (w13_weight_scale.dim() == 3
|
||||
and w13_weight_scale.shape[0] == num_experts
|
||||
and w13_weight_scale.shape[1] == intermediate_size * 2
|
||||
and w13_weight_scale.shape[2] == hidden_size // sf_block_size)
|
||||
assert (w2_weight.dim() == 3 and w2_weight.shape[0] == num_experts
|
||||
and w2_weight.shape[1] == hidden_size
|
||||
and w2_weight.shape[2] == intermediate_size // 2)
|
||||
assert (w2_weight_scale.dim() == 3
|
||||
and w2_weight_scale.shape[1] == hidden_size
|
||||
and w2_weight_scale.shape[2] == intermediate_size // sf_block_size)
|
||||
assert (w13_bias.dim() == 2 and w13_bias.shape[0] == num_experts
|
||||
and w13_bias.shape[1] == intermediate_size * 2)
|
||||
assert (w2_bias.dim() == 2 and w2_bias.shape[0] == num_experts
|
||||
and w2_bias.shape[1] == hidden_size)
|
||||
assert (
|
||||
w13_weight.dim() == 3
|
||||
and w13_weight.shape[0] == num_experts
|
||||
and w13_weight.shape[1] == intermediate_size * 2
|
||||
and w13_weight.shape[2] == hidden_size // 2
|
||||
)
|
||||
assert (
|
||||
w13_weight_scale.dim() == 3
|
||||
and w13_weight_scale.shape[0] == num_experts
|
||||
and w13_weight_scale.shape[1] == intermediate_size * 2
|
||||
and w13_weight_scale.shape[2] == hidden_size // sf_block_size
|
||||
)
|
||||
assert (
|
||||
w2_weight.dim() == 3
|
||||
and w2_weight.shape[0] == num_experts
|
||||
and w2_weight.shape[1] == hidden_size
|
||||
and w2_weight.shape[2] == intermediate_size // 2
|
||||
)
|
||||
assert (
|
||||
w2_weight_scale.dim() == 3
|
||||
and w2_weight_scale.shape[1] == hidden_size
|
||||
and w2_weight_scale.shape[2] == intermediate_size // sf_block_size
|
||||
)
|
||||
assert (
|
||||
w13_bias.dim() == 2
|
||||
and w13_bias.shape[0] == num_experts
|
||||
and w13_bias.shape[1] == intermediate_size * 2
|
||||
)
|
||||
assert (
|
||||
w2_bias.dim() == 2
|
||||
and w2_bias.shape[0] == num_experts
|
||||
and w2_bias.shape[1] == hidden_size
|
||||
)
|
||||
|
||||
# Swap w1 and w3 as the definition of
|
||||
# swiglu is different in the trtllm-gen
|
||||
w13_weight_scale_ = w13_weight_scale.clone()
|
||||
w13_weight_ = w13_weight.clone()
|
||||
w13_bias_ = w13_bias.clone()
|
||||
w13_weight[:, :intermediate_size, :].copy_(
|
||||
w13_weight_[:, intermediate_size:, :])
|
||||
w13_weight[:, intermediate_size:, :].copy_(
|
||||
w13_weight_[:, :intermediate_size, :])
|
||||
w13_weight[:, :intermediate_size, :].copy_(w13_weight_[:, intermediate_size:, :])
|
||||
w13_weight[:, intermediate_size:, :].copy_(w13_weight_[:, :intermediate_size, :])
|
||||
w13_weight_scale[:, :intermediate_size, :].copy_(
|
||||
w13_weight_scale_[:, intermediate_size:, :])
|
||||
w13_weight_scale_[:, intermediate_size:, :]
|
||||
)
|
||||
w13_weight_scale[:, intermediate_size:, :].copy_(
|
||||
w13_weight_scale_[:, :intermediate_size, :])
|
||||
w13_weight_scale_[:, :intermediate_size, :]
|
||||
)
|
||||
w13_bias[:, :intermediate_size].copy_(w13_bias_[:, intermediate_size:])
|
||||
w13_bias[:, intermediate_size:].copy_(w13_bias_[:, :intermediate_size])
|
||||
|
||||
@@ -261,18 +284,23 @@ def tg_mxfp4_moe(
|
||||
w13_bias_interleaved = []
|
||||
for i in range(num_experts):
|
||||
w13_weight_interleaved.append(
|
||||
reorder_rows_for_gated_act_gemm(w13_weight[i].clone()))
|
||||
reorder_rows_for_gated_act_gemm(w13_weight[i].clone())
|
||||
)
|
||||
w13_weight_scale_interleaved.append(
|
||||
reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone()))
|
||||
reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone())
|
||||
)
|
||||
w13_bias_interleaved.append(
|
||||
reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1,
|
||||
1)))
|
||||
reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1, 1))
|
||||
)
|
||||
w13_weight = torch.stack(w13_weight_interleaved).reshape(
|
||||
num_experts, 2 * intermediate_size, hidden_size // 2)
|
||||
num_experts, 2 * intermediate_size, hidden_size // 2
|
||||
)
|
||||
w13_weight_scale = torch.stack(w13_weight_scale_interleaved).reshape(
|
||||
num_experts, 2 * intermediate_size, hidden_size // 32)
|
||||
num_experts, 2 * intermediate_size, hidden_size // 32
|
||||
)
|
||||
w13_bias = torch.stack(w13_bias_interleaved).reshape(
|
||||
num_experts, 2 * intermediate_size)
|
||||
num_experts, 2 * intermediate_size
|
||||
)
|
||||
|
||||
# Shuffle weights and scaling factors for transposed mma output
|
||||
gemm1_weights_shuffled = []
|
||||
@@ -291,9 +319,11 @@ def tg_mxfp4_moe(
|
||||
w13_weight[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
)
|
||||
gemm1_weights_shuffled.append(w13_weight[i].view(
|
||||
torch.uint8)[permute_indices.to(
|
||||
w13_weight.device)].contiguous())
|
||||
gemm1_weights_shuffled.append(
|
||||
w13_weight[i]
|
||||
.view(torch.uint8)[permute_indices.to(w13_weight.device)]
|
||||
.contiguous()
|
||||
)
|
||||
# w13 scale shuffling
|
||||
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
|
||||
_cache_permute_indices,
|
||||
@@ -302,26 +332,35 @@ def tg_mxfp4_moe(
|
||||
num_elts_per_sf=16,
|
||||
)
|
||||
gemm1_scales_shuffled.append(
|
||||
nvfp4_block_scale_interleave(w13_weight_scale[i].view(
|
||||
torch.uint8)[permute_sf_indices.to(
|
||||
w13_weight_scale.device)].contiguous()))
|
||||
nvfp4_block_scale_interleave(
|
||||
w13_weight_scale[i]
|
||||
.view(torch.uint8)[permute_sf_indices.to(w13_weight_scale.device)]
|
||||
.contiguous()
|
||||
)
|
||||
)
|
||||
# w13 bias shuffling
|
||||
permute_bias_indices = _maybe_get_cached_w2_permute_indices(
|
||||
_cache_permute_indices,
|
||||
w13_bias[i].clone().reshape(-1, 1),
|
||||
epilogue_tile_m,
|
||||
)
|
||||
gemm1_bias_shuffled.append(w13_bias[i].clone().reshape(
|
||||
-1, 1)[permute_bias_indices.to(w13_bias.device)].contiguous())
|
||||
gemm1_bias_shuffled.append(
|
||||
w13_bias[i]
|
||||
.clone()
|
||||
.reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
|
||||
.contiguous()
|
||||
)
|
||||
# w2 weight shuffling
|
||||
permute_indices = _maybe_get_cached_w2_permute_indices(
|
||||
_cache_permute_indices,
|
||||
w2_weight[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
)
|
||||
gemm2_weights_shuffled.append(w2_weight[i].view(
|
||||
torch.uint8)[permute_indices.to(
|
||||
w2_weight.device)].contiguous())
|
||||
gemm2_weights_shuffled.append(
|
||||
w2_weight[i]
|
||||
.view(torch.uint8)[permute_indices.to(w2_weight.device)]
|
||||
.contiguous()
|
||||
)
|
||||
# w2 scale shuffling
|
||||
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
|
||||
_cache_permute_indices,
|
||||
@@ -330,48 +369,65 @@ def tg_mxfp4_moe(
|
||||
num_elts_per_sf=16,
|
||||
)
|
||||
gemm2_scales_shuffled.append(
|
||||
nvfp4_block_scale_interleave(w2_weight_scale[i].view(
|
||||
torch.uint8)[permute_sf_indices.to(
|
||||
w2_weight_scale.device)].contiguous()))
|
||||
nvfp4_block_scale_interleave(
|
||||
w2_weight_scale[i]
|
||||
.view(torch.uint8)[permute_sf_indices.to(w2_weight_scale.device)]
|
||||
.contiguous()
|
||||
)
|
||||
)
|
||||
# w2 bias shuffling
|
||||
permute_indices = _maybe_get_cached_w2_permute_indices(
|
||||
_cache_permute_indices,
|
||||
w2_bias[i].clone().reshape(-1, 1),
|
||||
epilogue_tile_m,
|
||||
)
|
||||
gemm2_bias_shuffled.append(w2_bias[i].clone().reshape(
|
||||
-1, 1)[permute_indices.to(w2_bias.device)].contiguous())
|
||||
gemm2_bias_shuffled.append(
|
||||
w2_bias[i]
|
||||
.clone()
|
||||
.reshape(-1, 1)[permute_indices.to(w2_bias.device)]
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
else:
|
||||
for i in range(num_experts):
|
||||
gemm1_weights_shuffled.append(
|
||||
shuffle_matrix_a(w13_weight[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)
|
||||
)
|
||||
gemm1_scales_shuffled.append(
|
||||
shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
shuffle_matrix_sf_a(
|
||||
w13_weight_scale[i].view(torch.uint8), epilogue_tile_m
|
||||
)
|
||||
)
|
||||
|
||||
gemm2_weights_shuffled.append(
|
||||
shuffle_matrix_a(w2_weight[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)
|
||||
)
|
||||
gemm2_scales_shuffled.append(
|
||||
shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
shuffle_matrix_sf_a(
|
||||
w2_weight_scale[i].view(torch.uint8), epilogue_tile_m
|
||||
)
|
||||
)
|
||||
gemm1_bias_shuffled.append(
|
||||
shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m))
|
||||
shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m)
|
||||
)
|
||||
gemm2_bias_shuffled.append(
|
||||
shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m))
|
||||
shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m)
|
||||
)
|
||||
|
||||
w13_weight = torch.stack(gemm1_weights_shuffled)
|
||||
w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape(
|
||||
num_experts, 2 * intermediate_size,
|
||||
hidden_size // sf_block_size).view(torch.float8_e4m3fn)
|
||||
w13_weight_scale = (
|
||||
torch.stack(gemm1_scales_shuffled)
|
||||
.reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size)
|
||||
.view(torch.float8_e4m3fn)
|
||||
)
|
||||
w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1)
|
||||
|
||||
w2_weight = torch.stack(gemm2_weights_shuffled)
|
||||
w2_weight_scale = torch.stack(gemm2_scales_shuffled).reshape(
|
||||
num_experts, hidden_size,
|
||||
intermediate_size // sf_block_size).view(torch.float8_e4m3fn)
|
||||
w2_weight_scale = (
|
||||
torch.stack(gemm2_scales_shuffled)
|
||||
.reshape(num_experts, hidden_size, intermediate_size // sf_block_size)
|
||||
.view(torch.float8_e4m3fn)
|
||||
)
|
||||
w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1)
|
||||
|
||||
tg_result = trtllm_fp4_block_scale_moe(
|
||||
@@ -401,7 +457,8 @@ def tg_mxfp4_moe(
|
||||
routed_scaling_factor=None,
|
||||
tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts),
|
||||
routing_method_type=1, # renormalize
|
||||
do_finalize=True)[0]
|
||||
do_finalize=True,
|
||||
)[0]
|
||||
return tg_result
|
||||
|
||||
|
||||
@@ -424,20 +481,21 @@ def check_accuracy(a, b, atol, rtol, percent):
|
||||
if mismatch_percent > 1 - percent:
|
||||
raise Exception(
|
||||
f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
|
||||
f"(threshold: {1-percent:.4f})")
|
||||
f"(threshold: {1 - percent:.4f})"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("topk", [1, 4])
|
||||
@pytest.mark.parametrize("num_experts", [32, 128])
|
||||
@pytest.mark.parametrize("num_tokens", [1, 128, 1024])
|
||||
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
|
||||
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
|
||||
(1.702, 1.0, 7.0)])
|
||||
@pytest.mark.parametrize("act_type", ['mxfp8', 'bf16'])
|
||||
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
|
||||
@pytest.mark.parametrize("act_type", ["mxfp8", "bf16"])
|
||||
@pytest.mark.parametrize("transpose_optimized", [False, True])
|
||||
@pytest.mark.skipif(
|
||||
not TRTLLM_GEN_MXFP4_AVAILABLE,
|
||||
reason="nvidia gpu and compute capability sm100 is required for this test")
|
||||
reason="nvidia gpu and compute capability sm100 is required for this test",
|
||||
)
|
||||
def test_trtllm_gen_mxfp4_fused_moe(
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
@@ -452,45 +510,52 @@ def test_trtllm_gen_mxfp4_fused_moe(
|
||||
):
|
||||
seed = 42
|
||||
torch.manual_seed(seed)
|
||||
hidden_states = torch.randn(num_tokens,
|
||||
hidden_size,
|
||||
device="cuda:0",
|
||||
dtype=torch.bfloat16)
|
||||
w13 = (torch.randn(num_experts,
|
||||
intermediate_size * 2,
|
||||
hidden_size,
|
||||
device="cuda:0",
|
||||
dtype=torch.bfloat16))
|
||||
w2 = (torch.randn(num_experts,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
device="cuda:0",
|
||||
dtype=torch.bfloat16))
|
||||
bias13 = torch.randn(num_experts, intermediate_size * 2,
|
||||
device="cuda:0") * 10
|
||||
hidden_states = torch.randn(
|
||||
num_tokens, hidden_size, device="cuda:0", dtype=torch.bfloat16
|
||||
)
|
||||
w13 = torch.randn(
|
||||
num_experts,
|
||||
intermediate_size * 2,
|
||||
hidden_size,
|
||||
device="cuda:0",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
w2 = torch.randn(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
device="cuda:0",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
bias13 = torch.randn(num_experts, intermediate_size * 2, device="cuda:0") * 10
|
||||
bias2 = torch.randn(num_experts, hidden_size, device="cuda:0") * 10
|
||||
router_logits = torch.rand(num_tokens, num_experts,
|
||||
dtype=torch.float32).cuda()
|
||||
router_logits = torch.rand(num_tokens, num_experts, dtype=torch.float32).cuda()
|
||||
|
||||
w13, w13_scale = fp4_quantize(w13,
|
||||
torch.tensor(1.0, device="cuda:0"),
|
||||
32,
|
||||
sf_use_ue8m0=True,
|
||||
is_sf_swizzled_layout=False)
|
||||
w13, w13_scale = fp4_quantize(
|
||||
w13,
|
||||
torch.tensor(1.0, device="cuda:0"),
|
||||
32,
|
||||
sf_use_ue8m0=True,
|
||||
is_sf_swizzled_layout=False,
|
||||
)
|
||||
w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape(
|
||||
num_experts, intermediate_size * 2, hidden_size // 32)
|
||||
w2, w2_scale = fp4_quantize(w2,
|
||||
torch.tensor(1.0, device="cuda:0"),
|
||||
32,
|
||||
sf_use_ue8m0=True,
|
||||
is_sf_swizzled_layout=False)
|
||||
num_experts, intermediate_size * 2, hidden_size // 32
|
||||
)
|
||||
w2, w2_scale = fp4_quantize(
|
||||
w2,
|
||||
torch.tensor(1.0, device="cuda:0"),
|
||||
32,
|
||||
sf_use_ue8m0=True,
|
||||
is_sf_swizzled_layout=False,
|
||||
)
|
||||
w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape(
|
||||
num_experts, hidden_size, intermediate_size // 32)
|
||||
if act_type == 'mxfp8':
|
||||
num_experts, hidden_size, intermediate_size // 32
|
||||
)
|
||||
if act_type == "mxfp8":
|
||||
hidden_states, hidden_states_scale = mxfp8_quantize(
|
||||
hidden_states, is_sf_swizzled_layout=False)
|
||||
hidden_states_scale = hidden_states_scale.view(
|
||||
torch.float8_e4m3fn).reshape(-1)
|
||||
hidden_states, is_sf_swizzled_layout=False
|
||||
)
|
||||
hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(-1)
|
||||
else:
|
||||
hidden_states_scale = None
|
||||
|
||||
@@ -500,9 +565,10 @@ def test_trtllm_gen_mxfp4_fused_moe(
|
||||
w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone())
|
||||
bias13_ref = bias13
|
||||
bias2_ref = bias2
|
||||
if act_type == 'mxfp8':
|
||||
hidden_states_ref = mxfp8_dequantize(
|
||||
hidden_states, hidden_states_scale).to(torch.float32)
|
||||
if act_type == "mxfp8":
|
||||
hidden_states_ref = mxfp8_dequantize(hidden_states, hidden_states_scale).to(
|
||||
torch.float32
|
||||
)
|
||||
else:
|
||||
hidden_states_ref = hidden_states.to(torch.float32)
|
||||
# Process tokens in chunks of 32 to reduce memory usage
|
||||
@@ -529,29 +595,31 @@ def test_trtllm_gen_mxfp4_fused_moe(
|
||||
|
||||
# trtllm-gen result
|
||||
if alpha is not None:
|
||||
alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
|
||||
alpha = torch.full((num_experts,), alpha, device=hidden_states.device)
|
||||
if limit is not None:
|
||||
limit = torch.full((num_experts, ), limit, device=hidden_states.device)
|
||||
limit = torch.full((num_experts,), limit, device=hidden_states.device)
|
||||
if beta is not None:
|
||||
beta = torch.full((num_experts, ), beta, device=hidden_states.device)
|
||||
tg_result = tg_mxfp4_moe(router_logits,
|
||||
topk,
|
||||
num_experts,
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
hidden_states,
|
||||
hidden_states_scale,
|
||||
w13,
|
||||
w13_scale,
|
||||
bias13,
|
||||
w2,
|
||||
w2_scale,
|
||||
bias2,
|
||||
act_type,
|
||||
alpha=alpha,
|
||||
beta=beta,
|
||||
limit=limit,
|
||||
transpose_optimized=transpose_optimized)
|
||||
beta = torch.full((num_experts,), beta, device=hidden_states.device)
|
||||
tg_result = tg_mxfp4_moe(
|
||||
router_logits,
|
||||
topk,
|
||||
num_experts,
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
hidden_states,
|
||||
hidden_states_scale,
|
||||
w13,
|
||||
w13_scale,
|
||||
bias13,
|
||||
w2,
|
||||
w2_scale,
|
||||
bias2,
|
||||
act_type,
|
||||
alpha=alpha,
|
||||
beta=beta,
|
||||
limit=limit,
|
||||
transpose_optimized=transpose_optimized,
|
||||
)
|
||||
# relatively loose check since the mxfp4 quantization is less accurate
|
||||
check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)
|
||||
|
||||
@@ -573,8 +641,7 @@ def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor:
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("num_tokens", [1, 128])
|
||||
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
|
||||
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
|
||||
(1.702, 1.0, 7.0)])
|
||||
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
|
||||
@pytest.mark.skipif(
|
||||
not HOPPER_MXFP4_BF16_AVAILABLE,
|
||||
reason="nvidia gpu sm90 and flashinfer are required for this test",
|
||||
@@ -593,52 +660,73 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
|
||||
device = "cuda:0"
|
||||
|
||||
# Inputs
|
||||
hidden_states = torch.randn(num_tokens,
|
||||
hidden_size,
|
||||
device=device,
|
||||
dtype=torch.bfloat16)
|
||||
hidden_states = torch.randn(
|
||||
num_tokens, hidden_size, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
# Random MXFP4 weights and scales (uint8), contiguous [w1; w3]
|
||||
w13_q = torch.randint(
|
||||
0,
|
||||
256, (num_experts, 2 * intermediate_size, hidden_size // 2),
|
||||
256,
|
||||
(num_experts, 2 * intermediate_size, hidden_size // 2),
|
||||
device=device,
|
||||
dtype=torch.uint8)
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
w13_scale = torch.randint(
|
||||
118,
|
||||
123, (num_experts, 2 * intermediate_size, hidden_size // 32),
|
||||
123,
|
||||
(num_experts, 2 * intermediate_size, hidden_size // 32),
|
||||
device=device,
|
||||
dtype=torch.uint8)
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
|
||||
w2_q = torch.randint(0,
|
||||
256,
|
||||
(num_experts, hidden_size, intermediate_size // 2),
|
||||
device=device,
|
||||
dtype=torch.uint8)
|
||||
w2_q = torch.randint(
|
||||
0,
|
||||
256,
|
||||
(num_experts, hidden_size, intermediate_size // 2),
|
||||
device=device,
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
w2_scale = torch.randint(
|
||||
118,
|
||||
123, (num_experts, hidden_size, intermediate_size // 32),
|
||||
123,
|
||||
(num_experts, hidden_size, intermediate_size // 32),
|
||||
device=device,
|
||||
dtype=torch.uint8)
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
# Bias contiguous [b1; b3]
|
||||
bias13 = (torch.randn(num_experts,
|
||||
2 * intermediate_size,
|
||||
device=device,
|
||||
dtype=torch.bfloat16) * 10)
|
||||
bias2 = (torch.randn(
|
||||
num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
|
||||
router_logits = torch.rand(num_tokens,
|
||||
num_experts,
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
bias13 = (
|
||||
torch.randn(
|
||||
num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
* 10
|
||||
)
|
||||
bias2 = (
|
||||
torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10
|
||||
)
|
||||
router_logits = torch.rand(
|
||||
num_tokens, num_experts, dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape(
|
||||
num_experts, 2 * intermediate_size, hidden_size)
|
||||
num_experts, 2 * intermediate_size, hidden_size
|
||||
)
|
||||
w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape(
|
||||
num_experts, hidden_size, intermediate_size)
|
||||
ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
|
||||
hidden_states.to(torch.float32), w13_ref,
|
||||
bias13.to(torch.float32), w2_ref,
|
||||
bias2.to(torch.float32), alpha, beta, limit, 'bf16')
|
||||
num_experts, hidden_size, intermediate_size
|
||||
)
|
||||
ref = reference_moe(
|
||||
router_logits.to(torch.float32),
|
||||
topk,
|
||||
num_experts,
|
||||
hidden_states.to(torch.float32),
|
||||
w13_ref,
|
||||
bias13.to(torch.float32),
|
||||
w2_ref,
|
||||
bias2.to(torch.float32),
|
||||
alpha,
|
||||
beta,
|
||||
limit,
|
||||
"bf16",
|
||||
)
|
||||
|
||||
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
|
||||
|
||||
@@ -654,23 +742,24 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
|
||||
w13_s_inter = _interleave_scales_lastdim_by4(w13_s)
|
||||
w2_s_inter = _interleave_scales_lastdim_by4(w2_scale)
|
||||
|
||||
routing_weights = torch.nn.functional.softmax(router_logits,
|
||||
dim=1,
|
||||
dtype=torch.float32)
|
||||
token_final_scales, token_selected_experts = torch.topk(routing_weights,
|
||||
topk,
|
||||
dim=-1)
|
||||
token_final_scales = (token_final_scales /
|
||||
token_final_scales.sum(dim=-1, keepdim=True))
|
||||
routing_weights = torch.nn.functional.softmax(
|
||||
router_logits, dim=1, dtype=torch.float32
|
||||
)
|
||||
token_final_scales, token_selected_experts = torch.topk(
|
||||
routing_weights, topk, dim=-1
|
||||
)
|
||||
token_final_scales = token_final_scales / token_final_scales.sum(
|
||||
dim=-1, keepdim=True
|
||||
)
|
||||
token_selected_experts = token_selected_experts.to(torch.int).contiguous()
|
||||
|
||||
out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
|
||||
if alpha is not None:
|
||||
alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
|
||||
alpha = torch.full((num_experts,), alpha, device=hidden_states.device)
|
||||
if beta is not None:
|
||||
beta = torch.full((num_experts, ), beta, device=hidden_states.device)
|
||||
beta = torch.full((num_experts,), beta, device=hidden_states.device)
|
||||
if limit is not None:
|
||||
limit = torch.full((num_experts, ), limit, device=hidden_states.device)
|
||||
limit = torch.full((num_experts,), limit, device=hidden_states.device)
|
||||
|
||||
_ = flashinfer_cutlass_fused_moe(
|
||||
input=hidden_states,
|
||||
@@ -680,8 +769,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
|
||||
fc2_expert_weights=w2_q,
|
||||
output_dtype=torch.bfloat16,
|
||||
output=out,
|
||||
quant_scales=[w13_s_inter.to(torch.uint8),
|
||||
w2_s_inter.to(torch.uint8)],
|
||||
quant_scales=[w13_s_inter.to(torch.uint8), w2_s_inter.to(torch.uint8)],
|
||||
fc1_expert_biases=w13_b,
|
||||
fc2_expert_biases=bias2.to(torch.bfloat16),
|
||||
swiglu_alpha=alpha,
|
||||
@@ -702,11 +790,13 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("num_tokens", [1, 128])
|
||||
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
|
||||
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
|
||||
(1.702, 1.0, 7.0)])
|
||||
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)])
|
||||
@pytest.mark.skipif(
|
||||
not (current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(100) and has_flashinfer()),
|
||||
not (
|
||||
current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(100)
|
||||
and has_flashinfer()
|
||||
),
|
||||
reason="NVIDIA GPU sm100 and flashinfer are required for this test",
|
||||
)
|
||||
def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
|
||||
@@ -723,32 +813,43 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
|
||||
device = "cuda:0"
|
||||
|
||||
# Inputs
|
||||
hidden_states = torch.randn(num_tokens,
|
||||
hidden_size,
|
||||
device=device,
|
||||
dtype=torch.bfloat16)
|
||||
hidden_states = torch.randn(
|
||||
num_tokens, hidden_size, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
# Float weights in w13 format [w1; w3]
|
||||
w13 = (torch.randn(num_experts,
|
||||
2 * intermediate_size,
|
||||
hidden_size,
|
||||
device=device,
|
||||
dtype=torch.bfloat16) / 10)
|
||||
w2 = (torch.randn(num_experts,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
device=device,
|
||||
dtype=torch.bfloat16) / 10)
|
||||
w13 = (
|
||||
torch.randn(
|
||||
num_experts,
|
||||
2 * intermediate_size,
|
||||
hidden_size,
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
/ 10
|
||||
)
|
||||
w2 = (
|
||||
torch.randn(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
/ 10
|
||||
)
|
||||
# Bias contiguous [b1; b3]
|
||||
bias13 = (torch.randn(num_experts,
|
||||
2 * intermediate_size,
|
||||
device=device,
|
||||
dtype=torch.bfloat16) * 10)
|
||||
bias2 = (torch.randn(
|
||||
num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
|
||||
router_logits = torch.rand(num_tokens,
|
||||
num_experts,
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
bias13 = (
|
||||
torch.randn(
|
||||
num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
* 10
|
||||
)
|
||||
bias2 = (
|
||||
torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10
|
||||
)
|
||||
router_logits = torch.rand(
|
||||
num_tokens, num_experts, dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
# Quantize weights to MXFP4 per expert (SM100 path)
|
||||
from flashinfer import mxfp4_quantize
|
||||
@@ -761,36 +862,56 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
|
||||
sfs.append(sf)
|
||||
return torch.stack(qs), torch.stack(sfs)
|
||||
|
||||
def dequant_mxfp4_batches(mat_fp4: torch.Tensor,
|
||||
scale_tensor: torch.Tensor):
|
||||
def dequant_mxfp4_batches(mat_fp4: torch.Tensor, scale_tensor: torch.Tensor):
|
||||
num_batches = mat_fp4.size(0)
|
||||
scale_tensor = scale_tensor.view(num_batches, -1)
|
||||
from flashinfer import mxfp4_dequantize
|
||||
return torch.stack([
|
||||
mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
|
||||
for b in range(num_batches)
|
||||
])
|
||||
|
||||
return torch.stack(
|
||||
[
|
||||
mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
|
||||
for b in range(num_batches)
|
||||
]
|
||||
)
|
||||
|
||||
w13_q, w13_scale = quant_mxfp4_batches(w13, num_experts)
|
||||
w2_q, w2_scale = quant_mxfp4_batches(w2, num_experts)
|
||||
|
||||
# Reference result using dequantized tensors and reference_moe
|
||||
w13_ref = dequant_mxfp4_batches(
|
||||
w13_q.view(torch.uint8),
|
||||
w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
|
||||
num_experts, 2 * intermediate_size, hidden_size).to(device)
|
||||
w2_ref = dequant_mxfp4_batches(
|
||||
w2_q.view(torch.uint8),
|
||||
w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
|
||||
num_experts, hidden_size, intermediate_size).to(device)
|
||||
w13_ref = (
|
||||
dequant_mxfp4_batches(
|
||||
w13_q.view(torch.uint8), w13_scale.view(torch.uint8).reshape(-1)
|
||||
)
|
||||
.to(torch.float32)
|
||||
.reshape(num_experts, 2 * intermediate_size, hidden_size)
|
||||
.to(device)
|
||||
)
|
||||
w2_ref = (
|
||||
dequant_mxfp4_batches(
|
||||
w2_q.view(torch.uint8), w2_scale.view(torch.uint8).reshape(-1)
|
||||
)
|
||||
.to(torch.float32)
|
||||
.reshape(num_experts, hidden_size, intermediate_size)
|
||||
.to(device)
|
||||
)
|
||||
|
||||
# Quantize activations for SM100 path and dequantize for reference
|
||||
hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)
|
||||
# Reference uses BF16 input but quantizes intermediate activation to MXFP8
|
||||
ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
|
||||
hidden_states.to(torch.float32), w13_ref,
|
||||
bias13.to(torch.float32), w2_ref,
|
||||
bias2.to(torch.float32), alpha, beta, limit, 'mxfp8')
|
||||
ref = reference_moe(
|
||||
router_logits.to(torch.float32),
|
||||
topk,
|
||||
num_experts,
|
||||
hidden_states.to(torch.float32),
|
||||
w13_ref,
|
||||
bias13.to(torch.float32),
|
||||
w2_ref,
|
||||
bias2.to(torch.float32),
|
||||
alpha,
|
||||
beta,
|
||||
limit,
|
||||
"mxfp8",
|
||||
)
|
||||
|
||||
# Prepare inputs for FlashInfer CUTLASS fused MoE
|
||||
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
|
||||
@@ -807,31 +928,28 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
|
||||
w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
|
||||
|
||||
# Build routing for kernel
|
||||
routing_weights = torch.nn.functional.softmax(router_logits,
|
||||
dim=1,
|
||||
dtype=torch.float32)
|
||||
token_final_scales, token_selected_experts = torch.topk(routing_weights,
|
||||
topk,
|
||||
dim=-1)
|
||||
token_final_scales = (token_final_scales /
|
||||
token_final_scales.sum(dim=-1, keepdim=True))
|
||||
routing_weights = torch.nn.functional.softmax(
|
||||
router_logits, dim=1, dtype=torch.float32
|
||||
)
|
||||
token_final_scales, token_selected_experts = torch.topk(
|
||||
routing_weights, topk, dim=-1
|
||||
)
|
||||
token_final_scales = token_final_scales / token_final_scales.sum(
|
||||
dim=-1, keepdim=True
|
||||
)
|
||||
token_selected_experts = token_selected_experts.to(torch.int).contiguous()
|
||||
|
||||
out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
|
||||
if alpha is not None:
|
||||
alpha_t = torch.full((num_experts, ),
|
||||
alpha,
|
||||
device=hidden_states.device)
|
||||
alpha_t = torch.full((num_experts,), alpha, device=hidden_states.device)
|
||||
else:
|
||||
alpha_t = None
|
||||
if beta is not None:
|
||||
beta_t = torch.full((num_experts, ), beta, device=hidden_states.device)
|
||||
beta_t = torch.full((num_experts,), beta, device=hidden_states.device)
|
||||
else:
|
||||
beta_t = None
|
||||
if limit is not None:
|
||||
limit_t = torch.full((num_experts, ),
|
||||
limit,
|
||||
device=hidden_states.device)
|
||||
limit_t = torch.full((num_experts,), limit, device=hidden_states.device)
|
||||
else:
|
||||
limit_t = None
|
||||
|
||||
|
||||
@@ -4,9 +4,11 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_weights
|
||||
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype)
|
||||
from tests.kernels.quantization.nvfp4_utils import (
|
||||
FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype,
|
||||
)
|
||||
from tests.kernels.utils import torch_moe
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
@@ -16,8 +18,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip("Nvfp4 Requires compute capability of 10 or above.",
|
||||
allow_module_level=True)
|
||||
pytest.skip(
|
||||
"Nvfp4 Requires compute capability of 10 or above.", allow_module_level=True
|
||||
)
|
||||
|
||||
MNK_FACTORS = [
|
||||
(2, 1024, 1024),
|
||||
@@ -38,36 +41,34 @@ MNK_FACTORS = [
|
||||
@pytest.mark.parametrize("topk", [1, 6, 8])
|
||||
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
|
||||
@torch.inference_mode()
|
||||
def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
||||
dtype: torch.dtype):
|
||||
def test_cutlass_fp4_moe_no_graph(
|
||||
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))):
|
||||
|
||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
):
|
||||
quant_blocksize = 16
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
|
||||
(_, w1_q, w1_blockscale,
|
||||
w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
|
||||
e,
|
||||
n,
|
||||
k,
|
||||
in_dtype=dtype,
|
||||
quant_dtype="nvfp4",
|
||||
block_shape=None, # use quant_blocksize?
|
||||
per_out_ch_quant=False,
|
||||
)
|
||||
(_, w1_q, w1_blockscale, w1_gs), (_, w2_q, w2_blockscale, w2_gs) = (
|
||||
make_test_weights(
|
||||
e,
|
||||
n,
|
||||
k,
|
||||
in_dtype=dtype,
|
||||
quant_dtype="nvfp4",
|
||||
block_shape=None, # use quant_blocksize?
|
||||
per_out_ch_quant=False,
|
||||
)
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(a,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False)
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
|
||||
|
||||
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
|
||||
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
|
||||
a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
|
||||
a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
|
||||
|
||||
assert w1_gs is not None
|
||||
assert w2_gs is not None
|
||||
@@ -97,40 +98,44 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
||||
)
|
||||
|
||||
# Reference check:
|
||||
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
|
||||
torch.amax(a.flatten(), dim=-1)).to(torch.float32)
|
||||
a_global_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
|
||||
|
||||
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
|
||||
a_scale_interleaved,
|
||||
a_global_scale,
|
||||
dtype=a.dtype,
|
||||
device=a.device,
|
||||
block_size=quant_blocksize)
|
||||
a_in_dtype = dequantize_nvfp4_to_dtype(
|
||||
a_fp4,
|
||||
a_scale_interleaved,
|
||||
a_global_scale,
|
||||
dtype=a.dtype,
|
||||
device=a.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
|
||||
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
|
||||
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
|
||||
|
||||
for idx in range(0, e):
|
||||
w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
|
||||
w1_blockscale[idx],
|
||||
w1_gs[idx],
|
||||
dtype=dtype,
|
||||
device=w1_q.device,
|
||||
block_size=quant_blocksize)
|
||||
w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
|
||||
w2_blockscale[idx],
|
||||
w2_gs[idx],
|
||||
dtype=dtype,
|
||||
device=w2_q.device,
|
||||
block_size=quant_blocksize)
|
||||
w1_d[idx] = dequantize_nvfp4_to_dtype(
|
||||
w1_q[idx],
|
||||
w1_blockscale[idx],
|
||||
w1_gs[idx],
|
||||
dtype=dtype,
|
||||
device=w1_q.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
w2_d[idx] = dequantize_nvfp4_to_dtype(
|
||||
w2_q[idx],
|
||||
w2_blockscale[idx],
|
||||
w2_gs[idx],
|
||||
dtype=dtype,
|
||||
device=w2_q.device,
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
|
||||
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
|
||||
|
||||
torch.testing.assert_close(torch_output,
|
||||
cutlass_output,
|
||||
atol=1e-1,
|
||||
rtol=1e-1)
|
||||
torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -9,13 +9,10 @@ import torch
|
||||
from tests.kernels.utils import torch_experts
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
CutlassBatchedExpertsFp8)
|
||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassBatchedExpertsFp8
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
|
||||
@@ -24,9 +21,13 @@ from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
|
||||
try:
|
||||
from pplx_kernels import AllToAll
|
||||
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_finalize, nvshmem_get_unique_id,
|
||||
nvshmem_init)
|
||||
from pplx_kernels.nvshmem import (
|
||||
nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_finalize,
|
||||
nvshmem_get_unique_id,
|
||||
nvshmem_init,
|
||||
)
|
||||
|
||||
has_pplx = True
|
||||
except ImportError:
|
||||
has_pplx = False
|
||||
@@ -50,12 +51,12 @@ def chunk_by_rank(t, r, w):
|
||||
chunk = rank_chunk(num, r, w)
|
||||
rem = num % w
|
||||
if rem == 0 or r < rem:
|
||||
return t[(r * chunk):(r + 1) * chunk].contiguous()
|
||||
return t[(r * chunk) : (r + 1) * chunk].contiguous()
|
||||
else:
|
||||
long_chunks = (num // w + 1) * rem
|
||||
short_chunks = (r - rem) * chunk
|
||||
start = long_chunks + short_chunks
|
||||
return t[start:start + chunk].contiguous()
|
||||
return t[start : start + chunk].contiguous()
|
||||
|
||||
|
||||
def pplx_cutlass_moe(
|
||||
@@ -75,7 +76,9 @@ def pplx_cutlass_moe(
|
||||
group_name: Optional[str],
|
||||
):
|
||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
||||
PplxPrepareAndFinalize)
|
||||
PplxPrepareAndFinalize,
|
||||
)
|
||||
|
||||
assert torch.cuda.current_device() == pgi.local_rank
|
||||
|
||||
num_tokens, hidden_dim = a.shape
|
||||
@@ -126,35 +129,40 @@ def pplx_cutlass_moe(
|
||||
ata,
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_local_experts=num_local_experts,
|
||||
num_dispatchers=num_dispatchers)
|
||||
num_dispatchers=num_dispatchers,
|
||||
)
|
||||
|
||||
ab_strides1 = torch.full((num_local_experts, ),
|
||||
hidden_dim,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
ab_strides2 = torch.full((num_local_experts, ),
|
||||
intermediate_dim,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
c_strides1 = torch.full((num_local_experts, ),
|
||||
2 * intermediate_dim,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
c_strides2 = torch.full((num_local_experts, ),
|
||||
hidden_dim,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
ab_strides1 = torch.full(
|
||||
(num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64
|
||||
)
|
||||
ab_strides2 = torch.full(
|
||||
(num_local_experts,), intermediate_dim, device="cuda", dtype=torch.int64
|
||||
)
|
||||
c_strides1 = torch.full(
|
||||
(num_local_experts,), 2 * intermediate_dim, device="cuda", dtype=torch.int64
|
||||
)
|
||||
c_strides2 = torch.full(
|
||||
(num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64
|
||||
)
|
||||
|
||||
experts = CutlassBatchedExpertsFp8(
|
||||
num_local_experts, num_dispatchers, out_dtype, ab_strides1,
|
||||
ab_strides2, c_strides1, c_strides2,
|
||||
num_local_experts,
|
||||
num_dispatchers,
|
||||
out_dtype,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
fp8_w8a8_moe_quant_config(
|
||||
per_act_token_quant=per_act_token,
|
||||
per_out_ch_quant=per_out_ch,
|
||||
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
|
||||
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
|
||||
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
|
||||
if per_act_token else a1_scale[rank]))
|
||||
if per_act_token
|
||||
else a1_scale[rank],
|
||||
),
|
||||
)
|
||||
|
||||
fused_cutlass_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
@@ -162,10 +170,10 @@ def pplx_cutlass_moe(
|
||||
)
|
||||
|
||||
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
|
||||
chunk_topk_weight = chunk_by_rank(topk_weights, rank,
|
||||
world_size).to(device)
|
||||
chunk_topk_ids = chunk_by_rank(topk_ids, rank,
|
||||
world_size).to(torch.uint32).to(device)
|
||||
chunk_topk_weight = chunk_by_rank(topk_weights, rank, world_size).to(device)
|
||||
chunk_topk_ids = (
|
||||
chunk_by_rank(topk_ids, rank, world_size).to(torch.uint32).to(device)
|
||||
)
|
||||
|
||||
out = fused_cutlass_experts(
|
||||
a_chunk,
|
||||
@@ -174,7 +182,7 @@ def pplx_cutlass_moe(
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
global_num_experts=num_experts,
|
||||
expert_map=None, #TODO
|
||||
expert_map=None, # TODO
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
@@ -210,35 +218,48 @@ def _pplx_moe(
|
||||
):
|
||||
try:
|
||||
if use_internode:
|
||||
uid = nvshmem_get_unique_id(
|
||||
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
|
||||
uid = (
|
||||
nvshmem_get_unique_id()
|
||||
if pgi.rank == 0
|
||||
else nvshmem_alloc_empty_unique_id()
|
||||
)
|
||||
torch.distributed.broadcast(uid, src=0)
|
||||
nvshmem_init(uid, pgi.rank, pgi.world_size)
|
||||
else:
|
||||
group_ranks = list(range(pgi.world_size))
|
||||
cpu_group = torch.distributed.new_group(group_ranks,
|
||||
backend="gloo")
|
||||
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
|
||||
group_name = cpu_group.group_name
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
torch_output = torch_experts(a_full, w1_full, w2_full,
|
||||
topk_weights, topk_ids)
|
||||
pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
|
||||
w2_scale, topk_weights, topk_ids,
|
||||
a1_scale, out_dtype, per_act_token,
|
||||
per_out_ch, group_name)
|
||||
torch_output = torch_experts(
|
||||
a_full, w1_full, w2_full, topk_weights, topk_ids
|
||||
)
|
||||
pplx_output = pplx_cutlass_moe(
|
||||
pgi,
|
||||
dp_size,
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
a1_scale,
|
||||
out_dtype,
|
||||
per_act_token,
|
||||
per_out_ch,
|
||||
group_name,
|
||||
)
|
||||
|
||||
torch_output = chunk_by_rank(torch_output, pgi.rank,
|
||||
pgi.world_size).to(pplx_output.device)
|
||||
torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(
|
||||
pplx_output.device
|
||||
)
|
||||
|
||||
# Uncomment if more debugging is needed
|
||||
# print("PPLX OUT:", pplx_output)
|
||||
# print("TORCH OUT:", torch_output)
|
||||
|
||||
torch.testing.assert_close(pplx_output,
|
||||
torch_output,
|
||||
atol=0.05,
|
||||
rtol=0)
|
||||
torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)
|
||||
finally:
|
||||
if use_internode:
|
||||
nvshmem_finalize()
|
||||
@@ -251,13 +272,15 @@ def _pplx_moe(
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]])
|
||||
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) # , [4, 2]])
|
||||
@pytest.mark.parametrize("use_internode", [False])
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
@requires_pplx
|
||||
def test_cutlass_moe_pplx(
|
||||
m: int,
|
||||
@@ -273,7 +296,6 @@ def test_cutlass_moe_pplx(
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
|
||||
dtype = torch.half
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10.0
|
||||
@@ -283,22 +305,18 @@ def test_cutlass_moe_pplx(
|
||||
n_b_scales = 2 * n if per_out_ch else 1
|
||||
k_b_scales = k if per_out_ch else 1
|
||||
|
||||
w1_q = torch.empty((e, 2 * n, k),
|
||||
device="cuda",
|
||||
dtype=torch.float8_e4m3fn)
|
||||
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn)
|
||||
w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn)
|
||||
w1_scale = torch.empty((e, n_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w2_scale = torch.empty((e, k_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32)
|
||||
w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32)
|
||||
|
||||
for expert in range(e):
|
||||
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
|
||||
w1[expert], use_per_token_if_dynamic=per_out_ch)
|
||||
w1[expert], use_per_token_if_dynamic=per_out_ch
|
||||
)
|
||||
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
|
||||
w2[expert], use_per_token_if_dynamic=per_out_ch)
|
||||
w2[expert], use_per_token_if_dynamic=per_out_ch
|
||||
)
|
||||
|
||||
w1_d = torch.empty_like(w1)
|
||||
w2_d = torch.empty_like(w2)
|
||||
@@ -307,19 +325,35 @@ def test_cutlass_moe_pplx(
|
||||
w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half()
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(a,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False)
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
|
||||
|
||||
world_size, dp_size = world_dp_size
|
||||
a_scale1 = torch.randn(
|
||||
(m if per_act_token else 1, 1), device="cuda",
|
||||
dtype=torch.float32) / 10.0
|
||||
a_scale1 = (
|
||||
torch.randn(
|
||||
(m if per_act_token else 1, 1), device="cuda", dtype=torch.float32
|
||||
)
|
||||
/ 10.0
|
||||
)
|
||||
if not per_act_token:
|
||||
a_scale1 = a_scale1.repeat(world_size, 1)
|
||||
|
||||
parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q,
|
||||
w1_scale, w2_scale, topk_weights, topk_ids, a_scale1,
|
||||
dtype, a, w1_d, w2_d, per_act_token, per_out_ch,
|
||||
use_internode)
|
||||
parallel_launch(
|
||||
world_size,
|
||||
_pplx_moe,
|
||||
dp_size,
|
||||
a,
|
||||
w1_q,
|
||||
w2_q,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
a_scale1,
|
||||
dtype,
|
||||
a,
|
||||
w1_d,
|
||||
w2_d,
|
||||
per_act_token,
|
||||
per_out_ch,
|
||||
use_internode,
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
Run `pytest tests/kernels/test_pplx_moe.py`.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
import textwrap
|
||||
@@ -15,29 +16,34 @@ import torch
|
||||
|
||||
try:
|
||||
from pplx_kernels import AllToAll
|
||||
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_finalize, nvshmem_get_unique_id,
|
||||
nvshmem_init)
|
||||
from pplx_kernels.nvshmem import (
|
||||
nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_finalize,
|
||||
nvshmem_get_unique_id,
|
||||
nvshmem_init,
|
||||
)
|
||||
|
||||
has_pplx = True
|
||||
except ImportError:
|
||||
has_pplx = False
|
||||
|
||||
from tests.kernels.moe.modular_kernel_tools.parallel_utils import (
|
||||
_set_vllm_config)
|
||||
from tests.kernels.moe.utils import (make_shared_experts, make_test_weights,
|
||||
naive_batched_moe)
|
||||
from tests.kernels.moe.modular_kernel_tools.parallel_utils import _set_vllm_config
|
||||
from tests.kernels.moe.utils import (
|
||||
make_shared_experts,
|
||||
make_test_weights,
|
||||
naive_batched_moe,
|
||||
)
|
||||
from tests.kernels.quant_utils import dequant
|
||||
from tests.kernels.utils import torch_experts
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe import fused_topk, override_config
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate)
|
||||
TopKWeightAndReduceDelegate,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import round_up
|
||||
|
||||
@@ -59,7 +65,7 @@ BATCHED_MOE_MNK_FACTORS = [
|
||||
|
||||
PPLX_COMBOS = [
|
||||
# TODO(bnell): figure out why this fails, seems to be test problem
|
||||
#(1, 128, 128),
|
||||
# (1, 128, 128),
|
||||
(2, 128, 512),
|
||||
(3, 1024, 2048),
|
||||
(4, 128, 128),
|
||||
@@ -91,17 +97,16 @@ def torch_prepare(
|
||||
num_tokens, hidden_dim = a.shape
|
||||
topk = topk_ids.shape[1]
|
||||
|
||||
tokens_per_expert = torch.bincount(topk_ids.view(-1),
|
||||
minlength=num_experts)
|
||||
tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
|
||||
|
||||
assert tokens_per_expert.numel() == num_experts
|
||||
|
||||
if max_num_tokens is None:
|
||||
max_num_tokens = int(tokens_per_expert.max().item())
|
||||
|
||||
b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim),
|
||||
dtype=a.dtype,
|
||||
device=a.device)
|
||||
b_a = torch.zeros(
|
||||
(num_experts, max_num_tokens, hidden_dim), dtype=a.dtype, device=a.device
|
||||
)
|
||||
|
||||
token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device)
|
||||
|
||||
@@ -109,28 +114,29 @@ def torch_prepare(
|
||||
for j in range(topk):
|
||||
expert_id = topk_ids[token, j]
|
||||
idx = token_counts[expert_id]
|
||||
b_a[expert_id, idx:idx + 1, :] = a[token, :]
|
||||
b_a[expert_id, idx : idx + 1, :] = a[token, :]
|
||||
token_counts[expert_id] = token_counts[expert_id] + 1
|
||||
|
||||
return b_a, tokens_per_expert
|
||||
|
||||
|
||||
def torch_finalize(b_out: torch.Tensor, topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor) -> torch.Tensor:
|
||||
def torch_finalize(
|
||||
b_out: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
num_tokens = topk_ids.shape[0]
|
||||
num_experts = b_out.shape[0]
|
||||
K = b_out.shape[-1]
|
||||
out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)
|
||||
expert_counts = torch.zeros(num_experts,
|
||||
dtype=torch.int,
|
||||
device=b_out.device)
|
||||
expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device)
|
||||
for token in range(num_tokens):
|
||||
expert_ids = topk_ids[token]
|
||||
for i in range(expert_ids.numel()):
|
||||
expert_id = expert_ids[i]
|
||||
idx = expert_counts[expert_id]
|
||||
out[token, :] = out[token, :] + b_out[expert_id, idx:idx +
|
||||
1, :] * topk_weight[token, i]
|
||||
out[token, :] = (
|
||||
out[token, :]
|
||||
+ b_out[expert_id, idx : idx + 1, :] * topk_weight[token, i]
|
||||
)
|
||||
expert_counts[expert_id] = expert_counts[expert_id] + 1
|
||||
|
||||
return out
|
||||
@@ -149,17 +155,18 @@ def torch_batched_moe(
|
||||
num_tokens, topk = topk_ids.shape
|
||||
_, max_num_tokens, K = b_a.shape
|
||||
assert num_experts == b_a.shape[0] and w2.shape[1] == K
|
||||
out = torch.zeros((num_experts, max_num_tokens, K),
|
||||
dtype=b_a.dtype,
|
||||
device=b_a.device)
|
||||
tmp = torch.empty((max_num_tokens, w1.shape[1] // 2),
|
||||
dtype=b_a.dtype,
|
||||
device=b_a.device)
|
||||
out = torch.zeros(
|
||||
(num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device
|
||||
)
|
||||
tmp = torch.empty(
|
||||
(max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device
|
||||
)
|
||||
for expert in range(num_experts):
|
||||
num = tokens_per_expert[expert]
|
||||
if num > 0:
|
||||
torch.ops._C.silu_and_mul(
|
||||
tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1))
|
||||
tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1)
|
||||
)
|
||||
out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1)
|
||||
|
||||
return torch_finalize(out, topk_weight, topk_ids)
|
||||
@@ -186,20 +193,16 @@ def test_fused_moe_batched_experts(
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
baseline_output = torch_experts(a, w1, w2, topk_weight,
|
||||
topk_ids) # only for baseline
|
||||
baseline_output = torch_experts(
|
||||
a, w1, w2, topk_weight, topk_ids
|
||||
) # only for baseline
|
||||
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
|
||||
batched_output = naive_batched_moe(
|
||||
a, w1, w2, topk_weight, topk_ids) # pick torch_experts or this
|
||||
a, w1, w2, topk_weight, topk_ids
|
||||
) # pick torch_experts or this
|
||||
|
||||
torch.testing.assert_close(baseline_output,
|
||||
torch_output,
|
||||
atol=2e-2,
|
||||
rtol=0)
|
||||
torch.testing.assert_close(baseline_output,
|
||||
batched_output,
|
||||
atol=2e-2,
|
||||
rtol=0)
|
||||
torch.testing.assert_close(baseline_output, torch_output, atol=2e-2, rtol=0)
|
||||
torch.testing.assert_close(baseline_output, batched_output, atol=2e-2, rtol=0)
|
||||
|
||||
|
||||
def create_pplx_prepare_finalize(
|
||||
@@ -217,7 +220,9 @@ def create_pplx_prepare_finalize(
|
||||
group_name: Optional[str],
|
||||
):
|
||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
||||
PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes)
|
||||
PplxPrepareAndFinalize,
|
||||
pplx_hidden_dim_scale_bytes,
|
||||
)
|
||||
|
||||
max_num_tokens = max(rank_chunk(num_tokens, 0, world_size), 1)
|
||||
num_local_experts = rank_chunk(num_experts, 0, world_size)
|
||||
@@ -266,28 +271,31 @@ def rank_chunk(num: int, r: int, w: int) -> int:
|
||||
|
||||
def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
|
||||
chunk = rank_chunk(t.shape[0], r, w)
|
||||
return t[(r * chunk):(r + 1) * chunk]
|
||||
return t[(r * chunk) : (r + 1) * chunk]
|
||||
|
||||
|
||||
def maybe_chunk_by_rank(t: Optional[torch.Tensor], r: int,
|
||||
w: int) -> Optional[torch.Tensor]:
|
||||
def maybe_chunk_by_rank(
|
||||
t: Optional[torch.Tensor], r: int, w: int
|
||||
) -> Optional[torch.Tensor]:
|
||||
if t is not None:
|
||||
return chunk_by_rank(t, r, w)
|
||||
else:
|
||||
return t
|
||||
|
||||
|
||||
def chunk_scales_by_rank(t: Optional[torch.Tensor], r: int,
|
||||
w: int) -> Optional[torch.Tensor]:
|
||||
def chunk_scales_by_rank(
|
||||
t: Optional[torch.Tensor], r: int, w: int
|
||||
) -> Optional[torch.Tensor]:
|
||||
if t is not None and t.numel() > 1:
|
||||
chunk = rank_chunk(t.shape[0], r, w)
|
||||
return t[(r * chunk):(r + 1) * chunk]
|
||||
return t[(r * chunk) : (r + 1) * chunk]
|
||||
else:
|
||||
return t
|
||||
|
||||
|
||||
def chunk_scales(t: Optional[torch.Tensor], start: int,
|
||||
end: int) -> Optional[torch.Tensor]:
|
||||
def chunk_scales(
|
||||
t: Optional[torch.Tensor], start: int, end: int
|
||||
) -> Optional[torch.Tensor]:
|
||||
if t is not None and t.numel() > 1:
|
||||
return t[start:end]
|
||||
else:
|
||||
@@ -350,8 +358,7 @@ def pplx_prepare_finalize(
|
||||
device=device,
|
||||
)
|
||||
|
||||
if (quant_dtype is not None and not per_act_token_quant
|
||||
and block_shape is None):
|
||||
if quant_dtype is not None and not per_act_token_quant and block_shape is None:
|
||||
a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
else:
|
||||
@@ -375,8 +382,7 @@ def pplx_prepare_finalize(
|
||||
),
|
||||
)
|
||||
|
||||
b_a = dummy_work(
|
||||
dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype))
|
||||
b_a = dummy_work(dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype))
|
||||
|
||||
prepare_finalize.finalize(
|
||||
out,
|
||||
@@ -410,15 +416,17 @@ def _pplx_prepare_finalize(
|
||||
):
|
||||
try:
|
||||
if use_internode:
|
||||
uid = nvshmem_get_unique_id(
|
||||
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
|
||||
uid = (
|
||||
nvshmem_get_unique_id()
|
||||
if pgi.rank == 0
|
||||
else nvshmem_alloc_empty_unique_id()
|
||||
)
|
||||
torch.distributed.broadcast(uid, src=0)
|
||||
nvshmem_init(uid, pgi.rank, pgi.world_size)
|
||||
group_name = None
|
||||
else:
|
||||
group_ranks = list(range(pgi.world_size))
|
||||
cpu_group = torch.distributed.new_group(group_ranks,
|
||||
backend="gloo")
|
||||
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
|
||||
group_name = cpu_group.group_name
|
||||
|
||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
@@ -426,22 +434,28 @@ def _pplx_prepare_finalize(
|
||||
|
||||
a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0)
|
||||
|
||||
torch_output = (a_rep.view(m, topk, k) *
|
||||
topk_weight.view(m, topk, 1).to(a_rep.dtype)).sum(
|
||||
dim=1)
|
||||
torch_output = (
|
||||
a_rep.view(m, topk, k) * topk_weight.view(m, topk, 1).to(a_rep.dtype)
|
||||
).sum(dim=1)
|
||||
|
||||
pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight,
|
||||
topk_ids, num_experts, quant_dtype,
|
||||
block_shape, per_act_token_quant,
|
||||
group_name)
|
||||
pplx_output = pplx_prepare_finalize(
|
||||
pgi,
|
||||
dp_size,
|
||||
a,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
num_experts,
|
||||
quant_dtype,
|
||||
block_shape,
|
||||
per_act_token_quant,
|
||||
group_name,
|
||||
)
|
||||
|
||||
torch_output = chunk_by_rank(torch_output, pgi.rank,
|
||||
pgi.world_size).to(pgi.device)
|
||||
torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(
|
||||
pgi.device
|
||||
)
|
||||
|
||||
torch.testing.assert_close(pplx_output,
|
||||
torch_output,
|
||||
atol=3e-2,
|
||||
rtol=3e-2)
|
||||
torch.testing.assert_close(pplx_output, torch_output, atol=3e-2, rtol=3e-2)
|
||||
finally:
|
||||
if use_internode:
|
||||
nvshmem_finalize()
|
||||
@@ -491,9 +505,19 @@ def test_pplx_prepare_finalize_slow(
|
||||
a = torch.randn((m, k), device=device, dtype=act_dtype) / 10
|
||||
score = torch.randn((m, e), device=device, dtype=act_dtype)
|
||||
|
||||
parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score,
|
||||
topk, e, quant_dtype, block_shape, per_act_token_quant,
|
||||
use_internode)
|
||||
parallel_launch(
|
||||
world_size,
|
||||
_pplx_prepare_finalize,
|
||||
dp_size,
|
||||
a,
|
||||
score,
|
||||
topk,
|
||||
e,
|
||||
quant_dtype,
|
||||
block_shape,
|
||||
per_act_token_quant,
|
||||
use_internode,
|
||||
)
|
||||
|
||||
|
||||
def pplx_moe(
|
||||
@@ -517,7 +541,6 @@ def pplx_moe(
|
||||
use_cudagraphs: bool = True,
|
||||
shared_experts: Optional[torch.nn.Module] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
|
||||
num_tokens, hidden_dim = a.shape
|
||||
num_experts = w1.shape[0]
|
||||
topk = topk_ids.shape[1]
|
||||
@@ -579,21 +602,23 @@ def pplx_moe(
|
||||
# large enough to trigger chunking. I'm leaving the flag and
|
||||
# setup code in case we are able to revisit this later.
|
||||
if use_compile:
|
||||
_fused_experts = torch.compile(fused_experts,
|
||||
backend='inductor',
|
||||
fullgraph=True)
|
||||
_fused_experts = torch.compile(
|
||||
fused_experts, backend="inductor", fullgraph=True
|
||||
)
|
||||
torch._dynamo.mark_dynamic(a_chunk, 0)
|
||||
torch._dynamo.mark_dynamic(chunk_topk_weight, 0)
|
||||
torch._dynamo.mark_dynamic(chunk_topk_ids, 0)
|
||||
else:
|
||||
_fused_experts = fused_experts
|
||||
|
||||
out = _fused_experts(a_chunk,
|
||||
w1_chunk,
|
||||
w2_chunk,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
global_num_experts=num_experts)
|
||||
out = _fused_experts(
|
||||
a_chunk,
|
||||
w1_chunk,
|
||||
w2_chunk,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
global_num_experts=num_experts,
|
||||
)
|
||||
|
||||
if use_cudagraphs:
|
||||
if isinstance(out, tuple):
|
||||
@@ -604,12 +629,14 @@ def pplx_moe(
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
out = _fused_experts(a_chunk,
|
||||
w1_chunk,
|
||||
w2_chunk,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
global_num_experts=num_experts)
|
||||
out = _fused_experts(
|
||||
a_chunk,
|
||||
w1_chunk,
|
||||
w2_chunk,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
global_num_experts=num_experts,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
graph.replay()
|
||||
@@ -640,15 +667,17 @@ def _pplx_moe(
|
||||
):
|
||||
try:
|
||||
if use_internode:
|
||||
uid = nvshmem_get_unique_id(
|
||||
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
|
||||
uid = (
|
||||
nvshmem_get_unique_id()
|
||||
if pgi.rank == 0
|
||||
else nvshmem_alloc_empty_unique_id()
|
||||
)
|
||||
torch.distributed.broadcast(uid, src=0)
|
||||
nvshmem_init(uid, pgi.rank, pgi.world_size)
|
||||
group_name = None
|
||||
else:
|
||||
group_ranks = list(range(pgi.world_size))
|
||||
cpu_group = torch.distributed.new_group(group_ranks,
|
||||
backend="gloo")
|
||||
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
|
||||
group_name = cpu_group.group_name
|
||||
|
||||
m, k = a.shape
|
||||
@@ -666,8 +695,7 @@ def _pplx_moe(
|
||||
w1_s = w1_s.to(device) if w1_s is not None else None
|
||||
w2_s = w2_s.to(device) if w2_s is not None else None
|
||||
|
||||
if (quant_dtype is not None and not per_act_token_quant
|
||||
and block_shape is None):
|
||||
if quant_dtype is not None and not per_act_token_quant and block_shape is None:
|
||||
a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
else:
|
||||
@@ -742,31 +770,27 @@ def _pplx_moe(
|
||||
if shared_output is not None:
|
||||
assert pplx_shared_output is not None
|
||||
chunked_shared_output = chunk_by_rank(
|
||||
shared_output, pgi.rank,
|
||||
pgi.world_size).to(pplx_shared_output.device)
|
||||
shared_output, pgi.rank, pgi.world_size
|
||||
).to(pplx_shared_output.device)
|
||||
else:
|
||||
chunked_shared_output = None
|
||||
|
||||
chunked_batch_output = chunk_by_rank(
|
||||
batched_output, pgi.rank, pgi.world_size).to(pplx_output.device)
|
||||
batched_output, pgi.rank, pgi.world_size
|
||||
).to(pplx_output.device)
|
||||
|
||||
torch.testing.assert_close(batched_output,
|
||||
torch_output,
|
||||
atol=3e-2,
|
||||
rtol=3e-2)
|
||||
torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2)
|
||||
|
||||
torch.testing.assert_close(pplx_output,
|
||||
chunked_batch_output,
|
||||
atol=3e-2,
|
||||
rtol=3e-2)
|
||||
torch.testing.assert_close(
|
||||
pplx_output, chunked_batch_output, atol=3e-2, rtol=3e-2
|
||||
)
|
||||
|
||||
if shared_experts is not None:
|
||||
assert chunked_shared_output is not None
|
||||
assert pplx_shared_output is not None
|
||||
torch.testing.assert_close(pplx_shared_output,
|
||||
chunked_shared_output,
|
||||
atol=3e-2,
|
||||
rtol=3e-2)
|
||||
torch.testing.assert_close(
|
||||
pplx_shared_output, chunked_shared_output, atol=3e-2, rtol=3e-2
|
||||
)
|
||||
|
||||
finally:
|
||||
if use_internode:
|
||||
@@ -823,15 +847,33 @@ def test_pplx_moe_slow(
|
||||
per_out_ch_quant=per_act_token_quant,
|
||||
)
|
||||
|
||||
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e,
|
||||
w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape,
|
||||
use_internode)
|
||||
parallel_launch(
|
||||
world_size,
|
||||
_pplx_moe,
|
||||
dp_size,
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
e,
|
||||
w1_s,
|
||||
w2_s,
|
||||
quant_dtype,
|
||||
per_act_token_quant,
|
||||
block_shape,
|
||||
use_internode,
|
||||
)
|
||||
|
||||
|
||||
def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
|
||||
use_shared_experts: bool, make_weights: bool,
|
||||
test_fn: Callable):
|
||||
|
||||
def _pplx_test_loop(
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
use_internode: bool,
|
||||
use_shared_experts: bool,
|
||||
make_weights: bool,
|
||||
test_fn: Callable,
|
||||
):
|
||||
def format_result(msg, ex=None):
|
||||
if ex is not None:
|
||||
x = str(ex)
|
||||
@@ -850,12 +892,12 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
|
||||
new_vllm_config = copy.deepcopy(vllm_config)
|
||||
new_vllm_config.parallel_config.data_parallel_size = pgi.world_size
|
||||
new_vllm_config.parallel_config.enable_expert_parallel = True
|
||||
_set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank,
|
||||
pgi.local_rank)
|
||||
_set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank, pgi.local_rank)
|
||||
|
||||
current_platform.seed_everything(7)
|
||||
combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES,
|
||||
[False, True], [None, [128, 128]])
|
||||
combos = itertools.product(
|
||||
PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, [False, True], [None, [128, 128]]
|
||||
)
|
||||
exceptions = []
|
||||
count = 0
|
||||
for mnk, e, topk, dtype, per_act_token_quant, block_shape in combos:
|
||||
@@ -873,13 +915,11 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
|
||||
f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, "
|
||||
f"dtype={dtype}, per_act_token={per_act_token_quant}, "
|
||||
f"block_shape={block_shape}, use_internode={use_internode}, "
|
||||
f"use_shared_experts={use_shared_experts}")
|
||||
f"use_shared_experts={use_shared_experts}"
|
||||
)
|
||||
|
||||
if not use_fp8_w8a8 and (per_act_token_quant
|
||||
or block_shape is not None):
|
||||
print(
|
||||
f"{test_desc} - Skip quantization test for non-quantized type."
|
||||
)
|
||||
if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
|
||||
print(f"{test_desc} - Skip quantization test for non-quantized type.")
|
||||
continue
|
||||
|
||||
if per_act_token_quant and block_shape is not None:
|
||||
@@ -934,10 +974,10 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
|
||||
if len(exceptions) > 0:
|
||||
raise RuntimeError(
|
||||
f"{len(exceptions)} of {count} tests failed in child process, "
|
||||
f"rank={pgi.rank}.")
|
||||
f"rank={pgi.rank}."
|
||||
)
|
||||
else:
|
||||
print(f"{count} of {count} tests passed in child process, "
|
||||
f"rank={pgi.rank}.")
|
||||
print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
|
||||
@@ -950,8 +990,15 @@ def test_pplx_prepare_finalize(
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
world_size, dp_size = world_dp_size
|
||||
parallel_launch(world_size * dp_size, _pplx_test_loop, dp_size,
|
||||
use_internode, False, False, _pplx_prepare_finalize)
|
||||
parallel_launch(
|
||||
world_size * dp_size,
|
||||
_pplx_test_loop,
|
||||
dp_size,
|
||||
use_internode,
|
||||
False,
|
||||
False,
|
||||
_pplx_prepare_finalize,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
|
||||
@@ -966,5 +1013,12 @@ def test_pplx_moe(
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
world_size, dp_size = world_dp_size
|
||||
parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode,
|
||||
use_shared_experts, True, _pplx_moe)
|
||||
parallel_launch(
|
||||
world_size,
|
||||
_pplx_test_loop,
|
||||
dp_size,
|
||||
use_internode,
|
||||
use_shared_experts,
|
||||
True,
|
||||
_pplx_moe,
|
||||
)
|
||||
|
||||
@@ -24,13 +24,14 @@ aiter_available = importlib.util.find_spec("aiter") is not None
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not (current_platform.is_rocm() and aiter_available),
|
||||
reason="AITER ops are only available on ROCm with aiter package installed")
|
||||
reason="AITER ops are only available on ROCm with aiter package installed",
|
||||
)
|
||||
|
||||
|
||||
def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
|
||||
"""Test that the custom op is correctly registered."""
|
||||
# Check if the op exists in torch.ops.vllm
|
||||
assert hasattr(torch.ops.vllm, 'rocm_aiter_biased_grouped_topk')
|
||||
assert hasattr(torch.ops.vllm, "rocm_aiter_biased_grouped_topk")
|
||||
|
||||
# Check if the op is callable
|
||||
assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk)
|
||||
@@ -39,7 +40,7 @@ def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
|
||||
def test_rocm_aiter_grouped_topk_custom_op_registration():
|
||||
"""Test that the custom op is correctly registered."""
|
||||
# Check if the op exists in torch.ops.vllm
|
||||
assert hasattr(torch.ops.vllm, 'rocm_aiter_grouped_topk')
|
||||
assert hasattr(torch.ops.vllm, "rocm_aiter_grouped_topk")
|
||||
|
||||
# Check if the op is callable
|
||||
assert callable(torch.ops.vllm.rocm_aiter_grouped_topk)
|
||||
@@ -56,25 +57,29 @@ def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
|
||||
renormalize = True
|
||||
scale_factor = 1.0
|
||||
|
||||
gating_output = torch.randn((token, expert),
|
||||
dtype=torch.bfloat16,
|
||||
device="cuda")
|
||||
e_score_correction_bias = torch.randn((expert, ),
|
||||
dtype=torch.bfloat16,
|
||||
device="cuda")
|
||||
gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda")
|
||||
e_score_correction_bias = torch.randn(
|
||||
(expert,), dtype=torch.bfloat16, device="cuda"
|
||||
)
|
||||
|
||||
device = gating_output.device
|
||||
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
|
||||
topk_weights = torch.empty((token, topk),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
|
||||
|
||||
# Define a function that uses the op
|
||||
def biased_grouped_topk_fn(gating_output, e_score_correction_bias,
|
||||
topk_weights, topk_ids):
|
||||
def biased_grouped_topk_fn(
|
||||
gating_output, e_score_correction_bias, topk_weights, topk_ids
|
||||
):
|
||||
return torch.ops.vllm.rocm_aiter_biased_grouped_topk(
|
||||
gating_output, e_score_correction_bias, topk_weights, topk_ids,
|
||||
num_expert_group, topk_group, renormalize, scale_factor)
|
||||
gating_output,
|
||||
e_score_correction_bias,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
renormalize,
|
||||
scale_factor,
|
||||
)
|
||||
|
||||
# Verify the op's fake implementation
|
||||
torch.library.opcheck(
|
||||
@@ -84,51 +89,49 @@ def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
|
||||
"num_expert_group": num_expert_group,
|
||||
"topk_group": topk_group,
|
||||
"need_renorm": renormalize,
|
||||
"routed_scaling_factor": scale_factor
|
||||
"routed_scaling_factor": scale_factor,
|
||||
},
|
||||
test_utils=("test_faketensor"))
|
||||
test_utils=("test_faketensor"),
|
||||
)
|
||||
|
||||
# Compile the function with appropriate settings
|
||||
compiled_fn = torch.compile(biased_grouped_topk_fn,
|
||||
fullgraph=True,
|
||||
backend="inductor",
|
||||
mode="reduce-overhead",
|
||||
dynamic=False)
|
||||
compiled_fn = torch.compile(
|
||||
biased_grouped_topk_fn,
|
||||
fullgraph=True,
|
||||
backend="inductor",
|
||||
mode="reduce-overhead",
|
||||
dynamic=False,
|
||||
)
|
||||
|
||||
topk_weights_original = torch.empty((token, topk),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
topk_ids_original = torch.empty((token, topk),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
topk_weights_original = torch.empty(
|
||||
(token, topk), dtype=torch.float32, device=device
|
||||
)
|
||||
topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device)
|
||||
|
||||
topk_weights_compiled = torch.empty((token, topk),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
topk_ids_compiled = torch.empty((token, topk),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
topk_weights_compiled = torch.empty(
|
||||
(token, topk), dtype=torch.float32, device=device
|
||||
)
|
||||
topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device)
|
||||
|
||||
# Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
|
||||
biased_grouped_topk_fn(gating_output, e_score_correction_bias,
|
||||
topk_weights_original, topk_ids_original)
|
||||
compiled_fn(gating_output, e_score_correction_bias, topk_weights_compiled,
|
||||
topk_ids_compiled)
|
||||
biased_grouped_topk_fn(
|
||||
gating_output, e_score_correction_bias, topk_weights_original, topk_ids_original
|
||||
)
|
||||
compiled_fn(
|
||||
gating_output, e_score_correction_bias, topk_weights_compiled, topk_ids_compiled
|
||||
)
|
||||
|
||||
# Sort the results for comparison since the order might not be deterministic
|
||||
topk_ids_original, indices_original = torch.sort(topk_ids_original)
|
||||
topk_weights_original = torch.gather(topk_weights_original, 1,
|
||||
indices_original)
|
||||
topk_weights_original = torch.gather(topk_weights_original, 1, indices_original)
|
||||
|
||||
topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
|
||||
topk_weights_compiled = torch.gather(topk_weights_compiled, 1,
|
||||
indices_compiled)
|
||||
topk_weights_compiled = torch.gather(topk_weights_compiled, 1, indices_compiled)
|
||||
|
||||
# Verify results match
|
||||
assert torch.allclose(topk_weights_original,
|
||||
topk_weights_compiled,
|
||||
rtol=1e-2,
|
||||
atol=1e-2)
|
||||
assert torch.allclose(
|
||||
topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2
|
||||
)
|
||||
assert torch.allclose(topk_ids_original, topk_ids_compiled)
|
||||
|
||||
|
||||
@@ -144,73 +147,73 @@ def test_rocm_aiter_grouped_topk_torch_compile_compatibility():
|
||||
scoring_func = "softmax"
|
||||
scale_factor = 1.0
|
||||
|
||||
gating_output = torch.randn((token, expert),
|
||||
dtype=torch.bfloat16,
|
||||
device="cuda")
|
||||
gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
device = gating_output.device
|
||||
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
|
||||
topk_weights = torch.empty((token, topk),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
|
||||
|
||||
# Define a function that uses the op
|
||||
def grouped_topk_fn(gating_output, topk_weights, topk_ids, scoring_func):
|
||||
return torch.ops.vllm.rocm_aiter_grouped_topk(
|
||||
gating_output, topk_weights, topk_ids, num_expert_group,
|
||||
topk_group, renormalize, scoring_func, scale_factor)
|
||||
gating_output,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
renormalize,
|
||||
scoring_func,
|
||||
scale_factor,
|
||||
)
|
||||
|
||||
# Verify the op's fake implementation
|
||||
torch.library.opcheck(torch.ops.vllm.rocm_aiter_grouped_topk,
|
||||
(gating_output, topk_weights, topk_ids),
|
||||
kwargs={
|
||||
"num_expert_group": num_expert_group,
|
||||
"topk_group": topk_group,
|
||||
"need_renorm": renormalize,
|
||||
"scoring_func": scoring_func,
|
||||
"routed_scaling_factor": scale_factor
|
||||
},
|
||||
test_utils=("test_faketensor"))
|
||||
torch.library.opcheck(
|
||||
torch.ops.vllm.rocm_aiter_grouped_topk,
|
||||
(gating_output, topk_weights, topk_ids),
|
||||
kwargs={
|
||||
"num_expert_group": num_expert_group,
|
||||
"topk_group": topk_group,
|
||||
"need_renorm": renormalize,
|
||||
"scoring_func": scoring_func,
|
||||
"routed_scaling_factor": scale_factor,
|
||||
},
|
||||
test_utils=("test_faketensor"),
|
||||
)
|
||||
|
||||
# Compile the function with appropriate settings
|
||||
compiled_fn = torch.compile(grouped_topk_fn,
|
||||
fullgraph=True,
|
||||
backend="inductor",
|
||||
mode="reduce-overhead",
|
||||
dynamic=False)
|
||||
compiled_fn = torch.compile(
|
||||
grouped_topk_fn,
|
||||
fullgraph=True,
|
||||
backend="inductor",
|
||||
mode="reduce-overhead",
|
||||
dynamic=False,
|
||||
)
|
||||
|
||||
topk_weights_original = torch.empty((token, topk),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
topk_ids_original = torch.empty((token, topk),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
topk_weights_original = torch.empty(
|
||||
(token, topk), dtype=torch.float32, device=device
|
||||
)
|
||||
topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device)
|
||||
|
||||
topk_weights_compiled = torch.empty((token, topk),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
topk_ids_compiled = torch.empty((token, topk),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
topk_weights_compiled = torch.empty(
|
||||
(token, topk), dtype=torch.float32, device=device
|
||||
)
|
||||
topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device)
|
||||
|
||||
# Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
|
||||
grouped_topk_fn(gating_output, topk_weights_original, topk_ids_original,
|
||||
scoring_func)
|
||||
compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled,
|
||||
scoring_func)
|
||||
grouped_topk_fn(
|
||||
gating_output, topk_weights_original, topk_ids_original, scoring_func
|
||||
)
|
||||
compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled, scoring_func)
|
||||
|
||||
# Sort the results for comparison since the order might not be deterministic
|
||||
topk_ids_original, indices_original = torch.sort(topk_ids_original)
|
||||
topk_weights_original = torch.gather(topk_weights_original, 1,
|
||||
indices_original)
|
||||
topk_weights_original = torch.gather(topk_weights_original, 1, indices_original)
|
||||
|
||||
topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
|
||||
topk_weights_compiled = torch.gather(topk_weights_compiled, 1,
|
||||
indices_compiled)
|
||||
topk_weights_compiled = torch.gather(topk_weights_compiled, 1, indices_compiled)
|
||||
|
||||
# Verify results match
|
||||
assert torch.allclose(topk_weights_original,
|
||||
topk_weights_compiled,
|
||||
rtol=1e-2,
|
||||
atol=1e-2)
|
||||
assert torch.allclose(
|
||||
topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2
|
||||
)
|
||||
assert torch.allclose(topk_ids_original, topk_ids_compiled)
|
||||
|
||||
@@ -5,7 +5,8 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
silu_mul_fp8_quant_deep_gemm_cuda)
|
||||
silu_mul_fp8_quant_deep_gemm_cuda,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
|
||||
@@ -34,7 +35,6 @@ CASES = [
|
||||
(256, 16, 7168, fp8_dtype),
|
||||
(256, 32, 7168, fp8_dtype),
|
||||
(256, 64, 7168, fp8_dtype),
|
||||
|
||||
# Only add a few fnuz tests to help with long CI times.
|
||||
(8, 512, 7168, torch.float8_e4m3fnuz),
|
||||
(8, 1024, 7168, torch.float8_e4m3fnuz),
|
||||
@@ -52,15 +52,15 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
|
||||
tokens_per_expert = torch.randint(
|
||||
low=T // 2,
|
||||
high=T,
|
||||
size=(E, ),
|
||||
size=(E,),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
# Run the Triton kernel
|
||||
y_q, y_s = silu_mul_fp8_quant_deep_gemm_cuda(y,
|
||||
tokens_per_expert,
|
||||
group_size=group_size)
|
||||
y_q, y_s = silu_mul_fp8_quant_deep_gemm_cuda(
|
||||
y, tokens_per_expert, group_size=group_size
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
fp8_info = torch.finfo(fp8_dtype)
|
||||
@@ -75,9 +75,9 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
|
||||
|
||||
for e in range(E):
|
||||
nt = tokens_per_expert[e].item()
|
||||
ref_s = torch.empty((T, cdiv(H, group_size)),
|
||||
dtype=torch.float32,
|
||||
device="cuda")
|
||||
ref_s = torch.empty(
|
||||
(T, cdiv(H, group_size)), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
ref_q = torch.empty((T, H), dtype=fp8_dtype, device="cuda")
|
||||
|
||||
for t in range(nt):
|
||||
@@ -87,14 +87,17 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
|
||||
# process full groups
|
||||
n_full_groups = H // group_size
|
||||
if n_full_groups > 0:
|
||||
data_grp = data[:n_full_groups * group_size].view(
|
||||
n_full_groups, group_size)
|
||||
data_grp = data[: n_full_groups * group_size].view(
|
||||
n_full_groups, group_size
|
||||
)
|
||||
amax = data_grp.abs().amax(dim=1).clamp(min=eps)
|
||||
scale = amax / fp8_max
|
||||
scaled = data[:n_full_groups *
|
||||
group_size] / scale.repeat_interleave(group_size)
|
||||
ref_q_row[:n_full_groups * group_size] = scaled.clamp(
|
||||
fp8_min, fp8_max).to(fp8_dtype)
|
||||
scaled = data[: n_full_groups * group_size] / scale.repeat_interleave(
|
||||
group_size
|
||||
)
|
||||
ref_q_row[: n_full_groups * group_size] = scaled.clamp(
|
||||
fp8_min, fp8_max
|
||||
).to(fp8_dtype)
|
||||
ref_s[t, :n_full_groups] = scale
|
||||
|
||||
# process remainder group
|
||||
|
||||
@@ -11,13 +11,11 @@ from tests.kernels.moe.utils import fused_moe
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.get_device_capability() < (9, 0):
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
|
||||
allow_module_level=True)
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
@@ -31,14 +29,13 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
|
||||
B = B.to(torch.float32)
|
||||
|
||||
assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
|
||||
assert B.ndim == 2 and B.is_contiguous(
|
||||
), "B must be a 2D contiguous tensor"
|
||||
assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"
|
||||
|
||||
# Reshape input
|
||||
M = A.numel() // A.shape[-1]
|
||||
B = B.t() # Transpose weight matrix
|
||||
N, K = B.shape
|
||||
origin_C_shape = A.shape[:-1] + (K, )
|
||||
origin_C_shape = A.shape[:-1] + (K,)
|
||||
A = A.reshape(M, N)
|
||||
|
||||
# As is per-token [M, 1], Bs is per-column [1, K]
|
||||
@@ -88,17 +85,17 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
# Quantize activation output with per-token
|
||||
act_out_q, act_out_s = ops.scaled_fp8_quant(
|
||||
act_out, use_per_token_if_dynamic=True)
|
||||
act_out, use_per_token_if_dynamic=True
|
||||
)
|
||||
|
||||
# Second MLP layer
|
||||
out[mask] = native_w8a8_per_token_matmul(act_out_q,
|
||||
w2[i],
|
||||
act_out_s,
|
||||
w2_s[i],
|
||||
output_dtype=a.dtype)
|
||||
out[mask] = native_w8a8_per_token_matmul(
|
||||
act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
|
||||
)
|
||||
# Apply routing weights and sum
|
||||
return (out.view(B, -1, w2.shape[1]) *
|
||||
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
return (
|
||||
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
|
||||
).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
@@ -116,8 +113,10 @@ TOP_KS = [2, 6]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N, K, E, topk, dtype, seed",
|
||||
itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS))
|
||||
@pytest.mark.parametrize(
|
||||
"M, N, K, E, topk, dtype, seed",
|
||||
itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS),
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
@@ -133,12 +132,10 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
|
||||
|
||||
# Generate int8 weights
|
||||
w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
|
||||
w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min,
|
||||
max=fp8_max).to(torch.float8_e4m3fn)
|
||||
w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
|
||||
w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min,
|
||||
max=fp8_max).to(torch.float8_e4m3fn)
|
||||
w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
# Generate scale for each column (per-column quantization)
|
||||
w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
|
||||
@@ -163,7 +160,7 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
|
||||
)
|
||||
|
||||
# Check results
|
||||
rel_diff = (torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
||||
rel_diff = torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
|
||||
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
|
||||
assert rel_diff < 0.05
|
||||
|
||||
@@ -6,17 +6,17 @@ import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from tests.kernels.quant_utils import per_block_cast_to_int8
|
||||
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX)
|
||||
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
BatchedPrepareAndFinalize,
|
||||
BatchedTritonExperts,
|
||||
NaiveBatchedExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.utils import round_up
|
||||
from vllm.utils.deep_gemm import per_block_cast_to_fp8
|
||||
|
||||
@@ -45,12 +45,7 @@ def triton_moe(
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
return fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
quant_config=quant_config)
|
||||
return fused_experts(a, w1, w2, topk_weight, topk_ids, quant_config=quant_config)
|
||||
|
||||
|
||||
def batched_moe(
|
||||
@@ -80,10 +75,9 @@ def batched_moe(
|
||||
)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
num_local_experts=w1.shape[0],
|
||||
rank=0),
|
||||
BatchedPrepareAndFinalize(
|
||||
max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
|
||||
),
|
||||
BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
@@ -121,10 +115,9 @@ def naive_batched_moe(
|
||||
)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
num_local_experts=w1.shape[0],
|
||||
rank=0),
|
||||
BatchedPrepareAndFinalize(
|
||||
max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
|
||||
),
|
||||
NaiveBatchedExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
@@ -135,8 +128,9 @@ def naive_batched_moe(
|
||||
return fused_experts(a, w1, w2, topk_weight, topk_ids)
|
||||
|
||||
|
||||
def chunk_scales(scales: Optional[torch.Tensor], start: int,
|
||||
end: int) -> Optional[torch.Tensor]:
|
||||
def chunk_scales(
|
||||
scales: Optional[torch.Tensor], start: int, end: int
|
||||
) -> Optional[torch.Tensor]:
|
||||
if scales is not None:
|
||||
if scales.numel() == 1:
|
||||
return scales
|
||||
@@ -159,13 +153,15 @@ def make_quantized_test_activations(
|
||||
a_scale = None
|
||||
|
||||
if quant_dtype is not None:
|
||||
assert (quant_dtype == torch.float8_e4m3fn
|
||||
or quant_dtype == torch.int8), "only fp8/int8 supported"
|
||||
assert quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8, (
|
||||
"only fp8/int8 supported"
|
||||
)
|
||||
a_q = torch.zeros_like(a, dtype=quant_dtype)
|
||||
a_scale_l = [None] * E
|
||||
for e in range(E):
|
||||
a_q[e], a_scale_l[e] = moe_kernel_quantize_input(
|
||||
a[e], None, quant_dtype, per_act_token_quant, block_shape)
|
||||
a[e], None, quant_dtype, per_act_token_quant, block_shape
|
||||
)
|
||||
a_scale = torch.stack(a_scale_l)
|
||||
|
||||
if not per_act_token_quant and block_shape is None:
|
||||
@@ -181,8 +177,11 @@ def moe_quantize_weights(
|
||||
per_token_quant: bool,
|
||||
block_shape: Optional[list[int]],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
assert (quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8
|
||||
or quant_dtype == "nvfp4"), "only fp8/int8/nvfp4 supported"
|
||||
assert (
|
||||
quant_dtype == torch.float8_e4m3fn
|
||||
or quant_dtype == torch.int8
|
||||
or quant_dtype == "nvfp4"
|
||||
), "only fp8/int8/nvfp4 supported"
|
||||
|
||||
w_gs = None
|
||||
|
||||
@@ -199,10 +198,12 @@ def moe_quantize_weights(
|
||||
else:
|
||||
if quant_dtype == torch.int8:
|
||||
w, w_s = ops.scaled_int8_quant(
|
||||
w, w_s, use_per_token_if_dynamic=per_token_quant)
|
||||
w, w_s, use_per_token_if_dynamic=per_token_quant
|
||||
)
|
||||
elif quant_dtype == torch.float8_e4m3fn:
|
||||
w, w_s = ops.scaled_fp8_quant(
|
||||
w, w_s, use_per_token_if_dynamic=per_token_quant)
|
||||
w, w_s, use_per_token_if_dynamic=per_token_quant
|
||||
)
|
||||
elif quant_dtype == "nvfp4":
|
||||
assert not per_token_quant
|
||||
w_amax = torch.abs(w).max().to(torch.float32)
|
||||
@@ -222,8 +223,7 @@ def make_test_weight(
|
||||
quant_dtype: Union[torch.dtype, str, None] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
per_out_ch_quant: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
|
||||
w_gs = None
|
||||
|
||||
@@ -233,7 +233,8 @@ def make_test_weight(
|
||||
w_gs_l = [None] * e
|
||||
for idx in range(e):
|
||||
w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
|
||||
w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape)
|
||||
w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape
|
||||
)
|
||||
|
||||
w = torch.stack(w_l)
|
||||
w_s = torch.stack(w_s_l)
|
||||
@@ -264,26 +265,25 @@ def make_test_weights(
|
||||
quant_dtype: Union[torch.dtype, str, None] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
per_out_ch_quant: bool = False,
|
||||
) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]],
|
||||
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]]:
|
||||
) -> tuple[
|
||||
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
|
||||
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
|
||||
]:
|
||||
return (
|
||||
make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
|
||||
per_out_ch_quant),
|
||||
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
|
||||
per_out_ch_quant),
|
||||
make_test_weight(
|
||||
e, 2 * n, k, in_dtype, quant_dtype, block_shape, per_out_ch_quant
|
||||
),
|
||||
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant),
|
||||
)
|
||||
|
||||
|
||||
def per_token_cast_to_fp8(
|
||||
x: torch.Tensor,
|
||||
block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
x: torch.Tensor, block_size: int = 128
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
pad_size = (block_size - (n % block_size)) % block_size
|
||||
x = torch.nn.functional.pad(x,
|
||||
(0, pad_size), value=0) if pad_size > 0 else x
|
||||
x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
|
||||
x_view = x.view(m, -1, block_size)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
|
||||
@@ -313,27 +313,31 @@ def make_test_quant_config(
|
||||
a1_gscale: Optional[torch.Tensor] = None
|
||||
a2_gscale: Optional[torch.Tensor] = None
|
||||
if quant_dtype == "nvfp4":
|
||||
a1_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
|
||||
a2_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
|
||||
a1_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
|
||||
a2_gscale = torch.ones((e,), device="cuda", dtype=torch.float32)
|
||||
a1_scale = a1_gscale
|
||||
a2_scale = a2_gscale
|
||||
else:
|
||||
a1_scale = None
|
||||
a2_scale = None
|
||||
|
||||
return w1, w2, FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
a1_gscale=a1_gscale,
|
||||
a2_gscale=a2_gscale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
# TODO: make sure this is handled properly
|
||||
g1_alphas=(1 / w1_gs) if w1_gs is not None else None,
|
||||
g2_alphas=(1 / w2_gs) if w2_gs is not None else None,
|
||||
return (
|
||||
w1,
|
||||
w2,
|
||||
FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
a1_gscale=a1_gscale,
|
||||
a2_gscale=a2_gscale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
# TODO: make sure this is handled properly
|
||||
g1_alphas=(1 / w1_gs) if w1_gs is not None else None,
|
||||
g2_alphas=(1 / w2_gs) if w2_gs is not None else None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -348,21 +352,23 @@ def fused_moe(
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
topk_weights, topk_ids, _ = fused_topk(hidden_states, score.float(), topk,
|
||||
renormalize)
|
||||
return fused_experts(hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
quant_config=quant_config)
|
||||
topk_weights, topk_ids, _ = fused_topk(
|
||||
hidden_states, score.float(), topk, renormalize
|
||||
)
|
||||
return fused_experts(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
|
||||
# CustomOp?
|
||||
class BaselineMM(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
b: torch.Tensor,
|
||||
@@ -372,15 +378,11 @@ class BaselineMM(torch.nn.Module):
|
||||
self.b = b.to(dtype=torch.float32)
|
||||
self.out_dtype = out_dtype
|
||||
|
||||
def forward(
|
||||
self,
|
||||
a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
return torch.mm(a.to(dtype=torch.float32),
|
||||
self.b).to(self.out_dtype), None
|
||||
def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
return torch.mm(a.to(dtype=torch.float32), self.b).to(self.out_dtype), None
|
||||
|
||||
|
||||
class TestMLP(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
w1: torch.Tensor,
|
||||
@@ -410,7 +412,6 @@ def make_naive_shared_experts(
|
||||
|
||||
|
||||
class RealMLP(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
@@ -425,37 +426,48 @@ class RealMLP(torch.nn.Module):
|
||||
w2_s: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear, RowParallelLinear)
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj")
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
self.gate_up_proj.register_parameter(
|
||||
"weight", torch.nn.Parameter(w1, requires_grad=False))
|
||||
"weight", torch.nn.Parameter(w1, requires_grad=False)
|
||||
)
|
||||
self.gate_up_proj.register_parameter(
|
||||
"weight_scale", torch.nn.Parameter(w1_s, requires_grad=False))
|
||||
"weight_scale", torch.nn.Parameter(w1_s, requires_grad=False)
|
||||
)
|
||||
self.gate_up_proj.register_parameter(
|
||||
"input_scale",
|
||||
None) #torch.nn.Parameter(None, requires_grad=False))
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
"input_scale", None
|
||||
) # torch.nn.Parameter(None, requires_grad=False))
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
self.down_proj.register_parameter(
|
||||
"weight", torch.nn.Parameter(w2, requires_grad=False))
|
||||
"weight", torch.nn.Parameter(w2, requires_grad=False)
|
||||
)
|
||||
self.down_proj.register_parameter(
|
||||
"weight_scale", torch.nn.Parameter(w2_s, requires_grad=False))
|
||||
"weight_scale", torch.nn.Parameter(w2_s, requires_grad=False)
|
||||
)
|
||||
self.down_proj.register_parameter(
|
||||
"input_scale",
|
||||
None) #torch.nn.Parameter(None, requires_grad=False))
|
||||
"input_scale", None
|
||||
) # torch.nn.Parameter(None, requires_grad=False))
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
raise ValueError(
|
||||
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
@@ -496,13 +508,6 @@ def make_shared_experts(
|
||||
w2_s = None
|
||||
quant_config = None
|
||||
|
||||
return RealMLP(K,
|
||||
N,
|
||||
w1,
|
||||
w2,
|
||||
"silu",
|
||||
quant_config,
|
||||
w1_s=w1_s,
|
||||
w2_s=w2_s)
|
||||
return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
|
||||
finally:
|
||||
torch.set_default_dtype(old_dtype)
|
||||
|
||||
Reference in New Issue
Block a user