diff --git a/tests/kernels/moe/test_cutedsl_moe.py b/tests/kernels/moe/test_cutedsl_moe.py index bca3eba0f..2a6f83695 100644 --- a/tests/kernels/moe/test_cutedsl_moe.py +++ b/tests/kernels/moe/test_cutedsl_moe.py @@ -17,7 +17,7 @@ from flashinfer import fp4_quantize from torch.nn import functional as F from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_moe import ( +from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_batched_moe import ( # noqa: E501 flashinfer_cutedsl_moe_masked, ) from vllm.utils.flashinfer import ( diff --git a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py new file mode 100644 index 000000000..5eaaf4673 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py @@ -0,0 +1,353 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import envs +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, +) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kNvfp4Dynamic, + kNvfp4Static, +) +from vllm.platforms import current_platform +from vllm.utils.flashinfer import ( + flashinfer_cutedsl_grouped_gemm_nt_masked, + has_flashinfer_cutedsl_grouped_gemm_nt_masked, + scaled_fp4_grouped_quantize, + silu_and_mul_scaled_nvfp4_experts_quantize, +) + +logger = init_logger(__name__) + + +class 
FlashInferCuteDSLBatchedExperts(mk.FusedMoEExpertsModular): + def __init__( + self, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + max_num_tokens: int, + num_dispatchers: int, + ): + super().__init__( + moe_config=moe_config, + quant_config=quant_config, + max_num_tokens=max_num_tokens, + num_dispatchers=num_dispatchers, + ) + assert quant_config.quant_dtype == "nvfp4", ( + "Only nvfp4 quantization are currently supported." + ) + self.out_dtype = moe_config.in_dtype + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale) + layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale) + + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts + + @staticmethod + def _supports_current_device() -> bool: + p = current_platform + return ( + p.is_cuda() + and p.is_device_capability_family(100) + and has_flashinfer_cutedsl_grouped_gemm_nt_masked() + ) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + SUPPORTED_W_A = [ + (kNvfp4Static, kNvfp4Dynamic), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + @staticmethod + def _supports_activation(activation: MoEActivation) -> bool: + return activation == MoEActivation.SILU + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return True + + def supports_expert_map(self) -> bool: + return False + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. 
+        return TopKWeightAndReduceDelegate()
+
+    def workspace_shapes(
+        self,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        global_num_experts: int,
+        local_num_experts: int,
+        expert_tokens_meta: mk.ExpertTokensMetadata | None,
+        activation: MoEActivation,
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+        """
+        Compute the shapes for the temporary and final outputs of the two gemms
+        and activation in the fused expert function. Since the gemms are
+        independent, the workspace for the first gemm can be shared with the
+        workspace for the last gemm.
+
+        Returns a tuple of:
+        - workspace13 shape tuple: must be large enough to hold the
+          result of either expert gemm.
+        - workspace2 shape tuple: must be large enough to hold the
+          result of the activation function.
+        - output shape tuple: must be exact size of the final gemm output.
+        - Note: in the batched experts format, the first dimension of each
+          tuple is the number of local experts.
+        """
+
+        K_dim = K * 2 if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else K
+        output_shape = (local_num_experts, M, K_dim)
+        workspace2 = (local_num_experts, M, N)
+        workspace1 = output_shape
+        return (workspace1, workspace2, output_shape)
+
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: MoEActivation,
+        global_num_experts: int,
+        expert_map: torch.Tensor | None,
+        a1q_scale: torch.Tensor | None,
+        a2_scale: torch.Tensor | None,  # Not used
+        workspace13: torch.Tensor | None,
+        workspace2: torch.Tensor | None,
+        expert_tokens_meta: mk.ExpertTokensMetadata | None,
+        apply_router_weight_on_input: bool | None,
+    ):
+        assert self.quant_dtype == "nvfp4", (
+            "Only nvfp4 quantization is currently supported."
+ ) + # Ensure w1_scale and w2_scale are not None before calling view + assert self.w1_scale is not None and self.w2_scale is not None, ( + "w1_scale and w2_scale must not be None for FlashInferExperts" + ) + assert expert_tokens_meta is not None + expert_num_tokens = expert_tokens_meta.expert_num_tokens + assert hidden_states.ndim == 3 + assert self.w1_scale.ndim == 3 + assert self.w2_scale.ndim == 3 + + input_global_scale = ( + None if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else self.a1_gscale + ) + flashinfer_hidden_states = ( + (hidden_states, a1q_scale) + if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH + else hidden_states + ) + flashinfer_cutedsl_moe_masked( + hidden_states=flashinfer_hidden_states, + input_global_scale=input_global_scale, + w1=w1, + w1_blockscale=self.w1_scale, + w1_alpha=self.g1_alphas, + w2=w2, + a2_global_scale=self.a2_gscale, + w2_blockscale=self.w2_scale, + w2_alpha=self.g2_alphas, + masked_m=expert_num_tokens, + workspace=workspace2, + out=output, + ) + + +def get_cute_dtype(input: torch.Tensor) -> str: + if input.dtype == torch.bfloat16: + return "bfloat16" + elif input.dtype == torch.float16: + return "float16" + elif input.dtype == torch.float32: + return "float32" + else: + raise ValueError(f"Unsupported cute dtype {input.dtype}") + + +def flashinfer_cutedsl_moe_masked( + hidden_states: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + input_global_scale: torch.Tensor, + w1: torch.Tensor, + w1_blockscale: torch.Tensor, + w1_alpha, + w2: torch.Tensor, + a2_global_scale: torch.Tensor, + w2_blockscale: torch.Tensor, + w2_alpha, + masked_m: torch.Tensor, + workspace: torch.Tensor, + out: torch.Tensor, +): + """ + Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL + kernels. 
+
+    Args:
+        hidden_states: Either of the following case
+            * torch.Tensor: [num_experts, m, k], bf16
+            * tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2],
+              uint8, [num_experts, m, k // 16], float8_e4m3fn
+        input_global_scale (torch.Tensor | None): (l,); must be None when
+            hidden_states is already quantized (tuple input)
+        w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
+        w1_blockscale (torch.Tensor): blockscale factors, e4m3,
+        w1_alpha (torch.Tensor): (l,)
+        w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
+        a2_global_scale (torch.Tensor): (l,)
+        w2_blockscale (torch.Tensor): blockscale factors, e4m3,
+        w2_alpha (torch.Tensor): (l,)
+        masked_m (torch.Tensor): Masked dimension indices
+        workspace (torch.Tensor): For gateup_output
+
+    Notes:
+        - Assumes max(masked_m) <= m.
+    """
+
+    # === Assertions on dtypes ===
+    assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}"
+    assert w1_blockscale.dtype == torch.float8_e4m3fn, (
+        f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
+    )
+    assert w1_alpha.dtype == torch.float32, (
+        f"w1_alpha must be float32, got {w1_alpha.dtype}"
+    )
+    assert w2.dtype == torch.uint8, f"w2 must be uint8, got {w2.dtype}"
+    assert a2_global_scale.dtype == torch.float32, (
+        f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
+    )
+    assert w2_blockscale.dtype == torch.float8_e4m3fn, (
+        f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
+    )
+    assert w2_alpha.dtype == torch.float32, (
+        f"w2_alpha must be float32, got {w2_alpha.dtype}"
+    )
+
+    # === Assertions on shapes ===
+    n = w2.shape[-1] * 2  # intermediate dimension
+    if isinstance(hidden_states, tuple):
+        assert input_global_scale is None, (
+            "input_global_scale must be None when input is pre-quantized"
+        )
+
+        aq = hidden_states[0].view(torch.uint8)
+        aq_sf = hidden_states[1].view(torch.float8_e4m3fn)
+        # m, k_by_2, num_experts = aq.shape
+        num_experts, m, k_by_2 = aq.shape
+        k = k_by_2 * 2
+        aq = aq.permute(1, 2, 0)
+    else:
+        num_experts, m, k = hidden_states.shape
+
+        assert
input_global_scale.dtype == torch.float32, ( + f"input_global_scale must be float32, got {input_global_scale.dtype}" + ) + assert input_global_scale.shape == (num_experts,), ( + f"input_global_scale must be (l,), got {input_global_scale.shape}" + ) + + aq, aq_sf = scaled_fp4_grouped_quantize( + hidden_states, + masked_m, + input_global_scale, + ) + + assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}" + assert w1.shape[-1] * 2 == k, ( + f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}" + ) + assert w2.shape[-2:] == ( + k, + n // 2, + ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}" + + assert w1_alpha.shape == (num_experts,), ( + f"w1_alpha must be (l,), got {w1_alpha.shape}" + ) + assert a2_global_scale.shape == (num_experts,), ( + f"a2_global_scale must be (l,), got {a2_global_scale.shape}" + ) + assert w2_alpha.shape == (num_experts,), ( + f"w2_alpha must be (l,), got {w2_alpha.shape}" + ) + + workspace = workspace.permute(1, 2, 0) # requirement of kernel + sf_vec_size = 16 + assert aq_sf.dtype == torch.float8_e4m3fn + assert aq.dtype == torch.uint8 + ab_dtype = "float4_e2m1fn" + sf_dtype = "float8_e4m3fn" + + if isinstance(hidden_states, tuple): + c_dtype = "bfloat16" + else: + c_dtype = get_cute_dtype(hidden_states) + + # Gemm1 + flashinfer_cutedsl_grouped_gemm_nt_masked( + (aq, aq_sf), + (w1.permute(1, 2, 0), w1_blockscale), + workspace, + masked_m, + ab_dtype=ab_dtype, + sf_dtype=sf_dtype, + c_dtype=c_dtype, + sf_vec_size=sf_vec_size, + alpha=w1_alpha.view(1, 1, num_experts), + alpha_dtype=get_cute_dtype(w1_alpha), + ) # in logical [m, n, l] + + # SILU and quantization + diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize( + workspace.permute(2, 0, 1), + masked_m, + a2_global_scale, + ) + + # Gemm2 + out = out.permute(1, 2, 0) # requirement of kernel + flashinfer_cutedsl_grouped_gemm_nt_masked( + (diq, diq_sf), + (w2.permute(1, 2, 0), w2_blockscale), + out, + masked_m, + ab_dtype=ab_dtype, + 
sf_dtype=sf_dtype, + c_dtype=c_dtype, + sf_vec_size=sf_vec_size, + alpha=w2_alpha.view(1, 1, num_experts), + alpha_dtype=get_cute_dtype(w2_alpha), + ) # in logical [m, k, l] + out = out.permute(2, 0, 1) diff --git a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py index a1db26619..5ce58220b 100644 --- a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py @@ -4,8 +4,6 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm import envs -from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, @@ -13,7 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate, + TopKWeightAndReduceNoOP, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, @@ -22,33 +20,42 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ) from vllm.platforms import current_platform from vllm.utils.flashinfer import ( - flashinfer_cutedsl_grouped_gemm_nt_masked, - has_flashinfer_cutedsl_grouped_gemm_nt_masked, - scaled_fp4_grouped_quantize, - silu_and_mul_scaled_nvfp4_experts_quantize, + flashinfer_cute_dsl_fused_moe_nvfp4, + has_flashinfer_cutedsl_moe_nvfp4, ) -logger = init_logger(__name__) - class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular): + """ + CuteDSL NvFP4 MoE experts using the FlashInfer functional API. + + Uses Standard activation format (non-batched). The kernel handles + routing, expert computation, and reduction internally. + Supports expert parallelism natively. 
+ """ + def __init__( self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, - max_num_tokens: int, - num_dispatchers: int, ): super().__init__( moe_config=moe_config, quant_config=quant_config, - max_num_tokens=max_num_tokens, - num_dispatchers=num_dispatchers, ) assert quant_config.quant_dtype == "nvfp4", ( - "Only nvfp4 quantization are currently supported." + "Only nvfp4 quantization is currently supported." ) self.out_dtype = moe_config.in_dtype + self.hidden_dim = moe_config.hidden_dim + self.intermediate_size_per_partition = ( + moe_config.intermediate_size_per_partition + ) + self.topk = moe_config.experts_per_token + self.local_num_experts = moe_config.num_local_experts + self.global_num_experts = moe_config.num_experts + self.ep_rank = moe_config.moe_parallel_config.ep_rank + self.local_expert_offset = self.ep_rank * self.local_num_experts def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale) @@ -56,7 +63,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular): @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: - return mk.FusedMoEActivationFormat.BatchedExperts + return mk.FusedMoEActivationFormat.Standard @staticmethod def _supports_current_device() -> bool: @@ -64,7 +71,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular): return ( p.is_cuda() and p.is_device_capability_family(100) - and has_flashinfer_cutedsl_grouped_gemm_nt_masked() + and has_flashinfer_cutedsl_moe_nvfp4() ) @staticmethod @@ -86,15 +93,16 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular): return activation == MoEActivation.SILU @staticmethod - def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + def _supports_parallel_config( + moe_parallel_config: FusedMoEParallelConfig, + ) -> bool: return True def supports_expert_map(self) -> bool: return False def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: 
- # Let PrepareAndFinalize::finalize() decide the impl. - return TopKWeightAndReduceDelegate() + return TopKWeightAndReduceNoOP() def workspace_shapes( self, @@ -107,29 +115,12 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular): expert_tokens_meta: mk.ExpertTokensMetadata | None, activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: - # We use global_num_experts due to how moe_align_block_size handles - # expert_maps. - """ - Compute the shapes for the temporary and final outputs of the two gemms - and activation in the fused expert function. Since the gemms are - independent, the workspace for the first gemm can be shared with the - workspace for the last gemm. - - Returns a tuple of: - - workspace13 shape tuple: must be large enough to hold the - result of either expert gemm. - - workspace2 shape tuple: must be large enough to hold the - result of the activation function. - - output shape tuple: must be exact size of the final gemm output. - - Workspace type: The dtype to use for the workspace tensors. - - Note: in order for activation chunking to work, the first dimension - of each tuple must be the number of tokens. - """ - K_dim = K * 2 if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else K - output_shape = (local_num_experts, M, K_dim) - workspace2 = (local_num_experts, M, N) - workspace1 = output_shape - return (workspace1, workspace2, output_shape) + workspace1 = (0,) + workspace2 = (0,) + # K is packed (K//2 for uint8), so output uses hidden_dim. 
+ assert self.hidden_dim == K * 2 + output = (M, self.hidden_dim) + return (workspace1, workspace2, output) def apply( self, @@ -143,210 +134,39 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular): global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, - a2_scale: torch.Tensor | None, # Not used + a2_scale: torch.Tensor | None, workspace13: torch.Tensor | None, workspace2: torch.Tensor | None, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool | None, ): - assert self.quant_dtype == "nvfp4", ( - "Only nvfp4 quantization are currently supported." - ) - # Ensure w1_scale and w2_scale are not None before calling view - assert self.w1_scale is not None and self.w2_scale is not None, ( - "w1_scale and w2_scale must not be None for FlashInferExperts" - ) - assert expert_tokens_meta is not None - expert_num_tokens = expert_tokens_meta.expert_num_tokens - assert hidden_states.ndim == 3 - assert self.w1_scale.ndim == 3 - assert self.w2_scale.ndim == 3 + assert self.quant_dtype == "nvfp4" + assert a1q_scale is not None + assert self.w1_scale is not None + assert self.w2_scale is not None - input_global_scale = ( - None if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else self.a1_gscale - ) - flashinfer_hidden_states = ( - (hidden_states, a1q_scale) - if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH - else hidden_states - ) - flashinfer_cutedsl_moe_masked( - hidden_states=flashinfer_hidden_states, - input_global_scale=input_global_scale, - w1=w1, - w1_blockscale=self.w1_scale, - w1_alpha=self.g1_alphas, - w2=w2, - a2_global_scale=self.a2_gscale, - w2_blockscale=self.w2_scale, - w2_alpha=self.g2_alphas, - masked_m=expert_num_tokens, - workspace=workspace2, - out=output, - ) + # a1q_scale is (M, K//16) float8_e4m3fn from fp4_quantize. + # The functional API expects x_sf with trailing dim: (M, K//16, 1). 
+ x_sf = a1q_scale.unsqueeze(-1) + from vllm.utils.flashinfer import _is_fi_autotuning, autotune -def get_cute_dtype(input: torch.Tensor) -> str: - if input.dtype == torch.bfloat16: - return "bfloat16" - elif input.dtype == torch.float16: - return "float16" - elif input.dtype == torch.float32: - return "float32" - else: - raise ValueError(f"Unsupported cute dtype {input.dtype}") - - -def flashinfer_cutedsl_moe_masked( - hidden_states: torch.Tensor | tuple[torch.Tensor, torch.Tensor], - input_global_scale: torch.Tensor, - w1: torch.Tensor, - w1_blockscale: torch.Tensor, - w1_alpha, - w2: torch.Tensor, - a2_global_scale: torch.Tensor, - w2_blockscale: torch.Tensor, - w2_alpha, - masked_m: torch.Tensor, - workspace: torch.Tensor, - out: torch.Tensor, -): - """ - Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL - kernels. - - Args: - hidden_states: Either of the following case - * torch.Tensor: [num_experts, m, k], bf16 - * tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2], - uint8, [num_experts, m, k // 16], float8_e4m3fn - input_global_scale (torch.Tensor): (l,) - w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8 - w1_blockscale (torch.Tensor): blockscale factors, e4m3, - w1_alpha (torch.Tensor): (l,) - w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8 - a2_global_scale (torch.Tensor): (l,) - w2_blockscale (torch.Tensor): blockscale factors, e4m3, - w2_alpha (torch.Tensor): (l,) - masked_m (torch.Tensor): Masked dimension indices - workspace (torch.Tensor): For gateup_output - - Notes: - - Assumes max(masked_m) <= m. 
- """ - - # === Assertions on dtypes === - assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}" - assert w1_blockscale.dtype == torch.float8_e4m3fn, ( - f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}" - ) - assert w1_alpha.dtype == torch.float32, ( - f"w1_alpha must be float32, got {w1_alpha.dtype}" - ) - assert w2.dtype == torch.uint8, f"w2 must be uint8, got {w2.dtype}" - assert a2_global_scale.dtype == torch.float32, ( - f"a2_global_scale must be float32, got {a2_global_scale.dtype}" - ) - assert w2_blockscale.dtype == torch.float8_e4m3fn, ( - f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}" - ) - assert w2_alpha.dtype == torch.float32, ( - f"w2_alpha must be float32, got {w2_alpha.dtype}" - ) - - # === Assertions on shapes === - n = w2.shape[-1] * 2 # intermediate dimension - if isinstance(hidden_states, tuple): - assert input_global_scale is None, ( - "input_global_scale is needed when input needs quant" - ) - - aq = hidden_states[0].view(torch.uint8) - aq_sf = hidden_states[1].view(torch.float8_e4m3fn) - # m, k_by_2, num_experts = aq.shape - num_experts, m, k_by_2 = aq.shape - k = k_by_2 * 2 - aq = aq.permute(1, 2, 0) - else: - num_experts, m, k = hidden_states.shape - - assert input_global_scale.dtype == torch.float32, ( - f"input_global_scale must be float32, got {input_global_scale.dtype}" - ) - assert input_global_scale.shape == (num_experts,), ( - f"input_global_scale must be (l,), got {input_global_scale.shape}" - ) - - aq, aq_sf = scaled_fp4_grouped_quantize( - hidden_states, - masked_m, - input_global_scale, - ) - - assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}" - assert w1.shape[-1] * 2 == k, ( - f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}" - ) - assert w2.shape[-2:] == ( - k, - n // 2, - ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}" - - assert w1_alpha.shape == (num_experts,), ( - f"w1_alpha must be (l,), got {w1_alpha.shape}" - ) - 
assert a2_global_scale.shape == (num_experts,), ( - f"a2_global_scale must be (l,), got {a2_global_scale.shape}" - ) - assert w2_alpha.shape == (num_experts,), ( - f"w2_alpha must be (l,), got {w2_alpha.shape}" - ) - - workspace = workspace.permute(1, 2, 0) # requirement of kernel - sf_vec_size = 16 - assert aq_sf.dtype == torch.float8_e4m3fn - assert aq.dtype == torch.uint8 - ab_dtype = "float4_e2m1fn" - sf_dtype = "float8_e4m3fn" - - if isinstance(hidden_states, tuple): - c_dtype = "bfloat16" - else: - c_dtype = get_cute_dtype(hidden_states) - - # Gemm1 - flashinfer_cutedsl_grouped_gemm_nt_masked( - (aq, aq_sf), - (w1.permute(1, 2, 0), w1_blockscale), - workspace, - masked_m, - ab_dtype=ab_dtype, - sf_dtype=sf_dtype, - c_dtype=c_dtype, - sf_vec_size=sf_vec_size, - alpha=w1_alpha.view(1, 1, num_experts), - alpha_dtype=get_cute_dtype(w1_alpha), - ) # in logical [m, n, l] - - # SILU and quantization - diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize( - workspace.permute(2, 0, 1), - masked_m, - a2_global_scale, - ) - - # Gemm2 - out = out.permute(1, 2, 0) # requirement of kernel - flashinfer_cutedsl_grouped_gemm_nt_masked( - (diq, diq_sf), - (w2.permute(1, 2, 0), w2_blockscale), - out, - masked_m, - ab_dtype=ab_dtype, - sf_dtype=sf_dtype, - c_dtype=c_dtype, - sf_vec_size=sf_vec_size, - alpha=w2_alpha.view(1, 1, num_experts), - alpha_dtype=get_cute_dtype(w2_alpha), - ) # in logical [m, k, l] - out = out.permute(2, 0, 1) + with autotune(_is_fi_autotuning): + flashinfer_cute_dsl_fused_moe_nvfp4( + x=hidden_states, + x_sf=x_sf, + token_selected_experts=topk_ids.to(torch.int32), + token_final_scales=topk_weights.float(), + w1_weight=w1, + w1_weight_sf=self.w1_scale, + w1_alpha=self.g1_alphas, + fc2_input_scale=self.a2_gscale, + w2_weight=w2, + w2_weight_sf=self.w2_scale, + w2_alpha=self.g2_alphas, + num_experts=self.global_num_experts, + top_k=self.topk, + num_local_experts=self.local_num_experts, + local_expert_offset=self.local_expert_offset, + 
moe_output=output, + ) diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index d946c5eb5..597d784d3 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.runner.shared_experts import ( ) from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( prepare_nvfp4_moe_layer_for_fi_or_cutlass, + prepare_nvfp4_moe_layer_for_flashinfer_cutedsl, ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, @@ -41,6 +42,7 @@ class NvFp4MoeBackend(Enum): FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM" FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS" FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL" + FLASHINFER_CUTEDSL_BATCHED = "FLASHINFER_CUTEDSL_BATCHED" VLLM_CUTLASS = "VLLM_CUTLASS" MARLIN = "MARLIN" @@ -49,6 +51,7 @@ FLASHINFER_NVFP4_MOE_BACKENDS = [ NvFp4MoeBackend.FLASHINFER_TRTLLM, NvFp4MoeBackend.FLASHINFER_CUTLASS, NvFp4MoeBackend.FLASHINFER_CUTEDSL, + NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED, ] fi_2_vllm_backend_map: dict[FlashinferMoeBackend, NvFp4MoeBackend] = { @@ -95,6 +98,13 @@ def backend_to_kernel_cls( return [FlashInferCuteDSLExperts] + elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED: + from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_batched_moe import ( # noqa: E501 + FlashInferCuteDSLBatchedExperts, + ) + + return [FlashInferCuteDSLBatchedExperts] + elif backend == NvFp4MoeBackend.VLLM_CUTLASS: from vllm.model_executor.layers.fused_moe.cutlass_moe import ( CutlassExpertsFp4, @@ -143,6 +153,7 @@ def select_nvfp4_moe_backend( AVAILABLE_BACKENDS = [ NvFp4MoeBackend.FLASHINFER_TRTLLM, NvFp4MoeBackend.FLASHINFER_CUTEDSL, + NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED, NvFp4MoeBackend.FLASHINFER_CUTLASS, NvFp4MoeBackend.VLLM_CUTLASS, NvFp4MoeBackend.MARLIN, @@ -198,6 +209,12 @@ def 
select_nvfp4_moe_backend( runner_backend = config.moe_backend if runner_backend != "auto": requested_backend = map_nvfp4_backend(runner_backend) + # For batched activation format, use batched variant if available. + if ( + activation_format == mk.FusedMoEActivationFormat.BatchedExperts + and requested_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL + ): + requested_backend = NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED return _return_or_raise( requested_backend, config, weight_key, activation_key, activation_format ) @@ -288,7 +305,28 @@ def convert_to_nvfp4_moe_kernel_format( torch.Tensor, torch.Tensor, ]: - if ( + if nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL: + ( + w13, + w13_scale, + w13_scale_2, + a13_scale, + w2, + w2_scale, + w2_scale_2, + a2_scale, + ) = prepare_nvfp4_moe_layer_for_flashinfer_cutedsl( + layer=layer, + w13=w13, + w13_scale=w13_scale, + w13_scale_2=w13_scale_2, + a13_scale=a13_scale, + w2=w2, + w2_scale=w2_scale, + w2_scale_2=w2_scale_2, + a2_scale=a2_scale, + ) + elif ( nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS or nvfp4_backend == NvFp4MoeBackend.VLLM_CUTLASS ): @@ -380,7 +418,13 @@ def make_nvfp4_moe_quant_config( # NOTE(rob): this is a hack until the MoE kernels # create their own quant configs. TRTLLM kernel # does not accept swizzled input quant scales. 
- is_nvfp4_scale_swizzled=(backend != NvFp4MoeBackend.FLASHINFER_TRTLLM), + is_nvfp4_scale_swizzled=( + backend + not in ( + NvFp4MoeBackend.FLASHINFER_TRTLLM, + NvFp4MoeBackend.FLASHINFER_CUTEDSL, + ) + ), ) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index d16d4a3d2..397442aec 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -60,6 +60,100 @@ def reorder_w1w3_to_w3w1( ) +def interleave_linear_and_gate( + x: torch.Tensor, + group_size: int = 64, + dim: int = -1, +) -> torch.Tensor: + """Interleave gate and linear weight rows for CuteDSL wrapper.""" + sizes = x.size() + dim = dim % x.dim() + assert sizes[dim] % (group_size * 2) == 0, ( + f"dim {dim} size {sizes[dim]} must be divisible by {group_size * 2}" + ) + prev_sizes = sizes[:dim] + post_sizes = sizes[dim + 1 :] + x = x.view(*prev_sizes, 2, sizes[dim] // (group_size * 2), group_size, *post_sizes) + x = x.transpose(dim, dim + 1).contiguous().view(*sizes) + return x + + +def prepare_nvfp4_moe_layer_for_flashinfer_cutedsl( + layer: "FusedMoE", + w13: torch.Tensor, + w13_scale: torch.Tensor, + w13_scale_2: torch.Tensor, + a13_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + w2_scale_2: torch.Tensor, + a2_scale: torch.Tensor, +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + """Prepare weights for the CuteDSL wrapper-based NvFP4 MoE backend. + + Converts weight scale factors to MMA layout expected by CuteDslMoEWrapper, + and interleaves w13 gate/linear rows. + """ + from flashinfer.cute_dsl.utils import convert_sf_to_mma_layout + + # Global scaling factors (same as other FlashInfer backends). 
+ num_experts = w13.shape[0] + a13_scale = a13_scale.max().to(torch.float32).expand(num_experts) + a2_scale = a2_scale.max().to(torch.float32).expand(num_experts) + + half = w13.shape[1] // 2 + w13 = torch.cat([w13[:, half:], w13[:, :half]], dim=1) + w13_scale = torch.cat([w13_scale[:, half:], w13_scale[:, :half]], dim=1) + + # Interleave up/gate rows for w13 weights and scales. + w13 = interleave_linear_and_gate(w13, group_size=64, dim=1) + w13_scale = interleave_linear_and_gate(w13_scale, group_size=64, dim=1) + + # Convert w13 scale factors: linear → swizzled → MMA layout. + w13_scale = swizzle_blockscale(w13_scale) + E, M_padded, K_sf_padded = w13_scale.shape + w13_scale_flat = w13_scale.reshape(E * M_padded, K_sf_padded) + w13_scale = convert_sf_to_mma_layout( + w13_scale_flat, + m=M_padded, + k=K_sf_padded * 16, + num_groups=E, + sf_vec_size=16, + ) + + # Convert w2 scale factors: linear → swizzled → MMA layout. + w2_scale = swizzle_blockscale(w2_scale) + E, M_padded, K_sf_padded = w2_scale.shape + w2_scale_flat = w2_scale.reshape(E * M_padded, K_sf_padded) + w2_scale = convert_sf_to_mma_layout( + w2_scale_flat, + m=M_padded, + k=K_sf_padded * 16, + num_groups=E, + sf_vec_size=16, + ) + + return ( + w13, + w13_scale, + w13_scale_2, + a13_scale, + w2, + w2_scale, + w2_scale_2, + a2_scale, + ) + + def prepare_static_weights_for_trtllm_fp4_moe( # args_dequant, # args, @@ -221,7 +315,7 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass( NvFp4MoeBackend.VLLM_CUTLASS, NvFp4MoeBackend.FLASHINFER_CUTLASS, NvFp4MoeBackend.FLASHINFER_TRTLLM, - NvFp4MoeBackend.FLASHINFER_CUTEDSL, + NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED, ] # Reorder [w1, w3] to [w3, w1] for FI NVFP4 MoE kernels. 
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 8ffac48cc..373134e65 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -128,6 +128,12 @@ scaled_fp4_grouped_quantize = _lazy_import_wrapper( nvfp4_block_scale_interleave = _lazy_import_wrapper( "flashinfer.fp4_quantization", "block_scale_interleave" ) +flashinfer_cute_dsl_fused_moe_nvfp4 = _lazy_import_wrapper( + "flashinfer", "cute_dsl_fused_moe_nvfp4" +) +flashinfer_convert_sf_to_mma_layout = _lazy_import_wrapper( + "flashinfer.cute_dsl.utils", "convert_sf_to_mma_layout" +) trtllm_fp4_block_scale_moe = _lazy_import_wrapper( "flashinfer", "trtllm_fp4_block_scale_moe" ) @@ -252,6 +258,15 @@ def has_flashinfer_cutedsl_grouped_gemm_nt_masked() -> bool: return True +@functools.cache +def has_flashinfer_cutedsl_moe_nvfp4() -> bool: + """Return ``True`` if FlashInfer cute_dsl_fused_moe_nvfp4 is available.""" + if not has_flashinfer_cutedsl(): + return False + mod = _get_submodule("flashinfer") + return mod is not None and hasattr(mod, "cute_dsl_fused_moe_nvfp4") + + @functools.cache def has_nvidia_artifactory() -> bool: """Return `True` if NVIDIA's artifactory is accessible. @@ -768,6 +783,8 @@ __all__ = [ "silu_and_mul_scaled_nvfp4_experts_quantize", "scaled_fp4_grouped_quantize", "nvfp4_block_scale_interleave", + "flashinfer_cute_dsl_fused_moe_nvfp4", + "flashinfer_convert_sf_to_mma_layout", "trtllm_fp4_block_scale_moe", "autotune", "has_flashinfer_moe", @@ -776,6 +793,7 @@ __all__ = [ "has_flashinfer_nvlink_one_sided", "has_flashinfer_cutlass_fused_moe", "has_flashinfer_cutedsl_grouped_gemm_nt_masked", + "has_flashinfer_cutedsl_moe_nvfp4", "has_flashinfer_fp8_blockscale_gemm", "has_nvidia_artifactory", "supports_trtllm_attention",