[MoE] Nvfp4 Masked Gemm: Add flashinfer grouped_gemm_nt_masked (#25990)

Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Shu Wang
2025-11-19 15:29:06 -06:00
committed by GitHub
parent cdeec2e606
commit 613abb50d5
10 changed files with 1064 additions and 35 deletions

View File

@@ -6,6 +6,7 @@ import deep_ep
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
@@ -27,6 +28,8 @@ logger = init_logger(__name__)
DEEPEP_QUANT_BLOCK_SIZE = 128
DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE]
logger = init_logger(__name__)
def dequant_fp8(
expert_x_fp8: torch.Tensor, expert_x_scales: torch.Tensor
@@ -187,16 +190,25 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# TODO (varun): Optimization - Use a batched version of quant
x = x.view((-1, hidden_dim))
q_dtype = quant_config.quant_dtype
if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
logger.info_once(
"Skip quantization when using FlashInfer CUTEDSL(masked_gemm) "
"for ModelOptNvFp4FusedMoE."
)
q_dtype = None
x, x_scales = moe_kernel_quantize_input(
x,
quant_config.a1_scale,
quant_config.quant_dtype,
q_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
)
x = x.view((num_experts, -1, hidden_dim))
if quant_config.quant_dtype is not None:
if q_dtype is not None:
assert x_scales is not None
x_scales = normalize_batched_scales_shape(x_scales, num_experts)
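
# Illustrative sketch (not part of the commit): the decision the hunk above makes in
# prepare(). The helper name below is hypothetical; the real code mutates q_dtype inline.
def _prepare_quant_dtype(backend: str, configured_dtype):
    # With the masked_gemm (CuteDSL) backend, activations are left unquantized here;
    # NVFP4 quantization happens later inside the CuteDSL masked-gemm path.
    if backend == "masked_gemm":
        return None
    return configured_dtype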

View File

@@ -0,0 +1,346 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
)
from vllm.utils.flashinfer import (
flashinfer_cutedsl_grouped_gemm_nt_masked,
has_flashinfer_cutedsl_grouped_gemm_nt_masked,
scaled_fp4_grouped_quantize,
silu_and_mul_scaled_nvfp4_experts_quantize,
)
logger = init_logger(__name__)
def is_valid_flashinfer_cutedsl_fused_moe(
hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor
) -> bool:
"""
Check if the given problem size is supported by the FlashInfer CuteDSL MoE
kernel.
"""
if not has_flashinfer_cutedsl_grouped_gemm_nt_masked():
logger.debug_once(
"FlashInferCuteDSLExperts disabled: "
"flashinfer_cutedsl_fused_moe not available."
)
return False
# Data type checks
if (
w1.dtype != torch.uint8
or w2.dtype != torch.uint8
or hidden_states.dtype not in [torch.float32, torch.float16, torch.bfloat16]
):
logger.debug_once(
"FlashInferCuteDSLExperts disabled: w1/w2 must be torch.uint8 "
f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be "
f"float32, float16, or bfloat16 (got {hidden_states.dtype})."
)
return False
return True
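
# Usage sketch (hypothetical call site): gate the CuteDSL dispatch on the check above.
# Dummy CPU tensors illustrate the expected dtypes/layouts only
# (activations bf16, w1: [l, 2 * n, k // 2] uint8, w2: [l, k, n // 2] uint8).
example_x = torch.randn(8, 16, 256, dtype=torch.bfloat16)
example_w1 = torch.empty(8, 1024, 128, dtype=torch.uint8)
example_w2 = torch.empty(8, 256, 256, dtype=torch.uint8)
use_cutedsl = is_valid_flashinfer_cutedsl_fused_moe(example_x, example_w1, example_w2)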
class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
out_dtype: torch.dtype,
quant_config: FusedMoEQuantConfig,
):
super().__init__(quant_config)
assert quant_config.quant_dtype == "nvfp4", (
"Only nvfp4 quantization are currently supported."
)
self.out_dtype = out_dtype
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts,
)
def supports_expert_map(self) -> bool:
return False
def supports_chunking(self) -> bool:
# This refers to TP chunking; DP chunking is handled separately.
# TODO(shuw@nvidia.com): Set to False to be consistent with
# batched_deep_gemm_moe
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
"""
Compute the shapes for the temporary and final outputs of the two gemms
and activation in the fused expert function. Since the gemms are
independent, the workspace for the first gemm can be shared with the
workspace for the last gemm.
Returns a tuple of:
- workspace13 shape tuple: must be large enough to hold the
result of either expert gemm.
- workspace2 shape tuple: must be large enough to hold the
result of the activation function.
- output shape tuple: must be exact size of the final gemm output.
Note: this implementation uses the batched experts format, so the first
dimension of each tuple is the number of local experts rather than tokens.
"""
output_shape = (local_num_experts, M, K)
workspace2 = (local_num_experts, M, N)
workspace1 = output_shape
return (workspace1, workspace2, output_shape)
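
# Worked example (illustrative sizes, assuming N is the gemm1/gate-up output width as
# for other batched experts implementations): with local_num_experts=8, M=64 tokens per
# expert, N=4096 and K=7168, workspace_shapes returns
#   workspace13 = (8, 64, 7168)   # shared with the final output buffer
#   workspace2  = (8, 64, 4096)   # holds the gemm1 (gate/up) result before activation
#   output      = (8, 64, 7168)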
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None, # Not used
workspace13: torch.Tensor | None,
workspace2: torch.Tensor | None,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool | None,
):
assert self.quant_dtype == "nvfp4", (
"Only nvfp4 quantization are currently supported."
)
# Ensure w1_scale and w2_scale are not None before calling view
assert self.w1_scale is not None and self.w2_scale is not None, (
"w1_scale and w2_scale must not be None for FlashInferExperts"
)
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens
assert hidden_states.ndim == 3
assert self.w1_scale.ndim == 3
assert self.w2_scale.ndim == 3
flashinfer_cutedsl_moe_masked(
hidden_states=hidden_states,
input_global_scale=self.a1_gscale,
w1=w1,
w1_blockscale=self.w1_scale,
w1_alpha=self.g1_alphas,
w2=w2,
a2_global_scale=self.a2_gscale,
w2_blockscale=self.w2_scale,
w2_alpha=self.g2_alphas,
masked_m=expert_num_tokens,
workspace=workspace2,
out=output,
)
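
# Data-flow note (summary of flashinfer_cutedsl_moe_masked below): the batched bf16
# activations are NVFP4-quantized per expert, gemm1 writes the gate/up projection into
# workspace2, silu_and_mul plus re-quantization produce the fp4 input for gemm2, and
# gemm2 writes the down projection into `output` in place.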
def get_cute_dtype(input: torch.Tensor) -> str:
if input.dtype == torch.bfloat16:
return "bfloat16"
elif input.dtype == torch.float16:
return "float16"
elif input.dtype == torch.float32:
return "float32"
else:
raise ValueError(f"Unsupported cute dtype {input.dtype}")
def flashinfer_cutedsl_moe_masked(
hidden_states: torch.Tensor,
input_global_scale: torch.Tensor,
w1: torch.Tensor,
w1_blockscale: torch.Tensor,
w1_alpha,
w2: torch.Tensor,
a2_global_scale: torch.Tensor,
w2_blockscale: torch.Tensor,
w2_alpha,
masked_m: torch.Tensor,
workspace: torch.Tensor,
out: torch.Tensor,
):
"""
Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
kernels.
Args:
hidden_states (torch.Tensor): [num_experts, m, k], bf16
input_global_scale (torch.Tensor): (l,)
w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
w1_blockscale (torch.Tensor): block scale factors, float8_e4m3fn
w1_alpha (torch.Tensor): (l,)
w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
a2_global_scale (torch.Tensor): (l,)
w2_blockscale (torch.Tensor): block scale factors, float8_e4m3fn
w2_alpha (torch.Tensor): (l,)
masked_m (torch.Tensor): (l,), number of valid tokens per expert
workspace (torch.Tensor): scratch buffer that receives the gemm1 (gate/up) output
out (torch.Tensor): [l, m, k], output buffer, written in place
Notes:
- Assumes max(masked_m) <= m.
"""
# === Assertions on dtypes ===
assert input_global_scale.dtype == torch.float32, (
f"input_global_scale must be float32, got {input_global_scale.dtype}"
)
assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}"
assert w1_blockscale.dtype == torch.float8_e4m3fn, (
f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
)
assert w1_alpha.dtype == torch.float32, (
f"w1_alpha must be float32, got {w1_alpha.dtype}"
)
assert w2.dtype == torch.uint8, f"w2 must be uint8, got {w2.dtype}"
assert a2_global_scale.dtype == torch.float32, (
f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
)
assert w2_blockscale.dtype == torch.float8_e4m3fn, (
f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
)
assert w2_alpha.dtype == torch.float32, (
f"w2_alpha must be float32, got {w2_alpha.dtype}"
)
# === Assertions on shapes ===
n = w2.shape[-1] * 2 # intermediate dimension
num_experts, m, k = hidden_states.shape
assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
assert w1.shape[-1] * 2 == k, (
f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
)
assert w2.shape[-2:] == (
k,
n // 2,
), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}"
assert input_global_scale.shape == (num_experts,), (
f"input_global_scale must be (l,), got {input_global_scale.shape}"
)
assert w1_alpha.shape == (num_experts,), (
f"w1_alpha must be (l,), got {w1_alpha.shape}"
)
assert a2_global_scale.shape == (num_experts,), (
f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
)
assert w2_alpha.shape == (num_experts,), (
f"w2_alpha must be (l,), got {w2_alpha.shape}"
)
aq, aq_sf = scaled_fp4_grouped_quantize(
hidden_states,
masked_m,
input_global_scale,
)
workspace = workspace.permute(1, 2, 0) # requirement of kernel
sf_vec_size = 16
assert aq_sf.dtype == torch.float8_e4m3fn
assert aq.dtype == torch.uint8
ab_dtype = "float4_e2m1fn"
sf_dtype = "float8_e4m3fn"
c_dtype = get_cute_dtype(hidden_states)
# Gemm1
flashinfer_cutedsl_grouped_gemm_nt_masked(
(aq, aq_sf),
(w1.permute(1, 2, 0), w1_blockscale),
workspace,
masked_m,
ab_dtype=ab_dtype,
sf_dtype=sf_dtype,
c_dtype=c_dtype,
sf_vec_size=sf_vec_size,
alpha=w1_alpha.view(1, 1, num_experts),
alpha_dtype=get_cute_dtype(w1_alpha),
) # in logical [m, n, l]
# SILU and quantization
diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(
workspace.permute(2, 0, 1),
masked_m,
a2_global_scale,
)
# Gemm2
out = out.permute(1, 2, 0) # requirement of kernel
flashinfer_cutedsl_grouped_gemm_nt_masked(
(diq, diq_sf),
(w2.permute(1, 2, 0), w2_blockscale),
out,
masked_m,
ab_dtype=ab_dtype,
sf_dtype=sf_dtype,
c_dtype=c_dtype,
sf_vec_size=sf_vec_size,
alpha=w2_alpha.view(1, 1, num_experts),
alpha_dtype=get_cute_dtype(w2_alpha),
) # in logical [m, k, l]
out = out.permute(2, 0, 1)  # restore [l, m, k] view; the caller's buffer was written in place
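
# Shape/dtype sketch for the entry point above (illustrative sizes; actually running it
# needs the FlashInfer CuteDSL kernels on an SM100 GPU, and the block scales / global
# scales come from NVFP4 weight quantization, so the call itself is left commented).
l, m, k, n = 8, 64, 7168, 2048   # experts, max tokens per expert, hidden, intermediate
hs = torch.randn(l, m, k, dtype=torch.bfloat16, device="cuda")
w1_fp4 = torch.empty(l, 2 * n, k // 2, dtype=torch.uint8, device="cuda")   # packed gate/up
w2_fp4 = torch.empty(l, k, n // 2, dtype=torch.uint8, device="cuda")       # packed down proj
num_valid = torch.full((l,), m // 2, dtype=torch.int32, device="cuda")     # masked_m
gateup_ws = torch.empty(l, m, 2 * n, dtype=torch.bfloat16, device="cuda")  # gemm1 workspace
moe_out = torch.empty(l, m, k, dtype=torch.bfloat16, device="cuda")
# flashinfer_cutedsl_moe_masked(hs, input_global_scale, w1_fp4, w1_blockscale, w1_alpha,
#     w2_fp4, a2_global_scale, w2_blockscale, w2_alpha, num_valid, gateup_ws, moe_out)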
def flashinfer_cutedsl_moe_fp4(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_config: FusedMoEQuantConfig,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
create_flashinfer_prepare_finalize,
)
fused_experts = mk.FusedMoEModularKernel(
create_flashinfer_prepare_finalize(use_dp=False), # could be swapped later
FlashInferCuteDSLExperts(
out_dtype=hidden_states.dtype,
quant_config=quant_config,
),
)
return fused_experts(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
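
# Dispatch sketch (hypothetical call, mirroring the ModelOptNvFp4FusedMoE hunk below):
# this wrapper is chosen when VLLM_FLASHINFER_MOE_BACKEND=masked_gemm; quant_config must
# be the NVFP4 FusedMoEQuantConfig built by the quantization method (omitted here).
# out = flashinfer_cutedsl_moe_fp4(hidden_states=x, w1=layer.w13_weight,
#     w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids,
#     quant_config=self.moe_quant_config)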

View File

@@ -1468,7 +1468,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
gemm1_weight = layer.w13_weight.data
gemm1_weight_scale = layer.w13_weight_scale.data
if self.allow_flashinfer:
if (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
):
gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
gemm1_weight, gemm1_weight_scale, dim=-2
)
@@ -1746,17 +1749,26 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
workspace=layer.workspace,
)
elif (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
):
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4,
elif self.allow_flashinfer:
assert self.flashinfer_moe_backend in (
FlashinferMoeBackend.CUTLASS,
FlashinferMoeBackend.CUTEDSL,
)
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
flashinfer_cutlass_moe_fp4,
)
flashinfer_fn_moe_fp4 = flashinfer_cutlass_moe_fp4
else:
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( # noqa: E501
flashinfer_cutedsl_moe_fp4,
)
flashinfer_fn_moe_fp4 = flashinfer_cutedsl_moe_fp4
assert self.moe_quant_config is not None
return flashinfer_cutlass_moe_fp4(
return flashinfer_fn_moe_fp4(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,

View File

@@ -10,6 +10,9 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
FlashInferCuteDSLExperts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
@@ -17,10 +20,14 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize im
create_flashinfer_prepare_finalize,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.flashinfer import (
has_flashinfer_cutedsl_grouped_gemm_nt_masked,
has_flashinfer_cutlass_fused_moe,
)
__all__ = [
"is_flashinfer_fp4_cutlass_moe_available",
"is_flashinfer_fp4_cutedsl_moe_available",
"reorder_w1w3_to_w3w1",
"build_flashinfer_fp4_cutlass_moe_prepare_finalize",
]
@@ -36,6 +43,16 @@ def is_flashinfer_fp4_cutlass_moe_available() -> bool:
)
def is_flashinfer_fp4_cutedsl_moe_available() -> bool:
"""Return ``True`` when FlashInfer CUTEDSL NV-FP4 kernels can be used."""
return (
envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
and current_platform.is_cuda()
and current_platform.is_device_capability(100)
)
def reorder_w1w3_to_w3w1(
weight: torch.Tensor, scale: torch.Tensor, dim: int = -2
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -72,15 +89,21 @@ def select_nvfp4_gemm_impl(
"""Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
if allow_flashinfer:
return FlashInferExperts(
out_dtype=moe.in_dtype,
quant_config=moe_quant_config,
ep_rank=moe.moe_parallel_config.ep_rank,
ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank,
tp_size=moe.moe_parallel_config.tp_size,
use_dp=moe.moe_parallel_config.dp_size > 1,
)
if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
return FlashInferCuteDSLExperts(
out_dtype=moe.in_dtype,
quant_config=moe_quant_config,
)
elif envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput":
return FlashInferExperts(
out_dtype=moe.in_dtype,
quant_config=moe_quant_config,
ep_rank=moe.moe_parallel_config.ep_rank,
ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank,
tp_size=moe.moe_parallel_config.tp_size,
use_dp=moe.moe_parallel_config.dp_size > 1,
)
# native cutlass experts currently don't support DP; TP case won't call this
raise ValueError(

View File

@@ -25,6 +25,7 @@ logger = init_logger(__name__)
class FlashinferMoeBackend(Enum):
TENSORRT_LLM = "TensorRT-LLM"
CUTLASS = "CUTLASS"
CUTEDSL = "CUTEDSL"
def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
@@ -273,19 +274,21 @@ def flashinfer_cutlass_moe_fp8(
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
# Prefer CUTLASS on SM90 to cover both SM90/SM100 generations
if flashinfer_moe_backend == "throughput" or current_platform.is_device_capability(
90
):
return FlashinferMoeBackend.CUTLASS
elif flashinfer_moe_backend == "latency":
return FlashinferMoeBackend.TENSORRT_LLM
backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS,
"latency": FlashinferMoeBackend.TENSORRT_LLM,
"masked_gemm": FlashinferMoeBackend.CUTEDSL,
}
flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
if flashinfer_moe_backend in backend_map:
return backend_map[flashinfer_moe_backend]
elif current_platform.is_device_capability(90):
return FlashinferMoeBackend.CUTLASS
allowed_backends = ["throughput", "latency"]
raise ValueError(
f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
f" expected one of {allowed_backends}"
f"Unknown flashinfer moe backend: {flashinfer_moe_backend!r}. "
f"Expected one of {list(backend_map.keys())}."
)
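
# Resolution summary for VLLM_FLASHINFER_MOE_BACKEND after this change:
#   "throughput"  -> FlashinferMoeBackend.CUTLASS
#   "latency"     -> FlashinferMoeBackend.TENSORRT_LLM
#   "masked_gemm" -> FlashinferMoeBackend.CUTEDSL
# Any other value falls back to CUTLASS on SM90 devices and otherwise raises ValueError.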

View File

@@ -5,6 +5,7 @@ from dataclasses import dataclass
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
is_flashinfer_fp4_cutedsl_moe_available,
is_flashinfer_fp4_cutlass_moe_available,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
@@ -32,7 +33,10 @@ def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:
"""Detect platform support for NV-FP4 fused-MoE path"""
cutlass_supported = cutlass_fp4_supported()
allow_flashinfer = cutlass_supported and is_flashinfer_fp4_cutlass_moe_available()
allow_flashinfer = cutlass_supported and (
is_flashinfer_fp4_cutlass_moe_available()
or is_flashinfer_fp4_cutedsl_moe_available()
)
if allow_flashinfer:
_logger.info_once(