diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index cd5bf47d6..bd10c3793 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -526,7 +526,7 @@ def test_run_cutlass_moe_fp8( c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64) c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64) - activation = lambda o, i: torch.ops._C.silu_and_mul(o, i) + activation = "silu" a1q, a1q_scale = moe_kernel_quantize_input( mt.a, mt.a_scale, torch.float8_e4m3fn, per_act_token ) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 56507a39b..07ced9769 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -1079,6 +1079,86 @@ def test_fused_marlin_moe_with_bias(m): torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) +@pytest.mark.flaky(reruns=2) +@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") +@pytest.mark.parametrize("m", [1, 64, 256]) +@pytest.mark.parametrize("n,k", [(1024, 1024), (2048, 2048)]) +@pytest.mark.parametrize("e,topk", [(8, 2), (64, 4)]) +def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int): + """Test Marlin MoE with non-gated activation (relu2_no_mul). + + Non-gated activations like relu2 don't have the gate-up projection pattern, + so w1 has shape (e, n, k) instead of (e, 2*n, k). + """ + torch.cuda.manual_seed(42) + + group_size = 16 # NVFP4 group size + is_k_full = True + quant_type = scalar_types.float4_e2m1f + dtype = torch.bfloat16 + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + # Non-gated: w1 shape is (e, n, k) not (e, 2*n, k) + w1 = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + w1_data = MarlinMoEWeightData.make( + w=w1, + quant_type=quant_type, + group_size=group_size, + act_order=False, + ) + + w2_data = MarlinMoEWeightData.make( + w=w2, + quant_type=quant_type, + group_size=group_size, + act_order=False, + ) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) + + with set_current_vllm_config(vllm_config): + torch_output = torch_moe( + a, + w1_data.w_ref, + w2_data.w_ref, + score, + topk, + activation="relu2", + ) + + marlin_output = fused_marlin_moe( + a, + w1_data.qweight, + w2_data.qweight, + None, # bias1 + None, # bias2 + w1_data.scales, + w2_data.scales, + score, + topk_weights, + topk_ids, + global_num_experts=e, + expert_map=None, + global_scale1=w1_data.global_scale, + global_scale2=w2_data.global_scale, + g_idx1=w1_data.g_idx, + g_idx2=w2_data.g_idx, + sort_indices1=w1_data.sort_indices, + sort_indices2=w2_data.sort_indices, + w1_zeros=w1_data.zeros, + w2_zeros=w2_data.zeros, + quant_type_id=quant_type.id, + is_k_full=is_k_full, + activation="relu2_no_mul", + ) + + torch.testing.assert_close(marlin_output, torch_output, atol=1e-1, rtol=0) + + @pytest.mark.parametrize("ep_size", [1, 2]) def test_moe_align_block_size_opcheck(ep_size): num_experts = 4 diff --git a/vllm/envs.py b/vllm/envs.py index f1ee13e33..1c31e83b7 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1451,6 +1451,7 @@ environment_variables: dict[str, Callable[[], Any]] = { # - "flashinfer-cudnn": use flashinfer cudnn GEMM backend # - "flashinfer-trtllm": use flashinfer trtllm GEMM backend # - "flashinfer-cutlass": use flashinfer cutlass GEMM backend + # - "marlin": use marlin GEMM backend (for GPUs without native FP4 support) # - : automatically pick an available backend "VLLM_NVFP4_GEMM_BACKEND": env_with_choices( "VLLM_NVFP4_GEMM_BACKEND", @@ -1460,6 +1461,7 @@ environment_variables: dict[str, Callable[[], Any]] = { "flashinfer-trtllm", "flashinfer-cutlass", "cutlass", + "marlin", ], ), # Controls garbage collection during CUDA graph capture. diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index c0ffa38fd..acc12d0da 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -2,8 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """CUTLASS based Fused MoE kernels.""" -from collections.abc import Callable - import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -21,7 +19,10 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP, ) -from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.model_executor.layers.fused_moe.utils import ( + _resize_cache, + apply_moe_activation, +) from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -33,7 +34,7 @@ def run_cutlass_moe_fp8( w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor, - activation_callable: Callable, + activation: str, global_num_experts: int, expert_map: torch.Tensor | None, w1_scale: torch.Tensor | None, @@ -55,6 +56,7 @@ def run_cutlass_moe_fp8( ): a1q = hidden_states + assert not activation.endswith("_no_mul"), "Only gated activation is supported" assert w1_scale is not None assert w2_scale is not None assert w1.dtype == torch.float8_e4m3fn @@ -198,7 +200,7 @@ def run_cutlass_moe_fp8( per_out_ch, ) - activation_callable(act_out, mm1_out) + apply_moe_activation(activation, act_out, mm1_out) a2q, a2q_scale = ops.scaled_fp8_quant( act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out @@ -288,8 +290,6 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): if expert_tokens_meta is not None: expert_num_tokens = expert_tokens_meta.expert_num_tokens - activation_callable = lambda o, i: self.activation(activation, o, i) - use_batched_format = ( self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts ) @@ -301,7 +301,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): w1, w2, topk_ids, - activation_callable, + activation, global_num_experts, expert_map, self.w1_scale, @@ -436,6 +436,7 @@ def run_cutlass_moe_fp4( w2_alphas: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str, workspace13: torch.Tensor, workspace2: torch.Tensor, m: int, @@ -544,8 +545,7 @@ def run_cutlass_moe_fp4( num_topk, ) c1 = _resize_cache(workspace13, (m * topk, n * 2)) - # Note: c2 workspace is no longer needed since SiLU is fused with quantization. - # c3 reuses workspace13 after c1 is consumed. + c2 = _resize_cache(workspace2, (m * topk, n)) c3 = _resize_cache(workspace13, (m * topk, k)) ops.cutlass_fp4_moe_mm( c1, @@ -559,10 +559,18 @@ def run_cutlass_moe_fp4( blockscale_offsets[:-1], ) del rep_a_fp4, rep_a_blockscale - # Fused SiLU+Mul+NVFP4 quantization - int_fp4, int_blockscale = ops.silu_and_mul_scaled_fp4_experts_quant( - c1, a2_gscale, expert_offsets, blockscale_offsets, num_topk - ) + if activation == "silu": + # Fused SiLU+Mul+NVFP4 quantization + # Note: c2 workspace is no longer needed since SiLU is fused with quantization. + # c3 reuses workspace13 after c1 is consumed. + int_fp4, int_blockscale = ops.silu_and_mul_scaled_fp4_experts_quant( + c1, a2_gscale, expert_offsets, blockscale_offsets, num_topk + ) + else: + apply_moe_activation(activation, c2, c1) + int_fp4, int_blockscale = ops.scaled_fp4_experts_quant( + c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk + ) ops.cutlass_fp4_moe_mm( c3, @@ -693,6 +701,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): w2_alphas=self.g2_alphas, topk_weights=topk_weights, topk_ids=topk_ids, + activation=activation, workspace13=workspace13, workspace2=workspace2, m=m, @@ -711,7 +720,7 @@ def run_cutlass_moe_w4a8_fp8( w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor, - activation_callable: Callable, + activation: str, global_num_experts: int, expert_map: torch.Tensor | None, w1_scale: torch.Tensor | None, @@ -815,7 +824,7 @@ def run_cutlass_moe_w4a8_fp8( s_strides1, ) - activation_callable(act_out, mm1_out) + apply_moe_activation(activation, act_out, mm1_out) a2q, a2q_scale = ops.scaled_fp8_quant( act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out @@ -936,7 +945,6 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute): assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE" expert_num_tokens = None - activation_callable = lambda o, i: self.activation(activation, o, i) use_batched_format = ( self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts @@ -951,7 +959,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute): w1, w2, topk_ids, - activation_callable, + activation, global_num_experts, expert_map, self.w1_scale, diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 77c6b97ea..be9dddb87 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -17,7 +17,11 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP, ) -from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_inplace +from vllm.model_executor.layers.fused_moe.utils import ( + _resize_cache, + apply_moe_activation, + disable_inplace, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( marlin_make_workspace_new, marlin_moe_intermediate_size, @@ -27,21 +31,6 @@ from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types -def default_activation_func( - activation: str, output: torch.Tensor, input: torch.Tensor -) -> None: - if activation == "silu": - torch.ops._C.silu_and_mul(output, input) - elif activation == "swigluoai": - # alpha = 1.702, limit = 7.0 - torch.ops._C.swigluoai_and_mul(output, input) - else: - raise ValueError( - f"Unsupported activation: {activation}. " - "Only silu and swigluoai activations are supported." - ) - - def _fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -62,7 +51,7 @@ def _fused_marlin_moe( activation: str = "silu", activation_func: Callable[ [str, torch.Tensor, torch.Tensor], None - ] = default_activation_func, + ] = apply_moe_activation, input_global_scale1: torch.Tensor | None = None, input_global_scale2: torch.Tensor | None = None, global_scale1: torch.Tensor | None = None, @@ -83,13 +72,13 @@ def _fused_marlin_moe( assert hidden_states.ndim == 2 M, K = hidden_states.size() N = marlin_moe_intermediate_size(w1, w2) - + w13_num_shards = 1 if "no_mul" in activation else 2 if workspace is None: workspace = marlin_make_workspace_new(hidden_states.device, 4) if intermediate_cache13 is None: intermediate_cache13 = torch.empty( - (M * num_topk * max(2 * N, K),), + (M * num_topk * max(w13_num_shards * N, K),), device=hidden_states.device, dtype=hidden_states.dtype, ) @@ -101,7 +90,9 @@ def _fused_marlin_moe( dtype=hidden_states.dtype, ) - intermediate_cache1 = _resize_cache(intermediate_cache13, (M * num_topk, 2 * N)) + intermediate_cache1 = _resize_cache( + intermediate_cache13, (M * num_topk, w13_num_shards * N) + ) intermediate_cache3 = _resize_cache(intermediate_cache13, (M * num_topk, K)) @@ -137,16 +128,17 @@ def _fused_marlin_moe( mul_topk_weights=apply_router_weight_on_input, b_q_type=quant_type, size_m=M, - size_n=2 * N, + size_n=w13_num_shards * N, size_k=K, is_k_full=is_k_full, use_atomic_add=False, use_fp32_reduce=True, is_zp_float=False, ) - activation_func( - activation, intermediate_cache2, intermediate_cache1.view(-1, 2 * N) + activation, + intermediate_cache2, + intermediate_cache1.view(-1, w13_num_shards * N), ) if output is None: @@ -216,7 +208,7 @@ def fused_marlin_moe( activation: str = "silu", activation_func: Callable[ [str, torch.Tensor, torch.Tensor], None - ] = default_activation_func, + ] = apply_moe_activation, moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None, expert_map: torch.Tensor | None = None, input_global_scale1: torch.Tensor | None = None, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3b3a789f6..fd3f76cb2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -619,30 +619,11 @@ class FusedMoE(CustomOp): # for heuristic purposes, so it must be initialized first. self.quant_method: FusedMoEMethodBase = _get_quant_method() - if not self.moe_config.is_act_and_mul: - # Avoid circular import - from vllm.model_executor.layers.quantization.modelopt import ( - ModelOptFp8MoEMethod, - ModelOptNvFp4FusedMoE, + if not self.moe_config.is_act_and_mul and not current_platform.is_cuda(): + raise NotImplementedError( + "is_act_and_mul=False is supported only for CUDA for now" ) - if not isinstance( - self.quant_method, - ( - UnquantizedFusedMoEMethod, - ModelOptFp8MoEMethod, - ModelOptNvFp4FusedMoE, - ), - ): - raise NotImplementedError( - "is_act_and_mul=False is supported only for unquantized " - ", ModelOpt FP8, and ModelOpt NvFp4 checkpoints" - ) - if not current_platform.is_cuda(): - raise NotImplementedError( - "is_act_and_mul=False is supported only for CUDA for now" - ) - if self.enable_eplb and not self.quant_method.supports_eplb: # TODO: Add support for additional quantization methods. # The implementation for other quantization methods does not diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index f5c3b9af6..6872b542f 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -52,7 +52,7 @@ def select_fp8_moe_backend( block_quant: bool, tp_size: int, with_lora_support: bool, - is_act_and_mul: bool = True, + is_act_and_mul: bool, allow_vllm_cutlass: bool = False, ) -> Fp8MoeBackend: """ @@ -128,7 +128,7 @@ def select_fp8_moe_backend( scope="local", ) - if use_deep_gemm and moe_use_deep_gemm and block_quant: + if use_deep_gemm and moe_use_deep_gemm and block_quant and is_act_and_mul: if not has_deep_gemm(): logger.warning_once( "DeepGEMM backend requested but not available.", scope="local" @@ -141,7 +141,12 @@ def select_fp8_moe_backend( logger.info_once(_make_log_backend("ROCm AITER"), scope="local") return Fp8MoeBackend.AITER - if allow_vllm_cutlass and not block_quant and cutlass_group_gemm_supported(): + if ( + allow_vllm_cutlass + and not block_quant + and cutlass_group_gemm_supported() + and is_act_and_mul + ): logger.info_once(_make_log_backend("vLLM CUTLASS"), scope="local") return Fp8MoeBackend.VLLM_CUTLASS diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index 1efb4d092..f2d69cf09 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -178,6 +178,7 @@ def convert_to_nvfp4_moe_kernel_format( w2=w2, w2_scale=w2_scale, w2_scale_2=w2_scale_2, + is_act_and_mul=is_act_and_mul, ) else: raise ValueError(f"Unknown NvFp4 backend for MoE: {nvfp4_backend}") diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index e74b4fd21..cd89f7c85 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -367,7 +367,8 @@ def apply_moe_activation( elif activation == GELU_NO_MUL: output.copy_(F.gelu(input)) elif activation == RELU2_NO_MUL: - torch.square(F.relu(input), out=output) + F.relu(input, inplace=True) + torch.square(input, out=output) else: raise ValueError(f"Unsupported FusedMoe activation: {activation}") diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 5763a4119..829c08e9d 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -764,8 +764,6 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert layer.activation == "silu", "Only SiLU activation is supported." - topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 85e73e504..423062c61 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -370,12 +370,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): layer_name: str | None = None, use_marlin: bool = False, ): - if not moe.is_act_and_mul: - raise ValueError( - "CompressedTensorsW4A4Nvfp4MoEMethod does not yet " - "support non gated MoE models." - ) - super().__init__(moe) self.group_size = 16 if use_marlin: @@ -388,6 +382,16 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ) else: self.nvfp4_backend = select_nvfp4_moe_backend() + + # TODO: move this type of check into the oracle. + if not self.moe.is_act_and_mul and self.nvfp4_backend not in [ + NvFp4MoeBackend.FLASHINFER_CUTLASS, + NvFp4MoeBackend.MARLIN, + ]: + raise NotImplementedError( + "Non-gated activations are only supported by FlashInfer " + f"CUTLASS and Marlin NvFP4 MoE backends, not {self.nvfp4_backend}." + ) self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.nvfp4_backend ) @@ -404,11 +408,12 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ): layer.num_experts = num_experts layer.params_dtype = params_dtype + w13_num_shards = 2 if self.moe.is_act_and_mul else 1 w13_weight = torch.nn.Parameter( torch.empty( num_experts, - 2 * intermediate_size_per_partition, + w13_num_shards * intermediate_size_per_partition, # 2 fp4 items are packed in the input dimension hidden_size // 2, requires_grad=False, @@ -436,7 +441,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): w13_weight_scale = torch.nn.Parameter( torch.empty( num_experts, - 2 * intermediate_size_per_partition, + w13_num_shards * intermediate_size_per_partition, # 2 fp4 items are packed in the input dimension hidden_size // self.group_size, dtype=torch.float8_e4m3fn, @@ -467,7 +472,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): # Weight Global Scales w13_weight_scale_2 = torch.nn.Parameter( - torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False + torch.empty(num_experts, w13_num_shards, dtype=torch.float32), + requires_grad=False, ) layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2) extra_weight_attrs.update( @@ -486,7 +492,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): # Input Global Scales w13_input_scale = torch.nn.Parameter( - torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False + torch.empty(num_experts, w13_num_shards, dtype=torch.float32), + requires_grad=False, ) layer.register_parameter("w13_input_global_scale", w13_input_scale) extra_weight_attrs.update( @@ -640,6 +647,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): x=x, router_logits=router_logits, top_k=layer.top_k, + activation=layer.activation, global_num_experts=layer.global_num_experts, num_expert_group=layer.num_expert_group, topk_group=layer.topk_group, @@ -666,6 +674,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): topk_ids=topk_ids, topk_weights=topk_weights, top_k=layer.top_k, + activation=layer.activation, global_num_experts=layer.global_num_experts, ) else: @@ -722,6 +731,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): block_quant=self.block_quant, tp_size=moe.tp_size, with_lora_support=moe.is_lora_enabled, + is_act_and_mul=moe.is_act_and_mul, # TODO(rob): enable selecting this externally. allow_vllm_cutlass=True, ) @@ -760,6 +770,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): layer.weight_block_size = None params_dtype = torch.float8_e4m3fn + w13_num_shards = 2 if self.moe.is_act_and_mul else 1 if self.block_quant: assert self.weight_block_size is not None @@ -791,7 +802,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): w13_weight = torch.nn.Parameter( torch.empty( num_experts, - 2 * intermediate_size_per_partition, + w13_num_shards * intermediate_size_per_partition, hidden_size, dtype=params_dtype, ), @@ -814,10 +825,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): # WEIGHT_SCALES if self.weight_quant.strategy == QuantizationStrategy.TENSOR: - # Allocate 2 scales for w1 and w3 respectively. - # They are combined to a single scale after weight loading. + # For gated MoE, allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + # For non-gated MoE, allocate 1 scale for w13. w13_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + torch.ones(num_experts, w13_num_shards, dtype=torch.float32), + requires_grad=False, ) layer.register_parameter("w13_weight_scale", w13_weight_scale) w2_weight_scale = torch.nn.Parameter( @@ -835,7 +848,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): w13_weight_scale = torch.nn.Parameter( torch.ones( num_experts, - 2 * intermediate_size_per_partition, + w13_num_shards * intermediate_size_per_partition, 1, dtype=torch.float32, ), @@ -858,7 +871,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): w13_weight_scale = torch.nn.Parameter( torch.ones( num_experts, - 2 * ((intermediate_size_per_partition + block_n - 1) // block_n), + w13_num_shards + * ((intermediate_size_per_partition + block_n - 1) // block_n), (hidden_size + block_k - 1) // block_k, dtype=torch.float32, ), @@ -930,11 +944,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): # Per-tensor kernels use a single scale, for W13, but on disk there # is a separate scale for W1 and W3. Requantize with the max scale. if self.weight_quant.strategy == QuantizationStrategy.TENSOR: - process_fp8_weight_tensor_strategy_moe( + w13, w13_scale = process_fp8_weight_tensor_strategy_moe( w13, w13_scale, shard_size=layer.intermediate_size_per_partition, num_experts=layer.num_local_experts, + is_act_and_mul=self.moe.is_act_and_mul, ) w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format( @@ -1166,12 +1181,13 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): **extra_weight_attrs, ): params_dtype = torch.int8 + w13_num_shards = 2 if self.moe.is_act_and_mul else 1 # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( num_experts, - 2 * intermediate_size_per_partition, + w13_num_shards * intermediate_size_per_partition, hidden_size, dtype=params_dtype, ), @@ -1196,7 +1212,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL w13_weight_scale = torch.nn.Parameter( torch.ones( - num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + num_experts, + w13_num_shards * intermediate_size_per_partition, + 1, + dtype=torch.float32, ), requires_grad=False, ) @@ -1296,6 +1315,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): **extra_weight_attrs, ): intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full") + w13_num_shards = 2 if self.moe.is_act_and_mul else 1 # Will transpose the loaded weight along the # intermediate and hidden dim sizes. Will @@ -1307,7 +1327,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): torch.empty( num_experts, hidden_size // self.packed_factor, - 2 * intermediate_size_per_partition, + w13_num_shards * intermediate_size_per_partition, dtype=torch.int32, ), requires_grad=False, @@ -1352,7 +1372,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): torch.ones( num_experts, num_groups_w13, - 2 * intermediate_size_per_partition, + w13_num_shards * intermediate_size_per_partition, dtype=params_dtype, ), requires_grad=False, @@ -1600,10 +1620,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert layer.activation == "silu", ( - f"{layer.activation} not supported for Marlin MoE." - ) - topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, @@ -1625,6 +1641,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): quant_type_id=self.quant_type.id, apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=layer.global_num_experts, + activation=layer.activation, expert_map=layer.expert_map, g_idx1=layer.w13_weight_g_idx, g_idx2=layer.w2_weight_g_idx, @@ -1675,11 +1692,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): extra_weight_attrs.update( {"is_transposed": True, "quant_method": self.strategy} ) + w13_num_shards = 2 if self.moe.is_act_and_mul else 1 w13_weight = torch.nn.Parameter( torch.empty( num_experts, hidden_size // self.packed_factor, - 2 * intermediate_size_per_partition, + w13_num_shards * intermediate_size_per_partition, dtype=torch.int32, ), requires_grad=False, @@ -1712,7 +1730,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): torch.ones( num_experts, num_groups_w13, - 2 * intermediate_size_per_partition, + w13_num_shards * intermediate_size_per_partition, dtype=params_dtype, ), requires_grad=False, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 1c0c35bf6..6c3412df8 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -637,6 +637,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): block_quant=self.block_quant, tp_size=layer.moe_parallel_config.tp_size, with_lora_support=self.moe.is_lora_enabled, + is_act_and_mul=self.moe.is_act_and_mul, ) if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 68a2c375e..8cb7b83b4 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -900,8 +900,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert layer.activation == "silu", "Only SiLU activation is supported." - topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index bcda7b42c..5eac19a17 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -733,6 +733,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): block_quant=False, tp_size=moe_config.moe_parallel_config.tp_size, with_lora_support=self.moe.is_lora_enabled, + is_act_and_mul=self.moe.is_act_and_mul, ) self.kernel: mk.FusedMoEModularKernel | None = None @@ -789,15 +790,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ) weight_loader = extra_weight_attrs.get("weight_loader") - if self.moe.is_act_and_mul: - w13_up_dim = 2 * intermediate_size_per_partition - else: - w13_up_dim = intermediate_size_per_partition + w13_num_shards = 2 if self.moe.is_act_and_mul else 1 w13_weight = ModelWeightParameter( data=torch.empty( num_experts, - w13_up_dim, + w13_num_shards * intermediate_size_per_partition, hidden_size, dtype=weight_dtype, ), @@ -826,7 +824,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): # For non-gated MoE, allocate 1 scale for w13. w13_weight_scale = PerTensorScaleParameter( data=torch.full( - (num_experts, 2 if self.moe.is_act_and_mul else 1), + (num_experts, w13_num_shards), 1.0, dtype=torch.float32, ), @@ -1132,6 +1130,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass": self.backend = "cutlass" assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}" + elif envs.VLLM_NVFP4_GEMM_BACKEND == "marlin": + self.backend = "marlin" + assert is_fp4_marlin_supported(), f"Marlin is required for {self.backend}" if self.backend == "none": raise ValueError( @@ -1337,13 +1338,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): self.quant_config = quant_config self.nvfp4_backend = select_nvfp4_moe_backend() # TODO: move this type of check into the oracle. - if ( - not self.moe.is_act_and_mul - and not self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS - ): + if not self.moe.is_act_and_mul and self.nvfp4_backend not in [ + NvFp4MoeBackend.FLASHINFER_CUTLASS, + NvFp4MoeBackend.MARLIN, + ]: raise NotImplementedError( "Non-gated activations are only supported by FlashInfer " - "CUTLASS NvFP4 MoE backend." + f"CUTLASS and Marlin NvFP4 MoE backends, not {self.nvfp4_backend}." ) self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( @@ -1409,11 +1410,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): weight_scale_dtype = torch.float8_e4m3fn weight_loader = extra_weight_attrs.get("weight_loader") global_num_experts = extra_weight_attrs.get("global_num_experts") + w13_num_shards = 2 if self.moe.is_act_and_mul else 1 # GEMM 1 w13_weight = ModelWeightParameter( data=torch.empty( num_experts, - (2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition, + w13_num_shards * intermediate_size_per_partition, # 2 fp4 items are packed in the input dimension hidden_size // 2, dtype=weight_dtype, @@ -1442,7 +1444,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): w13_weight_scale = ModelWeightParameter( data=torch.empty( num_experts, - (2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition, + w13_num_shards * intermediate_size_per_partition, # 2 fp4 items are packed in the input dimension hidden_size // self.quant_config.group_size, dtype=weight_scale_dtype, @@ -1472,9 +1474,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ) w13_weight_scale_2 = PerTensorScaleParameter( - data=torch.empty( - num_experts, 2 if self.moe.is_act_and_mul else 1, dtype=torch.float32 - ), + data=torch.empty(num_experts, w13_num_shards, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2) @@ -1495,7 +1495,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): w13_input_scale = PerTensorScaleParameter( data=torch.empty( global_sf_num_experts, - 2 if self.moe.is_act_and_mul else 1, + w13_num_shards, dtype=torch.float32, ), weight_loader=weight_loader, @@ -1616,6 +1616,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): x=x, router_logits=router_logits, top_k=layer.top_k, + activation=layer.activation, global_num_experts=layer.global_num_experts, num_expert_group=layer.num_expert_group, topk_group=layer.topk_group, @@ -1642,6 +1643,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): topk_ids=topk_ids, topk_weights=topk_weights, top_k=layer.top_k, + activation=layer.activation, global_num_experts=layer.global_num_experts, ) else: diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 912ff5a4a..272b13861 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -255,6 +255,7 @@ def flashinfer_trtllm_fp4_moe( x: torch.Tensor | tuple[torch.Tensor, torch.Tensor], router_logits: torch.Tensor, top_k: int, + activation: str, global_num_experts: int, num_expert_group: int | None, topk_group: int | None, @@ -269,6 +270,7 @@ def flashinfer_trtllm_fp4_moe( x: Input tensor router_logits: Router logits for expert selection top_k: Number of experts to select per token + activation: Activation function to use global_num_experts: Total number of experts across all ranks num_expert_group: Number of expert groups (for grouped routing) topk_group: Top-k within each group @@ -282,6 +284,12 @@ def flashinfer_trtllm_fp4_moe( from vllm.model_executor.models.llama4 import Llama4MoE + # https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2404 + assert activation == "silu", ( + "Only SiLU activation is supported for FlashInfer TRTLLM FP4 MoE. " + f"{activation} found instead." + ) + # Quantize input to FP4 if isinstance(x, tuple): hidden_states_fp4, hidden_states_scale_linear_fp4 = x @@ -352,6 +360,7 @@ def flashinfer_trtllm_fp4_routed_moe( topk_ids: torch.Tensor, topk_weights: torch.Tensor, top_k: int, + activation: str, global_num_experts: int, ) -> torch.Tensor: """ @@ -364,6 +373,7 @@ def flashinfer_trtllm_fp4_routed_moe( x: Input tensor topk_ids: Ids of selected experts top_k: Number of experts to select per token + activation: Activation function to use global_num_experts: Total number of experts across all ranks Returns: @@ -371,6 +381,12 @@ def flashinfer_trtllm_fp4_routed_moe( """ import flashinfer + # https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535 + assert activation == "silu", ( + "Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. " + f"{activation} found instead." + ) + # Pack top k ids and expert weights into a single int32 tensor, as # required by TRT-LLM packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 0e21c81f7..f57133aeb 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -233,8 +233,6 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: intermediate_size_per_partition = layer.intermediate_size_per_partition # apply_router_weight_on_input is not supported for moe marlin supports_router_weight = not layer.apply_router_weight_on_input - # moe marlin requires the activation to be silu - supports_activation = layer.activation == "silu" # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size) # down: (n, k) = (hidden_size, intermediate_size_per_partition) @@ -244,12 +242,7 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: and intermediate_size_per_partition % max(64, group_size) == 0 ) supports_group_size = group_size in [-1, 32, 64, 128] - return ( - supports_shape - and supports_group_size - and supports_router_weight - and supports_activation - ) + return supports_shape and supports_group_size and supports_router_weight def marlin_moe_intermediate_size(w1_packed: torch.Tensor, w2_packed: torch.Tensor): diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index 2ced41ef8..db56b84a9 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -235,6 +235,7 @@ def prepare_nvfp4_moe_layer_for_marlin( w2: torch.Tensor, w2_scale: torch.Tensor, w2_scale_2: torch.Tensor, + is_act_and_mul: bool, ) -> tuple[ torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor ]: @@ -266,8 +267,9 @@ def prepare_nvfp4_moe_layer_for_marlin( # Repack weights to marlin format def repack_weight(weight: torch.Tensor, name: str) -> torch.Tensor: tensor_list = [] + num_shards = 2 if is_act_and_mul else 1 if "w13" in name: - size_n, size_k = N * 2, K + size_n, size_k = N * num_shards, K else: size_n, size_k = K, N @@ -300,8 +302,9 @@ def prepare_nvfp4_moe_layer_for_marlin( g_scales = g_scales.to(param_dtype) tensor_list = [] + num_shards = 2 if is_act_and_mul else 1 if "w13" in name: - size_n, size_k = N * 2, K + size_n, size_k = N * num_shards, K else: size_n, size_k = K, N