diff --git a/tests/kernels/moe/test_cutedsl_moe.py b/tests/kernels/moe/test_cutedsl_moe.py
index 2a6f83695..bca3eba0f 100644
--- a/tests/kernels/moe/test_cutedsl_moe.py
+++ b/tests/kernels/moe/test_cutedsl_moe.py
@@ -17,7 +17,7 @@ from flashinfer import fp4_quantize
 from torch.nn import functional as F
 
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_batched_moe import (  # noqa: E501
+from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_moe import (
     flashinfer_cutedsl_moe_masked,
 )
 from vllm.utils.flashinfer import (
diff --git a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py
deleted file mode 100644
index 5eaaf4673..000000000
--- a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_batched_moe.py
+++ /dev/null
@@ -1,353 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import torch
-
-import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm import envs
-from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.activation import MoEActivation
-from vllm.model_executor.layers.fused_moe.config import (
-    FusedMoEConfig,
-    FusedMoEParallelConfig,
-    FusedMoEQuantConfig,
-)
-from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceDelegate,
-)
-from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    QuantKey,
-    kNvfp4Dynamic,
-    kNvfp4Static,
-)
-from vllm.platforms import current_platform
-from vllm.utils.flashinfer import (
-    flashinfer_cutedsl_grouped_gemm_nt_masked,
-    has_flashinfer_cutedsl_grouped_gemm_nt_masked,
-    scaled_fp4_grouped_quantize,
-    silu_and_mul_scaled_nvfp4_experts_quantize,
-)
-
-logger = init_logger(__name__)
-
-
-class FlashInferCuteDSLBatchedExperts(mk.FusedMoEExpertsModular):
-    def __init__(
-        self,
-        moe_config: FusedMoEConfig,
-        quant_config: FusedMoEQuantConfig,
-        max_num_tokens: int,
-        num_dispatchers: int,
-    ):
-        super().__init__(
-            moe_config=moe_config,
-            quant_config=quant_config,
-            max_num_tokens=max_num_tokens,
-            num_dispatchers=num_dispatchers,
-        )
-        assert quant_config.quant_dtype == "nvfp4", (
-            "Only nvfp4 quantization are currently supported."
-        )
-        self.out_dtype = moe_config.in_dtype
-
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
-        layer.w2_weight_scale_2.data.mul_(layer.w2_input_scale)
-
-    @staticmethod
-    def activation_format() -> mk.FusedMoEActivationFormat:
-        return mk.FusedMoEActivationFormat.BatchedExperts
-
-    @staticmethod
-    def _supports_current_device() -> bool:
-        p = current_platform
-        return (
-            p.is_cuda()
-            and p.is_device_capability_family(100)
-            and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
-        )
-
-    @staticmethod
-    def _supports_no_act_and_mul() -> bool:
-        return False
-
-    @staticmethod
-    def _supports_quant_scheme(
-        weight_key: QuantKey | None,
-        activation_key: QuantKey | None,
-    ) -> bool:
-        SUPPORTED_W_A = [
-            (kNvfp4Static, kNvfp4Dynamic),
-        ]
-        return (weight_key, activation_key) in SUPPORTED_W_A
-
-    @staticmethod
-    def _supports_activation(activation: MoEActivation) -> bool:
-        return activation == MoEActivation.SILU
-
-    @staticmethod
-    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
-        return True
-
-    def supports_expert_map(self) -> bool:
-        return False
-
-    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
-        # Let PrepareAndFinalize::finalize() decide the impl.
-        return TopKWeightAndReduceDelegate()
-
-    def workspace_shapes(
-        self,
-        M: int,
-        N: int,
-        K: int,
-        topk: int,
-        global_num_experts: int,
-        local_num_experts: int,
-        expert_tokens_meta: mk.ExpertTokensMetadata | None,
-        activation: MoEActivation,
-    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
-        """
-        Compute the shapes for the temporary and final outputs of the two gemms
-        and activation in the fused expert function. Since the gemms are
-        independent, the workspace for the first gemm can be shared with the
-        workspace for the last gemm.
-
-        Returns a tuple of:
-        - workspace13 shape tuple: must be large enough to hold the
-          result of either expert gemm.
-        - workspace2 shape tuple: must be large enough to hold the
-          result of the activation function.
-        - output shape tuple: must be exact size of the final gemm output.
-        - Workspace type: The dtype to use for the workspace tensors.
-        - Note: in order for activation chunking to work, the first dimension
-          of each tuple must be the number of tokens.
-        """
-
-        # We use global_num_experts due to how moe_align_block_size handles
-        # expert_maps.
-        K_dim = K * 2 if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else K
-        output_shape = (local_num_experts, M, K_dim)
-        workspace2 = (local_num_experts, M, N)
-        workspace1 = output_shape
-        return (workspace1, workspace2, output_shape)
-
-    def apply(
-        self,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        activation: MoEActivation,
-        global_num_experts: int,
-        expert_map: torch.Tensor | None,
-        a1q_scale: torch.Tensor | None,
-        a2_scale: torch.Tensor | None,  # Not used
-        workspace13: torch.Tensor | None,
-        workspace2: torch.Tensor | None,
-        expert_tokens_meta: mk.ExpertTokensMetadata | None,
-        apply_router_weight_on_input: bool | None,
-    ):
-        assert self.quant_dtype == "nvfp4", (
-            "Only nvfp4 quantization are currently supported."
-        )
-        # Ensure w1_scale and w2_scale are not None before calling view
-        assert self.w1_scale is not None and self.w2_scale is not None, (
-            "w1_scale and w2_scale must not be None for FlashInferExperts"
-        )
-        assert expert_tokens_meta is not None
-        expert_num_tokens = expert_tokens_meta.expert_num_tokens
-        assert hidden_states.ndim == 3
-        assert self.w1_scale.ndim == 3
-        assert self.w2_scale.ndim == 3
-
-        input_global_scale = (
-            None if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else self.a1_gscale
-        )
-        flashinfer_hidden_states = (
-            (hidden_states, a1q_scale)
-            if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH
-            else hidden_states
-        )
-        flashinfer_cutedsl_moe_masked(
-            hidden_states=flashinfer_hidden_states,
-            input_global_scale=input_global_scale,
-            w1=w1,
-            w1_blockscale=self.w1_scale,
-            w1_alpha=self.g1_alphas,
-            w2=w2,
-            a2_global_scale=self.a2_gscale,
-            w2_blockscale=self.w2_scale,
-            w2_alpha=self.g2_alphas,
-            masked_m=expert_num_tokens,
-            workspace=workspace2,
-            out=output,
-        )
-
-
-def get_cute_dtype(input: torch.Tensor) -> str:
-    if input.dtype == torch.bfloat16:
-        return "bfloat16"
-    elif input.dtype == torch.float16:
-        return "float16"
-    elif input.dtype == torch.float32:
-        return "float32"
-    else:
-        raise ValueError(f"Unsupported cute dtype {input.dtype}")
-
-
-def flashinfer_cutedsl_moe_masked(
-    hidden_states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
-    input_global_scale: torch.Tensor,
-    w1: torch.Tensor,
-    w1_blockscale: torch.Tensor,
-    w1_alpha,
-    w2: torch.Tensor,
-    a2_global_scale: torch.Tensor,
-    w2_blockscale: torch.Tensor,
-    w2_alpha,
-    masked_m: torch.Tensor,
-    workspace: torch.Tensor,
-    out: torch.Tensor,
-):
-    """
-    Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
-    kernels.
-
-    Args:
-        hidden_states: Either of the following case
-            * torch.Tensor: [num_experts, m, k], bf16
-            * tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2],
-              uint8, [num_experts, m, k // 16], float8_e4m3fn
-        input_global_scale (torch.Tensor): (l,)
-        w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
-        w1_blockscale (torch.Tensor): blockscale factors, e4m3,
-        w1_alpha (torch.Tensor): (l,)
-        w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
-        a2_global_scale (torch.Tensor): (l,)
-        w2_blockscale (torch.Tensor): blockscale factors, e4m3,
-        w2_alpha (torch.Tensor): (l,)
-        masked_m (torch.Tensor): Masked dimension indices
-        workspace (torch.Tensor): For gateup_output
-
-    Notes:
-        - Assumes max(masked_m) <= m.
- """ - - # === Assertions on dtypes === - assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}" - assert w1_blockscale.dtype == torch.float8_e4m3fn, ( - f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}" - ) - assert w1_alpha.dtype == torch.float32, ( - f"w1_alpha must be float32, got {w1_alpha.dtype}" - ) - assert w2.dtype == torch.uint8, f"w2 must be uint8, got {w2.dtype}" - assert a2_global_scale.dtype == torch.float32, ( - f"a2_global_scale must be float32, got {a2_global_scale.dtype}" - ) - assert w2_blockscale.dtype == torch.float8_e4m3fn, ( - f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}" - ) - assert w2_alpha.dtype == torch.float32, ( - f"w2_alpha must be float32, got {w2_alpha.dtype}" - ) - - # === Assertions on shapes === - n = w2.shape[-1] * 2 # intermediate dimension - if isinstance(hidden_states, tuple): - assert input_global_scale is None, ( - "input_global_scale is needed when input needs quant" - ) - - aq = hidden_states[0].view(torch.uint8) - aq_sf = hidden_states[1].view(torch.float8_e4m3fn) - # m, k_by_2, num_experts = aq.shape - num_experts, m, k_by_2 = aq.shape - k = k_by_2 * 2 - aq = aq.permute(1, 2, 0) - else: - num_experts, m, k = hidden_states.shape - - assert input_global_scale.dtype == torch.float32, ( - f"input_global_scale must be float32, got {input_global_scale.dtype}" - ) - assert input_global_scale.shape == (num_experts,), ( - f"input_global_scale must be (l,), got {input_global_scale.shape}" - ) - - aq, aq_sf = scaled_fp4_grouped_quantize( - hidden_states, - masked_m, - input_global_scale, - ) - - assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}" - assert w1.shape[-1] * 2 == k, ( - f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}" - ) - assert w2.shape[-2:] == ( - k, - n // 2, - ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}" - - assert w1_alpha.shape == (num_experts,), ( - f"w1_alpha must be (l,), got {w1_alpha.shape}" - ) - assert a2_global_scale.shape == (num_experts,), ( - f"a2_global_scale must be (l,), got {a2_global_scale.shape}" - ) - assert w2_alpha.shape == (num_experts,), ( - f"w2_alpha must be (l,), got {w2_alpha.shape}" - ) - - workspace = workspace.permute(1, 2, 0) # requirement of kernel - sf_vec_size = 16 - assert aq_sf.dtype == torch.float8_e4m3fn - assert aq.dtype == torch.uint8 - ab_dtype = "float4_e2m1fn" - sf_dtype = "float8_e4m3fn" - - if isinstance(hidden_states, tuple): - c_dtype = "bfloat16" - else: - c_dtype = get_cute_dtype(hidden_states) - - # Gemm1 - flashinfer_cutedsl_grouped_gemm_nt_masked( - (aq, aq_sf), - (w1.permute(1, 2, 0), w1_blockscale), - workspace, - masked_m, - ab_dtype=ab_dtype, - sf_dtype=sf_dtype, - c_dtype=c_dtype, - sf_vec_size=sf_vec_size, - alpha=w1_alpha.view(1, 1, num_experts), - alpha_dtype=get_cute_dtype(w1_alpha), - ) # in logical [m, n, l] - - # SILU and quantization - diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize( - workspace.permute(2, 0, 1), - masked_m, - a2_global_scale, - ) - - # Gemm2 - out = out.permute(1, 2, 0) # requirement of kernel - flashinfer_cutedsl_grouped_gemm_nt_masked( - (diq, diq_sf), - (w2.permute(1, 2, 0), w2_blockscale), - out, - masked_m, - ab_dtype=ab_dtype, - sf_dtype=sf_dtype, - c_dtype=c_dtype, - sf_vec_size=sf_vec_size, - alpha=w2_alpha.view(1, 1, num_experts), - alpha_dtype=get_cute_dtype(w2_alpha), - ) # in logical [m, k, l] - out = out.permute(2, 0, 1) diff --git a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py 
index 5ce58220b..a1db26619 100644
--- a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py
+++ b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutedsl_moe.py
@@ -4,6 +4,8 @@
 import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm import envs
+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.activation import MoEActivation
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
@@ -11,7 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
 )
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceNoOP,
+    TopKWeightAndReduceDelegate,
 )
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
@@ -20,42 +22,33 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 )
 from vllm.platforms import current_platform
 from vllm.utils.flashinfer import (
-    flashinfer_cute_dsl_fused_moe_nvfp4,
-    has_flashinfer_cutedsl_moe_nvfp4,
+    flashinfer_cutedsl_grouped_gemm_nt_masked,
+    has_flashinfer_cutedsl_grouped_gemm_nt_masked,
+    scaled_fp4_grouped_quantize,
+    silu_and_mul_scaled_nvfp4_experts_quantize,
 )
 
+logger = init_logger(__name__)
+
 
 class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
-    """
-    CuteDSL NvFP4 MoE experts using the FlashInfer functional API.
-
-    Uses Standard activation format (non-batched). The kernel handles
-    routing, expert computation, and reduction internally.
-    Supports expert parallelism natively.
-    """
-
     def __init__(
         self,
         moe_config: FusedMoEConfig,
         quant_config: FusedMoEQuantConfig,
+        max_num_tokens: int,
+        num_dispatchers: int,
     ):
         super().__init__(
             moe_config=moe_config,
             quant_config=quant_config,
+            max_num_tokens=max_num_tokens,
+            num_dispatchers=num_dispatchers,
         )
         assert quant_config.quant_dtype == "nvfp4", (
             "Only nvfp4 quantization is currently supported."
         )
         self.out_dtype = moe_config.in_dtype
-        self.hidden_dim = moe_config.hidden_dim
-        self.intermediate_size_per_partition = (
-            moe_config.intermediate_size_per_partition
-        )
-        self.topk = moe_config.experts_per_token
-        self.local_num_experts = moe_config.num_local_experts
-        self.global_num_experts = moe_config.num_experts
-        self.ep_rank = moe_config.moe_parallel_config.ep_rank
-        self.local_expert_offset = self.ep_rank * self.local_num_experts
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.w13_weight_scale_2.data.mul_(layer.w13_input_scale)
@@ -63,7 +56,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
 
     @staticmethod
     def activation_format() -> mk.FusedMoEActivationFormat:
-        return mk.FusedMoEActivationFormat.Standard
+        return mk.FusedMoEActivationFormat.BatchedExperts
 
     @staticmethod
     def _supports_current_device() -> bool:
@@ -71,7 +64,7 @@
         return (
             p.is_cuda()
             and p.is_device_capability_family(100)
-            and has_flashinfer_cutedsl_moe_nvfp4()
+            and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
         )
 
     @staticmethod
@@ -93,16 +86,15 @@
         return activation == MoEActivation.SILU
 
     @staticmethod
-    def _supports_parallel_config(
-        moe_parallel_config: FusedMoEParallelConfig,
-    ) -> bool:
+    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
         return True
 
     def supports_expert_map(self) -> bool:
         return False
 
     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
-        return TopKWeightAndReduceNoOP()
+        # Let PrepareAndFinalize::finalize() decide the impl.
+        return TopKWeightAndReduceDelegate()
 
     def workspace_shapes(
         self,
@@ -115,12 +107,27 @@
         expert_tokens_meta: mk.ExpertTokensMetadata | None,
         activation: MoEActivation,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
-        workspace1 = (0,)
-        workspace2 = (0,)
-        # K is packed (K//2 for uint8), so output uses hidden_dim.
-        assert self.hidden_dim == K * 2
-        output = (M, self.hidden_dim)
-        return (workspace1, workspace2, output)
+        """
+        Compute the shapes for the temporary and final outputs of the two gemms
+        and activation in the fused expert function. Since the gemms are
+        independent, the workspace for the first gemm can be shared with the
+        workspace for the last gemm.
+
+        Returns a tuple of:
+        - workspace13 shape tuple: must be large enough to hold the
+          result of either expert gemm.
+        - workspace2 shape tuple: must be large enough to hold the
+          result of the activation function.
+        - output shape tuple: must be exact size of the final gemm output.
+        - Workspace type: The dtype to use for the workspace tensors.
+        - Note: in order for activation chunking to work, the first dimension
+          of each tuple must be the number of tokens.
+ """ + K_dim = K * 2 if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else K + output_shape = (local_num_experts, M, K_dim) + workspace2 = (local_num_experts, M, N) + workspace1 = output_shape + return (workspace1, workspace2, output_shape) def apply( self, @@ -134,39 +143,210 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular): global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, - a2_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, # Not used workspace13: torch.Tensor | None, workspace2: torch.Tensor | None, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool | None, ): - assert self.quant_dtype == "nvfp4" - assert a1q_scale is not None - assert self.w1_scale is not None - assert self.w2_scale is not None + assert self.quant_dtype == "nvfp4", ( + "Only nvfp4 quantization are currently supported." + ) + # Ensure w1_scale and w2_scale are not None before calling view + assert self.w1_scale is not None and self.w2_scale is not None, ( + "w1_scale and w2_scale must not be None for FlashInferExperts" + ) + assert expert_tokens_meta is not None + expert_num_tokens = expert_tokens_meta.expert_num_tokens + assert hidden_states.ndim == 3 + assert self.w1_scale.ndim == 3 + assert self.w2_scale.ndim == 3 - # a1q_scale is (M, K//16) float8_e4m3fn from fp4_quantize. - # The functional API expects x_sf with trailing dim: (M, K//16, 1). - x_sf = a1q_scale.unsqueeze(-1) + input_global_scale = ( + None if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else self.a1_gscale + ) + flashinfer_hidden_states = ( + (hidden_states, a1q_scale) + if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH + else hidden_states + ) + flashinfer_cutedsl_moe_masked( + hidden_states=flashinfer_hidden_states, + input_global_scale=input_global_scale, + w1=w1, + w1_blockscale=self.w1_scale, + w1_alpha=self.g1_alphas, + w2=w2, + a2_global_scale=self.a2_gscale, + w2_blockscale=self.w2_scale, + w2_alpha=self.g2_alphas, + masked_m=expert_num_tokens, + workspace=workspace2, + out=output, + ) - from vllm.utils.flashinfer import _is_fi_autotuning, autotune - with autotune(_is_fi_autotuning): - flashinfer_cute_dsl_fused_moe_nvfp4( - x=hidden_states, - x_sf=x_sf, - token_selected_experts=topk_ids.to(torch.int32), - token_final_scales=topk_weights.float(), - w1_weight=w1, - w1_weight_sf=self.w1_scale, - w1_alpha=self.g1_alphas, - fc2_input_scale=self.a2_gscale, - w2_weight=w2, - w2_weight_sf=self.w2_scale, - w2_alpha=self.g2_alphas, - num_experts=self.global_num_experts, - top_k=self.topk, - num_local_experts=self.local_num_experts, - local_expert_offset=self.local_expert_offset, - moe_output=output, - ) +def get_cute_dtype(input: torch.Tensor) -> str: + if input.dtype == torch.bfloat16: + return "bfloat16" + elif input.dtype == torch.float16: + return "float16" + elif input.dtype == torch.float32: + return "float32" + else: + raise ValueError(f"Unsupported cute dtype {input.dtype}") + + +def flashinfer_cutedsl_moe_masked( + hidden_states: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + input_global_scale: torch.Tensor, + w1: torch.Tensor, + w1_blockscale: torch.Tensor, + w1_alpha, + w2: torch.Tensor, + a2_global_scale: torch.Tensor, + w2_blockscale: torch.Tensor, + w2_alpha, + masked_m: torch.Tensor, + workspace: torch.Tensor, + out: torch.Tensor, +): + """ + Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL + kernels. 
+
+    Args:
+        hidden_states: Either of the following cases
+            * torch.Tensor: [num_experts, m, k], bf16
+            * tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2],
+              uint8, [num_experts, m, k // 16], float8_e4m3fn
+        input_global_scale (torch.Tensor): (l,); must be None when
+            hidden_states is pre-quantized
+        w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
+        w1_blockscale (torch.Tensor): blockscale factors, e4m3
+        w1_alpha (torch.Tensor): (l,)
+        w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
+        a2_global_scale (torch.Tensor): (l,)
+        w2_blockscale (torch.Tensor): blockscale factors, e4m3
+        w2_alpha (torch.Tensor): (l,)
+        masked_m (torch.Tensor): (l,) number of valid rows per expert
+        workspace (torch.Tensor): scratch for the gemm1 (gate-up) output
+
+    Notes:
+        - Assumes max(masked_m) <= m.
+    """
+
+    # === Assertions on dtypes ===
+    assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}"
+    assert w1_blockscale.dtype == torch.float8_e4m3fn, (
+        f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
+    )
+    assert w1_alpha.dtype == torch.float32, (
+        f"w1_alpha must be float32, got {w1_alpha.dtype}"
+    )
+    assert w2.dtype == torch.uint8, f"w2 must be uint8, got {w2.dtype}"
+    assert a2_global_scale.dtype == torch.float32, (
+        f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
+    )
+    assert w2_blockscale.dtype == torch.float8_e4m3fn, (
+        f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
+    )
+    assert w2_alpha.dtype == torch.float32, (
+        f"w2_alpha must be float32, got {w2_alpha.dtype}"
+    )
+
+    # === Assertions on shapes ===
+    n = w2.shape[-1] * 2  # intermediate dimension
+    if isinstance(hidden_states, tuple):
+        assert input_global_scale is None, (
+            "input_global_scale must be None when hidden_states is pre-quantized"
+        )
+
+        aq = hidden_states[0].view(torch.uint8)
+        aq_sf = hidden_states[1].view(torch.float8_e4m3fn)
+        num_experts, m, k_by_2 = aq.shape
+        k = k_by_2 * 2
+        aq = aq.permute(1, 2, 0)
+    else:
+        num_experts, m, k = hidden_states.shape
+
+        assert input_global_scale.dtype == torch.float32, (
+            f"input_global_scale must be float32, got {input_global_scale.dtype}"
+        )
+        assert input_global_scale.shape == (num_experts,), (
+            f"input_global_scale must be (l,), got {input_global_scale.shape}"
+        )
+
+        aq, aq_sf = scaled_fp4_grouped_quantize(
+            hidden_states,
+            masked_m,
+            input_global_scale,
+        )
+
+    assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
+    assert w1.shape[-1] * 2 == k, (
+        f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
+    )
+    assert w2.shape[-2:] == (
+        k,
+        n // 2,
+    ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}"
+
+    assert w1_alpha.shape == (num_experts,), (
+        f"w1_alpha must be (l,), got {w1_alpha.shape}"
+    )
+    assert a2_global_scale.shape == (num_experts,), (
+        f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
+    )
+    assert w2_alpha.shape == (num_experts,), (
+        f"w2_alpha must be (l,), got {w2_alpha.shape}"
+    )
+
+    workspace = workspace.permute(1, 2, 0)  # requirement of kernel
+    sf_vec_size = 16
+    assert aq_sf.dtype == torch.float8_e4m3fn
+    assert aq.dtype == torch.uint8
+    ab_dtype = "float4_e2m1fn"
+    sf_dtype = "float8_e4m3fn"
+
+    if isinstance(hidden_states, tuple):
+        c_dtype = "bfloat16"
+    else:
+        c_dtype = get_cute_dtype(hidden_states)
+
+    # Gemm1
+    flashinfer_cutedsl_grouped_gemm_nt_masked(
+        (aq, aq_sf),
+        (w1.permute(1, 2, 0), w1_blockscale),
+        workspace,
+        masked_m,
+        ab_dtype=ab_dtype,
+        sf_dtype=sf_dtype,
+        c_dtype=c_dtype,
+        sf_vec_size=sf_vec_size,
+        alpha=w1_alpha.view(1, 1, num_experts),
+        alpha_dtype=get_cute_dtype(w1_alpha),
+    )  # in logical [m, n, l]
+
+    # SILU and quantization
+    diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(
+        workspace.permute(2, 0, 1),
+        masked_m,
+        a2_global_scale,
+    )
+
+    # Gemm2
+    out = out.permute(1, 2, 0)  # requirement of kernel
+    flashinfer_cutedsl_grouped_gemm_nt_masked(
+        (diq, diq_sf),
+        (w2.permute(1, 2, 0), w2_blockscale),
+        out,
+        masked_m,
+        ab_dtype=ab_dtype,
+        sf_dtype=sf_dtype,
+        c_dtype=c_dtype,
+        sf_vec_size=sf_vec_size,
+        alpha=w2_alpha.view(1, 1, num_experts),
+        alpha_dtype=get_cute_dtype(w2_alpha),
+    )  # in logical [m, k, l]
+    out = out.permute(2, 0, 1)
diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
index e50c27d93..35451e87d 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
@@ -19,7 +19,6 @@ from vllm.model_executor.layers.fused_moe.config import (
 )
 from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
     prepare_nvfp4_moe_layer_for_fi_or_cutlass,
-    prepare_nvfp4_moe_layer_for_flashinfer_cutedsl,
 )
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     FlashinferMoeBackend,
@@ -39,7 +38,6 @@ class NvFp4MoeBackend(Enum):
     FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
     FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS"
     FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL"
-    FLASHINFER_CUTEDSL_BATCHED = "FLASHINFER_CUTEDSL_BATCHED"
     VLLM_CUTLASS = "VLLM_CUTLASS"
     MARLIN = "MARLIN"
 
@@ -48,7 +46,6 @@ FLASHINFER_NVFP4_MOE_BACKENDS = [
     NvFp4MoeBackend.FLASHINFER_TRTLLM,
     NvFp4MoeBackend.FLASHINFER_CUTLASS,
     NvFp4MoeBackend.FLASHINFER_CUTEDSL,
-    NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED,
 ]
 
 fi_2_vllm_backend_map: dict[FlashinferMoeBackend, NvFp4MoeBackend] = {
@@ -95,13 +92,6 @@ def backend_to_kernel_cls(
 
         return [FlashInferCuteDSLExperts]
 
-    elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED:
-        from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_batched_moe import (  # noqa: E501
-            FlashInferCuteDSLBatchedExperts,
-        )
-
-        return [FlashInferCuteDSLBatchedExperts]
-
     elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
         from vllm.model_executor.layers.fused_moe.cutlass_moe import (
             CutlassExpertsFp4,
@@ -150,7 +140,6 @@ def select_nvfp4_moe_backend(
     AVAILABLE_BACKENDS = [
         NvFp4MoeBackend.FLASHINFER_TRTLLM,
         NvFp4MoeBackend.FLASHINFER_CUTEDSL,
-        NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED,
         NvFp4MoeBackend.FLASHINFER_CUTLASS,
         NvFp4MoeBackend.VLLM_CUTLASS,
         NvFp4MoeBackend.MARLIN,
@@ -206,12 +195,6 @@ def select_nvfp4_moe_backend(
     runner_backend = config.moe_backend
     if runner_backend != "auto":
         requested_backend = map_nvfp4_backend(runner_backend)
-        # For batched activation format, use batched variant if available.
-        if (
-            activation_format == mk.FusedMoEActivationFormat.BatchedExperts
-            and requested_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL
-        ):
-            requested_backend = NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED
         return _return_or_raise(
             requested_backend, config, weight_key, activation_key, activation_format
         )
@@ -302,28 +285,7 @@ def convert_to_nvfp4_moe_kernel_format(
     torch.Tensor,
     torch.Tensor,
 ]:
-    if nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
-        (
-            w13,
-            w13_scale,
-            w13_scale_2,
-            a13_scale,
-            w2,
-            w2_scale,
-            w2_scale_2,
-            a2_scale,
-        ) = prepare_nvfp4_moe_layer_for_flashinfer_cutedsl(
-            layer=layer,
-            w13=w13,
-            w13_scale=w13_scale,
-            w13_scale_2=w13_scale_2,
-            a13_scale=a13_scale,
-            w2=w2,
-            w2_scale=w2_scale,
-            w2_scale_2=w2_scale_2,
-            a2_scale=a2_scale,
-        )
-    elif (
+    if (
         nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS
         or nvfp4_backend == NvFp4MoeBackend.VLLM_CUTLASS
     ):
@@ -415,13 +377,7 @@ def make_nvfp4_moe_quant_config(
         # NOTE(rob): this is a hack until the MoE kernels
         # create their own quant configs. TRTLLM kernel
         # does not accept swizzled input quant scales.
-        is_nvfp4_scale_swizzled=(
-            backend
-            not in (
-                NvFp4MoeBackend.FLASHINFER_TRTLLM,
-                NvFp4MoeBackend.FLASHINFER_CUTEDSL,
-            )
-        ),
+        is_nvfp4_scale_swizzled=(backend != NvFp4MoeBackend.FLASHINFER_TRTLLM),
     )
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
index 5b4f7caa3..66300ceae 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
@@ -60,100 +60,6 @@ def reorder_w1w3_to_w3w1(
     )
 
 
-def interleave_linear_and_gate(
-    x: torch.Tensor,
-    group_size: int = 64,
-    dim: int = -1,
-) -> torch.Tensor:
-    """Interleave gate and linear weight rows for CuteDSL wrapper."""
-    sizes = x.size()
-    dim = dim % x.dim()
-    assert sizes[dim] % (group_size * 2) == 0, (
-        f"dim {dim} size {sizes[dim]} must be divisible by {group_size * 2}"
-    )
-    prev_sizes = sizes[:dim]
-    post_sizes = sizes[dim + 1 :]
-    x = x.view(*prev_sizes, 2, sizes[dim] // (group_size * 2), group_size, *post_sizes)
-    x = x.transpose(dim, dim + 1).contiguous().view(*sizes)
-    return x
-
-
-def prepare_nvfp4_moe_layer_for_flashinfer_cutedsl(
-    layer: "FusedMoE",
-    w13: torch.Tensor,
-    w13_scale: torch.Tensor,
-    w13_scale_2: torch.Tensor,
-    a13_scale: torch.Tensor,
-    w2: torch.Tensor,
-    w2_scale: torch.Tensor,
-    w2_scale_2: torch.Tensor,
-    a2_scale: torch.Tensor,
-) -> tuple[
-    torch.Tensor,
-    torch.Tensor,
-    torch.Tensor,
-    torch.Tensor,
-    torch.Tensor,
-    torch.Tensor,
-    torch.Tensor,
-    torch.Tensor,
-]:
-    """Prepare weights for the CuteDSL wrapper-based NvFP4 MoE backend.
-
-    Converts weight scale factors to MMA layout expected by CuteDslMoEWrapper,
-    and interleaves w13 gate/linear rows.
-    """
-    from flashinfer.cute_dsl.utils import convert_sf_to_mma_layout
-
-    # Global scaling factors (same as other FlashInfer backends).
-    num_experts = w13.shape[0]
-    a13_scale = a13_scale.max().to(torch.float32).expand(num_experts)
-    a2_scale = a2_scale.max().to(torch.float32).expand(num_experts)
-
-    half = w13.shape[1] // 2
-    w13 = torch.cat([w13[:, half:], w13[:, :half]], dim=1)
-    w13_scale = torch.cat([w13_scale[:, half:], w13_scale[:, :half]], dim=1)
-
-    # Interleave up/gate rows for w13 weights and scales.
-    w13 = interleave_linear_and_gate(w13, group_size=64, dim=1)
-    w13_scale = interleave_linear_and_gate(w13_scale, group_size=64, dim=1)
-
-    # Convert w13 scale factors: linear → swizzled → MMA layout.
-    w13_scale = swizzle_blockscale(w13_scale)
-    E, M_padded, K_sf_padded = w13_scale.shape
-    w13_scale_flat = w13_scale.reshape(E * M_padded, K_sf_padded)
-    w13_scale = convert_sf_to_mma_layout(
-        w13_scale_flat,
-        m=M_padded,
-        k=K_sf_padded * 16,
-        num_groups=E,
-        sf_vec_size=16,
-    )
-
-    # Convert w2 scale factors: linear → swizzled → MMA layout.
-    w2_scale = swizzle_blockscale(w2_scale)
-    E, M_padded, K_sf_padded = w2_scale.shape
-    w2_scale_flat = w2_scale.reshape(E * M_padded, K_sf_padded)
-    w2_scale = convert_sf_to_mma_layout(
-        w2_scale_flat,
-        m=M_padded,
-        k=K_sf_padded * 16,
-        num_groups=E,
-        sf_vec_size=16,
-    )
-
-    return (
-        w13,
-        w13_scale,
-        w13_scale_2,
-        a13_scale,
-        w2,
-        w2_scale,
-        w2_scale_2,
-        a2_scale,
-    )
-
-
 def prepare_static_weights_for_trtllm_fp4_moe(
     # args_dequant,
     # args,
@@ -315,7 +221,7 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
         NvFp4MoeBackend.VLLM_CUTLASS,
         NvFp4MoeBackend.FLASHINFER_CUTLASS,
         NvFp4MoeBackend.FLASHINFER_TRTLLM,
-        NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED,
+        NvFp4MoeBackend.FLASHINFER_CUTEDSL,
     ]
 
     # Reorder [w1, w3] to [w3, w1] for FI NVFP4 MoE kernels.
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index 5ae0f3b78..065a9ca89 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -128,12 +128,6 @@ scaled_fp4_grouped_quantize = _lazy_import_wrapper(
 nvfp4_block_scale_interleave = _lazy_import_wrapper(
     "flashinfer.fp4_quantization", "block_scale_interleave"
 )
-flashinfer_cute_dsl_fused_moe_nvfp4 = _lazy_import_wrapper(
-    "flashinfer", "cute_dsl_fused_moe_nvfp4"
-)
-flashinfer_convert_sf_to_mma_layout = _lazy_import_wrapper(
-    "flashinfer.cute_dsl.utils", "convert_sf_to_mma_layout"
-)
 trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
     "flashinfer", "trtllm_fp4_block_scale_moe"
 )
@@ -257,15 +251,6 @@ def has_flashinfer_cutedsl_grouped_gemm_nt_masked() -> bool:
     return True
 
 
-@functools.cache
-def has_flashinfer_cutedsl_moe_nvfp4() -> bool:
-    """Return ``True`` if FlashInfer cute_dsl_fused_moe_nvfp4 is available."""
-    if not has_flashinfer_cutedsl():
-        return False
-    mod = _get_submodule("flashinfer")
-    return mod is not None and hasattr(mod, "cute_dsl_fused_moe_nvfp4")
-
-
 @functools.cache
 def has_nvidia_artifactory() -> bool:
     """Return `True` if NVIDIA's artifactory is accessible.
@@ -782,8 +767,6 @@ __all__ = [
     "silu_and_mul_scaled_nvfp4_experts_quantize",
     "scaled_fp4_grouped_quantize",
     "nvfp4_block_scale_interleave",
-    "flashinfer_cute_dsl_fused_moe_nvfp4",
-    "flashinfer_convert_sf_to_mma_layout",
    "trtllm_fp4_block_scale_moe",
     "autotune",
     "has_flashinfer_moe",
@@ -792,7 +775,6 @@ __all__ = [
     "has_flashinfer_nvlink_one_sided",
     "has_flashinfer_cutlass_fused_moe",
     "has_flashinfer_cutedsl_grouped_gemm_nt_masked",
-    "has_flashinfer_cutedsl_moe_nvfp4",
     "has_flashinfer_fp8_blockscale_gemm",
     "has_nvidia_artifactory",
     "supports_trtllm_attention",
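
Note: below is a minimal, CPU-only sketch of the tensor contract that the relocated `flashinfer_cutedsl_moe_masked` enforces. The dimension names `l, m, n, k` follow its docstring; the toy sizes are illustrative assumptions, not values from this change, and no FlashInfer kernel is invoked.

```python
# Shape sketch for the bf16 path of flashinfer_cutedsl_moe_masked.
# Blockscale tensors are omitted; their swizzled layout comes from the
# weight-loading path, not from this function's documented contract.
import torch

l, m, n, k = 4, 8, 256, 512  # experts, max rows/expert, intermediate, hidden

# bf16 activations are quantized to nvfp4 inside the function.
hidden_states = torch.randn(l, m, k, dtype=torch.bfloat16)
input_global_scale = torch.ones(l, dtype=torch.float32)

# fp4 weights pack two values per uint8, hence the halved last dims.
w1 = torch.zeros(l, 2 * n, k // 2, dtype=torch.uint8)  # gate-up projection
w2 = torch.zeros(l, k, n // 2, dtype=torch.uint8)      # down projection
w1_alpha = torch.ones(l, dtype=torch.float32)
w2_alpha = torch.ones(l, dtype=torch.float32)
a2_global_scale = torch.ones(l, dtype=torch.float32)

# masked_m[e] = number of valid rows for expert e; max(masked_m) <= m.
masked_m = torch.randint(0, m + 1, (l,), dtype=torch.int32)

# workspace receives the gemm1 (gate-up) output, out the gemm2 output.
workspace = torch.empty(l, m, 2 * n, dtype=torch.bfloat16)
out = torch.empty(l, m, k, dtype=torch.bfloat16)

# The shape relations the function asserts before launching the kernels:
n_inferred = w2.shape[-1] * 2
assert w1.shape[-2] == 2 * n_inferred and w1.shape[-1] * 2 == k
assert w2.shape[-2:] == (k, n_inferred // 2)
assert input_global_scale.shape == w1_alpha.shape == masked_m.shape == (l,)
print("nvfp4 masked-MoE layout contracts hold:",
      {"l": l, "m": m, "n": n_inferred, "k": k})
```

On a compute-capability-10.x GPU with FlashInfer installed, these are the tensors (moved to CUDA, with real quantized weights and e4m3 blockscales) that `FlashInferCuteDSLExperts.apply` hands to `flashinfer_cutedsl_moe_masked`; with `VLLM_DEEPEPLL_NVFP4_DISPATCH` set, the pre-quantized tuple path (`uint8 [l, m, k // 2]` plus `float8_e4m3fn [l, m, k // 16]`) is used instead, with `input_global_scale=None`.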