diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index 8ef34bfd6..c8366ecce 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -87,6 +87,10 @@ def _rocm_aiter_fused_moe_impl(
     a2_scale: torch.Tensor | None = None,
     num_local_tokens: torch.Tensor | None = None,
     output_dtype: torch.dtype | None = None,
+    hidden_pad: int = 0,
+    intermediate_pad: int = 0,
+    bias1: torch.Tensor | None = None,
+    bias2: torch.Tensor | None = None,
 ) -> torch.Tensor:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
@@ -110,6 +114,10 @@ def _rocm_aiter_fused_moe_impl(
         a2_scale,
         num_local_tokens=num_local_tokens,
         dtype=output_dtype,
+        hidden_pad=hidden_pad,
+        intermediate_pad=intermediate_pad,
+        bias1=bias1,
+        bias2=bias2,
     )
 
 
@@ -307,6 +315,28 @@ def _rocm_aiter_grouped_topk_fake(
     pass
 
 
+def _rocm_aiter_fused_topk_impl(
+    x: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    gate_up: bool,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    from aiter.fused_moe import fused_topk
+
+    # fused_topk returns (topk_weights, topk_indices)
+    return fused_topk(x, router_logits, top_k, gate_up)
+
+
+def _rocm_aiter_fused_topk_fake(
+    x: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    gate_up: bool,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    weights = x.new_empty((x.shape[0], top_k), dtype=torch.float32)
+    return weights, x.new_empty((x.shape[0], top_k), dtype=torch.int32)
+
+
 # Cache whether aiter supports FP8 MLA parameters
 _AITER_MLA_SUPPORTS_FP8: bool | None = None
 
@@ -994,6 +1024,70 @@ class rocm_aiter_ops:
         cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
         cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
 
+    @staticmethod
+    def get_aiter_activation_type(activation_str: str):
+        """
+        Maps an activation type string to the aiter ActivationType enum.
+        Supported activation types: "no", "none", "silu", "gelu", "swiglu".
+        Returns None if the string is not recognized.
+
+        Args:
+            activation_str (str): Activation type as string.
+
+        Returns:
+            Aiter ActivationType enum value, or None if not found.
+        """
+        # Import only locally, since aiter may not always be available.
+        try:
+            from aiter import ActivationType
+        except ImportError:
+            return None
+
+        if not isinstance(activation_str, str):
+            return None
+
+        name = activation_str.strip().lower()
+        mapping = {
+            "none": ActivationType.No,
+            "no": ActivationType.No,
+            "silu": ActivationType.Silu,
+            "gelu": ActivationType.Gelu,
+            "swiglu": ActivationType.Swiglu,
+        }
+        return mapping.get(name)
+
+    @staticmethod
+    def get_aiter_quant_type(quant_type_str: str):
+        """
+        Maps a quantization type string to the aiter QuantType enum.
+        Supported: "no", "per_tensor", "per_token", "per_1x32",
+        "per_1x128", "per_128x128".
+
+        Args:
+            quant_type_str (str): Quantization type as string.
+
+        Returns:
+            Aiter QuantType enum value, or None if not found.
+        """
+        try:
+            from aiter import QuantType
+        except ImportError:
+            return None
+
+        if not isinstance(quant_type_str, str):
+            return None
+
+        name = quant_type_str.strip().lower()
+        mapping = {
+            "no": QuantType.No,
+            "per_tensor": QuantType.per_Tensor,
+            "per_token": QuantType.per_Token,
+            "per_1x32": QuantType.per_1x32,
+            "per_1x128": QuantType.per_1x128,
+            "per_128x128": QuantType.per_128x128,
+        }
+        return mapping.get(name)
+
     @classmethod
     @if_aiter_supported
     def is_enabled(cls) -> bool:
@@ -1127,6 +1221,14 @@ class rocm_aiter_ops:
             dispatch_key=current_platform.dispatch_key,
         )
 
+        direct_register_custom_op(
+            op_name="rocm_aiter_fused_topk",
+            op_func=_rocm_aiter_fused_topk_impl,
+            mutates_args=[],
+            fake_impl=_rocm_aiter_fused_topk_fake,
+            dispatch_key=current_platform.dispatch_key,
+        )
+
         direct_register_custom_op(
             op_name="rocm_aiter_mla_decode_fwd",
             op_func=_rocm_aiter_mla_decode_fwd_impl,
@@ -1360,6 +1462,10 @@ class rocm_aiter_ops:
         a2_scale: torch.Tensor | None = None,
         num_local_tokens: torch.Tensor | None = None,
         output_dtype: torch.dtype | None = None,
+        hidden_pad: int = 0,
+        intermediate_pad: int = 0,
+        bias1: torch.Tensor | None = None,
+        bias2: torch.Tensor | None = None,
     ) -> torch.Tensor:
         return torch.ops.vllm.rocm_aiter_fused_moe(
             hidden_states,
@@ -1377,6 +1483,10 @@ class rocm_aiter_ops:
             a2_scale,
             num_local_tokens,
             output_dtype,
+            hidden_pad,
+            intermediate_pad,
+            bias1,
+            bias2,
         )
 
     @staticmethod
@@ -1481,6 +1591,15 @@ class rocm_aiter_ops:
             routed_scaling_factor,
         )
 
+    @staticmethod
+    def fused_topk(
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        gate_up: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return torch.ops.vllm.rocm_aiter_fused_topk(x, router_logits, top_k, gate_up)
+
     @staticmethod
     def mla_decode_fwd(
         q: torch.Tensor,
@@ -1701,6 +1820,47 @@ class rocm_aiter_ops:
 
         return shuffle_weight(tensor, layout=layout)
 
+    @staticmethod
+    def shuffle_weight_a16w4(
+        tensor: "torch.Tensor",
+        nLane: int,
+        gate_up: bool,
+    ) -> "torch.Tensor":
+        """
+        Shuffles the weight tensor into the (A16W4) layout for AITER kernels.
+
+        Args:
+            tensor: The input weight tensor to be shuffled.
+            nLane: Lane count used by the shuffle layout (16 at the call sites).
+            gate_up: Whether the weight is w13 (True) or w2 (False).
+        Returns:
+            torch.Tensor: The shuffled tensor.
+        """
+        from aiter.ops.shuffle import shuffle_weight_a16w4
+
+        return shuffle_weight_a16w4(tensor, nLane, gate_up)
+
+    @staticmethod
+    def shuffle_scale_a16w4(
+        tensor: "torch.Tensor",
+        num_experts: int,
+        gate_up: bool,
+    ) -> "torch.Tensor":
+        """
+        Shuffles the scale tensor into the (A16W4) layout for AITER kernels.
+
+        Args:
+            tensor: The input scale tensor to be shuffled.
+            num_experts: Number of experts, needed for reshaping logic.
+            gate_up: Whether the scale is for w13 (True) or w2 (False).
+
+        Returns:
+            torch.Tensor: The shuffled scale tensor.
+        """
+        from aiter.ops.shuffle import shuffle_scale_a16w4
+
+        return shuffle_scale_a16w4(tensor, num_experts, gate_up)
+
     @staticmethod
     def shuffle_weights(
         *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 9318bedff..29dd03596 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -6,6 +6,7 @@ import torch
 from torch.nn.parameter import Parameter
 
 from vllm import envs
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention import Attention
@@ -77,6 +78,8 @@ class Mxfp4Backend(Enum):
     # Triton Backend
     TRITON = 6
 
+    CK = 7
+
 
 def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
     """
@@ -167,9 +170,15 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
     elif current_platform.is_xpu():
         logger.info_once("Using xpu backend on XPU")
         return Mxfp4Backend.MARLIN
-    elif current_platform.is_rocm() and has_triton_kernels():
-        logger.info_once("Using Triton backend")
-        return Mxfp4Backend.TRITON
+    elif current_platform.is_rocm():
+        from vllm.platforms.rocm import on_gfx950
+
+        if rocm_aiter_ops.is_enabled() and on_gfx950():
+            logger.info_once("Using CK MXFP4 MoE backend (Aiter ROCm)")
+            return Mxfp4Backend.CK
+        elif has_triton_kernels():
+            logger.info_once("Using Triton backend")
+            return Mxfp4Backend.TRITON
 
     return Mxfp4Backend.NONE
 
@@ -338,6 +347,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
 
         self.intermediate_size = intermediate_size_per_partition_after_pad
         self.hidden_size = hidden_size
+        self.hidden_pad = extra_weight_attrs.get("hidden_pad", 0)
+        self.intermediate_pad = (
+            intermediate_size_per_partition_after_pad - intermediate_size_per_partition
+        )
         # Fused gate_up_proj (column parallel)
         w13_weight = torch.nn.Parameter(
             torch.zeros(
@@ -784,6 +797,66 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 ),
                 shared_experts=None,
             )
+        elif self.mxfp4_backend == Mxfp4Backend.CK:
+            if layer.w13_bias is not None:
+                layer.w13_bias.data = layer.w13_bias.data.to(torch.float32)
+            if layer.w2_bias is not None:
+                layer.w2_bias.data = layer.w2_bias.data.to(torch.float32)
+
+            e, n, k = layer.w13_weight.shape
+            layer.w13_weight.view(torch.uint8).copy_(
+                layer.w13_weight.data.view(torch.uint8)
+                .view(e, n // 2, 2, k)
+                .permute(0, 2, 1, 3)
+                .contiguous()
+                .view(e, n, k)
+            )
+            layer.w13_weight_scale.data = (
+                layer.w13_weight_scale.data.view(e, n // 2, 2, -1)
+                .permute(0, 2, 1, 3)
+                .contiguous()
+                .view(e, n, -1)
+            )
+            layer.w13_weight.data = layer.w13_weight.data.view(torch.float4_e2m1fn_x2)
+            layer.w2_weight.data = layer.w2_weight.data.view(torch.float4_e2m1fn_x2)
+
+            layer.w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
+                layer.w13_weight, 16, True
+            )
+            shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
+                layer.w13_weight_scale.view(-1, layer.w13_weight_scale.shape[-1]),
+                self.num_experts,
+                True,
+            )
+
+            layer.w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
+                layer.w2_weight, 16, False
+            )
+            shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
+                layer.w2_weight_scale.view(-1, layer.w2_weight_scale.shape[-1]),
+                self.num_experts,
+                False,
+            )
+
+            layer.w13_bias.data = (
+                layer.w13_bias.data.view(-1, n // 2, 2)
+                .permute(0, 2, 1)
+                .contiguous()
+                .view(-1, n)
+            )
+
+            layer.w13_weight_scale = torch.nn.Parameter(
+                shuffled_w13_scale, requires_grad=False
+            )
+            layer.w2_weight_scale = torch.nn.Parameter(
+                shuffled_w2_scale, requires_grad=False
+            )
+            # replace_parameter(layer, "w13_bias", w13_bias)
+            # replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
+            # replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
+            # replace_parameter(layer, "w13_weight", w13_weight)
+            # replace_parameter(layer, "w2_weight", w2_weight)
+
         elif self.mxfp4_backend == Mxfp4Backend.TRITON:
             from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
 
@@ -792,7 +865,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
 
             layer.w13_bias = Parameter(w13_bias, requires_grad=False)
             layer.w2_bias = Parameter(w2_bias, requires_grad=False)
-
             # Ideally we'd use FusedMoEModularKernel.prepare_finalize object
             # (stored in self.fused_experts) to determine if the MoE has a
             # batched activation format. As self.fused_experts is not
@@ -803,7 +875,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
             else:
                 num_warps = 8
-
             w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
                 layer.w13_weight, layer.w13_weight_scale, num_warps
             )
@@ -817,13 +888,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             self.w2_precision_config = PrecisionConfig(
                 weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
             )
-
             self.w13_weight = w13_weight
             self.w2_weight = w2_weight
             del layer.w13_weight
             del layer.w2_weight
             layer.w13_weight = w13_weight
             layer.w2_weight = w2_weight
+
         else:
             raise ValueError(
                 f"Unsupported mxfp4_backend: {self.mxfp4_backend}: "
@@ -862,6 +933,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         elif self.mxfp4_backend in [
             Mxfp4Backend.SM100_FI_MXFP4_BF16,
             Mxfp4Backend.SM90_FI_MXFP4_BF16,
+            Mxfp4Backend.CK,
         ]:
             return mxfp4_w4a16_moe_quant_config(
                 w1_bias=layer.w13_bias,
@@ -933,6 +1005,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
             or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
             or self.mxfp4_backend == Mxfp4Backend.TRITON
+            or self.mxfp4_backend == Mxfp4Backend.CK
         )
 
     def apply(
@@ -1054,6 +1127,27 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 tune_max_num_tokens=max(self.max_capture_size, 1),
             )[0]
             return trtllm_gen_output
+        elif self.mxfp4_backend == Mxfp4Backend.CK:
+            topk_weights, topk_ids = rocm_aiter_ops.fused_topk(
+                x, router_logits, layer.top_k, True
+            )
+            output = rocm_aiter_ops.fused_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                activation_method=rocm_aiter_ops.get_aiter_activation_type("swiglu"),
+                quant_method=rocm_aiter_ops.get_aiter_quant_type("per_1x32"),
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                doweight_stage1=False,
+                hidden_pad=self.hidden_pad // 128 * 128,
+                intermediate_pad=self.intermediate_pad // 64 * 64 * 2,
+                bias1=layer.w13_bias,
+                bias2=layer.w2_bias,
+            )
+            return output
         elif self.mxfp4_backend == Mxfp4Backend.TRITON:
             from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
                 triton_kernel_moe_forward,
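
Usage sketch (not part of the patch): the CK branch of Mxfp4MoEMethod.apply() above composes the new rocm_aiter_ops helpers roughly as below. The standalone function name, the bare `layer` handle, and the omitted padding/doweight arguments (the patch passes them explicitly) are illustrative assumptions; the authoritative call sequence is the diff itself.

import torch

from vllm._aiter_ops import rocm_aiter_ops


def ck_mxfp4_moe_sketch(
    x: torch.Tensor,              # [num_tokens, hidden_size] activations
    router_logits: torch.Tensor,  # [num_tokens, num_experts]
    layer,                        # module holding the shuffled MXFP4 weights/scales/biases
) -> torch.Tensor:
    # Route tokens with the AITER fused top-k kernel.
    topk_weights, topk_ids = rocm_aiter_ops.fused_topk(
        x, router_logits, layer.top_k, True
    )
    # String -> aiter enum mapping; both helpers return None if aiter is unavailable.
    activation = rocm_aiter_ops.get_aiter_activation_type("swiglu")
    quant = rocm_aiter_ops.get_aiter_quant_type("per_1x32")
    # Fused MoE over the pre-shuffled w13/w2 weights and per-1x32 scales.
    return rocm_aiter_ops.fused_moe(
        x,
        layer.w13_weight,
        layer.w2_weight,
        topk_weights,
        topk_ids,
        activation_method=activation,
        quant_method=quant,
        w1_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        bias1=layer.w13_bias,
        bias2=layer.w2_bias,
    )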