From 8c11001ba2110134df3c3aecc63f2559ad1f5996 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Date: Thu, 15 Jan 2026 12:13:08 -0800 Subject: [PATCH] [ROCM] DSfp4 mla projection gemms weight dynamic quantization (#32238) Signed-off-by: Aleksandr Malyshev Co-authored-by: Aleksandr Malyshev --- vllm/_aiter_ops.py | 30 +++++ vllm/envs.py | 6 + .../layers/attention/mla_attention.py | 127 +++++++----------- .../layers/quantization/quark/utils.py | 14 ++ 4 files changed, 97 insertions(+), 80 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index c9ad8a5ae..2a247c6d5 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -837,6 +837,7 @@ class rocm_aiter_ops: _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION # TODO: Consolidate under _LINEAR_ENABLED _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM + _FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM # TODO: Consolidate under _LINEAR_ENABLED _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM # TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE @@ -863,6 +864,7 @@ class rocm_aiter_ops: cls._SHUFFLE_KV_CACHE_ENABLED = envs.VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM + cls._FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS @@ -923,6 +925,11 @@ class rocm_aiter_ops: def is_fp8bmm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._FP8BMM_ENABLED + @classmethod + @if_aiter_supported + def is_fp4bmm_enabled(cls) -> bool: + return cls._AITER_ENABLED and cls._FP4BMM_ENABLED + @classmethod @if_aiter_supported def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool: @@ -1403,6 +1410,29 @@ class rocm_aiter_ops: query = query.view(query_shape) key = key.view(key_shape) + @staticmethod + def batched_gemm_a16wfp4( + X: torch.Tensor, + W: torch.Tensor, + w_scale: torch.Tensor, + Y: torch.Tensor, + transpose_bm: bool | None = False, + prequant: bool | None = False, + y_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + # ruff: noqa: E501 # isort: skip + from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4 + + return batched_gemm_a16wfp4( + X, + W, + w_scale, + y=Y, + transpose_bm=transpose_bm, + prequant=prequant, + y_scale=y_scale, + ) + @staticmethod def triton_fp8_bmm( X: torch.Tensor, diff --git a/vllm/envs.py b/vllm/envs.py index 65bbd29f3..f1ee13e33 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -121,6 +121,7 @@ if TYPE_CHECKING: VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False VLLM_ROCM_USE_AITER_FP8BMM: bool = True + VLLM_ROCM_USE_AITER_FP4BMM: bool = True VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = True @@ -991,6 +992,11 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_ROCM_USE_AITER_FP8BMM": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1") ), + # Whether to use aiter triton fp4 bmm kernel + # By default is enabled. + "VLLM_ROCM_USE_AITER_FP4BMM": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_FP4BMM", "True").lower() in ("true", "1") + ), # Use AITER triton unified attention for V1 attention "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower() diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index a1bce65ce..e9cfa4a08 100755 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -1183,6 +1183,12 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): self.q_pad_num_heads = q_pad_num_heads self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled() + # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported + self.is_aiter_triton_fp4_bmm_enabled = ( + rocm_aiter_ops.is_fp4bmm_enabled() + and self.kv_b_proj.weight.dtype == torch.bfloat16 + ) + def process_weights_after_loading(self, act_dtype: torch.dtype): # we currently do not have quantized bmm's which are needed for # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform @@ -1211,7 +1217,21 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) - if self.is_aiter_triton_fp8_bmm_enabled: + # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported + if self.is_aiter_triton_fp4_bmm_enabled: + from vllm.model_executor.layers.quantization.quark.utils import ( + quark_quantize_weight_to_mxfp4, + ) + + self.W_K, self.W_K_scale = quark_quantize_weight_to_mxfp4(W_UK) + # Convert from (L, N, P) to (N, L, P) + self.W_K = self.W_K.transpose(0, 1) + self.W_K_scale = self.W_K_scale.transpose(0, 1) + + self.W_V, self.W_V_scale = quark_quantize_weight_to_mxfp4( + W_UV.permute(1, 2, 0) + ) + elif self.is_aiter_triton_fp8_bmm_enabled: W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( @@ -1261,16 +1281,26 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - - if self.is_aiter_triton_fp8_bmm_enabled: - out = out.view(-1, self.num_heads, self.v_head_dim) + out = out.view(-1, self.num_heads, self.v_head_dim) + if self.is_aiter_triton_fp4_bmm_enabled: + out = rocm_aiter_ops.batched_gemm_a16wfp4( + x, + self.W_V, + self.W_V_scale, + out, + transpose_bm=True, + prequant=True, + y_scale=None, + ) + x = out.view(-1, self.num_heads * self.v_head_dim) + elif self.is_aiter_triton_fp8_bmm_enabled: # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) x = rocm_aiter_ops.triton_fp8_bmm( x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out ) else: # Convert from (B, N * V) to (N, B, V) - out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1) + out = out.transpose(0, 1) # Multiply (N, B, L) x (N, L, V) -> (N, B, V) torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot" @@ -1579,80 +1609,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): # Convert from (q_len, num_heads) to (num_heads, q_len) return attn_out, lse.transpose(0, 1).contiguous() - def process_weights_after_loading(self, act_dtype: torch.dtype): - # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform - # the bmm's in 16-bit, the extra memory overhead of this is fairly low - kv_b_proj_weight = get_and_maybe_dequant_weights( - self.kv_b_proj, out_dtype=act_dtype - ).T - assert kv_b_proj_weight.shape == ( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - ), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}" - ) - kv_b_proj_weight = kv_b_proj_weight.view( - self.kv_lora_rank, - self.num_heads, - self.qk_nope_head_dim + self.v_head_dim, - ) - - W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1 - ) - - if self.is_aiter_triton_fp8_bmm_enabled: - W_K = W_UK.transpose(0, 1) # 16 512 128 - W_V = W_UV.permute(1, 2, 0) # 16 128 512 - self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( - W_K, dtype=current_platform.fp8_dtype() - ) - self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( - W_V, dtype=current_platform.fp8_dtype() - ) - - # The kernel operates on non-padded inputs. Hence, pre-compiling - # triton kernel to avoid runtime compilation for unseen batch sizes - # Pre-compile for batch sizes 1 to 1024 to cover most use-cases. - # On DS-R1, this step adds roughly 50s to the model loading time. - max_batch_size = 1024 # [ToDo] Find the optimal upper limit - pre_compilation_list = list(range(1, max_batch_size + 1)) - if is_global_first_rank(): - pre_compilation_list = tqdm( - pre_compilation_list, - desc="[Aiter Triton] Pre-compiling fp8 BMM kernel", - total=max_batch_size, - ) - - for m in pre_compilation_list: - x = torch.empty( - (self.W_K.shape[0], m, self.W_K.shape[2]), - dtype=torch.bfloat16, - device=self.W_K.device, - ) - rocm_aiter_ops.triton_fp8_bmm( - x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True - ) - - x = torch.empty( - (self.W_V.shape[0], m, self.W_V.shape[2]), - dtype=torch.bfloat16, - device=self.W_V.device, - ) - rocm_aiter_ops.triton_fp8_bmm( - x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True - ) - else: - # Convert from (L, N, V) to (N, L, V) - self.W_UV = W_UV.transpose(0, 1) - # Convert from (L, N, P) to (N, P, L) - self.W_UK_T = W_UK.permute(1, 2, 0) - def _concat_k_nope_k_pe( self, k_nope: torch.Tensor, k_pe: torch.Tensor ) -> torch.Tensor: @@ -2033,7 +1989,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): decode_pe_padded.copy_(decode_q_pe) decode_q_pe = decode_pe_padded - if self.is_aiter_triton_fp8_bmm_enabled: + if self.is_aiter_triton_fp4_bmm_enabled: + from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4 + + decode_ql_nope = batched_gemm_a16wfp4( + decode_q_nope, + self.W_K, + self.W_K_scale, + transpose_bm=True, + prequant=True, + y_scale=layer._q_scale if fp8_attention else None, + ) + elif self.is_aiter_triton_fp8_bmm_enabled: # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm( decode_q_nope, diff --git a/vllm/model_executor/layers/quantization/quark/utils.py b/vllm/model_executor/layers/quantization/quark/utils.py index dc82f94eb..98ac1a4f3 100644 --- a/vllm/model_executor/layers/quantization/quark/utils.py +++ b/vllm/model_executor/layers/quantization/quark/utils.py @@ -6,6 +6,7 @@ from types import MappingProxyType from typing import Any import regex as re +import torch def deep_compare(dict1: Any, dict2: Any) -> bool: @@ -103,3 +104,16 @@ def _is_equal_or_regex_match( elif target == value: return True return False + + +# utility for tensor dims > 2 cases +def quark_quantize_weight_to_mxfp4(w: torch.Tensor): + assert w.dtype == torch.bfloat16, ( + "Quark dynamic quantization is supported only for fp16 weights and only to MXF4" + ) + + from aiter.ops.triton.quant import dynamic_mxfp4_quant + + *dims, d = w.shape + w, w_scales = dynamic_mxfp4_quant(w.reshape(-1, d)) + return w.view(*dims, d // 2), w_scales.view(*dims, d // 32)