[ROCM] DSfp4 mla projection gemms weight dynamic quantization (#32238)
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
This commit is contained in:
committed by
GitHub
parent
bd292be0c0
commit
8c11001ba2
@@ -837,6 +837,7 @@ class rocm_aiter_ops:
|
|||||||
_TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
|
_TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
|
||||||
# TODO: Consolidate under _LINEAR_ENABLED
|
# TODO: Consolidate under _LINEAR_ENABLED
|
||||||
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
|
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
|
||||||
|
_FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM
|
||||||
# TODO: Consolidate under _LINEAR_ENABLED
|
# TODO: Consolidate under _LINEAR_ENABLED
|
||||||
_FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
|
_FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
|
||||||
# TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE
|
# TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE
|
||||||
@@ -863,6 +864,7 @@ class rocm_aiter_ops:
|
|||||||
cls._SHUFFLE_KV_CACHE_ENABLED = envs.VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT
|
cls._SHUFFLE_KV_CACHE_ENABLED = envs.VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT
|
||||||
cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
|
cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
|
||||||
cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
|
cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
|
||||||
|
cls._FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM
|
||||||
cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
|
cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
|
||||||
cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
|
cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
|
||||||
cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
|
cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
|
||||||
@@ -923,6 +925,11 @@ class rocm_aiter_ops:
|
|||||||
def is_fp8bmm_enabled(cls) -> bool:
|
def is_fp8bmm_enabled(cls) -> bool:
|
||||||
return cls._AITER_ENABLED and cls._FP8BMM_ENABLED
|
return cls._AITER_ENABLED and cls._FP8BMM_ENABLED
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@if_aiter_supported
|
||||||
|
def is_fp4bmm_enabled(cls) -> bool:
|
||||||
|
return cls._AITER_ENABLED and cls._FP4BMM_ENABLED
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@if_aiter_supported
|
@if_aiter_supported
|
||||||
def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool:
|
def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool:
|
||||||
@@ -1403,6 +1410,29 @@ class rocm_aiter_ops:
|
|||||||
query = query.view(query_shape)
|
query = query.view(query_shape)
|
||||||
key = key.view(key_shape)
|
key = key.view(key_shape)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def batched_gemm_a16wfp4(
|
||||||
|
X: torch.Tensor,
|
||||||
|
W: torch.Tensor,
|
||||||
|
w_scale: torch.Tensor,
|
||||||
|
Y: torch.Tensor,
|
||||||
|
transpose_bm: bool | None = False,
|
||||||
|
prequant: bool | None = False,
|
||||||
|
y_scale: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# ruff: noqa: E501 # isort: skip
|
||||||
|
from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4
|
||||||
|
|
||||||
|
return batched_gemm_a16wfp4(
|
||||||
|
X,
|
||||||
|
W,
|
||||||
|
w_scale,
|
||||||
|
y=Y,
|
||||||
|
transpose_bm=transpose_bm,
|
||||||
|
prequant=prequant,
|
||||||
|
y_scale=y_scale,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def triton_fp8_bmm(
|
def triton_fp8_bmm(
|
||||||
X: torch.Tensor,
|
X: torch.Tensor,
|
||||||
|
|||||||
@@ -121,6 +121,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
|
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
|
||||||
VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False
|
VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False
|
||||||
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
|
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
|
||||||
|
VLLM_ROCM_USE_AITER_FP4BMM: bool = True
|
||||||
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
|
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
|
||||||
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False
|
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False
|
||||||
VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = True
|
VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = True
|
||||||
@@ -991,6 +992,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_ROCM_USE_AITER_FP8BMM": lambda: (
|
"VLLM_ROCM_USE_AITER_FP8BMM": lambda: (
|
||||||
os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1")
|
os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1")
|
||||||
),
|
),
|
||||||
|
# Whether to use aiter triton fp4 bmm kernel
|
||||||
|
# By default is enabled.
|
||||||
|
"VLLM_ROCM_USE_AITER_FP4BMM": lambda: (
|
||||||
|
os.getenv("VLLM_ROCM_USE_AITER_FP4BMM", "True").lower() in ("true", "1")
|
||||||
|
),
|
||||||
# Use AITER triton unified attention for V1 attention
|
# Use AITER triton unified attention for V1 attention
|
||||||
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: (
|
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: (
|
||||||
os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower()
|
os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower()
|
||||||
|
|||||||
@@ -1183,6 +1183,12 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
|
|||||||
self.q_pad_num_heads = q_pad_num_heads
|
self.q_pad_num_heads = q_pad_num_heads
|
||||||
self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
|
self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
|
||||||
|
|
||||||
|
# If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
|
||||||
|
self.is_aiter_triton_fp4_bmm_enabled = (
|
||||||
|
rocm_aiter_ops.is_fp4bmm_enabled()
|
||||||
|
and self.kv_b_proj.weight.dtype == torch.bfloat16
|
||||||
|
)
|
||||||
|
|
||||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||||
# we currently do not have quantized bmm's which are needed for
|
# we currently do not have quantized bmm's which are needed for
|
||||||
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
|
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
|
||||||
@@ -1211,7 +1217,21 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
|
|||||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
|
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.is_aiter_triton_fp8_bmm_enabled:
|
# If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
|
||||||
|
if self.is_aiter_triton_fp4_bmm_enabled:
|
||||||
|
from vllm.model_executor.layers.quantization.quark.utils import (
|
||||||
|
quark_quantize_weight_to_mxfp4,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.W_K, self.W_K_scale = quark_quantize_weight_to_mxfp4(W_UK)
|
||||||
|
# Convert from (L, N, P) to (N, L, P)
|
||||||
|
self.W_K = self.W_K.transpose(0, 1)
|
||||||
|
self.W_K_scale = self.W_K_scale.transpose(0, 1)
|
||||||
|
|
||||||
|
self.W_V, self.W_V_scale = quark_quantize_weight_to_mxfp4(
|
||||||
|
W_UV.permute(1, 2, 0)
|
||||||
|
)
|
||||||
|
elif self.is_aiter_triton_fp8_bmm_enabled:
|
||||||
W_K = W_UK.transpose(0, 1) # 16 512 128
|
W_K = W_UK.transpose(0, 1) # 16 512 128
|
||||||
W_V = W_UV.permute(1, 2, 0) # 16 128 512
|
W_V = W_UV.permute(1, 2, 0) # 16 128 512
|
||||||
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
|
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
|
||||||
@@ -1261,16 +1281,26 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
|
|||||||
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
|
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
|
||||||
# Convert from (B, N, L) to (N, B, L)
|
# Convert from (B, N, L) to (N, B, L)
|
||||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||||
|
out = out.view(-1, self.num_heads, self.v_head_dim)
|
||||||
if self.is_aiter_triton_fp8_bmm_enabled:
|
if self.is_aiter_triton_fp4_bmm_enabled:
|
||||||
out = out.view(-1, self.num_heads, self.v_head_dim)
|
out = rocm_aiter_ops.batched_gemm_a16wfp4(
|
||||||
|
x,
|
||||||
|
self.W_V,
|
||||||
|
self.W_V_scale,
|
||||||
|
out,
|
||||||
|
transpose_bm=True,
|
||||||
|
prequant=True,
|
||||||
|
y_scale=None,
|
||||||
|
)
|
||||||
|
x = out.view(-1, self.num_heads * self.v_head_dim)
|
||||||
|
elif self.is_aiter_triton_fp8_bmm_enabled:
|
||||||
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
|
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
|
||||||
x = rocm_aiter_ops.triton_fp8_bmm(
|
x = rocm_aiter_ops.triton_fp8_bmm(
|
||||||
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
|
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Convert from (B, N * V) to (N, B, V)
|
# Convert from (B, N * V) to (N, B, V)
|
||||||
out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
|
out = out.transpose(0, 1)
|
||||||
|
|
||||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||||
torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"
|
torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"
|
||||||
@@ -1579,80 +1609,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
# Convert from (q_len, num_heads) to (num_heads, q_len)
|
# Convert from (q_len, num_heads) to (num_heads, q_len)
|
||||||
return attn_out, lse.transpose(0, 1).contiguous()
|
return attn_out, lse.transpose(0, 1).contiguous()
|
||||||
|
|
||||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
|
||||||
# we currently do not have quantized bmm's which are needed for
|
|
||||||
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
|
|
||||||
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
|
|
||||||
kv_b_proj_weight = get_and_maybe_dequant_weights(
|
|
||||||
self.kv_b_proj, out_dtype=act_dtype
|
|
||||||
).T
|
|
||||||
assert kv_b_proj_weight.shape == (
|
|
||||||
self.kv_lora_rank,
|
|
||||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
|
||||||
), (
|
|
||||||
f"{kv_b_proj_weight.shape=}, "
|
|
||||||
f"{self.kv_lora_rank=}, "
|
|
||||||
f"{self.num_heads=}, "
|
|
||||||
f"{self.qk_nope_head_dim=}, "
|
|
||||||
f"{self.v_head_dim=}"
|
|
||||||
)
|
|
||||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
|
||||||
self.kv_lora_rank,
|
|
||||||
self.num_heads,
|
|
||||||
self.qk_nope_head_dim + self.v_head_dim,
|
|
||||||
)
|
|
||||||
|
|
||||||
W_UK, W_UV = kv_b_proj_weight.split(
|
|
||||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.is_aiter_triton_fp8_bmm_enabled:
|
|
||||||
W_K = W_UK.transpose(0, 1) # 16 512 128
|
|
||||||
W_V = W_UV.permute(1, 2, 0) # 16 128 512
|
|
||||||
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
|
|
||||||
W_K, dtype=current_platform.fp8_dtype()
|
|
||||||
)
|
|
||||||
self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
|
|
||||||
W_V, dtype=current_platform.fp8_dtype()
|
|
||||||
)
|
|
||||||
|
|
||||||
# The kernel operates on non-padded inputs. Hence, pre-compiling
|
|
||||||
# triton kernel to avoid runtime compilation for unseen batch sizes
|
|
||||||
# Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
|
|
||||||
# On DS-R1, this step adds roughly 50s to the model loading time.
|
|
||||||
max_batch_size = 1024 # [ToDo] Find the optimal upper limit
|
|
||||||
pre_compilation_list = list(range(1, max_batch_size + 1))
|
|
||||||
if is_global_first_rank():
|
|
||||||
pre_compilation_list = tqdm(
|
|
||||||
pre_compilation_list,
|
|
||||||
desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
|
|
||||||
total=max_batch_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
for m in pre_compilation_list:
|
|
||||||
x = torch.empty(
|
|
||||||
(self.W_K.shape[0], m, self.W_K.shape[2]),
|
|
||||||
dtype=torch.bfloat16,
|
|
||||||
device=self.W_K.device,
|
|
||||||
)
|
|
||||||
rocm_aiter_ops.triton_fp8_bmm(
|
|
||||||
x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
|
|
||||||
)
|
|
||||||
|
|
||||||
x = torch.empty(
|
|
||||||
(self.W_V.shape[0], m, self.W_V.shape[2]),
|
|
||||||
dtype=torch.bfloat16,
|
|
||||||
device=self.W_V.device,
|
|
||||||
)
|
|
||||||
rocm_aiter_ops.triton_fp8_bmm(
|
|
||||||
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Convert from (L, N, V) to (N, L, V)
|
|
||||||
self.W_UV = W_UV.transpose(0, 1)
|
|
||||||
# Convert from (L, N, P) to (N, P, L)
|
|
||||||
self.W_UK_T = W_UK.permute(1, 2, 0)
|
|
||||||
|
|
||||||
def _concat_k_nope_k_pe(
|
def _concat_k_nope_k_pe(
|
||||||
self, k_nope: torch.Tensor, k_pe: torch.Tensor
|
self, k_nope: torch.Tensor, k_pe: torch.Tensor
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@@ -2033,7 +1989,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
decode_pe_padded.copy_(decode_q_pe)
|
decode_pe_padded.copy_(decode_q_pe)
|
||||||
decode_q_pe = decode_pe_padded
|
decode_q_pe = decode_pe_padded
|
||||||
|
|
||||||
if self.is_aiter_triton_fp8_bmm_enabled:
|
if self.is_aiter_triton_fp4_bmm_enabled:
|
||||||
|
from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4
|
||||||
|
|
||||||
|
decode_ql_nope = batched_gemm_a16wfp4(
|
||||||
|
decode_q_nope,
|
||||||
|
self.W_K,
|
||||||
|
self.W_K_scale,
|
||||||
|
transpose_bm=True,
|
||||||
|
prequant=True,
|
||||||
|
y_scale=layer._q_scale if fp8_attention else None,
|
||||||
|
)
|
||||||
|
elif self.is_aiter_triton_fp8_bmm_enabled:
|
||||||
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
|
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
|
||||||
decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
|
decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
|
||||||
decode_q_nope,
|
decode_q_nope,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from types import MappingProxyType
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import regex as re
|
import regex as re
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def deep_compare(dict1: Any, dict2: Any) -> bool:
|
def deep_compare(dict1: Any, dict2: Any) -> bool:
|
||||||
@@ -103,3 +104,16 @@ def _is_equal_or_regex_match(
|
|||||||
elif target == value:
|
elif target == value:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# utility for tensor dims > 2 cases
|
||||||
|
def quark_quantize_weight_to_mxfp4(w: torch.Tensor):
|
||||||
|
assert w.dtype == torch.bfloat16, (
|
||||||
|
"Quark dynamic quantization is supported only for fp16 weights and only to MXF4"
|
||||||
|
)
|
||||||
|
|
||||||
|
from aiter.ops.triton.quant import dynamic_mxfp4_quant
|
||||||
|
|
||||||
|
*dims, d = w.shape
|
||||||
|
w, w_scales = dynamic_mxfp4_quant(w.reshape(-1, d))
|
||||||
|
return w.view(*dims, d // 2), w_scales.view(*dims, d // 32)
|
||||||
|
|||||||
Reference in New Issue
Block a user