From c3a9752b0c11f87677e2ab918e524af7a368c664 Mon Sep 17 00:00:00 2001 From: Pavani Majety Date: Fri, 30 Jan 2026 10:30:46 -0800 Subject: [PATCH] [Hardware][SM100] Add TRTLLM Kernel for INT4 W4A16 Kernel. (#32437) Signed-off-by: Pavani Majety --- .../moe/test_marlin_vs_trtllm_mxint4.py | 272 ++++++++++++++++++ vllm/envs.py | 11 +- vllm/model_executor/layers/fused_moe/layer.py | 14 +- .../compressed_tensors_moe.py | 183 +++++++++++- .../utils/flashinfer_mxint4_moe.py | 266 +++++++++++++++++ vllm/utils/flashinfer.py | 4 +- 6 files changed, 727 insertions(+), 23 deletions(-) create mode 100644 tests/kernels/moe/test_marlin_vs_trtllm_mxint4.py create mode 100644 vllm/model_executor/layers/quantization/utils/flashinfer_mxint4_moe.py diff --git a/tests/kernels/moe/test_marlin_vs_trtllm_mxint4.py b/tests/kernels/moe/test_marlin_vs_trtllm_mxint4.py new file mode 100644 index 000000000..d6735b126 --- /dev/null +++ b/tests/kernels/moe/test_marlin_vs_trtllm_mxint4.py @@ -0,0 +1,272 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Test comparing Marlin INT4 MoE vs FlashInfer TRT-LLM MXINT4 MoE.""" + +import pytest +import torch + +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe, +) +from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import ( + grouped_topk, +) +from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import ( + prepare_static_weights_for_trtllm_mxint4_moe, +) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + + +def mxint4_quantize( + x: torch.Tensor, sf_vec_size: int = 32 +) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize BF16 tensor to MXINT4 with block scaling (group_size=sf_vec_size). + + Returns: + - uint8 packed (2 INT4/byte): [..., k//2] - stores SIGNED INT4 [-8, 7] + - scales in BF16: [..., k//sf_vec_size] + """ + x_reshaped = x.reshape(-1, sf_vec_size) + x_max = x_reshaped.max(dim=-1, keepdim=True)[0].to(torch.float32) + x_min = x_reshaped.min(dim=-1, keepdim=True)[0].to(torch.float32) + x_max = x_max * 8.0 / 7.0 + amax = torch.where(x_max > -x_min, x_max, -x_min) + scales = amax / 8.0 + x_scaled = x_reshaped * scales.reciprocal() + x_int8 = ( + x_scaled.round().clamp(-8, 7).to(torch.int8).reshape(-1, sf_vec_size // 2, 2) + ) + x_int4 = (x_int8[..., 0] & 0x0F) | ((x_int8[..., 1] & 0x0F) << 4) + return ( + x_int4.to(torch.uint8).reshape(*x.shape[:-1], x.shape[-1] // 2), + scales.to(x.dtype).reshape(*x.shape[:-1], x.shape[-1] // sf_vec_size), + ) + + +def mxint4_quantize_moe_weights( + weights_bf16: torch.Tensor, group_size: int = 32 +) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize MoE weights [e, n, k] to MxInt4 format. 
+ + Args: + weights_bf16: BF16 weights of shape [num_experts, out_features, in_features] + group_size: Quantization group size (default: 32) + + Returns: + - weights_mxint4: Quantized weights [e, n, k//2] uint8 + - scales_mxint4: Quantization scales [e, n, k//group_size] bf16 + """ + e = weights_bf16.shape[0] + weight_list = [] + scale_list = [] + + for i in range(e): + w_q, w_s = mxint4_quantize(weights_bf16[i], sf_vec_size=group_size) + weight_list.append(w_q) + scale_list.append(w_s) + + return torch.stack(weight_list), torch.stack(scale_list) + + +__all__ = [ + "mxint4_quantize", + "mxint4_quantize_moe_weights", + "marlin_quantize_moe_weights", +] + + +def marlin_quantize_moe_weights( + weights_bf16: torch.Tensor, group_size: int = 32 +) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize MoE weights [e, n, k] to Marlin INT4 format. + + Args: + weights_bf16: BF16 weights of shape [num_experts, out_features, in_features] + group_size: Quantization group size (default: 32) + + Returns: + - weights_marlin: Marlin quantized weights [e, k//8, n] int32 + - scales_marlin: Marlin quantization scales [e, k//group_size, n] bf16 + """ + from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + marlin_quantize, + ) + + e, n, k = weights_bf16.shape + weight_list = [] + scale_list = [] + + for i in range(e): + # Transpose for Marlin: [n, k] → [k, n] + w_t = weights_bf16[i].T.contiguous() + _, w_q, w_s, _, _, _ = marlin_quantize( + w_t, scalar_types.uint4b8, group_size, act_order=False + ) + weight_list.append(w_q) + scale_list.append(w_s) + + # Stack to get [e, ...] shape + weights_marlin = torch.stack(weight_list) # [e, k // 8, n] + scales_marlin = torch.stack(scale_list) # [e, k // group_size, n] + + return weights_marlin, scales_marlin + + +TRTLLM_GEN_AVAILABLE = ( + current_platform.is_cuda() and current_platform.is_device_capability_family(100) +) + + +@pytest.mark.skipif(not TRTLLM_GEN_AVAILABLE, reason="Skip for non SM100") +@pytest.mark.parametrize("m", [1, 33]) +@pytest.mark.parametrize("n", [7168]) +@pytest.mark.parametrize("k", [512]) +@pytest.mark.parametrize("e", [384]) +@pytest.mark.parametrize("topk", [8]) +@pytest.mark.parametrize("group_size", [32]) +def test_marlin_vs_trtllm_mxint4_moe_kimik2(monkeypatch, m, n, k, e, topk, group_size): + """Compare Marlin INT4 MoE vs FlashInfer TRT-LLM MXINT4 MoE. + + Uses mxint4_quantize() to generate common INT4 weights + BF16 scales, + then runs both Marlin and TRT-LLM kernels and compares outputs. + """ + pytest.importorskip("flashinfer") + monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_INT4", "1") + + torch.cuda.manual_seed(0) + + dtype = torch.bfloat16 + + # DeepSeekV3 routing config (from Kimi-K2-Thinking config.json) + n_group = 1 # n_group from model config + topk_group = 1 # topk_group from model config + routed_scaling = 2.827 # routed_scaling_factor from model config + + # Input - realistic activation range for LLM (after LayerNorm: mean~0, std~1) + a = torch.randn((m, k), device="cuda", dtype=dtype) * 0.5 + + # Generate routing logits and bias (DeepSeekV3 expects float logits) + # Realistic ranges: logits typically [-3, 3], bias [-2, 2] + routing_logits = torch.randn((m, e), device="cuda", dtype=torch.float32) * 1.5 + routing_bias = torch.randn(e, device="cuda", dtype=torch.float32) * 0.8 + + # 1. 
Generate BF16 weights (SHARED between both paths) + # Realistic weight initialization: Xavier/Glorot uniform scaling + # std = sqrt(2 / (fan_in + fan_out)) + std_w1 = (2.0 / (k + 2 * n)) ** 0.5 + std_w2 = (2.0 / (n + k)) ** 0.5 + w1_bf16 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) * std_w1 + w2_bf16 = torch.randn((e, k, n), device="cuda", dtype=dtype) * std_w2 + + # === Path 1: TRT-LLM FlashInfer MXINT4 MoE === + # Similar to: if self.use_flashinfer_mxint4_moe + # Quantize using MXINT4 method (signed INT4) + w1_int4, w1_scales = mxint4_quantize_moe_weights(w1_bf16, group_size) + w2_int4, w2_scales = mxint4_quantize_moe_weights(w2_bf16, group_size) + + trtllm_weights = prepare_static_weights_for_trtllm_mxint4_moe( + gemm1_weights=w1_int4, + gemm1_scales=w1_scales, + gemm2_weights=w2_int4, + gemm2_scales=w2_scales, + ) + + from flashinfer import RoutingMethodType + from flashinfer.fused_moe import trtllm_mxint4_block_scale_moe + + # Routing handled internally by trtllm_mxint4_block_scale_moe + trtllm_output = trtllm_mxint4_block_scale_moe( + routing_logits=routing_logits, + routing_bias=routing_bias.to(torch.bfloat16), + hidden_states=a, + gemm1_weights=trtllm_weights["gemm1_weights"], + gemm1_weights_scale=trtllm_weights["gemm1_scales"], + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=trtllm_weights["gemm2_weights"], + gemm2_weights_scale=trtllm_weights["gemm2_scales"], + num_experts=e, + top_k=topk, + n_group=n_group, + topk_group=topk_group, + intermediate_size=n, + local_expert_offset=0, + local_num_experts=e, + routed_scaling_factor=routed_scaling, + routing_method_type=RoutingMethodType.DeepSeekV3, + enable_pdl=None, + output=None, + tune_max_num_tokens=8192, + ).to(dtype) + + # === Path 2: Marlin INT4 MoE === + # Similar to: else (non-flashinfer path) + # Quantize using Marlin's method (UINT4b8) + w1_marlin, w1_scales_marlin = marlin_quantize_moe_weights(w1_bf16, group_size) + w2_marlin, w2_scales_marlin = marlin_quantize_moe_weights(w2_bf16, group_size) + + # Use production routing kernel (same as router.select_experts internally uses) + topk_weights, topk_ids = grouped_topk( + hidden_states=a, + gating_output=routing_logits, + topk=topk, + renormalize=False, # DeepSeekV3 doesn't renormalize + num_expert_group=n_group, + topk_group=topk_group, + scoring_func="sigmoid", # DeepSeekV3 uses sigmoid + routed_scaling_factor=routed_scaling, + e_score_correction_bias=routing_bias, + ) + + marlin_output = fused_marlin_moe( + a, + w1_marlin, + w2_marlin, + None, + None, + w1_scales_marlin, + w2_scales_marlin, + None, # gating_output not needed when topk_weights/ids provided + topk_weights, + topk_ids, + global_num_experts=e, + expert_map=None, + global_scale1=None, + global_scale2=None, + g_idx1=None, + g_idx2=None, + input_global_scale1=None, + input_global_scale2=None, + sort_indices1=None, + sort_indices2=None, + w1_zeros=None, + w2_zeros=None, + input_dtype=dtype, + quant_type_id=scalar_types.uint4b8.id, + is_k_full=True, + ) + + # Sanity check: manually compute BF16 reference for comparison + # Use same routing as Marlin path for consistency + bf16_output = torch.zeros((m, k), device="cuda", dtype=dtype) + for token_idx in range(m): + for expert_rank in range(topk): + expert_id = topk_ids[token_idx, expert_rank].item() + weight = topk_weights[token_idx, expert_rank].item() + # w1: [2*n, k] @ [k] -> [2*n] + up_gate = a[token_idx] @ w1_bf16[expert_id].T # [2*n] + gate, up = up_gate.chunk(2, dim=0) + intermediate = torch.nn.functional.silu(gate) 
* up # [n] + # w2: [k, n] @ [n] -> [k] + expert_out = intermediate @ w2_bf16[expert_id].T # [k] + bf16_output[token_idx] += weight * expert_out + # Compare against BF16 reference. + torch.testing.assert_close(marlin_output, bf16_output, atol=0.3, rtol=1.0) + torch.testing.assert_close(trtllm_output, bf16_output, atol=0.3, rtol=1.0) + + # Compare against each other for sanity. + # Note: Different quantization schemes (UINT4b8 vs signed MXINT4) cause + # some differences + torch.testing.assert_close(marlin_output, trtllm_output, atol=0.3, rtol=6.0) diff --git a/vllm/envs.py b/vllm/envs.py index 1c9eacae1..741a2163c 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -174,6 +174,7 @@ if TYPE_CHECKING: VLLM_USE_FLASHINFER_MOE_FP16: bool = False VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False + VLLM_USE_FLASHINFER_MOE_INT4: bool = False VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = ( "latency" ) @@ -1240,18 +1241,22 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER": lambda: bool( int(os.getenv("VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER", "0")) ), - # Allow use of FlashInfer MoE kernels for fused moe ops. + # Allow use of FlashInfer BF16 MoE kernels for fused moe ops. "VLLM_USE_FLASHINFER_MOE_FP16": lambda: bool( int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", "0")) ), - # Allow use of FlashInfer MoE kernels for fused moe ops. + # Allow use of FlashInfer FP8 MoE kernels for fused moe ops. "VLLM_USE_FLASHINFER_MOE_FP8": lambda: bool( int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0")) ), - # Allow use of FlashInfer CUTLASS kernels for fused moe ops. + # Allow use of FlashInfer NVFP4 MoE kernels for fused moe ops. "VLLM_USE_FLASHINFER_MOE_FP4": lambda: bool( int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4", "0")) ), + # Allow use of FlashInfer MxInt4 MoE kernels for fused moe ops. + "VLLM_USE_FLASHINFER_MOE_INT4": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER_MOE_INT4", "0")) + ), # If set to 1, use the FlashInfer # MXFP8 (activation) x MXFP4 (weight) MoE backend. "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8": lambda: bool( diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index d095e4074..6af458df5 100755 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1138,6 +1138,11 @@ class FusedMoE(CustomOp): return False if return_success else None # Hereafter, `expert_id` is local physical id + # is_transposed: if the dim to shard the weight + # should be flipped. Required by GPTQ, compressed-tensors + # should be whatever dimension intermediate_size_per_partition is + is_transposed = getattr(param, "is_transposed", False) + # compressed-tensors checkpoints with packed weights are stored flipped # TODO (mgoin): check self.quant_method.quant_config.quant_format # against known CompressionFormat enum values that have this quality @@ -1145,7 +1150,10 @@ class FusedMoE(CustomOp): "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod", ): - loaded_weight = loaded_weight.t().contiguous() + if is_transposed: + loaded_weight = loaded_weight.t().contiguous() + else: + loaded_weight = loaded_weight if shard_id not in ("w1", "w2", "w3"): raise ValueError(f"shard_id must be ['w1','w2','w3'] but got {shard_id}.") @@ -1183,10 +1191,6 @@ class FusedMoE(CustomOp): ) return True if return_success else None - # is_transposed: if the dim to shard the weight - # should be flipped. 
Required by GPTQ, compressed-tensors - # should be whatever dimension intermediate_size_per_partition is - is_transposed = getattr(param, "is_transposed", False) shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] if is_transposed: shard_dim = int(not shard_dim) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index f26ddfb87..dbfa8fb9b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -63,6 +63,11 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( flashinfer_trtllm_fp4_moe, flashinfer_trtllm_fp4_routed_moe, ) +from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import ( + flashinfer_trtllm_mxint4_moe, + is_flashinfer_mxint4_moe_available, + prepare_static_weights_for_trtllm_mxint4_moe, +) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_fi_trtllm_fp8_per_tensor_moe, ) @@ -1247,8 +1252,89 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): self.actorder = weight_quant.actorder self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits] - self.use_marlin = True + self.marlin_input_dtype = get_marlin_input_dtype(layer_name) + self.use_flashinfer_mxint4_moe = ( + is_flashinfer_mxint4_moe_available() + and self.group_size == 32 + and weight_quant.num_bits == 4 + ) + self.kernel_backend = ( + "Flashinfer" if self.use_flashinfer_mxint4_moe else "Marlin" + ) + logger.info_once( + f"Using {self.kernel_backend} backend for WNA16 MoE " + f"(group_size={self.group_size}, num_bits={self.num_bits})", + scope="local", + ) + + def get_weight_shape( + self, + weight_name: str, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + num_groups_w2: int | None = None, + num_groups_w13: int | None = None, + ) -> tuple[int, int, int]: + """ + Get the shape of the weight based on the weight name, number of experts + hidden size, intermediate size per partition, number of groups for w2, + and number of groups for w13. Pass in num_groups_w2 and num_groups_w13 + for weight scales. 
+ """ + if weight_name == "w13_scale": + assert num_groups_w13 is not None, ( + "num_groups_w13 must be provided for weight scales" + ) + if weight_name == "w2_scale": + assert num_groups_w2 is not None, ( + "num_groups_w2 must be provided for weight scales" + ) + w13_num_shards = 2 if self.moe.is_act_and_mul else 1 + shape_map = { + "w13_weight": { + "Flashinfer": ( + num_experts, + w13_num_shards * intermediate_size_per_partition, + hidden_size // self.packed_factor, + ), + "Marlin": ( + num_experts, + hidden_size // self.packed_factor, + w13_num_shards * intermediate_size_per_partition, + ), + }, + "w13_scale": { + "Flashinfer": ( + num_experts, + w13_num_shards * intermediate_size_per_partition, + num_groups_w13, + ), + "Marlin": ( + num_experts, + num_groups_w13, + w13_num_shards * intermediate_size_per_partition, + ), + }, + "w2_weight": { + "Flashinfer": ( + num_experts, + hidden_size, + intermediate_size_per_partition // self.packed_factor, + ), + "Marlin": ( + num_experts, + intermediate_size_per_partition // self.packed_factor, + hidden_size, + ), + }, + "w2_scale": { + "Flashinfer": (num_experts, hidden_size, num_groups_w2), + "Marlin": (num_experts, num_groups_w2, hidden_size), + }, + } + return shape_map[weight_name][self.kernel_backend] def create_weights( self, @@ -1260,19 +1346,23 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): **extra_weight_attrs, ): intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full") - w13_num_shards = 2 if self.moe.is_act_and_mul else 1 # Will transpose the loaded weight along the # intermediate and hidden dim sizes. Will # shard for TP along the transposed dims + is_transposed = self.kernel_backend != "Flashinfer" extra_weight_attrs.update( - {"is_transposed": True, "quant_method": self.strategy} + {"is_transposed": is_transposed, "quant_method": self.strategy} ) + w13_weight = torch.nn.Parameter( torch.empty( - num_experts, - hidden_size // self.packed_factor, - w13_num_shards * intermediate_size_per_partition, + *self.get_weight_shape( + "w13_weight", + num_experts, + hidden_size, + intermediate_size_per_partition, + ), dtype=torch.int32, ), requires_grad=False, @@ -1282,9 +1372,12 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): w2_weight = torch.nn.Parameter( torch.empty( - num_experts, - intermediate_size_per_partition // self.packed_factor, - hidden_size, + *self.get_weight_shape( + "w2_weight", + num_experts, + hidden_size, + intermediate_size_per_partition, + ), dtype=torch.int32, ), requires_grad=False, @@ -1315,9 +1408,13 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): w13_scale = torch.nn.Parameter( torch.ones( - num_experts, - num_groups_w13, - w13_num_shards * intermediate_size_per_partition, + *self.get_weight_shape( + "w13_scale", + num_experts, + hidden_size, + intermediate_size_per_partition, + num_groups_w13=num_groups_w13, + ), dtype=params_dtype, ), requires_grad=False, @@ -1326,7 +1423,16 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): set_weight_attrs(w13_scale, extra_weight_attrs) w2_scale = torch.nn.Parameter( - torch.ones(num_experts, num_groups_w2, hidden_size, dtype=params_dtype), + torch.ones( + *self.get_weight_shape( + "w2_scale", + num_experts, + hidden_size, + intermediate_size_per_partition, + num_groups_w2=num_groups_w2, + ), + dtype=params_dtype, + ), requires_grad=False, ) layer.register_parameter("w2_weight_scale", w2_scale) @@ -1396,6 +1502,27 @@ class 
CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: num_experts = layer.w13_weight_g_idx.shape[0] device = layer.w13_weight_g_idx.device + if self.kernel_backend == "Flashinfer": + dict_weights_mxint4 = prepare_static_weights_for_trtllm_mxint4_moe( + layer.w13_weight_packed, + layer.w13_weight_scale, + layer.w2_weight_packed, + layer.w2_weight_scale, + ) + replace_parameter( + layer, "w13_weight_packed", dict_weights_mxint4["gemm1_weights"] + ) + replace_parameter( + layer, "w13_weight_scale", dict_weights_mxint4["gemm1_scales"] + ) + replace_parameter( + layer, "w2_weight_packed", dict_weights_mxint4["gemm2_weights"] + ) + replace_parameter( + layer, "w2_weight_scale", dict_weights_mxint4["gemm2_scales"] + ) + return None + is_a_8bit = ( self.marlin_input_dtype is not None and self.marlin_input_dtype.itemsize == 1 @@ -1560,6 +1687,35 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): is_k_full=self.is_k_full, ) + @property + def is_monolithic(self) -> bool: + return self.kernel_backend == "Flashinfer" + + def apply_monolithic( + self, + layer: FusedMoE, + x: torch.Tensor, + router_logits: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + assert self.kernel_backend == "Flashinfer" + return flashinfer_trtllm_mxint4_moe( + x=x, + router_logits=router_logits, + w13_weight_packed=layer.w13_weight_packed, + w13_weight_scale=layer.w13_weight_scale, + w2_weight_packed=layer.w2_weight_packed, + w2_weight_scale=layer.w2_weight_scale, + global_num_experts=layer.global_num_experts, + top_k=layer.top_k, + intermediate_size_per_partition=layer.intermediate_size_per_partition, + local_num_experts=layer.local_num_experts, + ep_rank=layer.ep_rank, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + e_score_correction_bias=layer.e_score_correction_bias, + routing_method_type=layer.routing_method_type, + ) + def apply( self, layer: FusedMoE, @@ -1567,6 +1723,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): topk_weights: torch.Tensor, topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + assert self.kernel_backend == "Marlin" return fused_marlin_moe( x, layer.w13_weight_packed, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_mxint4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_mxint4_moe.py new file mode 100644 index 000000000..98a3d1e12 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_mxint4_moe.py @@ -0,0 +1,266 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Utility helpers for MxInt4 + FlashInfer fused-MoE path""" + +import functools + +import torch + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe + +__all__ = [ + "prepare_static_weights_for_trtllm_mxint4_moe", + "flashinfer_trtllm_mxint4_moe", + "is_flashinfer_mxint4_moe_available", +] + +logger = init_logger(__name__) + + +@functools.cache +def is_flashinfer_mxint4_moe_available() -> bool: + """Return `True` when FlashInfer MxInt4 kernels can be used.""" + return ( + envs.VLLM_USE_FLASHINFER_MOE_INT4 + and has_flashinfer_trtllm_fused_moe() + and current_platform.is_cuda() + and current_platform.is_device_capability_family(100) + ) + + +def 
prepare_static_weights_for_trtllm_mxint4_moe( + gemm1_weights: torch.Tensor, + gemm1_scales: torch.Tensor, + gemm2_weights: torch.Tensor, + gemm2_scales: torch.Tensor, +) -> dict[str, torch.Tensor]: + """ + Prepare MxInt4 weights for TRT-LLM kernel. + + Input: + gemm1_weights: [num_experts, 2*intermediate_size, hidden_size//8] int32 + (checkpoint uint4b8 packed) or uint8 (already packed signed int4) + gemm1_scales: [num_experts, 2*intermediate_size, hidden_size//32] bf16 + gemm2_weights: [num_experts, hidden_size, intermediate_size//8] int32 + (checkpoint uint4b8 packed) or uint8 (already packed signed int4) + gemm2_scales: [num_experts, hidden_size, intermediate_size//32] bf16 + + Returns: + Dict with keys 'gemm1_weights', 'gemm1_scales', 'gemm2_weights', + 'gemm2_scales' containing shuffled/packed tensors ready for kernel + """ + from flashinfer import block_scale_interleave + from flashinfer.fused_moe import ( + convert_to_block_layout, + ) + from flashinfer.fused_moe.core import ( + _maybe_get_cached_w3_w1_permute_indices, + get_w2_permute_indices_with_cache, + ) + + from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( + reorder_w1w3_to_w3w1, + ) + from vllm.model_executor.layers.quantization.utils.quant_utils import ( + convert_packed_uint4b8_to_signed_int4_inplace, + ) + + device = gemm1_weights.device + assert gemm1_weights.ndim == 3, ( + f"Expected a 3D gemm1_weights tensor, got {gemm1_weights.shape}" + ) + assert gemm1_scales.ndim == 3, ( + f"Expected a 3D gemm1_scales tensor, got {gemm1_scales.shape}" + ) + assert gemm2_weights.ndim == 3, ( + f"Expected a 3D gemm2_weights tensor, got {gemm2_weights.shape}" + ) + assert gemm2_scales.ndim == 3, ( + f"Expected a 3D gemm2_scales tensor, got {gemm2_scales.shape}" + ) + + # Convert checkpoint format (uint4b8 in int32) to signed int4 + # Checkpoint stores INT4 as unsigned [0, 15], kernel expects signed [-8, 7] + if gemm1_weights.dtype == torch.int32 and gemm2_weights.dtype == torch.int32: + convert_packed_uint4b8_to_signed_int4_inplace(gemm1_weights) + convert_packed_uint4b8_to_signed_int4_inplace(gemm2_weights) + + gemm1_weights, gemm1_scales = reorder_w1w3_to_w3w1( + gemm1_weights, gemm1_scales, dim=-2 + ) + + _cache_permute_indices: dict[torch.Size, torch.Tensor] = {} + num_experts = gemm1_weights.shape[0] + + # Convert quantized weights to proper formats - + gemm1_weights_mxint4 = gemm1_weights.view(torch.uint8) + assert gemm1_scales.dtype == torch.bfloat16 + gemm2_weights_mxint4 = gemm2_weights.view(torch.uint8) + assert gemm2_scales.dtype == torch.bfloat16 + + epilogue_tile_m = 128 + gemm1_weights_mxint4_shuffled = [] + gemm1_scales_shuffled = [] + gemm2_weights_mxint4_shuffled = [] + gemm2_scales_shuffled = [] + + for i in range(num_experts): + # Calculate the permute indices for the following: + # 1. Reorder rows of W1 and scales for fused gated activation + # 2. 
Shuffle weights and scaling factors for transposed mma output + # for both w3_w1 and w2 weights and scale factors + permute_indices = _maybe_get_cached_w3_w1_permute_indices( + _cache_permute_indices, + gemm1_weights_mxint4[i], + epilogue_tile_m, + ) + gemm1_weights_shuffled = gemm1_weights_mxint4[i][ + permute_indices.to(gemm1_weights.device) + ].contiguous() + permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices( + _cache_permute_indices, + gemm1_scales[i], + epilogue_tile_m, + num_elts_per_sf=32, + ).to(device) + gemm1_scales_shuffled.append( + block_scale_interleave(gemm1_scales[i][permute_sf_indices].contiguous()) + ) + + permute_indices = get_w2_permute_indices_with_cache( + _cache_permute_indices, + gemm2_weights_mxint4[i], + epilogue_tile_m, + ) + gemm2_weights_shuffled = gemm2_weights_mxint4[i][ + permute_indices.to(gemm2_weights.device) + ].contiguous() + + permute_sf_indices = get_w2_permute_indices_with_cache( + _cache_permute_indices, + gemm2_scales[i], + epilogue_tile_m, + num_elts_per_sf=16, + ) + gemm2_scales_shuffled.append( + block_scale_interleave( + gemm2_scales[i][permute_sf_indices.to(gemm2_scales.device)].contiguous() + ) + ) + + block_k = 128 + gemm1_weights_shuffled = convert_to_block_layout( + gemm1_weights_shuffled.view(torch.uint8), block_k + ) + gemm2_weights_shuffled = convert_to_block_layout( + gemm2_weights_shuffled.view(torch.uint8), block_k + ) + + gemm1_weights_mxint4_shuffled.append(gemm1_weights_shuffled) + gemm2_weights_mxint4_shuffled.append(gemm2_weights_shuffled) + + gemm1_weights_mxint4_shuffled = torch.stack(gemm1_weights_mxint4_shuffled) + gemm2_weights_mxint4_shuffled = torch.stack(gemm2_weights_mxint4_shuffled) + gemm1_scales_shuffled = torch.stack(gemm1_scales_shuffled).view(torch.bfloat16) + gemm2_scales_shuffled = torch.stack(gemm2_scales_shuffled).view(torch.bfloat16) + return { + "gemm1_weights": gemm1_weights_mxint4_shuffled, + "gemm1_scales": gemm1_scales_shuffled, + "gemm2_weights": gemm2_weights_mxint4_shuffled, + "gemm2_scales": gemm2_scales_shuffled, + } + + +def flashinfer_trtllm_mxint4_moe( + x: torch.Tensor, + router_logits: torch.Tensor, + w13_weight_packed: torch.Tensor, + w13_weight_scale: torch.Tensor, + w2_weight_packed: torch.Tensor, + w2_weight_scale: torch.Tensor, + global_num_experts: int, + top_k: int, + intermediate_size_per_partition: int, + local_num_experts: int, + ep_rank: int = 0, + num_expert_group: int | None = None, + topk_group: int | None = None, + e_score_correction_bias: torch.Tensor | None = None, + routing_method_type: int | None = None, +) -> torch.Tensor: + """ + Apply FlashInfer TensorRT-LLM MxInt4 MoE kernel. + + Args: + x: Input hidden states. dtype: bfloat16 + router_logits: Router logits for expert selection. dtype: bfloat16/float32 + w13_weight_packed: Packed gate+up weights. dtype: uint8 + w13_weight_scale: Scales for gate+up weights. dtype: bfloat16 + w2_weight_packed: Packed down weights. dtype: uint8 + w2_weight_scale: Scales for down weights. dtype: bfloat16 + global_num_experts: Total number of experts across all ranks + top_k: Number of experts to select per token + intermediate_size_per_partition: Intermediate size per partition + local_num_experts: Number of experts on this rank + ep_rank: Expert parallelism rank (default: 0) + num_expert_group: Number of expert groups (default: None -> 0) + topk_group: Top-k within groups (default: None -> 0) + e_score_correction_bias: Optional routing bias. 
dtype: bfloat16 + routing_method_type: FlashInfer RoutingMethodType enum value + + Returns: + Output tensor from MoE layer. dtype: same as x (bfloat16) + """ + from flashinfer import RoutingMethodType + from flashinfer.fused_moe import trtllm_mxint4_block_scale_moe + + assert x.dtype == torch.bfloat16, f"x dtype must be bfloat16, got {x.dtype}" + assert w13_weight_packed.dtype == torch.uint8, ( + f"w13_weight_packed dtype must be uint8, got {w13_weight_packed.dtype}" + ) + assert w13_weight_scale.dtype == torch.bfloat16, ( + f"w13_weight_scale dtype must be bfloat16, got {w13_weight_scale.dtype}" + ) + assert w2_weight_packed.dtype == torch.uint8, ( + f"w2_weight_packed dtype must be uint8, got {w2_weight_packed.dtype}" + ) + assert w2_weight_scale.dtype == torch.bfloat16, ( + f"w2_weight_scale dtype must be bfloat16, got {w2_weight_scale.dtype}" + ) + + routing_bias = None + if e_score_correction_bias is not None: + routing_bias = e_score_correction_bias.to(torch.bfloat16) + + if routing_method_type == RoutingMethodType.DeepSeekV3: + router_logits = router_logits.to(torch.float32) + + out = trtllm_mxint4_block_scale_moe( + routing_logits=router_logits, + routing_bias=routing_bias, + hidden_states=x, + gemm1_weights=w13_weight_packed.data, + gemm1_weights_scale=w13_weight_scale.data, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=w2_weight_packed.data, + gemm2_weights_scale=w2_weight_scale.data, + num_experts=global_num_experts, + top_k=top_k, + n_group=num_expert_group if num_expert_group is not None else 0, + topk_group=topk_group if topk_group is not None else 0, + intermediate_size=intermediate_size_per_partition, + local_expert_offset=ep_rank * local_num_experts, + local_num_experts=local_num_experts, + routed_scaling_factor=None, + routing_method_type=routing_method_type, + enable_pdl=None, + output=None, + tune_max_num_tokens=8192, + ).to(x.dtype) + + return out diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 44fdff247..cf5089247 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -129,12 +129,11 @@ scaled_fp4_grouped_quantize = _lazy_import_wrapper( "flashinfer", "scaled_fp4_grouped_quantize" ) nvfp4_block_scale_interleave = _lazy_import_wrapper( - "flashinfer", "nvfp4_block_scale_interleave" + "flashinfer.fp4_quantization", "block_scale_interleave" ) trtllm_fp4_block_scale_moe = _lazy_import_wrapper( "flashinfer", "trtllm_fp4_block_scale_moe" ) - # Special case for autotune since it returns a context manager autotune = _lazy_import_wrapper( "flashinfer.autotuner", @@ -196,6 +195,7 @@ def has_flashinfer_trtllm_fused_moe() -> bool: ("flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"), ("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"), ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"), + ("flashinfer.fused_moe", "trtllm_mxint4_block_scale_moe"), ] for module_name, attr_name in required_functions: mod = _get_submodule(module_name)
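
A minimal dequantization sketch (not part of the patch), assuming the MXINT4 layout produced by `mxint4_quantize()` in the new test: two signed INT4 values per uint8 (low nibble = even element, high nibble = odd element) plus one BF16 scale per 32-element block. The helper name `mxint4_dequantize` is hypothetical and only makes the round trip explicit; the TRT-LLM kernel consumes the packed form directly.

```python
import torch


def mxint4_dequantize(
    packed: torch.Tensor, scales: torch.Tensor, sf_vec_size: int = 32
) -> torch.Tensor:
    """Hypothetical inverse of mxint4_quantize(): uint8 [..., k//2] packed
    nibbles + scales [..., k//sf_vec_size] -> dequantized [..., k]."""
    lo = (packed & 0x0F).to(torch.int16)
    hi = (packed >> 4).to(torch.int16)
    # Sign-extend each nibble from [0, 15] back to signed INT4 in [-8, 7].
    lo = torch.where(lo >= 8, lo - 16, lo)
    hi = torch.where(hi >= 8, hi - 16, hi)
    # Even elements came from the low nibble, odd elements from the high nibble.
    vals = torch.stack([lo, hi], dim=-1).reshape(
        *packed.shape[:-1], packed.shape[-1] * 2
    )
    # One scale covers each block of sf_vec_size consecutive elements.
    return vals.to(scales.dtype) * scales.repeat_interleave(sf_vec_size, dim=-1)


# Round-trip check against the quantizer defined in the new test file
# (assumes the repo root is on sys.path):
#   from tests.kernels.moe.test_marlin_vs_trtllm_mxint4 import mxint4_quantize
#   x = torch.randn(4, 128, dtype=torch.bfloat16)
#   packed, scales = mxint4_quantize(x)
#   err = (x - mxint4_dequantize(packed, scales)).abs().max()  # ~one quant step
```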
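
Likewise, a sketch of the checkpoint-format conversion that `prepare_static_weights_for_trtllm_mxint4_moe()` relies on: compressed-tensors checkpoints store each 4-bit weight as uint4b8 (unsigned with a bias of 8, i.e. real value = stored value - 8), while the TRT-LLM kernel expects plain two's-complement signed INT4. This is an illustrative out-of-place variant under that assumption, not the actual `convert_packed_uint4b8_to_signed_int4_inplace` helper.

```python
import torch


def uint4b8_to_signed_int4(packed_i32: torch.Tensor) -> torch.Tensor:
    """Illustrative: map every uint4b8 nibble (stored [0, 15], bias 8) to its
    signed INT4 two's-complement encoding ([-8, 7]).

    (u - 8) mod 16 == u XOR 8, so flipping bit 3 of each nibble (0x88 per
    byte) re-biases all eight nibbles of an int32 word at once.
    """
    as_bytes = packed_i32.contiguous().view(torch.uint8)
    return (as_bytes ^ 0x88).view(torch.int32)


# e.g. stored nibble 0 (real value -8) becomes 0b1000 == -8 in signed INT4,
# and stored nibble 15 (real value 7) becomes 0b0111 == 7.
```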