diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index 10678e376..af47ca91a 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -14,6 +14,7 @@ from tests.kernels.utils import torch_moe from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config from vllm.model_executor.layers.fused_moe.cutlass_moe import ( CutlassExpertsFp4, @@ -147,5 +148,130 @@ def test_cutlass_fp4_moe_no_graph( torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1) +# step3.5-flash uses swiglustep activation (clipped SwiGLU with limit=7.0) +# for MoE layers 43-44. This tests the non-fused activation fallback path +# in run_cutlass_moe_fp4 (apply_moe_activation + separate fp4 quantization). +# Model dims: e=288, topk=8, n=1280 (moe_intermediate_size), k=4096 (hidden) +SWIGLUSTEP_MNK_FACTORS = [ + (2, 1280, 4096), + (64, 1280, 4096), + (224, 1280, 4096), +] + + +@pytest.mark.parametrize("m,n,k", SWIGLUSTEP_MNK_FACTORS) +@pytest.mark.parametrize("e", [64, 288]) +@pytest.mark.parametrize("topk", [1, 8]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@torch.inference_mode() +def test_cutlass_fp4_moe_swiglustep( + m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, workspace_init +): + set_random_seed(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + quant_blocksize = 16 + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + + (_, w1_q, w1_blockscale, w1_gs), (_, w2_q, w2_blockscale, w2_gs) = ( + make_test_weights( + e, + n, + k, + in_dtype=dtype, + quant_dtype="nvfp4", + block_shape=None, + per_out_ch_quant=False, + ) + ) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False) + + a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32) + a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32) + + assert w1_gs is not None + assert w2_gs is not None + assert w1_blockscale is not None + assert w2_blockscale is not None + + quant_config = nvfp4_moe_quant_config( + g1_alphas=(1 / w1_gs), + g2_alphas=(1 / w2_gs), + a1_gscale=a1_gs, + a2_gscale=a2_gs, + w1_scale=w1_blockscale, + w2_scale=w2_blockscale, + ) + + kernel = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + CutlassExpertsFp4( + moe_config=make_dummy_moe_config(), + quant_config=quant_config, + ), + inplace=False, + ) + + cutlass_output = kernel( + hidden_states=a, + w1=w1_q, + w2=w2_q, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=MoEActivation.SWIGLUSTEP, + ) + + # Reference: dequantize everything and run torch_moe with swiglustep + a_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1) + ).to(torch.float32) + a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale) + + a_in_dtype = dequantize_nvfp4_to_dtype( + a_fp4, + a_scale_interleaved, + a_global_scale, + dtype=a.dtype, + device=a.device, + block_size=quant_blocksize, + ) + + w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype) + w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype) + + for idx in range(0, e): + w1_d[idx] = dequantize_nvfp4_to_dtype( + w1_q[idx], + w1_blockscale[idx], + w1_gs[idx], + dtype=dtype, + device=w1_q.device, + block_size=quant_blocksize, + ) + w2_d[idx] = dequantize_nvfp4_to_dtype( + w2_q[idx], + w2_blockscale[idx], + w2_gs[idx], + dtype=dtype, + device=w2_q.device, + block_size=quant_blocksize, + ) + + torch_output = torch_moe( + a_in_dtype, + w1_d, + w2_d, + score, + topk, + activation=MoEActivation.SWIGLUSTEP, + ) + + torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1) + + if __name__ == "__main__": test_cutlass_fp4_moe_no_graph((2, 1024, 1024), 40, 1, torch.half) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 4f8948778..ae9430d29 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -690,10 +690,14 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): @staticmethod def _supports_activation(activation: MoEActivation) -> bool: + # SILU uses a fused silu+mul+fp4_quant kernel path. + # Other gated activations use the generic apply_moe_activation() + # fallback + separate fp4 quantization in run_cutlass_moe_fp4(). return activation in [ MoEActivation.SILU, MoEActivation.GELU, MoEActivation.SWIGLUOAI, + MoEActivation.SWIGLUSTEP, ] @staticmethod diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index e5f32ebd1..4a8f31255 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -586,10 +586,13 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): @staticmethod def _supports_activation(activation: MoEActivation) -> bool: + # Marlin uses apply_moe_activation() callback for activation, + # so any activation supported there can be used here. return activation in [ MoEActivation.SILU, MoEActivation.GELU, MoEActivation.SWIGLUOAI, + MoEActivation.SWIGLUSTEP, MoEActivation.SILU_NO_MUL, MoEActivation.GELU_NO_MUL, MoEActivation.RELU2_NO_MUL, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 0fecc7bbc..097d0bc01 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -652,9 +652,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert not self.is_monolithic - assert layer.activation == MoEActivation.SILU, ( - f"Only SiLU activation is supported, not {layer.activation}." - ) # EPLB path if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index 195cfcedd..fcdd770fe 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -2,7 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Jurassic model.""" -from collections.abc import Iterable +import typing +from collections.abc import Callable, Iterable from typing import Any import torch @@ -231,6 +232,7 @@ class Step3p5Attention(nn.Module): hidden_size, self.total_num_heads, bias=False, + quant_config=quant_config, prefix=f"{prefix}.g_proj", ) @@ -640,12 +642,22 @@ class Step3p5Model(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() + # Old packed 3D format: .moe.gate_proj.weight [num_experts, out, in] expert_params_mapping = [ (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"), (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"), (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"), ] + # New per-expert format: .moe.experts.E.gate_proj.weight_packed [out, in] + per_expert_mapping = FusedMoE.make_expert_params_mapping( + self, + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.moe_num_experts, + ) + disable_moe_stacked_params = [data[1] for data in expert_params_mapping] for name, loaded_weight in weights: @@ -668,6 +680,54 @@ class Step3p5Model(nn.Module): if layer_idx >= config.num_hidden_layers: continue + # Per-expert MoE weights (new format from LLM Compressor): + # .moe.experts.{E}.{gate,up,down}_proj.{weight_packed,scale,...} + # Each weight is individual per-expert, not stacked 3D. + if ".moe.experts." in local_name: + is_expert_weight = False + for mapping in per_expert_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in local_name: + continue + is_expert_weight = True + name_mapped = local_name.replace(weight_name, param_name) + if is_pp_missing_parameter(name_mapped, self): + continue + if name_mapped not in params_dict: + continue + param = params_dict[name_mapped] + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + loaded_params.add(name_mapped) + break + else: + if ( + not is_expert_weight + and not is_pp_missing_parameter(local_name, self) + and local_name in params_dict + ): + # Not an expert proj — use default loader + # (e.g. share_expert weights if they matched) + param = params_dict[local_name] + weight_loader = getattr( + param, + "weight_loader", + default_weight_loader, + ) + weight_loader(param, loaded_weight) + loaded_params.add(local_name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in local_name: continue @@ -703,6 +763,16 @@ class Step3p5Model(nn.Module): param = params_dict[replaced_name] weight_loader = param.weight_loader moe_expert_num = self.moe_num_experts + # Per-tensor global scales (e.g. weight_global_scale) + # have shape [1] in compressed-tensors NVFP4 checkpoints. + # Expand to per-expert before the iteration loop. + if ( + loaded_weight.shape[0] == 1 + and loaded_weight.shape[0] != moe_expert_num + ): + loaded_weight = loaded_weight.expand( + moe_expert_num, *loaded_weight.shape[1:] + ) assert loaded_weight.shape[0] == moe_expert_num for expert_id in range(moe_expert_num): loaded_weight_expert = loaded_weight[expert_id]