[Attention] Deepseek v3 MLA support with FP8 compute (#12601)

This PR implements Deepseek V3 MLA support by performing matrix absorption on the FP8 weights: the FP8 weights are dequantized, absorbed, and (by default) requantized back to FP8.

---------

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: simon-mo <simon.mo@hey.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>
Lucas Wilkinson
2025-02-01 00:52:51 -05:00
committed by GitHub
parent 3e1c76cf3a
commit baeded2569
10 changed files with 579 additions and 84 deletions

----------------------------------------

@@ -1,17 +1,29 @@
 from abc import abstractmethod
 from dataclasses import dataclass
-from typing import Any, Dict, Generic, List, Optional
+from typing import Any, Dict, Generic, List, Optional, Tuple

 import torch
+from compressed_tensors.quantization import QuantizationStrategy

 from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.attention.backends.abstract import (AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl, T)
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               RowParallelLinear)
+                                               LinearBase, RowParallelLinear,
+                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
+    CompressedTensorsLinearMethod)
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsW8A8Fp8)
+from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    scaled_dequantize, scaled_quantize)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

 from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -25,11 +37,11 @@ class MLACommonMetadata(AttentionMetadata):

 class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
     """
     Common class for implementing repeated parts

     Main reference: DeepseekV2 paper, and FlashInfer Implementation
     (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

     Deepseek's MLA attention works the following way:
     * Use a single latent vector to represent the entire KV cache.
     * The attention "simulates" a multi-head attention, while the compute is
@@ -46,7 +58,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
     * V: V head dim.
     * kv_c: latent/compressed KV
     * q_c: latent/compressed Q

     #
     # Outside the MLA attention backend
     #
@@ -55,21 +67,21 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
        kv_c_k_pe (B, Lkv+R).
     2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). q_c
        and kv_c are normalized.

     #
     # Inside the MLA attention backend
     #

     * if prefill:

     3. The q_c is then projected up into the multi-head version.
        * q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
          (B, N, P) and q_pe (B, N, R).
     4. q_pe, k_pe are then passed through rotary embeddings.
     5. kv_c and k_pe are concatenated and inserted into the cache
     6. The kv_c is then projected up into the multi-head version.
        * kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
          dimensions for K and V, which is split into k_nope (B, N, P)
          and v (B, N, V).
     7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from
        q_nope, q_pe, k_nope, k_pe.
@@ -112,7 +124,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
     From @tsu-bin's calculation, we only want to use the absorption technique
     for decode. The prefill algorithm should still use the up-projected MHA
     for less flops and memory usage.
     """

     def __init__(
@@ -162,8 +174,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):

     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
-            return self.o_proj_absorbed(
-                x.reshape(-1, self.num_heads * self.kv_lora_rank))[0]
+            if is_fp8(self.W_UV_O):
+                output_parallel = apply_fp8_linear_generic(
+                    x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
+                    self.reqaunt_input_group_shape,
+                    self.reqaunt_weight_group_shape)
+            else:
+                output_parallel = torch.matmul(x.flatten(start_dim=1),
+                                               self.W_UV_O)
+            if self.tp_size > 1:
+                output = tensor_model_parallel_all_reduce(output_parallel)
+            else:
+                output = output_parallel
+            return output
         else:
             x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
             return self.o_proj(x.reshape(-1,
@@ -171,6 +194,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):

     def _q_proj_and_k_up_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
+            if is_fp8(self.W_Q_UK):
+                return apply_fp8_linear_generic(
+                    x, self.W_Q_UK, self.W_Q_UK_scales,
+                    self.reqaunt_input_group_shape,
+                    self.reqaunt_weight_group_shape).view(
+                        -1, self.num_heads, self.kv_lora_rank)
             return torch.matmul(x, self.W_Q_UK)\
                 .view(-1, self.num_heads, self.kv_lora_rank)
         else:
@@ -179,8 +208,91 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
             .view(-1, self.num_heads, self.kv_lora_rank)

-    def process_weights_after_loading(self):
-        kv_b_proj_weight = self.kv_b_proj.weight.T
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+
+        def is_layer_fp8(layer: LinearBase) -> bool:
+            return isinstance(layer.quant_method, Fp8LinearMethod) or\
+                (isinstance(layer.quant_method, CompressedTensorsLinearMethod)
+                 and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
+
+        def quantization_scheme_supported(layer: LinearBase) -> bool:
+            return isinstance(layer.quant_method, UnquantizedLinearMethod) or\
+                is_layer_fp8(layer)
+
+        # TODO(lucas) This is very gross, we need a more wide scale refactor
+        # of all the FP8 code with a more standard way of
+        # defining schemes/group-shapes, we should also potentially force
+        # quant_methods to support a decompress function
+        #
+        # returns input_group_shape, weight_group_shape
+        def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
+                Tuple[Tuple[int, int], Tuple[int, int]]:
+            if isinstance(layer.quant_method, Fp8LinearMethod):
+                if layer.quant_method.block_quant is not None:
+                    weight_block_size = \
+                        layer.quant_method.quant_config.weight_block_size
+                    # per-token-group (1, X), block-quantized (X, Y)
+                    return (1, weight_block_size[-1]), weight_block_size
+                else:
+                    return (-1, -1), (-1, -1)  # per-tensor, per-tensor
+            elif isinstance(layer.quant_method,
+                            CompressedTensorsLinearMethod) and \
+                    isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
+                # this is hacky but we always assume that for
+                # CompressedTensorsW8A8Fp8 the input is dynamic per-token;
+                # we ignore if it is static per-tensor since we are going to
+                # requantize afterwards anyway
+                strategy = layer.scheme.strategy
+                if strategy == QuantizationStrategy.TENSOR:
+                    return (1, -1), (-1, -1)  # per-token, per-tensor
+                elif strategy == QuantizationStrategy.CHANNEL:
+                    return (1, -1), (-1, 1)  # per-token, per-channel
+                else:
+                    raise NotImplementedError(
+                        f"QuantizationStrategy.{strategy} is not supported for "
+                        "fp8 MLA, please run with VLLM_MLA_DISABLE=1")
+            else:
+                raise NotImplementedError(
+                    "Can't determine scale group shapes for "
+                    f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
+                )
+
+        def get_scales(layer: LinearBase) -> torch.Tensor:
+            if hasattr(layer, "weight_scale_inv"):
+                return layer.weight_scale_inv
+            return layer.weight_scale
+
+        def get_and_maybe_dequant_weights(layer: LinearBase):
+            if is_layer_fp8(layer):
+                if isinstance(layer.quant_method,
+                              CompressedTensorsLinearMethod) and \
+                        isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
+                    # NOTE(lucas): not sure why, but `CompressedTensorsW8A8Fp8`
+                    # seems to store weights as (input, output) instead of
+                    # (output, input) so we need to transpose
+                    weight = layer.weight.T  # standardize to (output, input)
+                else:
+                    weight = layer.weight
+                _, weight_scale_group_shape = \
+                    get_scale_group_shapes_for_fp8(layer)
+                scales = get_scales(layer)
+                return scaled_dequantize(weight, scales,
+                                         weight_scale_group_shape)
+            else:
+                return layer.weight
+
+        if not (quantization_scheme_supported(self.kv_b_proj) and
+                quantization_scheme_supported(self.q_proj) and
+                quantization_scheme_supported(self.o_proj)):
+            raise NotImplementedError(
+                "Only FP8 and UnquantizedLinearMethod are supported for MLA"
+                ", please run with VLLM_MLA_DISABLE=1")
+
+        weight_dtype = self.kv_b_proj.weight.dtype
+        assert self.o_proj.weight.dtype == weight_dtype
+        assert self.q_proj.weight.dtype == weight_dtype
+
+        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
+
         assert kv_b_proj_weight.shape == (
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
@@ -198,18 +310,35 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         W_UK, W_UV = kv_b_proj_weight.split(
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

-        q_proj = self.q_proj.weight.T\
+        q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
                 .view(-1, self.num_heads, self.qk_head_dim)

         # can be W_Q or W_UQ depending q_lora_rank, the former if
         # q_lora_rank is None, the latter otherwise. From the Attention backend
         # perspective though we call these both W_Q and rely on the layer
         # to pass in the correct matrix
-        W_Q = q_proj[..., :self.qk_nope_head_dim]
-        self.W_QR = q_proj[..., self.qk_nope_head_dim:]\
+        W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
+        self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
             .flatten(start_dim=1).contiguous()

+        # W_QR is small so for simplicity we don't bother requantizing it
+        self.W_QR = self.W_QR.to(act_dtype)
+
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
+            requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
+            if is_fp8(weight_dtype) and requantization_enabled:
+                # This assumes it is wise to requantize using the same group
+                # shapes (i.e. strategy, per-tensor, per-channel, block etc.)
+                # that the weights were originally quantized with
+                requant_input_group_shape, requant_weight_group_shape = \
+                    get_scale_group_shapes_for_fp8(self.q_proj)
+                assert (requant_input_group_shape, requant_weight_group_shape)\
+                    == get_scale_group_shapes_for_fp8(self.kv_b_proj)
+                assert (requant_input_group_shape, requant_weight_group_shape)\
+                    == get_scale_group_shapes_for_fp8(self.o_proj)
+                self.reqaunt_input_group_shape = requant_input_group_shape
+                self.reqaunt_weight_group_shape = requant_weight_group_shape
+
             #
             # Perform matrix-absorption following
             #   https://github.com/flashinfer-ai/flashinfer/pull/551
@@ -223,25 +352,44 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             # latter otherwise
             # basically if q_lora_rank is none we are absorbing into q_proj
             # instead of UQ
-            self.W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
+            W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
                 .flatten(start_dim=1).contiguous()

-            W_O = self.o_proj.weight\
+            if is_fp8(weight_dtype) and requantization_enabled:
+                W_Q_UK, W_Q_UK_scales = scaled_quantize(
+                    W_Q_UK,
+                    self.reqaunt_weight_group_shape,
+                    quant_dtype=current_platform_fp8_dtype)
+                # For FP8 save the transpose so we can use
+                # `apply_w8a8_block_fp8_linear` directly
+                self.W_Q_UK = W_Q_UK.T.contiguous()
+                self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
+            else:
+                self.W_Q_UK = W_Q_UK.to(act_dtype)
+
+            W_O = get_and_maybe_dequant_weights(self.o_proj)\
                 .view(-1, self.num_heads, self.v_head_dim)
-            self.W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
+            W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
                 .flatten(start_dim=0, end_dim=1).contiguous()

-            tp_size = get_tensor_model_parallel_world_size()
-            self.o_proj_absorbed = RowParallelLinear(
-                self.W_UV_O.shape[0] * tp_size,
-                self.W_UV_O.shape[1],
-                bias=False,
-                # TODO(lucas) figure out how to properly forward quant_method
-                #quant_config=self.o_proj.quant_method,
-            )
-
-            self.o_proj_absorbed.weight = torch.nn.Parameter(self.W_UV_O.T)
+            if is_fp8(weight_dtype) and requantization_enabled:
+                W_UV_O, W_UV_O_scales = scaled_quantize(
+                    W_UV_O,
+                    self.reqaunt_weight_group_shape,
+                    quant_dtype=current_platform_fp8_dtype)
+                # For FP8 save the transpose so we can use
+                # `apply_w8a8_block_fp8_linear` directly
+                self.W_UV_O = W_UV_O.T.contiguous()
+                self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
+            else:
+                self.W_UV_O = W_UV_O.to(act_dtype)
+
+            self.tp_size = get_tensor_model_parallel_world_size()
         else:
+            if is_fp8(weight_dtype):
+                raise NotImplementedError(
+                    "Currently fp8 requires matrix absorption")
+
             self.W_UV = W_UV
             self.W_UK = W_UK
             self.W_Q = W_Q.flatten(start_dim=1)

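To make the absorption above concrete, here is a minimal standalone sketch (toy dimensions, not the vLLM implementation; it follows the docstring's B/N/P/L/V notation) verifying that folding W_UK into the query path and W_UV into the output projection leaves the attention math unchanged:

import torch

torch.manual_seed(0)
# Toy sizes in the docstring's notation: B tokens, N heads, P nope head dim,
# L latent dim (kv_lora_rank), V value head dim, H hidden size.
B, N, P, L, V, H = 4, 2, 16, 32, 16, 64
dt = torch.float64  # float64 so the two summation orders agree tightly

W_UK = torch.randn(L, N, P, dtype=dt)  # latent -> per-head K (nope part)
W_UV = torch.randn(L, N, V, dtype=dt)  # latent -> per-head V
W_O = torch.randn(H, N, V, dtype=dt)   # per-head V -> hidden (o_proj)
q_nope = torch.randn(B, N, P, dtype=dt)
kv_c = torch.randn(B, L, dtype=dt)     # compressed KV latents (the cache)

# Unabsorbed (prefill-style): up-project K, then compute scores per head.
k_nope = torch.einsum("cl,lnp->cnp", kv_c, W_UK)
scores_mha = torch.einsum("bnp,cnp->bnc", q_nope, k_nope)

# Absorbed (decode-style): fold W_UK into q so attention runs in latent space,
# matching the `bnp,lnp->bnl` einsum in `_q_proj_and_k_up_proj`.
q_latent = torch.einsum("bnp,lnp->bnl", q_nope, W_UK)
scores_mla = torch.einsum("bnl,cl->bnc", q_latent, kv_c)
torch.testing.assert_close(scores_mha, scores_mla)

# Likewise W_UV folds into o_proj: (attn @ W_UV) @ W_O == attn @ (W_UV @ W_O),
# matching the `lnd,hnd->nlh` einsum used to build W_UV_O above.
attn_latent = torch.randn(B, N, L, dtype=dt)  # attention-weighted latents
v = torch.einsum("bnl,lnv->bnv", attn_latent, W_UV)
out_mha = torch.einsum("bnv,hnv->bh", v, W_O)
W_UV_O = torch.einsum("lnv,hnv->nlh", W_UV, W_O)
out_mla = torch.einsum("bnl,nlh->bh", attn_latent, W_UV_O)
torch.testing.assert_close(out_mha, out_mla)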
----------------------------------------

@@ -57,14 +57,12 @@ class TritonMLABackend(AttentionBackend):

     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
         block_size: int,
         num_kv_heads: int,  # assumed to be 1 for MLA
-        kv_lora_rank: int,  # passed via head_size
+        head_size: int,
     ) -> Tuple[int, ...]:
-        # TODO(lucas): remove hardcoding k_pe size as 1/8th of kv_lora_rank
-        k_pe_size = kv_lora_rank // 8
-        return (num_blocks, block_size, kv_lora_rank + k_pe_size)
+        return (num_blocks, block_size, head_size)

     @staticmethod
     def swap_blocks(
@@ -83,7 +81,7 @@ class TritonMLABackend(AttentionBackend):

     @staticmethod
     def get_supported_head_sizes() -> List[int]:
-        return [512]
+        return [576]


 class TritonMLAState(AttentionState):
@@ -624,8 +622,6 @@ class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]):
                 self.multimodal_placeholder_maps.items()
             }

-        num_kv_splits = 8
-
         return TritonMLAMetadata(
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
@@ -645,7 +641,7 @@ class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]):
             context_lens_tensor=context_lens_tensor,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
-            num_kv_splits=num_kv_splits,
+            num_kv_splits=4,  # TODO(lucas) add heuristic
             head_dim=self.runner.model_config.get_head_size(),
         )

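The head-size change above (512 -> 576) reflects that the MLA cache now stores the rotary part k_pe alongside the latent, with the size taken from the model config instead of the old kv_lora_rank // 8 hard-coding. A quick sanity check with the Deepseek V2/V3 config values:

# Deepseek V2/V3 values from the HF config
kv_lora_rank = 512       # compressed KV latent (kv_c)
qk_rope_head_dim = 64    # rotary part cached alongside it (k_pe)
head_size = kv_lora_rank + qk_rope_head_dim
assert head_size == 576                        # matches get_supported_head_sizes()
assert kv_lora_rank // 8 == qk_rope_head_dim   # why the old hard-coding worked

num_blocks, block_size = 1024, 16
kv_cache_shape = (num_blocks, block_size, head_size)  # get_kv_cache_shape()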
----------------------------------------

@@ -200,9 +200,9 @@ class Attention(nn.Module):
             s += f", backend={self.impl.__class__.__name__}"
         return s

-    def process_weights_after_loading(self):
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
         if hasattr(self.impl, "process_weights_after_loading"):
-            self.impl.process_weights_after_loading()
+            self.impl.process_weights_after_loading(act_dtype)


 class MultiHeadAttention(nn.Module):

----------------------------------------

@@ -739,18 +739,19 @@ class ModelConfig:
     @property
     def is_deepseek_mla(self) -> bool:
-        # TODO add deepseek_v3
-        return hasattr(self.hf_text_config,
-                       "model_type") and (self.hf_text_config.model_type
-                                          in ('deepseek_v2'))
+        return (hasattr(self.hf_text_config, "model_type")) \
+            and (self.hf_text_config.model_type in \
+                ('deepseek_v2', 'deepseek_v3'))\
+            and (self.hf_text_config.kv_lora_rank is not None)

     def get_head_size(self) -> int:
         # TODO remove hard code
         if self.is_deepseek_mla:
+            qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim",
+                                       0)
             if self.use_mla:
-                return self.hf_text_config.kv_lora_rank
+                return self.hf_text_config.kv_lora_rank + qk_rope_head_dim
             else:
-                qk_rope_head_dim = getattr(self.hf_text_config,
-                                           "qk_rope_head_dim", 0)
                 qk_nope_head_dim = getattr(self.hf_text_config,
                                            "qk_nope_head_dim", 0)
                 if qk_rope_head_dim and qk_nope_head_dim:
@@ -969,6 +970,32 @@ class ModelConfig:

     @property
     def use_mla(self) -> bool:
+        if self.quantization is not None and self.quantization not in [
+                "fp8", "compressed-tensors"]:
+            logger.warning(
+                "MLA is not supported with %s quantization. "
+                "Disabling MLA.", self.quantization)
+            return False
+
+        # If using a "compressed-tensors" checkpoint, check that all groups
+        # have fp8 for both weights and activations.
+        if self.quantization == "compressed-tensors":
+            quant_config = self._parse_quant_hf_config()
+            for group_name, cfg in quant_config.get("config_groups",
+                                                    ("", {})).items():
+                act_cfg = cfg.get("input_activations", {})
+                act_type = None if act_cfg is None else act_cfg.get("type", "")
+                w_cfg = cfg.get("weights", {})
+                w_type = None if w_cfg is None else w_cfg.get("type", "")
+                if act_type != "fp8" or w_type != "fp8":
+                    logger.warning(
+                        "compressed-tensors MLA support requires fp8 "
+                        "activations and weights in group '%s', but got "
+                        "activations type '%s' and weights type '%s'.\n "
+                        "Full config: %s", group_name, act_type, w_type,
+                        quant_config)
+                    return False
+
         use_mla = (self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE)
         return use_mla

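For reference, a hypothetical compressed-tensors config fragment that passes the gate above (the check logic is restated standalone here; field names follow the loop in use_mla):

# Hypothetical config_groups fragment: MLA stays enabled only if every group
# uses fp8 for both weights and input activations.
quant_cfg = {
    "config_groups": {
        "group_0": {
            "weights": {"type": "fp8", "strategy": "channel"},
            "input_activations": {"type": "fp8", "dynamic": True},
        },
    },
}

def mla_ok(quant_cfg: dict) -> bool:
    for _name, cfg in quant_cfg.get("config_groups", {}).items():
        act_cfg = cfg.get("input_activations") or {}
        w_cfg = cfg.get("weights") or {}
        if act_cfg.get("type") != "fp8" or w_cfg.get("type") != "fp8":
            return False
    return True

assert mla_ok(quant_cfg)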
----------------------------------------

@@ -79,6 +79,7 @@ if TYPE_CHECKING:
     VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
     VLLM_MLA_DISABLE: bool = False
     VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
+    VLLM_MLA_DISABLE_REQUANTIZATION: bool = False


 def get_default_cache_root():
@@ -519,7 +520,16 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     # storing more weights, W_Q_UK and W_UV_O, so can increase memory usage;
     # this is enabled by default
     "VLLM_MLA_PERFORM_MATRIX_ABSORPTION":
-    lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1")))
+    lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1"))),
+
+    # When running MLA with matrix-absorption enabled and fp8 quantized
+    # weights, we perform the matrix-absorption in float32 precision; after
+    # the matrices are absorbed we requantize the weights back to fp8. This
+    # flag can be used to disable the requantization step, and instead
+    # convert the absorbed matrices to match the activation type. This can
+    # lead to higher memory and compute usage but better preserves the
+    # accuracy of the original model.
+    "VLLM_MLA_DISABLE_REQUANTIZATION":
+    lambda: bool(int(os.getenv("VLLM_MLA_DISABLE_REQUANTIZATION", "0")))
 }

 # end-env-vars-definition

----------------------------------------

@@ -2,7 +2,7 @@

 import functools
 import json
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import triton
@@ -10,10 +10,24 @@ import triton.language as tl

 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    _normalize_quant_group_shape, scaled_dequantize)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    apply_fp8_linear)
 from vllm.platforms import current_platform

 logger = init_logger(__name__)

+current_platform_fp8_dtype = (torch.float8_e4m3fnuz
+                              if current_platform.is_rocm() else
+                              torch.float8_e4m3fn)
+
+
+def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
+    if isinstance(x, torch.Tensor):
+        x = x.dtype
+    return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
+

 def apply_w8a8_block_fp8_linear(
     input: torch.Tensor,
@@ -55,6 +69,42 @@ def apply_w8a8_block_fp8_linear(
     return output.to(dtype=input.dtype).view(*output_shape)


+# Unify the interface between `apply_w8a8_block_fp8_linear` and
+# `apply_fp8_linear`
+# NOTE(lucas): this is quite messy, we should think through this more formally
+def apply_fp8_linear_generic(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_group_shape: Tuple[int, int],
+    weight_group_shape: Tuple[int, int],
+    input_scale: Optional[torch.Tensor] = None,  # static scale if one
+) -> torch.Tensor:
+    # View input as 2D matrix for fp8 methods
+    input = input.view(-1, input.shape[-1])
+    weight_group_shape = _normalize_quant_group_shape(
+        weight, weight_group_shape)
+    input_group_shape = _normalize_quant_group_shape(input, input_group_shape)
+
+    def is_dim_blocked(dim, shape, group_shape):
+        return group_shape < shape[dim] and group_shape > 1
+
+    if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
+        and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
+        input_group_shape == (1, weight_group_shape[1]):
+        return apply_w8a8_block_fp8_linear(input, weight,
+                                           list(weight_group_shape),
+                                           weight_scale)
+    else:
+        # Despite having "linear" in the name, this doesn't conform to
+        # `torch.nn.functional.linear` (defined as `input @ weight.T`),
+        # so we explicitly transpose the weight matrix here
+        return apply_fp8_linear(input, weight.T, weight_scale.T,
+                                use_per_token_if_dynamic=\
+                                    (input_group_shape == (1, input.shape[1])))
+
+
 def input_to_float8(
     x: torch.Tensor,
     dtype: Optional[torch.dtype] = None
@@ -75,7 +125,6 @@ def input_to_float8(

 def block_quant_to_tensor_quant(
     x_q_block: torch.Tensor,
     x_s: torch.Tensor,
-    block_size: List[int],
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """This function converts block-wise quantization to tensor-wise
     quantization. The inputs are block-wise quantization tensor `x_q_block`,
@@ -83,26 +132,7 @@ def block_quant_to_tensor_quant(
     The outputs are tensor-wise quantization tensor and tensor-wise
     quantization scale. Note only float8 is supported for now.
     """
-    block_n, block_k = block_size[0], block_size[1]
-    n, k = x_q_block.shape
-    n_tiles = (n + block_n - 1) // block_n
-    k_tiles = (k + block_k - 1) // block_k
-    assert n_tiles == x_s.shape[0]
-    assert k_tiles == x_s.shape[1]
-
-    x_dq_block = x_q_block.to(torch.float32)
-
-    x_dq_block_tiles = [[
-        x_dq_block[
-            j * block_n:min((j + 1) * block_n, n),
-            i * block_k:min((i + 1) * block_k, k),
-        ] for i in range(k_tiles)
-    ] for j in range(n_tiles)]
-
-    for i in range(k_tiles):
-        for j in range(n_tiles):
-            x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
-
+    x_dq_block = scaled_dequantize(x_q_block, x_s)
     x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
     return x_q_tensor, scale

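To illustrate how `apply_fp8_linear_generic` decides between the blocked and non-blocked kernels, a sketch of just the shape logic (the real call needs fp8 tensors and fp8-capable hardware; `normalize` stands in for `_normalize_quant_group_shape`):

# Shape logic only: -1 in a group shape means "full extent of that dim".
def normalize(shape, group_shape):
    return tuple(g if g > 0 else shape[i] for i, g in enumerate(group_shape))

def is_dim_blocked(dim, shape, group):
    return 1 < group < shape[dim]

weight_shape = (1024, 512)                     # (out, in), as stored for MLA
w_group = normalize(weight_shape, (128, 128))  # deepseek-style block quant
x_group = normalize((4, 512), (1, 128))        # per-token-group activations

takes_block_path = (is_dim_blocked(0, weight_shape, w_group[0])
                    and is_dim_blocked(1, weight_shape, w_group[1])
                    and x_group == (1, w_group[1]))
assert takes_block_path  # -> apply_w8a8_block_fp8_linear
# Anything else (e.g. per-tensor (-1, -1) scales) falls through to
# apply_fp8_linear with the weight transposed.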
----------------------------------------

@@ -1,5 +1,5 @@
 """This file is used for /tests and /benchmarks"""
-from typing import List, Optional
+from typing import List, Optional, Tuple

 import numpy
 import torch
@@ -20,6 +20,120 @@ FUSED_LAYER_NAME_MAPPING = {
 }


+# Normalize the group_shape to the full extent for any dims that are -1
+def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int,
+                                                                     int]):
+    # -1 means full extent
+    return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
+            group_shape[1] if group_shape[1] > 0 else x.shape[-1])
+
+
+# Useful when treating N-dimensional group scaling as extended numpy-style
+# broadcasting. In numpy, broadcasting simply stretches dimensions with an
+# extent of 1 to match the target shape by repeating the data along that
+# dimension. We extend these semantics: if the extent of a dimension in the
+# source shape is not 1 and does not match the target shape, we repeat each
+# element along that dimension src_shape[dim] // target_shape[dim] times.
+# Example: if we have
+#   a = [[1, 2],  and target_shape = (2, 4)
+#        [3, 4]]
+# then we would expand a to:
+#   a = [[1, 1, 2, 2],
+#        [3, 3, 4, 4]]
+# NOTE this function does not explicitly broadcast dimensions
+# with an extent of 1, since this can be done implicitly by pytorch
+def group_broadcast(t, shape):
+    for i, s in enumerate(shape):
+        if t.shape[i] != s and t.shape[i] != 1:
+            assert s % t.shape[i] == 0
+            t = t.unsqueeze(i + 1)\
+                .expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\
+                .flatten(i, i + 1)
+    return t
+
+
+# Quantize assuming one scale per group of elements with shape group_shape,
+# example group shapes:
+# * (-1, -1)   for per-tensor quantization
+# * (1, -1)    for per-row quantization
+# * (-1, 1)    for per-column quantization
+# * (128, 128) for 128x128 deepseek style block quantization
+# * (1, 128)   for deepseek style activation quantization
+#              (i.e. per-token-per-group)
+def scaled_quantize(
+    x: torch.Tensor,
+    group_shape: Tuple[int, int],
+    quant_dtype: torch.dtype,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    group_shape = _normalize_quant_group_shape(x, group_shape)
+    assert quant_dtype.is_floating_point, \
+        "currently `scaled_quantize` only supports floating point dtypes " \
+        "but could be extended to support other dtypes"
+
+    finfo = torch.finfo(quant_dtype)
+
+    # Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N)
+    assert x.ndim == 2
+    assert x.shape[0] % group_shape[0] == 0 and x.shape[1] % group_shape[1] == 0
+    blk_m, blk_n = x.shape[0] // group_shape[0], x.shape[1] // group_shape[1]
+    x_blkd = x.reshape(blk_m, group_shape[0], blk_n, group_shape[1])
+
+    # Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N)
+    x_blkd_permd = x_blkd.permute(0, 2, 1, 3)
+
+    # Flatten to (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N)
+    x_blkd_permd = x_blkd_permd.flatten(start_dim=2)
+
+    # Compute scales
+    min_val, max_val = x_blkd_permd.aminmax(dim=-1)
+    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    scale = finfo.max / amax
+
+    # Apply scale and convert from:
+    # (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N)
+    x_scl_sat = (x_blkd_permd * scale.unsqueeze(-1))\
+        .clamp(min=finfo.min, max=finfo.max)\
+        .reshape(blk_m, blk_n, group_shape[0], group_shape[1])\
+        .permute(0, 2, 1, 3)\
+        .reshape(x.shape)
+
+    return x_scl_sat.to(quant_dtype).contiguous(), scale.float().reciprocal()
+
+
+# inverses `scaled_quantize`
+def scaled_dequantize(
+    x_q: torch.Tensor,
+    x_s: torch.Tensor,
+    group_shape: Optional[Tuple[int, int]] = None,
+    out_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    if group_shape is not None:
+        group_shape = _normalize_quant_group_shape(x_q, group_shape)
+
+    if x_s.ndim == 0:  # scalar
+        x_s = x_s.unsqueeze(-1).unsqueeze(-1)  # convert to (1, 1) tensor
+    if x_s.ndim == 1:
+        if group_shape is None:
+            raise AssertionError(
+                "if x_s is a 1D tensor, group_shape must be provided, "
+                "otherwise it's ambiguous which dimension to broadcast x_s to")
+        # unsqueeze the scales for the dimension where we want to broadcast
+        # across the full extent
+        if group_shape[0] == x_q.shape[-2]:
+            x_s = x_s.unsqueeze(-2)
+        elif group_shape[1] == x_q.shape[-1]:
+            x_s = x_s.unsqueeze(-1)
+        else:
+            raise AssertionError(
+                "if x_s is a vector we should be broadcasting it to the full "
+                "extent of one of the dimensions")
+
+    if group_shape is not None:
+        assert x_s.shape[-1] == x_q.shape[-1] // group_shape[1]
+        assert x_s.shape[-2] == x_q.shape[-2] // group_shape[0]
+    x_s = group_broadcast(x_s.to(torch.float32), x_q.shape)
+    return (x_q.to(torch.float32) * x_s).to(out_dtype)
+
+
 def pack_quantized_values_into_int32(w_q: torch.Tensor,
                                      wtype: ScalarType,
                                      packed_dim: int = 0):

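A round-trip sketch of the new `scaled_quantize`/`scaled_dequantize` helpers with deepseek-style 128x128 block scales (assumes a PyTorch build with float8 support):

import torch

w = torch.randn(256, 512)

# One scale per 128x128 block -> scales of shape (256/128, 512/128) = (2, 4)
w_q, w_s = scaled_quantize(w, (128, 128), quant_dtype=torch.float8_e4m3fn)
assert w_q.shape == w.shape and w_s.shape == (2, 4)

# Dequantize broadcasts each scale over its 128x128 block via group_broadcast
w_dq = scaled_dequantize(w_q, w_s, group_shape=(128, 128))
print((w - w_dq).abs().max())  # small: just fp8 rounding error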
----------------------------------------

@@ -398,11 +398,13 @@ class DefaultModelLoader(BaseModelLoader):
                     # parameters onto device for processing and back off after.
                     with device_loading_context(module, target_device):
                         quant_method.process_weights_after_loading(module)
-                elif isinstance(module, Attention) and \
+                if isinstance(module, Attention) and \
                     hasattr(module, "process_weights_after_loading"):
                     # When attention modules need to process weights after
-                    # currently only used by MLA
-                    module.process_weights_after_loading()
+                    # loading, currently only used by MLA
+                    # TODO(lucas): see if there is a way to unify the
+                    # signatures of process_weights_after_loading
+                    module.process_weights_after_loading(model_config.dtype)
         return model.eval()
@@ -439,6 +441,11 @@ class DummyModelLoader(BaseModelLoader):
                     with device_loading_context(
                             module, torch.device(device_config.device)):
                         quant_method.process_weights_after_loading(module)
+                if isinstance(module, Attention) and \
+                    hasattr(module, "process_weights_after_loading"):
+                    # When attention modules need to process weights after
+                    # loading, currently only used by MLA
+                    module.process_weights_after_loading(model_config.dtype)
         return model.eval()
@@ -633,6 +640,12 @@ class ShardedStateLoader(BaseModelLoader):
                 quant_method = getattr(module, "quant_method", None)
                 if quant_method is not None:
                     quant_method.process_weights_after_loading(module)
+                if isinstance(module, Attention) and \
+                    hasattr(module, "process_weights_after_loading"):
+                    # When attention modules need to process weights after
+                    # loading, currently only used by MLA
+                    module.process_weights_after_loading(
+                        model_config.dtype)
             rank = get_tensor_model_parallel_rank()
             pattern = os.path.join(
                 local_model_path,
@@ -1272,7 +1285,7 @@ class GGUFModelLoader(BaseModelLoader):

 class RunaiModelStreamerLoader(BaseModelLoader):
     """
     Model loader that can load safetensors
     files from local FS or S3 bucket.
     """
@@ -1369,6 +1382,11 @@ class RunaiModelStreamerLoader(BaseModelLoader):
             if quant_method is not None:
                 with device_loading_context(module, target_device):
                     quant_method.process_weights_after_loading(module)
+            if isinstance(module, Attention) and \
+                hasattr(module, "process_weights_after_loading"):
+                # When attention modules need to process weights after
+                # loading, currently only used by MLA
+                module.process_weights_after_loading(model_config.dtype)
         return model.eval()

----------------------------------------

@@ -27,7 +27,7 @@ from torch import nn
 from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import (get_pp_group,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
@@ -333,12 +333,156 @@ class DeepseekV3Attention(nn.Module):
         return output


+class DeepseekV3MLAAttention(nn.Module):
+    """
+    Main reference: DeepseekV2 paper, and FlashInfer Implementation
+    (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
+
+    For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        hidden_size: int,
+        num_heads: int,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        v_head_dim: int,
+        q_lora_rank: Optional[int],
+        kv_lora_rank: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+
+        self.q_lora_rank = q_lora_rank
+        self.kv_lora_rank = kv_lora_rank
+
+        self.num_heads = num_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        assert num_heads % tp_size == 0
+        self.num_local_heads = num_heads // tp_size
+
+        self.scaling = self.qk_head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(self.hidden_size,
+                                             self.q_lora_rank,
+                                             bias=False,
+                                             quant_config=quant_config,
+                                             prefix=f"{prefix}.q_a_proj")
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
+                                         eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(q_lora_rank,
+                                                 self.num_heads *
+                                                 self.qk_head_dim,
+                                                 bias=False,
+                                                 quant_config=quant_config,
+                                                 prefix=f"{prefix}.q_b_proj")
+        else:
+            self.q_proj = ColumnParallelLinear(self.hidden_size,
+                                               self.num_heads *
+                                               self.qk_head_dim,
+                                               bias=False,
+                                               quant_config=quant_config,
+                                               prefix=f"{prefix}.q_proj")
+
+        self.kv_a_proj_with_mqa = ReplicatedLinear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.kv_a_proj_with_mqa")
+        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
+                                      eps=config.rms_norm_eps)
+        self.kv_b_proj = ColumnParallelLinear(
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.kv_b_proj")
+        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
+                                        self.hidden_size,
+                                        bias=False,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.o_proj")
+
+        rope_scaling["rope_type"] = 'deepseek_yarn'
+        self.rotary_emb = get_rope(qk_rope_head_dim,
+                                   rotary_dim=qk_rope_head_dim,
+                                   max_position=max_position_embeddings,
+                                   base=rope_theta,
+                                   rope_scaling=rope_scaling,
+                                   is_neox_style=False)
+        if rope_scaling:
+            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
+            scaling_factor = rope_scaling["factor"]
+            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
+            self.scaling = self.scaling * mscale * mscale
+
+        self.mla_attn = Attention(
+            num_heads=self.num_local_heads,
+            head_size=self.kv_lora_rank,
+            scale=self.scaling,
+            num_kv_heads=1,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+            use_mla=True,
+            # MLA Args
+            q_lora_rank=self.q_lora_rank,
+            kv_lora_rank=self.kv_lora_rank,
+            qk_nope_head_dim=self.qk_nope_head_dim,
+            qk_rope_head_dim=self.qk_rope_head_dim,
+            qk_head_dim=self.qk_head_dim,
+            v_head_dim=self.v_head_dim,
+            rotary_emb=self.rotary_emb,
+            q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
+            kv_b_proj=self.kv_b_proj,
+            o_proj=self.o_proj,
+        )
+
+        self.prefix = prefix
+        self.debug_layer_idx = int(self.prefix.split(".")[-2])
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        if self.q_lora_rank is not None:
+            ckq = self.q_a_proj(hidden_states)[0]
+            hidden_states_or_q_c = self.q_a_layernorm(ckq)
+        else:
+            hidden_states_or_q_c = hidden_states
+        kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
+            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+        return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
+                             attn_metadata)
+
+
 class DeepseekV3DecoderLayer(nn.Module):

     def __init__(
         self,
         config: PretrainedConfig,
         prefix: str,
+        model_config: ModelConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -351,7 +495,11 @@ class DeepseekV3DecoderLayer(nn.Module):
         # DecoderLayers are created with `make_layers` which passes the prefix
         # with the layer's index.
         layer_idx = int(prefix.split(sep='.')[-1])
-        self.self_attn = DeepseekV3Attention(
+        if model_config.use_mla:
+            attn_cls = DeepseekV3MLAAttention
+        else:
+            attn_cls = DeepseekV3Attention
+        self.self_attn = attn_cls(
             config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -428,6 +576,7 @@ class DeepseekV3Model(nn.Module):
         super().__init__()

         config = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
@@ -447,6 +596,7 @@ class DeepseekV3Model(nn.Module):
             lambda prefix: DeepseekV3DecoderLayer(
                 config,
                 prefix,
+                model_config=model_config,
                 cache_config=cache_config,
                 quant_config=quant_config,
             ),

----------------------------------------

@@ -110,7 +110,9 @@ class CacheEngine:
             parallel_config, LayerBlockType.attention)

         key_cache_block = cache_config.block_size * num_heads * head_size
-        value_cache_block = key_cache_block
+        # For MLA there is no value cache, since the latent vector
+        # is joint keys and values.
+        value_cache_block = key_cache_block if not model_config.use_mla else 0
         total = num_attention_layers * (key_cache_block + value_cache_block)
         if cache_config.cache_dtype == "auto":
             dtype = model_config.dtype
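
To see what the zeroed value cache buys, a back-of-the-envelope per-block calculation with Deepseek V3-style numbers (61 layers; bf16 cache assumed for illustration):

# MLA accounting: one latent "head" of head_size 576, no value cache.
block_size, num_heads, head_size = 16, 1, 576
num_attention_layers, dtype_size = 61, 2  # bf16 = 2 bytes

key_cache_block = block_size * num_heads * head_size    # 9216 elements
value_cache_block = 0  # latent vector is joint keys and values
total = num_attention_layers * (key_cache_block + value_cache_block)
print(total * dtype_size / 1024, "KiB per cache block")  # ~1098 KiB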