[Quantization] Add FlashInfer CuteDSL batched experts backend for NVFP4 MoE (#38251)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Yongye Zhu
2026-04-06 14:57:53 -04:00
committed by GitHub
parent 94fbb09894
commit e8ebbdde83
6 changed files with 574 additions and 245 deletions

View File

@@ -17,7 +17,7 @@ from flashinfer import fp4_quantize
from torch.nn import functional as F from torch.nn import functional as F
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_moe import ( from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_batched_moe import ( # noqa: E501
flashinfer_cutedsl_moe_masked, flashinfer_cutedsl_moe_masked,
) )
from vllm.utils.flashinfer import ( from vllm.utils.flashinfer import (

View File

@@ -0,0 +1,353 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Dynamic,
kNvfp4Static,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
flashinfer_cutedsl_grouped_gemm_nt_masked,
has_flashinfer_cutedsl_grouped_gemm_nt_masked,
scaled_fp4_grouped_quantize,
silu_and_mul_scaled_nvfp4_experts_quantize,
)
logger = init_logger(__name__)
class FlashInferCuteDSLBatchedExperts(mk.FusedMoEExpertsModular):
    """NVFP4 MoE experts backend using FlashInfer CuteDSL masked grouped-GEMM
    kernels on the BatchedExperts activation format.

    Inputs arrive pre-routed as a dense [local_num_experts, max_tokens, K]
    batch with a per-expert token-count mask; the two expert GEMMs and the
    SiLU-and-mul activation are delegated to ``flashinfer_cutedsl_moe_masked``.
    Requires SM100-family CUDA devices (see ``_supports_current_device``).
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
        max_num_tokens: int,
        num_dispatchers: int,
    ):
        super().__init__(
            moe_config=moe_config,
            quant_config=quant_config,
            max_num_tokens=max_num_tokens,
            num_dispatchers=num_dispatchers,
        )
        assert quant_config.quant_dtype == "nvfp4", (
            "Only nvfp4 quantization are currently supported."
        )
        # Final output is produced in the layer's activation dtype.
        self.out_dtype = moe_config.in_dtype

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Fold the static activation input scales into the weight global
        # (per-tensor) scales so the kernel only needs one alpha per GEMM.
        layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
        layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.BatchedExperts

    @staticmethod
    def _supports_current_device() -> bool:
        # CuteDSL masked grouped GEMM is only built for SM100-family
        # (Blackwell) CUDA devices, and FlashInfer must provide the kernel.
        p = current_platform
        return (
            p.is_cuda()
            and p.is_device_capability_family(100)
            and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
        )

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        # The kernel pipeline hard-codes a fused silu-and-mul between the
        # two GEMMs, so a plain (non gate/up) MLP is not supported.
        return False

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        # Only static NVFP4 weights with dynamically quantized NVFP4
        # activations are supported.
        SUPPORTED_W_A = [
            (kNvfp4Static, kNvfp4Dynamic),
        ]
        return (weight_key, activation_key) in SUPPORTED_W_A

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        return activation == MoEActivation.SILU

    @staticmethod
    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
        return True

    def supports_expert_map(self) -> bool:
        return False

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        # Let PrepareAndFinalize::finalize() decide the impl.
        return TopKWeightAndReduceDelegate()

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """
        Compute the shapes for the temporary and final outputs of the two gemms
        and activation in the fused expert function. Since the gemms are
        independent, the workspace for the first gemm can be shared with the
        workspace for the last gemm.
        Returns a tuple of:
        - workspace13 shape tuple: must be large enough to hold the
        result of either expert gemm.
        - workspace2 shape tuple: must be large enough to hold the
        result of the activation function.
        - output shape tuple: must be exact size of the final gemm output.
        - Workspace type: The dtype to use for the workspace tensors.
        - Note: in order for activation chunking to work, the first dimension
        of each tuple must be the number of tokens.
        """
        # We use global_num_experts due to how moe_align_block_size handles
        # expert_maps.
        # NOTE(review): when VLLM_DEEPEPLL_NVFP4_DISPATCH is set, K presumably
        # arrives packed (two fp4 values per byte), so the unpacked hidden
        # dimension is K * 2 — confirm against the dispatcher.
        K_dim = K * 2 if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else K
        output_shape = (local_num_experts, M, K_dim)
        workspace2 = (local_num_experts, M, N)
        workspace1 = output_shape
        return (workspace1, workspace2, output_shape)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,  # Not used
        workspace13: torch.Tensor | None,
        workspace2: torch.Tensor | None,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool | None,
    ):
        """Run the masked batched-experts MoE kernel, writing into ``output``.

        ``hidden_states`` is the batched [num_experts, m, k] activation
        tensor; ``expert_tokens_meta.expert_num_tokens`` gives the valid
        token count per expert (the mask for the masked GEMMs).
        """
        assert self.quant_dtype == "nvfp4", (
            "Only nvfp4 quantization are currently supported."
        )
        # Ensure w1_scale and w2_scale are not None before calling view
        assert self.w1_scale is not None and self.w2_scale is not None, (
            "w1_scale and w2_scale must not be None for FlashInferExperts"
        )
        assert expert_tokens_meta is not None
        expert_num_tokens = expert_tokens_meta.expert_num_tokens
        assert hidden_states.ndim == 3
        assert self.w1_scale.ndim == 3
        assert self.w2_scale.ndim == 3
        # With DeepEP-LL NVFP4 dispatch the input is already quantized (the
        # (data, scale) tuple path below), so no input global scale is passed;
        # otherwise the kernel quantizes internally using a1_gscale.
        input_global_scale = (
            None if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else self.a1_gscale
        )
        flashinfer_hidden_states = (
            (hidden_states, a1q_scale)
            if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH
            else hidden_states
        )
        flashinfer_cutedsl_moe_masked(
            hidden_states=flashinfer_hidden_states,
            input_global_scale=input_global_scale,
            w1=w1,
            w1_blockscale=self.w1_scale,
            w1_alpha=self.g1_alphas,
            w2=w2,
            a2_global_scale=self.a2_gscale,
            w2_blockscale=self.w2_scale,
            w2_alpha=self.g2_alphas,
            masked_m=expert_num_tokens,
            workspace=workspace2,
            out=output,
        )
def get_cute_dtype(input: torch.Tensor) -> str:
    """Return the CuteDSL dtype name string for *input*'s torch dtype.

    Only bf16/fp16/fp32 are recognized; any other dtype raises ValueError.
    """
    name = {
        torch.bfloat16: "bfloat16",
        torch.float16: "float16",
        torch.float32: "float32",
    }.get(input.dtype)
    if name is None:
        raise ValueError(f"Unsupported cute dtype {input.dtype}")
    return name
def flashinfer_cutedsl_moe_masked(
    hidden_states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
    input_global_scale: torch.Tensor | None,
    w1: torch.Tensor,
    w1_blockscale: torch.Tensor,
    w1_alpha: torch.Tensor,
    w2: torch.Tensor,
    a2_global_scale: torch.Tensor,
    w2_blockscale: torch.Tensor,
    w2_alpha: torch.Tensor,
    masked_m: torch.Tensor,
    workspace: torch.Tensor,
    out: torch.Tensor,
):
    """
    Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
    kernels.

    Runs GEMM1 (gate/up projection), a fused silu-and-mul + NVFP4
    requantization, and GEMM2 (down projection). Results are written
    in-place into ``out``.

    Args:
        hidden_states: Either of the following case
            * torch.Tensor: [num_experts, m, k], bf16 (quantized here using
              ``input_global_scale``)
            * tuple[torch.Tensor, torch.Tensor]: pre-quantized NVFP4 input:
              [num_experts, m, k // 2] uint8 data and
              [num_experts, m, k // 16] float8_e4m3fn block scales
        input_global_scale (torch.Tensor): (l,); must be None when
            ``hidden_states`` is a pre-quantized tuple
        w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
        w1_blockscale (torch.Tensor): blockscale factors, e4m3,
        w1_alpha (torch.Tensor): (l,)
        w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
        a2_global_scale (torch.Tensor): (l,)
        w2_blockscale (torch.Tensor): blockscale factors, e4m3,
        w2_alpha (torch.Tensor): (l,)
        masked_m (torch.Tensor): per-expert valid token counts
        workspace (torch.Tensor): For gateup_output
        out (torch.Tensor): [num_experts, m, k] output buffer, written
            in-place
    Notes:
        - Assumes max(masked_m) <= m.
    """
    # === Assertions on dtypes ===
    assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}"
    assert w1_blockscale.dtype == torch.float8_e4m3fn, (
        f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
    )
    assert w1_alpha.dtype == torch.float32, (
        f"w1_alpha must be float32, got {w1_alpha.dtype}"
    )
    assert w2.dtype == torch.uint8, f"w2 must be uint8, got {w2.dtype}"
    assert a2_global_scale.dtype == torch.float32, (
        f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
    )
    assert w2_blockscale.dtype == torch.float8_e4m3fn, (
        f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
    )
    assert w2_alpha.dtype == torch.float32, (
        f"w2_alpha must be float32, got {w2_alpha.dtype}"
    )

    # === Assertions on shapes ===
    n = w2.shape[-1] * 2  # intermediate dimension
    if isinstance(hidden_states, tuple):
        # Pre-quantized path: input arrives already in NVFP4, so a global
        # input scale must NOT be supplied. (Fixed message: the old text
        # stated the opposite of the asserted condition.)
        assert input_global_scale is None, (
            "input_global_scale must be None when hidden_states is "
            "already quantized"
        )
        aq = hidden_states[0].view(torch.uint8)
        aq_sf = hidden_states[1].view(torch.float8_e4m3fn)
        num_experts, m, k_by_2 = aq.shape
        k = k_by_2 * 2
        # Kernel expects logical [m, k/2, l] layout.
        aq = aq.permute(1, 2, 0)
    else:
        # High-precision path: quantize to NVFP4 here using the per-expert
        # global input scale.
        num_experts, m, k = hidden_states.shape
        assert input_global_scale.dtype == torch.float32, (
            f"input_global_scale must be float32, got {input_global_scale.dtype}"
        )
        assert input_global_scale.shape == (num_experts,), (
            f"input_global_scale must be (l,), got {input_global_scale.shape}"
        )
        aq, aq_sf = scaled_fp4_grouped_quantize(
            hidden_states,
            masked_m,
            input_global_scale,
        )

    assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
    assert w1.shape[-1] * 2 == k, (
        f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
    )
    assert w2.shape[-2:] == (
        k,
        n // 2,
    ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}"
    assert w1_alpha.shape == (num_experts,), (
        f"w1_alpha must be (l,), got {w1_alpha.shape}"
    )
    assert a2_global_scale.shape == (num_experts,), (
        f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
    )
    assert w2_alpha.shape == (num_experts,), (
        f"w2_alpha must be (l,), got {w2_alpha.shape}"
    )

    workspace = workspace.permute(1, 2, 0)  # requirement of kernel
    sf_vec_size = 16
    assert aq_sf.dtype == torch.float8_e4m3fn
    assert aq.dtype == torch.uint8
    ab_dtype = "float4_e2m1fn"
    sf_dtype = "float8_e4m3fn"
    if isinstance(hidden_states, tuple):
        # Pre-quantized input carries no dtype info; the DeepEP-LL dispatch
        # path produces bf16 activations.
        c_dtype = "bfloat16"
    else:
        c_dtype = get_cute_dtype(hidden_states)

    # Gemm1: gate/up projection into the workspace.
    flashinfer_cutedsl_grouped_gemm_nt_masked(
        (aq, aq_sf),
        (w1.permute(1, 2, 0), w1_blockscale),
        workspace,
        masked_m,
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=w1_alpha.view(1, 1, num_experts),
        alpha_dtype=get_cute_dtype(w1_alpha),
    )  # in logical [m, n, l]

    # SILU-and-mul activation fused with NVFP4 requantization for GEMM2.
    diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(
        workspace.permute(2, 0, 1),
        masked_m,
        a2_global_scale,
    )

    # Gemm2: down projection. The kernel writes through the permuted view,
    # which shares storage with the caller's `out`, so `out` is filled
    # in-place and nothing needs to be permuted back or returned.
    out_kernel_view = out.permute(1, 2, 0)  # requirement of kernel
    flashinfer_cutedsl_grouped_gemm_nt_masked(
        (diq, diq_sf),
        (w2.permute(1, 2, 0), w2_blockscale),
        out_kernel_view,
        masked_m,
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=w2_alpha.view(1, 1, num_experts),
        alpha_dtype=get_cute_dtype(w2_alpha),
    )  # in logical [m, k, l]

View File

@@ -4,8 +4,6 @@
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
@@ -13,7 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
@@ -22,33 +20,42 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import ( from vllm.utils.flashinfer import (
flashinfer_cutedsl_grouped_gemm_nt_masked, flashinfer_cute_dsl_fused_moe_nvfp4,
has_flashinfer_cutedsl_grouped_gemm_nt_masked, has_flashinfer_cutedsl_moe_nvfp4,
scaled_fp4_grouped_quantize,
silu_and_mul_scaled_nvfp4_experts_quantize,
) )
logger = init_logger(__name__)
class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular): class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
"""
CuteDSL NvFP4 MoE experts using the FlashInfer functional API.
Uses Standard activation format (non-batched). The kernel handles
routing, expert computation, and reduction internally.
Supports expert parallelism natively.
"""
def __init__( def __init__(
self, self,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
max_num_tokens: int,
num_dispatchers: int,
): ):
super().__init__( super().__init__(
moe_config=moe_config, moe_config=moe_config,
quant_config=quant_config, quant_config=quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
) )
assert quant_config.quant_dtype == "nvfp4", ( assert quant_config.quant_dtype == "nvfp4", (
"Only nvfp4 quantization are currently supported." "Only nvfp4 quantization is currently supported."
) )
self.out_dtype = moe_config.in_dtype self.out_dtype = moe_config.in_dtype
self.hidden_dim = moe_config.hidden_dim
self.intermediate_size_per_partition = (
moe_config.intermediate_size_per_partition
)
self.topk = moe_config.experts_per_token
self.local_num_experts = moe_config.num_local_experts
self.global_num_experts = moe_config.num_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank
self.local_expert_offset = self.ep_rank * self.local_num_experts
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale) layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
@@ -56,7 +63,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
@staticmethod @staticmethod
def activation_format() -> mk.FusedMoEActivationFormat: def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts return mk.FusedMoEActivationFormat.Standard
@staticmethod @staticmethod
def _supports_current_device() -> bool: def _supports_current_device() -> bool:
@@ -64,7 +71,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
return ( return (
p.is_cuda() p.is_cuda()
and p.is_device_capability_family(100) and p.is_device_capability_family(100)
and has_flashinfer_cutedsl_grouped_gemm_nt_masked() and has_flashinfer_cutedsl_moe_nvfp4()
) )
@staticmethod @staticmethod
@@ -86,15 +93,16 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
return activation == MoEActivation.SILU return activation == MoEActivation.SILU
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(
moe_parallel_config: FusedMoEParallelConfig,
) -> bool:
return True return True
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return False return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl. return TopKWeightAndReduceNoOP()
return TopKWeightAndReduceDelegate()
def workspace_shapes( def workspace_shapes(
self, self,
@@ -107,29 +115,12 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: MoEActivation, activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# We use global_num_experts due to how moe_align_block_size handles workspace1 = (0,)
# expert_maps. workspace2 = (0,)
""" # K is packed (K//2 for uint8), so output uses hidden_dim.
Compute the shapes for the temporary and final outputs of the two gemms assert self.hidden_dim == K * 2
and activation in the fused expert function. Since the gemms are output = (M, self.hidden_dim)
independent, the workspace for the first gemm can be shared with the return (workspace1, workspace2, output)
workspace for the last gemm.
Returns a tuple of:
- workspace13 shape tuple: must be large enough to hold the
result of either expert gemm.
- workspace2 shape tuple: must be large enough to hold the
result of the activation function.
- output shape tuple: must be exact size of the final gemm output.
- Workspace type: The dtype to use for the workspace tensors.
- Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
"""
K_dim = K * 2 if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else K
output_shape = (local_num_experts, M, K_dim)
workspace2 = (local_num_experts, M, N)
workspace1 = output_shape
return (workspace1, workspace2, output_shape)
def apply( def apply(
self, self,
@@ -143,210 +134,39 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
global_num_experts: int, global_num_experts: int,
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None, a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None, # Not used a2_scale: torch.Tensor | None,
workspace13: torch.Tensor | None, workspace13: torch.Tensor | None,
workspace2: torch.Tensor | None, workspace2: torch.Tensor | None,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool | None, apply_router_weight_on_input: bool | None,
): ):
assert self.quant_dtype == "nvfp4", ( assert self.quant_dtype == "nvfp4"
"Only nvfp4 quantization are currently supported." assert a1q_scale is not None
) assert self.w1_scale is not None
# Ensure w1_scale and w2_scale are not None before calling view assert self.w2_scale is not None
assert self.w1_scale is not None and self.w2_scale is not None, (
"w1_scale and w2_scale must not be None for FlashInferExperts"
)
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens
assert hidden_states.ndim == 3
assert self.w1_scale.ndim == 3
assert self.w2_scale.ndim == 3
input_global_scale = ( # a1q_scale is (M, K//16) float8_e4m3fn from fp4_quantize.
None if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else self.a1_gscale # The functional API expects x_sf with trailing dim: (M, K//16, 1).
) x_sf = a1q_scale.unsqueeze(-1)
flashinfer_hidden_states = (
(hidden_states, a1q_scale) from vllm.utils.flashinfer import _is_fi_autotuning, autotune
if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH
else hidden_states with autotune(_is_fi_autotuning):
) flashinfer_cute_dsl_fused_moe_nvfp4(
flashinfer_cutedsl_moe_masked( x=hidden_states,
hidden_states=flashinfer_hidden_states, x_sf=x_sf,
input_global_scale=input_global_scale, token_selected_experts=topk_ids.to(torch.int32),
w1=w1, token_final_scales=topk_weights.float(),
w1_blockscale=self.w1_scale, w1_weight=w1,
w1_weight_sf=self.w1_scale,
w1_alpha=self.g1_alphas, w1_alpha=self.g1_alphas,
w2=w2, fc2_input_scale=self.a2_gscale,
a2_global_scale=self.a2_gscale, w2_weight=w2,
w2_blockscale=self.w2_scale, w2_weight_sf=self.w2_scale,
w2_alpha=self.g2_alphas, w2_alpha=self.g2_alphas,
masked_m=expert_num_tokens, num_experts=self.global_num_experts,
workspace=workspace2, top_k=self.topk,
out=output, num_local_experts=self.local_num_experts,
local_expert_offset=self.local_expert_offset,
moe_output=output,
) )
def get_cute_dtype(input: torch.Tensor) -> str:
if input.dtype == torch.bfloat16:
return "bfloat16"
elif input.dtype == torch.float16:
return "float16"
elif input.dtype == torch.float32:
return "float32"
else:
raise ValueError(f"Unsupported cute dtype {input.dtype}")
def flashinfer_cutedsl_moe_masked(
hidden_states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
input_global_scale: torch.Tensor,
w1: torch.Tensor,
w1_blockscale: torch.Tensor,
w1_alpha,
w2: torch.Tensor,
a2_global_scale: torch.Tensor,
w2_blockscale: torch.Tensor,
w2_alpha,
masked_m: torch.Tensor,
workspace: torch.Tensor,
out: torch.Tensor,
):
"""
Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
kernels.
Args:
hidden_states: Either of the following case
* torch.Tensor: [num_experts, m, k], bf16
* tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2],
uint8, [num_experts, m, k // 16], float8_e4m3fn
input_global_scale (torch.Tensor): (l,)
w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
w1_blockscale (torch.Tensor): blockscale factors, e4m3,
w1_alpha (torch.Tensor): (l,)
w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
a2_global_scale (torch.Tensor): (l,)
w2_blockscale (torch.Tensor): blockscale factors, e4m3,
w2_alpha (torch.Tensor): (l,)
masked_m (torch.Tensor): Masked dimension indices
workspace (torch.Tensor): For gateup_output
Notes:
- Assumes max(masked_m) <= m.
"""
# === Assertions on dtypes ===
assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}"
assert w1_blockscale.dtype == torch.float8_e4m3fn, (
f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
)
assert w1_alpha.dtype == torch.float32, (
f"w1_alpha must be float32, got {w1_alpha.dtype}"
)
assert w2.dtype == torch.uint8, f"w2 must be uint8, got {w2.dtype}"
assert a2_global_scale.dtype == torch.float32, (
f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
)
assert w2_blockscale.dtype == torch.float8_e4m3fn, (
f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
)
assert w2_alpha.dtype == torch.float32, (
f"w2_alpha must be float32, got {w2_alpha.dtype}"
)
# === Assertions on shapes ===
n = w2.shape[-1] * 2 # intermediate dimension
if isinstance(hidden_states, tuple):
assert input_global_scale is None, (
"input_global_scale is needed when input needs quant"
)
aq = hidden_states[0].view(torch.uint8)
aq_sf = hidden_states[1].view(torch.float8_e4m3fn)
# m, k_by_2, num_experts = aq.shape
num_experts, m, k_by_2 = aq.shape
k = k_by_2 * 2
aq = aq.permute(1, 2, 0)
else:
num_experts, m, k = hidden_states.shape
assert input_global_scale.dtype == torch.float32, (
f"input_global_scale must be float32, got {input_global_scale.dtype}"
)
assert input_global_scale.shape == (num_experts,), (
f"input_global_scale must be (l,), got {input_global_scale.shape}"
)
aq, aq_sf = scaled_fp4_grouped_quantize(
hidden_states,
masked_m,
input_global_scale,
)
assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
assert w1.shape[-1] * 2 == k, (
f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
)
assert w2.shape[-2:] == (
k,
n // 2,
), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}"
assert w1_alpha.shape == (num_experts,), (
f"w1_alpha must be (l,), got {w1_alpha.shape}"
)
assert a2_global_scale.shape == (num_experts,), (
f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
)
assert w2_alpha.shape == (num_experts,), (
f"w2_alpha must be (l,), got {w2_alpha.shape}"
)
workspace = workspace.permute(1, 2, 0) # requirement of kernel
sf_vec_size = 16
assert aq_sf.dtype == torch.float8_e4m3fn
assert aq.dtype == torch.uint8
ab_dtype = "float4_e2m1fn"
sf_dtype = "float8_e4m3fn"
if isinstance(hidden_states, tuple):
c_dtype = "bfloat16"
else:
c_dtype = get_cute_dtype(hidden_states)
# Gemm1
flashinfer_cutedsl_grouped_gemm_nt_masked(
(aq, aq_sf),
(w1.permute(1, 2, 0), w1_blockscale),
workspace,
masked_m,
ab_dtype=ab_dtype,
sf_dtype=sf_dtype,
c_dtype=c_dtype,
sf_vec_size=sf_vec_size,
alpha=w1_alpha.view(1, 1, num_experts),
alpha_dtype=get_cute_dtype(w1_alpha),
) # in logical [m, n, l]
# SILU and quantization
diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(
workspace.permute(2, 0, 1),
masked_m,
a2_global_scale,
)
# Gemm2
out = out.permute(1, 2, 0) # requirement of kernel
flashinfer_cutedsl_grouped_gemm_nt_masked(
(diq, diq_sf),
(w2.permute(1, 2, 0), w2_blockscale),
out,
masked_m,
ab_dtype=ab_dtype,
sf_dtype=sf_dtype,
c_dtype=c_dtype,
sf_vec_size=sf_vec_size,
alpha=w2_alpha.view(1, 1, num_experts),
alpha_dtype=get_cute_dtype(w2_alpha),
) # in logical [m, k, l]
out = out.permute(2, 0, 1)

View File

@@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
prepare_nvfp4_moe_layer_for_fi_or_cutlass, prepare_nvfp4_moe_layer_for_fi_or_cutlass,
prepare_nvfp4_moe_layer_for_flashinfer_cutedsl,
) )
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend, FlashinferMoeBackend,
@@ -41,6 +42,7 @@ class NvFp4MoeBackend(Enum):
FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM" FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS" FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS"
FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL" FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL"
FLASHINFER_CUTEDSL_BATCHED = "FLASHINFER_CUTEDSL_BATCHED"
VLLM_CUTLASS = "VLLM_CUTLASS" VLLM_CUTLASS = "VLLM_CUTLASS"
MARLIN = "MARLIN" MARLIN = "MARLIN"
@@ -49,6 +51,7 @@ FLASHINFER_NVFP4_MOE_BACKENDS = [
NvFp4MoeBackend.FLASHINFER_TRTLLM, NvFp4MoeBackend.FLASHINFER_TRTLLM,
NvFp4MoeBackend.FLASHINFER_CUTLASS, NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.FLASHINFER_CUTEDSL, NvFp4MoeBackend.FLASHINFER_CUTEDSL,
NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED,
] ]
fi_2_vllm_backend_map: dict[FlashinferMoeBackend, NvFp4MoeBackend] = { fi_2_vllm_backend_map: dict[FlashinferMoeBackend, NvFp4MoeBackend] = {
@@ -95,6 +98,13 @@ def backend_to_kernel_cls(
return [FlashInferCuteDSLExperts] return [FlashInferCuteDSLExperts]
elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED:
from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_batched_moe import ( # noqa: E501
FlashInferCuteDSLBatchedExperts,
)
return [FlashInferCuteDSLBatchedExperts]
elif backend == NvFp4MoeBackend.VLLM_CUTLASS: elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
from vllm.model_executor.layers.fused_moe.cutlass_moe import ( from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4, CutlassExpertsFp4,
@@ -143,6 +153,7 @@ def select_nvfp4_moe_backend(
AVAILABLE_BACKENDS = [ AVAILABLE_BACKENDS = [
NvFp4MoeBackend.FLASHINFER_TRTLLM, NvFp4MoeBackend.FLASHINFER_TRTLLM,
NvFp4MoeBackend.FLASHINFER_CUTEDSL, NvFp4MoeBackend.FLASHINFER_CUTEDSL,
NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED,
NvFp4MoeBackend.FLASHINFER_CUTLASS, NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.VLLM_CUTLASS, NvFp4MoeBackend.VLLM_CUTLASS,
NvFp4MoeBackend.MARLIN, NvFp4MoeBackend.MARLIN,
@@ -198,6 +209,12 @@ def select_nvfp4_moe_backend(
runner_backend = config.moe_backend runner_backend = config.moe_backend
if runner_backend != "auto": if runner_backend != "auto":
requested_backend = map_nvfp4_backend(runner_backend) requested_backend = map_nvfp4_backend(runner_backend)
# For batched activation format, use batched variant if available.
if (
activation_format == mk.FusedMoEActivationFormat.BatchedExperts
and requested_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL
):
requested_backend = NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED
return _return_or_raise( return _return_or_raise(
requested_backend, config, weight_key, activation_key, activation_format requested_backend, config, weight_key, activation_key, activation_format
) )
@@ -288,7 +305,28 @@ def convert_to_nvfp4_moe_kernel_format(
torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor,
]: ]:
if ( if nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
(
w13,
w13_scale,
w13_scale_2,
a13_scale,
w2,
w2_scale,
w2_scale_2,
a2_scale,
) = prepare_nvfp4_moe_layer_for_flashinfer_cutedsl(
layer=layer,
w13=w13,
w13_scale=w13_scale,
w13_scale_2=w13_scale_2,
a13_scale=a13_scale,
w2=w2,
w2_scale=w2_scale,
w2_scale_2=w2_scale_2,
a2_scale=a2_scale,
)
elif (
nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS
or nvfp4_backend == NvFp4MoeBackend.VLLM_CUTLASS or nvfp4_backend == NvFp4MoeBackend.VLLM_CUTLASS
): ):
@@ -380,7 +418,13 @@ def make_nvfp4_moe_quant_config(
# NOTE(rob): this is a hack until the MoE kernels # NOTE(rob): this is a hack until the MoE kernels
# create their own quant configs. TRTLLM kernel # create their own quant configs. TRTLLM kernel
# does not accept swizzled input quant scales. # does not accept swizzled input quant scales.
is_nvfp4_scale_swizzled=(backend != NvFp4MoeBackend.FLASHINFER_TRTLLM), is_nvfp4_scale_swizzled=(
backend
not in (
NvFp4MoeBackend.FLASHINFER_TRTLLM,
NvFp4MoeBackend.FLASHINFER_CUTEDSL,
)
),
) )

View File

@@ -60,6 +60,100 @@ def reorder_w1w3_to_w3w1(
) )
def interleave_linear_and_gate(
x: torch.Tensor,
group_size: int = 64,
dim: int = -1,
) -> torch.Tensor:
"""Interleave gate and linear weight rows for CuteDSL wrapper."""
sizes = x.size()
dim = dim % x.dim()
assert sizes[dim] % (group_size * 2) == 0, (
f"dim {dim} size {sizes[dim]} must be divisible by {group_size * 2}"
)
prev_sizes = sizes[:dim]
post_sizes = sizes[dim + 1 :]
x = x.view(*prev_sizes, 2, sizes[dim] // (group_size * 2), group_size, *post_sizes)
x = x.transpose(dim, dim + 1).contiguous().view(*sizes)
return x
def prepare_nvfp4_moe_layer_for_flashinfer_cutedsl(
layer: "FusedMoE",
w13: torch.Tensor,
w13_scale: torch.Tensor,
w13_scale_2: torch.Tensor,
a13_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
w2_scale_2: torch.Tensor,
a2_scale: torch.Tensor,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
"""Prepare weights for the CuteDSL wrapper-based NvFP4 MoE backend.
Converts weight scale factors to MMA layout expected by CuteDslMoEWrapper,
and interleaves w13 gate/linear rows.
"""
from flashinfer.cute_dsl.utils import convert_sf_to_mma_layout
# Global scaling factors (same as other FlashInfer backends).
num_experts = w13.shape[0]
a13_scale = a13_scale.max().to(torch.float32).expand(num_experts)
a2_scale = a2_scale.max().to(torch.float32).expand(num_experts)
half = w13.shape[1] // 2
w13 = torch.cat([w13[:, half:], w13[:, :half]], dim=1)
w13_scale = torch.cat([w13_scale[:, half:], w13_scale[:, :half]], dim=1)
# Interleave up/gate rows for w13 weights and scales.
w13 = interleave_linear_and_gate(w13, group_size=64, dim=1)
w13_scale = interleave_linear_and_gate(w13_scale, group_size=64, dim=1)
# Convert w13 scale factors: linear → swizzled → MMA layout.
w13_scale = swizzle_blockscale(w13_scale)
E, M_padded, K_sf_padded = w13_scale.shape
w13_scale_flat = w13_scale.reshape(E * M_padded, K_sf_padded)
w13_scale = convert_sf_to_mma_layout(
w13_scale_flat,
m=M_padded,
k=K_sf_padded * 16,
num_groups=E,
sf_vec_size=16,
)
# Convert w2 scale factors: linear → swizzled → MMA layout.
w2_scale = swizzle_blockscale(w2_scale)
E, M_padded, K_sf_padded = w2_scale.shape
w2_scale_flat = w2_scale.reshape(E * M_padded, K_sf_padded)
w2_scale = convert_sf_to_mma_layout(
w2_scale_flat,
m=M_padded,
k=K_sf_padded * 16,
num_groups=E,
sf_vec_size=16,
)
return (
w13,
w13_scale,
w13_scale_2,
a13_scale,
w2,
w2_scale,
w2_scale_2,
a2_scale,
)
def prepare_static_weights_for_trtllm_fp4_moe( def prepare_static_weights_for_trtllm_fp4_moe(
# args_dequant, # args_dequant,
# args, # args,
@@ -221,7 +315,7 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
NvFp4MoeBackend.VLLM_CUTLASS, NvFp4MoeBackend.VLLM_CUTLASS,
NvFp4MoeBackend.FLASHINFER_CUTLASS, NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.FLASHINFER_TRTLLM, NvFp4MoeBackend.FLASHINFER_TRTLLM,
NvFp4MoeBackend.FLASHINFER_CUTEDSL, NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED,
] ]
# Reorder [w1, w3] to [w3, w1] for FI NVFP4 MoE kernels. # Reorder [w1, w3] to [w3, w1] for FI NVFP4 MoE kernels.

View File

@@ -128,6 +128,12 @@ scaled_fp4_grouped_quantize = _lazy_import_wrapper(
nvfp4_block_scale_interleave = _lazy_import_wrapper( nvfp4_block_scale_interleave = _lazy_import_wrapper(
"flashinfer.fp4_quantization", "block_scale_interleave" "flashinfer.fp4_quantization", "block_scale_interleave"
) )
flashinfer_cute_dsl_fused_moe_nvfp4 = _lazy_import_wrapper(
"flashinfer", "cute_dsl_fused_moe_nvfp4"
)
flashinfer_convert_sf_to_mma_layout = _lazy_import_wrapper(
"flashinfer.cute_dsl.utils", "convert_sf_to_mma_layout"
)
trtllm_fp4_block_scale_moe = _lazy_import_wrapper( trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
"flashinfer", "trtllm_fp4_block_scale_moe" "flashinfer", "trtllm_fp4_block_scale_moe"
) )
@@ -252,6 +258,15 @@ def has_flashinfer_cutedsl_grouped_gemm_nt_masked() -> bool:
return True return True
@functools.cache
def has_flashinfer_cutedsl_moe_nvfp4() -> bool:
"""Return ``True`` if FlashInfer cute_dsl_fused_moe_nvfp4 is available."""
if not has_flashinfer_cutedsl():
return False
mod = _get_submodule("flashinfer")
return mod is not None and hasattr(mod, "cute_dsl_fused_moe_nvfp4")
@functools.cache @functools.cache
def has_nvidia_artifactory() -> bool: def has_nvidia_artifactory() -> bool:
"""Return `True` if NVIDIA's artifactory is accessible. """Return `True` if NVIDIA's artifactory is accessible.
@@ -768,6 +783,8 @@ __all__ = [
"silu_and_mul_scaled_nvfp4_experts_quantize", "silu_and_mul_scaled_nvfp4_experts_quantize",
"scaled_fp4_grouped_quantize", "scaled_fp4_grouped_quantize",
"nvfp4_block_scale_interleave", "nvfp4_block_scale_interleave",
"flashinfer_cute_dsl_fused_moe_nvfp4",
"flashinfer_convert_sf_to_mma_layout",
"trtllm_fp4_block_scale_moe", "trtllm_fp4_block_scale_moe",
"autotune", "autotune",
"has_flashinfer_moe", "has_flashinfer_moe",
@@ -776,6 +793,7 @@ __all__ = [
"has_flashinfer_nvlink_one_sided", "has_flashinfer_nvlink_one_sided",
"has_flashinfer_cutlass_fused_moe", "has_flashinfer_cutlass_fused_moe",
"has_flashinfer_cutedsl_grouped_gemm_nt_masked", "has_flashinfer_cutedsl_grouped_gemm_nt_masked",
"has_flashinfer_cutedsl_moe_nvfp4",
"has_flashinfer_fp8_blockscale_gemm", "has_flashinfer_fp8_blockscale_gemm",
"has_nvidia_artifactory", "has_nvidia_artifactory",
"supports_trtllm_attention", "supports_trtllm_attention",