[Attention] Deepseek v3 MLA support with FP8 compute (#12601)
This PR implements the Deepseek V3 support by performing matrix absorption the fp8 weights --------- Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: simon-mo <simon.mo@hey.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com> Co-authored-by: Tyler Michael Smith <tysmith@redhat.com> Co-authored-by: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>
This commit is contained in:
@@ -1,17 +1,29 @@
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Generic, List, Optional
|
from typing import Any, Dict, Generic, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from compressed_tensors.quantization import QuantizationStrategy
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.attention.backends.abstract import (AttentionLayer,
|
from vllm.attention.backends.abstract import (AttentionLayer,
|
||||||
AttentionMetadata,
|
AttentionMetadata,
|
||||||
MLAAttentionImpl, T)
|
MLAAttentionImpl, T)
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||||
|
tensor_model_parallel_all_reduce)
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
RowParallelLinear)
|
LinearBase, RowParallelLinear,
|
||||||
|
UnquantizedLinearMethod)
|
||||||
|
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||||
|
CompressedTensorsLinearMethod)
|
||||||
|
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||||
|
CompressedTensorsW8A8Fp8)
|
||||||
|
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
||||||
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
|
apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
scaled_dequantize, scaled_quantize)
|
||||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||||
|
|
||||||
@@ -162,8 +174,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|||||||
|
|
||||||
def _v_up_proj_and_o_proj(self, x):
|
def _v_up_proj_and_o_proj(self, x):
|
||||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||||
return self.o_proj_absorbed(
|
if is_fp8(self.W_UV_O):
|
||||||
x.reshape(-1, self.num_heads * self.kv_lora_rank))[0]
|
output_parallel = apply_fp8_linear_generic(
|
||||||
|
x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
|
||||||
|
self.reqaunt_input_group_shape,
|
||||||
|
self.reqaunt_weight_group_shape)
|
||||||
|
else:
|
||||||
|
output_parallel = torch.matmul(x.flatten(start_dim=1),
|
||||||
|
self.W_UV_O)
|
||||||
|
if self.tp_size > 1:
|
||||||
|
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||||
|
else:
|
||||||
|
output = output_parallel
|
||||||
|
return output
|
||||||
else:
|
else:
|
||||||
x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
|
x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
|
||||||
return self.o_proj(x.reshape(-1,
|
return self.o_proj(x.reshape(-1,
|
||||||
@@ -171,6 +194,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|||||||
|
|
||||||
def _q_proj_and_k_up_proj(self, x):
|
def _q_proj_and_k_up_proj(self, x):
|
||||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||||
|
if is_fp8(self.W_Q_UK):
|
||||||
|
return apply_fp8_linear_generic(
|
||||||
|
x, self.W_Q_UK, self.W_Q_UK_scales,
|
||||||
|
self.reqaunt_input_group_shape,
|
||||||
|
self.reqaunt_weight_group_shape).view(
|
||||||
|
-1, self.num_heads, self.kv_lora_rank)
|
||||||
return torch.matmul(x, self.W_Q_UK)\
|
return torch.matmul(x, self.W_Q_UK)\
|
||||||
.view(-1, self.num_heads, self.kv_lora_rank)
|
.view(-1, self.num_heads, self.kv_lora_rank)
|
||||||
else:
|
else:
|
||||||
@@ -179,8 +208,91 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|||||||
return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
|
return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
|
||||||
.view(-1, self.num_heads, self.kv_lora_rank)
|
.view(-1, self.num_heads, self.kv_lora_rank)
|
||||||
|
|
||||||
def process_weights_after_loading(self):
|
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||||
kv_b_proj_weight = self.kv_b_proj.weight.T
|
|
||||||
|
def is_layer_fp8(layer: LinearBase) -> bool:
|
||||||
|
return isinstance(layer.quant_method, Fp8LinearMethod) or\
|
||||||
|
(isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
|
||||||
|
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
|
||||||
|
|
||||||
|
def quantization_scheme_supported(layer: LinearBase) -> bool:
|
||||||
|
return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
|
||||||
|
is_layer_fp8(layer)
|
||||||
|
|
||||||
|
# TODO(lucas) This is very gross, we need a more wide scale refactor of
|
||||||
|
# all the FP8 code with a more standard way of
|
||||||
|
# defining schemes/group-shapes, we should also potentially force
|
||||||
|
# quant_methods to support a decompress function
|
||||||
|
#
|
||||||
|
# returns input_group_shape, weight_group_shape
|
||||||
|
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
|
||||||
|
Tuple[Tuple[int, int], Tuple[int, int]]:
|
||||||
|
if isinstance(layer.quant_method, Fp8LinearMethod):
|
||||||
|
if layer.quant_method.block_quant is not None:
|
||||||
|
weight_block_size = \
|
||||||
|
layer.quant_method.quant_config.weight_block_size
|
||||||
|
# per-token-group (1, X), block-quantized (X, Y)
|
||||||
|
return (1, weight_block_size[-1]), weight_block_size
|
||||||
|
else:
|
||||||
|
return (-1, -1), (-1, -1) # per-tensor, per-tensor
|
||||||
|
elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
|
||||||
|
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||||
|
# this is hacky but we always assume the for
|
||||||
|
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
|
||||||
|
# we ignore if it is static-per-tensor since we are going to
|
||||||
|
# requantize after later anyways
|
||||||
|
strategy = layer.scheme.strategy
|
||||||
|
if strategy == QuantizationStrategy.TENSOR:
|
||||||
|
return (1, -1), (-1, -1) # per-token, per-tensor
|
||||||
|
elif strategy == QuantizationStrategy.CHANNEL:
|
||||||
|
return (1, -1), (-1, 1) # per-token, per-channel
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"QuantizationStrategy.{strategy} is not supported for "
|
||||||
|
"fp8 MLA, please run with VLLM_MLA_DISABLE=1")
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Can't determine scale group shapes for "
|
||||||
|
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_scales(layer: LinearBase) -> torch.Tensor:
|
||||||
|
if hasattr(layer, "weight_scale_inv"):
|
||||||
|
return layer.weight_scale_inv
|
||||||
|
return layer.weight_scale
|
||||||
|
|
||||||
|
def get_and_maybe_dequant_weights(layer: LinearBase):
|
||||||
|
if is_layer_fp8(layer):
|
||||||
|
if isinstance(layer.quant_method, \
|
||||||
|
CompressedTensorsLinearMethod) and \
|
||||||
|
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||||
|
# NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8`
|
||||||
|
# seems to store weights as (input, output) instead of
|
||||||
|
# (output, input) so we need to transpose
|
||||||
|
weight = layer.weight.T # standardize to (output, input)
|
||||||
|
else:
|
||||||
|
weight = layer.weight
|
||||||
|
_, weight_scale_group_shape = \
|
||||||
|
get_scale_group_shapes_for_fp8(layer)
|
||||||
|
scales = get_scales(layer)
|
||||||
|
|
||||||
|
return scaled_dequantize(weight, scales,
|
||||||
|
weight_scale_group_shape)
|
||||||
|
else:
|
||||||
|
return layer.weight
|
||||||
|
|
||||||
|
if not (quantization_scheme_supported(self.kv_b_proj) and\
|
||||||
|
quantization_scheme_supported(self.q_proj) and\
|
||||||
|
quantization_scheme_supported(self.o_proj)):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Only FP8 and UnquantizedLinearMethod are supported for MLA"
|
||||||
|
", please run with VLLM_MLA_DISABLE=1")
|
||||||
|
|
||||||
|
weight_dtype = self.kv_b_proj.weight.dtype
|
||||||
|
assert self.o_proj.weight.dtype == weight_dtype
|
||||||
|
assert self.q_proj.weight.dtype == weight_dtype
|
||||||
|
|
||||||
|
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
||||||
assert kv_b_proj_weight.shape == (
|
assert kv_b_proj_weight.shape == (
|
||||||
self.kv_lora_rank,
|
self.kv_lora_rank,
|
||||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
||||||
@@ -198,18 +310,35 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|||||||
W_UK, W_UV = kv_b_proj_weight.split(
|
W_UK, W_UV = kv_b_proj_weight.split(
|
||||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||||
|
|
||||||
q_proj = self.q_proj.weight.T\
|
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
|
||||||
.view(-1, self.num_heads, self.qk_head_dim)
|
.view(-1, self.num_heads, self.qk_head_dim)
|
||||||
|
|
||||||
# can be W_Q or W_UQ depending q_lora_rank, the former if
|
# can be W_Q or W_UQ depending q_lora_rank, the former if
|
||||||
# q_lora_rank is None, the latter otherwise. From the Attention backend
|
# q_lora_rank is None, the latter otherwise. From the Attention backend
|
||||||
# perspective though we call these both W_Q and rely on the layer
|
# perspective though we call these both W_Q and rely on the layer
|
||||||
# to pass in the correct matrix
|
# to pass in the correct matrix
|
||||||
W_Q = q_proj[..., :self.qk_nope_head_dim]
|
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
|
||||||
self.W_QR = q_proj[..., self.qk_nope_head_dim:]\
|
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
|
||||||
.flatten(start_dim=1).contiguous()
|
.flatten(start_dim=1).contiguous()
|
||||||
|
|
||||||
|
# W_QR is small so for simplicity we dont bother requantizing it
|
||||||
|
self.W_QR = self.W_QR.to(act_dtype)
|
||||||
|
|
||||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||||
|
requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
|
||||||
|
if is_fp8(weight_dtype) and requantization_enabled:
|
||||||
|
# This assumes it wise to requantize using the same group shapes
|
||||||
|
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
|
||||||
|
# weights were originally quantized
|
||||||
|
requant_input_group_shape, requant_weight_group_shape = \
|
||||||
|
get_scale_group_shapes_for_fp8(self.q_proj)
|
||||||
|
assert (requant_input_group_shape, requant_weight_group_shape)\
|
||||||
|
== get_scale_group_shapes_for_fp8(self.kv_b_proj)
|
||||||
|
assert (requant_input_group_shape, requant_weight_group_shape)\
|
||||||
|
== get_scale_group_shapes_for_fp8(self.o_proj)
|
||||||
|
self.reqaunt_input_group_shape = requant_input_group_shape
|
||||||
|
self.reqaunt_weight_group_shape = requant_weight_group_shape
|
||||||
|
|
||||||
#
|
#
|
||||||
# Perform matrix-absorption following
|
# Perform matrix-absorption following
|
||||||
# https://github.com/flashinfer-ai/flashinfer/pull/551
|
# https://github.com/flashinfer-ai/flashinfer/pull/551
|
||||||
@@ -223,25 +352,44 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|||||||
# latter otherwise
|
# latter otherwise
|
||||||
# basically if q_lora_rank is none we are absorbing into q_proj
|
# basically if q_lora_rank is none we are absorbing into q_proj
|
||||||
# instead of UQ
|
# instead of UQ
|
||||||
self.W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
|
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
|
||||||
.flatten(start_dim=1).contiguous()
|
.flatten(start_dim=1).contiguous()
|
||||||
|
|
||||||
W_O = self.o_proj.weight\
|
if is_fp8(weight_dtype) and requantization_enabled:
|
||||||
|
W_Q_UK, W_Q_UK_scales = scaled_quantize(
|
||||||
|
W_Q_UK,
|
||||||
|
self.reqaunt_weight_group_shape,
|
||||||
|
quant_dtype=current_platform_fp8_dtype)
|
||||||
|
# For FP8 save the transpose so we can use
|
||||||
|
# `apply_w8a8_block_fp8_linear` directly
|
||||||
|
self.W_Q_UK = W_Q_UK.T.contiguous()
|
||||||
|
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
|
||||||
|
else:
|
||||||
|
self.W_Q_UK = W_Q_UK.to(act_dtype)
|
||||||
|
|
||||||
|
W_O = get_and_maybe_dequant_weights(self.o_proj)\
|
||||||
.view(-1, self.num_heads, self.v_head_dim)
|
.view(-1, self.num_heads, self.v_head_dim)
|
||||||
self.W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
|
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
|
||||||
.flatten(start_dim=0, end_dim=1).contiguous()
|
.flatten(start_dim=0, end_dim=1).contiguous()
|
||||||
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
if is_fp8(weight_dtype) and requantization_enabled:
|
||||||
self.o_proj_absorbed = RowParallelLinear(
|
W_UV_O, W_UV_O_scales = scaled_quantize(
|
||||||
self.W_UV_O.shape[0] * tp_size,
|
W_UV_O,
|
||||||
self.W_UV_O.shape[1],
|
self.reqaunt_weight_group_shape,
|
||||||
bias=False,
|
quant_dtype=current_platform_fp8_dtype)
|
||||||
# TODO(lucas) figure out how to properly forward quant_method
|
# For FP8 save the transpose so we can use
|
||||||
#quant_config=self.o_proj.quant_method,
|
# `apply_w8a8_block_fp8_linear` directly
|
||||||
)
|
self.W_UV_O = W_UV_O.T.contiguous()
|
||||||
|
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
|
||||||
self.o_proj_absorbed.weight = torch.nn.Parameter(self.W_UV_O.T)
|
|
||||||
else:
|
else:
|
||||||
|
self.W_UV_O = W_UV_O.to(act_dtype)
|
||||||
|
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
else:
|
||||||
|
if is_fp8(weight_dtype):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Currently fp8 requires matrix absorption")
|
||||||
|
|
||||||
self.W_UV = W_UV
|
self.W_UV = W_UV
|
||||||
self.W_UK = W_UK
|
self.W_UK = W_UK
|
||||||
self.W_Q = W_Q.flatten(start_dim=1)
|
self.W_Q = W_Q.flatten(start_dim=1)
|
||||||
|
|||||||
@@ -60,11 +60,9 @@ class TritonMLABackend(AttentionBackend):
|
|||||||
num_blocks: int,
|
num_blocks: int,
|
||||||
block_size: int,
|
block_size: int,
|
||||||
num_kv_heads: int, # assumed to be 1 for MLA
|
num_kv_heads: int, # assumed to be 1 for MLA
|
||||||
kv_lora_rank: int, # passed via head_size
|
head_size: int,
|
||||||
) -> Tuple[int, ...]:
|
) -> Tuple[int, ...]:
|
||||||
# TODO(lucas): remove hardcoding k_pe size as 1/8th of kv_lora_rank
|
return (num_blocks, block_size, head_size)
|
||||||
k_pe_size = kv_lora_rank // 8
|
|
||||||
return (num_blocks, block_size, kv_lora_rank + k_pe_size)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def swap_blocks(
|
def swap_blocks(
|
||||||
@@ -83,7 +81,7 @@ class TritonMLABackend(AttentionBackend):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_supported_head_sizes() -> List[int]:
|
def get_supported_head_sizes() -> List[int]:
|
||||||
return [512]
|
return [576]
|
||||||
|
|
||||||
|
|
||||||
class TritonMLAState(AttentionState):
|
class TritonMLAState(AttentionState):
|
||||||
@@ -624,8 +622,6 @@ class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]):
|
|||||||
self.multimodal_placeholder_maps.items()
|
self.multimodal_placeholder_maps.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
num_kv_splits = 8
|
|
||||||
|
|
||||||
return TritonMLAMetadata(
|
return TritonMLAMetadata(
|
||||||
num_prefills=self.num_prefills,
|
num_prefills=self.num_prefills,
|
||||||
slot_mapping=slot_mapping_tensor,
|
slot_mapping=slot_mapping_tensor,
|
||||||
@@ -645,7 +641,7 @@ class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]):
|
|||||||
context_lens_tensor=context_lens_tensor,
|
context_lens_tensor=context_lens_tensor,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
use_cuda_graph=use_captured_graph,
|
use_cuda_graph=use_captured_graph,
|
||||||
num_kv_splits=num_kv_splits,
|
num_kv_splits=4, # TODO(lucas) add heuristic
|
||||||
head_dim=self.runner.model_config.get_head_size(),
|
head_dim=self.runner.model_config.get_head_size(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -200,9 +200,9 @@ class Attention(nn.Module):
|
|||||||
s += f", backend={self.impl.__class__.__name__}"
|
s += f", backend={self.impl.__class__.__name__}"
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def process_weights_after_loading(self):
|
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||||
if hasattr(self.impl, "process_weights_after_loading"):
|
if hasattr(self.impl, "process_weights_after_loading"):
|
||||||
self.impl.process_weights_after_loading()
|
self.impl.process_weights_after_loading(act_dtype)
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(nn.Module):
|
class MultiHeadAttention(nn.Module):
|
||||||
|
|||||||
@@ -739,18 +739,19 @@ class ModelConfig:
|
|||||||
@property
|
@property
|
||||||
def is_deepseek_mla(self) -> bool:
|
def is_deepseek_mla(self) -> bool:
|
||||||
# TODO add deepseek_v3
|
# TODO add deepseek_v3
|
||||||
return hasattr(self.hf_text_config,
|
return (hasattr(self.hf_text_config, "model_type")) \
|
||||||
"model_type") and (self.hf_text_config.model_type
|
and (self.hf_text_config.model_type in \
|
||||||
in ('deepseek_v2'))
|
('deepseek_v2', 'deepseek_v3'))\
|
||||||
|
and (self.hf_text_config.kv_lora_rank is not None)
|
||||||
|
|
||||||
def get_head_size(self) -> int:
|
def get_head_size(self) -> int:
|
||||||
# TODO remove hard code
|
# TODO remove hard code
|
||||||
if self.is_deepseek_mla:
|
if self.is_deepseek_mla:
|
||||||
|
qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim",
|
||||||
|
0)
|
||||||
if self.use_mla:
|
if self.use_mla:
|
||||||
return self.hf_text_config.kv_lora_rank
|
return self.hf_text_config.kv_lora_rank + qk_rope_head_dim
|
||||||
else:
|
else:
|
||||||
qk_rope_head_dim = getattr(self.hf_text_config,
|
|
||||||
"qk_rope_head_dim", 0)
|
|
||||||
qk_nope_head_dim = getattr(self.hf_text_config,
|
qk_nope_head_dim = getattr(self.hf_text_config,
|
||||||
"qk_nope_head_dim", 0)
|
"qk_nope_head_dim", 0)
|
||||||
if qk_rope_head_dim and qk_nope_head_dim:
|
if qk_rope_head_dim and qk_nope_head_dim:
|
||||||
@@ -969,6 +970,32 @@ class ModelConfig:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def use_mla(self) -> bool:
|
def use_mla(self) -> bool:
|
||||||
|
if self.quantization is not None and self.quantization not in [\
|
||||||
|
"fp8", "compressed-tensors"]:
|
||||||
|
logger.warning(
|
||||||
|
"MLA is not supported with %s quantization. "
|
||||||
|
"Disabling MLA.", self.quantization)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# If using a "compressed-tensors" checkpoint, check that all groups
|
||||||
|
# have fp8 for both weights and activations.
|
||||||
|
if self.quantization == "compressed-tensors":
|
||||||
|
quant_config = self._parse_quant_hf_config()
|
||||||
|
for group_name, cfg in quant_config.get("config_groups",
|
||||||
|
("", {})).items():
|
||||||
|
act_cfg = cfg.get("input_activations", {})
|
||||||
|
act_type = None if act_cfg is None else act_cfg.get("type", "")
|
||||||
|
w_cfg = cfg.get("weights", {})
|
||||||
|
w_type = None if w_cfg is None else w_cfg.get("type", "")
|
||||||
|
if act_type != "fp8" or w_type != "fp8":
|
||||||
|
logger.warning(
|
||||||
|
"compressed-tensors MLA support requires fp8 "
|
||||||
|
"activations and weights in group '%s', but got "
|
||||||
|
"activations type '%s' and weights type '%s'.\n "
|
||||||
|
"Full config: %s", group_name, act_type, w_type,
|
||||||
|
quant_config)
|
||||||
|
return False
|
||||||
|
|
||||||
use_mla = (self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE)
|
use_mla = (self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE)
|
||||||
return use_mla
|
return use_mla
|
||||||
|
|
||||||
|
|||||||
12
vllm/envs.py
12
vllm/envs.py
@@ -79,6 +79,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
|
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
|
||||||
VLLM_MLA_DISABLE: bool = False
|
VLLM_MLA_DISABLE: bool = False
|
||||||
VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
|
VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
|
||||||
|
VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
|
||||||
|
|
||||||
|
|
||||||
def get_default_cache_root():
|
def get_default_cache_root():
|
||||||
@@ -519,7 +520,16 @@ environment_variables: Dict[str, Callable[[], Any]] = {
|
|||||||
# storing more weights, W_Q_UK and W_UV_O, so can increase memory usage,
|
# storing more weights, W_Q_UK and W_UV_O, so can increase memory usage,
|
||||||
# the is enabled by default
|
# the is enabled by default
|
||||||
"VLLM_MLA_PERFORM_MATRIX_ABSORPTION":
|
"VLLM_MLA_PERFORM_MATRIX_ABSORPTION":
|
||||||
lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1")))
|
lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1"))),
|
||||||
|
|
||||||
|
# When running MLA with matrix-absorption enabled and fp8 quantized weights
|
||||||
|
# we perform the matrix-absorption in float32 precision, after the matrices
|
||||||
|
# are absorbed we requantize the weights back to fp8, this flag can be used
|
||||||
|
# to disable the requantization step, and instead convert the absorbed
|
||||||
|
# matrices to match the activation type. This can lead to higher memory and
|
||||||
|
# compute usage but better preserves the accuracy of the original model.
|
||||||
|
"VLLM_MLA_DISABLE_REQUANTIZATION":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_MLA_DISABLE_REQUANTIZATION", "0")))
|
||||||
}
|
}
|
||||||
|
|
||||||
# end-env-vars-definition
|
# end-env-vars-definition
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
import functools
|
import functools
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
@@ -10,10 +10,24 @@ import triton.language as tl
|
|||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
_normalize_quant_group_shape, scaled_dequantize)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
|
apply_fp8_linear)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
current_platform_fp8_dtype = (torch.float8_e4m3fnuz
|
||||||
|
if current_platform.is_rocm() else
|
||||||
|
torch.float8_e4m3fn)
|
||||||
|
|
||||||
|
|
||||||
|
def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
|
||||||
|
if isinstance(x, torch.Tensor):
|
||||||
|
x = x.dtype
|
||||||
|
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
|
||||||
|
|
||||||
|
|
||||||
def apply_w8a8_block_fp8_linear(
|
def apply_w8a8_block_fp8_linear(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
@@ -55,6 +69,42 @@ def apply_w8a8_block_fp8_linear(
|
|||||||
return output.to(dtype=input.dtype).view(*output_shape)
|
return output.to(dtype=input.dtype).view(*output_shape)
|
||||||
|
|
||||||
|
|
||||||
|
# Unify the interface between `apply_w8a8_block_fp8_linear` and
|
||||||
|
# `apply_fp8_linear`
|
||||||
|
# NOTE(lucas): this is quite messy, we should think through this more formally
|
||||||
|
def apply_fp8_linear_generic(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
input_group_shape: Tuple[int, int],
|
||||||
|
weight_group_shape: Tuple[int, int],
|
||||||
|
input_scale: Optional[torch.Tensor] = None, # static scale if one
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# View input as 2D matrix for fp8 methods
|
||||||
|
input = input.view(-1, input.shape[-1])
|
||||||
|
|
||||||
|
weight_group_shape = _normalize_quant_group_shape(\
|
||||||
|
weight, weight_group_shape)
|
||||||
|
input_group_shape = _normalize_quant_group_shape(input, input_group_shape)
|
||||||
|
|
||||||
|
def is_dim_blocked(dim, shape, group_shape):
|
||||||
|
return group_shape < shape[dim] and group_shape > 1
|
||||||
|
|
||||||
|
if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
|
||||||
|
and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
|
||||||
|
input_group_shape == (1, weight_group_shape[1]):
|
||||||
|
return apply_w8a8_block_fp8_linear(input, weight,
|
||||||
|
list(weight_group_shape),
|
||||||
|
weight_scale)
|
||||||
|
else:
|
||||||
|
# Despite having linear in the it doesn't conform to
|
||||||
|
# `torch.nn.functional.linear` which is defined as `input @ weight.T`
|
||||||
|
# so we explicitly transpose the weight matrix here
|
||||||
|
return apply_fp8_linear(input, weight.T, weight_scale.T,
|
||||||
|
use_per_token_if_dynamic=\
|
||||||
|
(input_group_shape == (1, input.shape[1])))
|
||||||
|
|
||||||
|
|
||||||
def input_to_float8(
|
def input_to_float8(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
dtype: Optional[torch.dtype] = None
|
dtype: Optional[torch.dtype] = None
|
||||||
@@ -75,7 +125,6 @@ def input_to_float8(
|
|||||||
def block_quant_to_tensor_quant(
|
def block_quant_to_tensor_quant(
|
||||||
x_q_block: torch.Tensor,
|
x_q_block: torch.Tensor,
|
||||||
x_s: torch.Tensor,
|
x_s: torch.Tensor,
|
||||||
block_size: List[int],
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""This function converts block-wise quantization to tensor-wise
|
"""This function converts block-wise quantization to tensor-wise
|
||||||
quantization. The inputs are block-wise quantization tensor `x_q_block`,
|
quantization. The inputs are block-wise quantization tensor `x_q_block`,
|
||||||
@@ -83,26 +132,7 @@ def block_quant_to_tensor_quant(
|
|||||||
The outputs are tensor-wise quantization tensor and tensor-wise
|
The outputs are tensor-wise quantization tensor and tensor-wise
|
||||||
quantization scale. Note only float8 is supported for now.
|
quantization scale. Note only float8 is supported for now.
|
||||||
"""
|
"""
|
||||||
block_n, block_k = block_size[0], block_size[1]
|
x_dq_block = scaled_dequantize(x_q_block, x_s)
|
||||||
n, k = x_q_block.shape
|
|
||||||
n_tiles = (n + block_n - 1) // block_n
|
|
||||||
k_tiles = (k + block_k - 1) // block_k
|
|
||||||
assert n_tiles == x_s.shape[0]
|
|
||||||
assert k_tiles == x_s.shape[1]
|
|
||||||
|
|
||||||
x_dq_block = x_q_block.to(torch.float32)
|
|
||||||
|
|
||||||
x_dq_block_tiles = [[
|
|
||||||
x_dq_block[
|
|
||||||
j * block_n:min((j + 1) * block_n, n),
|
|
||||||
i * block_k:min((i + 1) * block_k, k),
|
|
||||||
] for i in range(k_tiles)
|
|
||||||
] for j in range(n_tiles)]
|
|
||||||
|
|
||||||
for i in range(k_tiles):
|
|
||||||
for j in range(n_tiles):
|
|
||||||
x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
|
|
||||||
|
|
||||||
x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
|
x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
|
||||||
return x_q_tensor, scale
|
return x_q_tensor, scale
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""This file is used for /tests and /benchmarks"""
|
"""This file is used for /tests and /benchmarks"""
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import torch
|
import torch
|
||||||
@@ -20,6 +20,120 @@ FUSED_LAYER_NAME_MAPPING = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Normalize the group_shape to the full extent for any dims that are -1
|
||||||
|
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int,
|
||||||
|
int]):
|
||||||
|
# -1 means full extent
|
||||||
|
return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
|
||||||
|
group_shape[1] if group_shape[1] > 0 else x.shape[-1])
|
||||||
|
|
||||||
|
|
||||||
|
# Useful when treating N-dimensional group scaling as extended numpy-style
|
||||||
|
# broadcasting in numpy simply stretches dimensions with an extent of 1 to match
|
||||||
|
# the target shape by repeating the data along that dimension (broadcasting)
|
||||||
|
# , we extend these semantics to say if the extent of a dimension in the
|
||||||
|
# source shape is not 1 and does not match the target shape we repeat each
|
||||||
|
# element along that dimension src_shape[dim] // target_shape[dim] times
|
||||||
|
# example if we have:
|
||||||
|
# a = [[1, 2], and target_shape = (2, 4)
|
||||||
|
# [3, 4]]
|
||||||
|
# then we would expand a to:
|
||||||
|
# a = [[1, 1, 2, 2],
|
||||||
|
# [3, 3, 4, 4]]
|
||||||
|
# NOTE this function this function does not explicitly broadcast dimensions
|
||||||
|
# with an extent of 1, since this can be done implicitly by pytorch
|
||||||
|
def group_broadcast(t, shape):
|
||||||
|
for i, s in enumerate(shape):
|
||||||
|
if t.shape[i] != s and t.shape[i] != 1:
|
||||||
|
assert s % t.shape[i] == 0
|
||||||
|
t = t.unsqueeze(i + 1)\
|
||||||
|
.expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\
|
||||||
|
.flatten(i, i + 1)
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
# Quantize assuming once scale per group of elements with shape group_shape,
|
||||||
|
# example group shapes:
|
||||||
|
# * (-1, -1) for per-tensor quantization
|
||||||
|
# * (1, -1) for per-row quantization
|
||||||
|
# * (-1, 1) for per-column quantization
|
||||||
|
# * (128, 128) for 128x128 deepseek style block quantization
|
||||||
|
# * (1, 128) for deepseek style activation quantization
|
||||||
|
# (i.e. per-token-per-group)
|
||||||
|
def scaled_quantize(
|
||||||
|
x: torch.Tensor,
|
||||||
|
group_shape: Tuple[int, int],
|
||||||
|
quant_dtype: torch.dtype,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
group_shape = _normalize_quant_group_shape(x, group_shape)
|
||||||
|
assert quant_dtype.is_floating_point, \
|
||||||
|
"currently `scaled_quantize` only supports floating point dtypes " \
|
||||||
|
"but could be extended to support other dtypes"
|
||||||
|
|
||||||
|
finfo = torch.finfo(quant_dtype)
|
||||||
|
|
||||||
|
# Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N)
|
||||||
|
assert x.ndim == 2
|
||||||
|
assert x.shape[0] % group_shape[0] == 0 and x.shape[1] % group_shape[1] == 0
|
||||||
|
blk_m, blk_n = x.shape[0] // group_shape[0], x.shape[1] // group_shape[1]
|
||||||
|
x_blkd = x.reshape(blk_m, group_shape[0], blk_n, group_shape[1])
|
||||||
|
|
||||||
|
# Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N)
|
||||||
|
x_blkd_permd = x_blkd.permute(0, 2, 1, 3)
|
||||||
|
# Flatten to (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N)
|
||||||
|
x_blkd_permd = x_blkd_permd.flatten(start_dim=2)
|
||||||
|
|
||||||
|
# Compute scales
|
||||||
|
min_val, max_val = x_blkd_permd.aminmax(dim=-1)
|
||||||
|
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||||
|
scale = finfo.max / amax
|
||||||
|
|
||||||
|
# Apply scale and convert form:
|
||||||
|
# (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N)
|
||||||
|
x_scl_sat = (x_blkd_permd * scale.unsqueeze(-1))\
|
||||||
|
.clamp(min=finfo.min, max=finfo.max)\
|
||||||
|
.reshape(blk_m, blk_n, group_shape[0], group_shape[1])\
|
||||||
|
.permute(0, 2, 1, 3)\
|
||||||
|
.reshape(x.shape)
|
||||||
|
|
||||||
|
return x_scl_sat.to(quant_dtype).contiguous(), scale.float().reciprocal()
|
||||||
|
|
||||||
|
|
||||||
|
# inverses `scaled_quantize`
|
||||||
|
def scaled_dequantize(
|
||||||
|
x_q: torch.Tensor,
|
||||||
|
x_s: torch.Tensor,
|
||||||
|
group_shape: Optional[Tuple[int, int]] = None,
|
||||||
|
out_dtype: torch.dtype = torch.float32,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if group_shape is not None:
|
||||||
|
group_shape = _normalize_quant_group_shape(x_q, group_shape)
|
||||||
|
|
||||||
|
if x_s.ndim == 0: # scalar
|
||||||
|
x_s = x_s.unsqueeze(-1).unsqueeze(-1) # convert to (1, 1) tensor
|
||||||
|
if x_s.ndim == 1:
|
||||||
|
if group_shape is None:
|
||||||
|
raise AssertionError(
|
||||||
|
"if x_s is 1D tensor, group_shape must be provided otherwise "
|
||||||
|
"its ambiguous which dimension to broadcast x_s to")
|
||||||
|
# unsqueeze the scales for the dimension where we want to broadcast
|
||||||
|
# across the full extent
|
||||||
|
if group_shape[0] == x_q.shape[-2]:
|
||||||
|
x_s = x_s.unsqueeze(-2)
|
||||||
|
elif group_shape[1] == x_q.shape[-1]:
|
||||||
|
x_s = x_s.unsqueeze(-1)
|
||||||
|
else:
|
||||||
|
raise AssertionError(
|
||||||
|
"if x_s is a vector we should be broadcasting it to the full "
|
||||||
|
"extent of one of the dimensions")
|
||||||
|
|
||||||
|
if group_shape is not None:
|
||||||
|
assert x_s.shape[-1] == x_q.shape[-1] // group_shape[1]
|
||||||
|
assert x_s.shape[-2] == x_q.shape[-2] // group_shape[0]
|
||||||
|
x_s = group_broadcast(x_s.to(torch.float32), x_q.shape)
|
||||||
|
return (x_q.to(torch.float32) * x_s).to(out_dtype)
|
||||||
|
|
||||||
|
|
||||||
def pack_quantized_values_into_int32(w_q: torch.Tensor,
|
def pack_quantized_values_into_int32(w_q: torch.Tensor,
|
||||||
wtype: ScalarType,
|
wtype: ScalarType,
|
||||||
packed_dim: int = 0):
|
packed_dim: int = 0):
|
||||||
|
|||||||
@@ -398,11 +398,13 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
# parameters onto device for processing and back off after.
|
# parameters onto device for processing and back off after.
|
||||||
with device_loading_context(module, target_device):
|
with device_loading_context(module, target_device):
|
||||||
quant_method.process_weights_after_loading(module)
|
quant_method.process_weights_after_loading(module)
|
||||||
elif isinstance(module, Attention) and \
|
if isinstance(module, Attention) and \
|
||||||
hasattr(module, "process_weights_after_loading"):
|
hasattr(module, "process_weights_after_loading"):
|
||||||
# When attention modules need to process weights after
|
# When attention modules need to process weights after
|
||||||
# currently only used by MLA
|
# currently only used by MLA
|
||||||
module.process_weights_after_loading()
|
# TODO(lucas): see if there is a way to unify the signatures
|
||||||
|
# of process_weights_after_loading
|
||||||
|
module.process_weights_after_loading(model_config.dtype)
|
||||||
return model.eval()
|
return model.eval()
|
||||||
|
|
||||||
|
|
||||||
@@ -439,6 +441,11 @@ class DummyModelLoader(BaseModelLoader):
|
|||||||
with device_loading_context(
|
with device_loading_context(
|
||||||
module, torch.device(device_config.device)):
|
module, torch.device(device_config.device)):
|
||||||
quant_method.process_weights_after_loading(module)
|
quant_method.process_weights_after_loading(module)
|
||||||
|
if isinstance(module, Attention) and \
|
||||||
|
hasattr(module, "process_weights_after_loading"):
|
||||||
|
# When attention modules need to process weights after
|
||||||
|
# currently only used by MLA
|
||||||
|
module.process_weights_after_loading(model_config.dtype)
|
||||||
return model.eval()
|
return model.eval()
|
||||||
|
|
||||||
|
|
||||||
@@ -633,6 +640,12 @@ class ShardedStateLoader(BaseModelLoader):
|
|||||||
quant_method = getattr(module, "quant_method", None)
|
quant_method = getattr(module, "quant_method", None)
|
||||||
if quant_method is not None:
|
if quant_method is not None:
|
||||||
quant_method.process_weights_after_loading(module)
|
quant_method.process_weights_after_loading(module)
|
||||||
|
if isinstance(module, Attention) and \
|
||||||
|
hasattr(module, "process_weights_after_loading"):
|
||||||
|
# When attention modules need to process weights after
|
||||||
|
# currently only used by MLA
|
||||||
|
module.process_weights_after_loading(
|
||||||
|
model_config.dtype)
|
||||||
rank = get_tensor_model_parallel_rank()
|
rank = get_tensor_model_parallel_rank()
|
||||||
pattern = os.path.join(
|
pattern = os.path.join(
|
||||||
local_model_path,
|
local_model_path,
|
||||||
@@ -1369,6 +1382,11 @@ class RunaiModelStreamerLoader(BaseModelLoader):
|
|||||||
if quant_method is not None:
|
if quant_method is not None:
|
||||||
with device_loading_context(module, target_device):
|
with device_loading_context(module, target_device):
|
||||||
quant_method.process_weights_after_loading(module)
|
quant_method.process_weights_after_loading(module)
|
||||||
|
if isinstance(module, Attention) and \
|
||||||
|
hasattr(module, "process_weights_after_loading"):
|
||||||
|
# When attention modules need to process weights after
|
||||||
|
# currently only used by MLA
|
||||||
|
module.process_weights_after_loading(model_config.dtype)
|
||||||
return model.eval()
|
return model.eval()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from torch import nn
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.config import CacheConfig, VllmConfig
|
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||||
from vllm.distributed import (get_pp_group,
|
from vllm.distributed import (get_pp_group,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
@@ -333,12 +333,156 @@ class DeepseekV3Attention(nn.Module):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3MLAAttention(nn.Module):
|
||||||
|
"""
|
||||||
|
Main reference: DeepseekV2 paper, and FlashInfer Implementation
|
||||||
|
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
|
||||||
|
|
||||||
|
For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
qk_nope_head_dim: int,
|
||||||
|
qk_rope_head_dim: int,
|
||||||
|
v_head_dim: int,
|
||||||
|
q_lora_rank: Optional[int],
|
||||||
|
kv_lora_rank: int,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
|
max_position_embeddings: int = 8192,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.qk_nope_head_dim = qk_nope_head_dim
|
||||||
|
self.qk_rope_head_dim = qk_rope_head_dim
|
||||||
|
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||||
|
self.v_head_dim = v_head_dim
|
||||||
|
|
||||||
|
self.q_lora_rank = q_lora_rank
|
||||||
|
self.kv_lora_rank = kv_lora_rank
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
assert num_heads % tp_size == 0
|
||||||
|
self.num_local_heads = num_heads // tp_size
|
||||||
|
|
||||||
|
self.scaling = self.qk_head_dim**-0.5
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
|
if self.q_lora_rank is not None:
|
||||||
|
self.q_a_proj = ReplicatedLinear(self.hidden_size,
|
||||||
|
self.q_lora_rank,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.q_a_proj")
|
||||||
|
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
|
||||||
|
self.num_heads *
|
||||||
|
self.qk_head_dim,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.q_b_proj")
|
||||||
|
else:
|
||||||
|
self.q_proj = ColumnParallelLinear(self.hidden_size,
|
||||||
|
self.num_heads *
|
||||||
|
self.qk_head_dim,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.q_proj")
|
||||||
|
|
||||||
|
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.kv_a_proj_with_mqa")
|
||||||
|
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
self.kv_b_proj = ColumnParallelLinear(
|
||||||
|
self.kv_lora_rank,
|
||||||
|
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.kv_b_proj")
|
||||||
|
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
|
||||||
|
self.hidden_size,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.o_proj")
|
||||||
|
|
||||||
|
rope_scaling["rope_type"] = 'deepseek_yarn'
|
||||||
|
self.rotary_emb = get_rope(qk_rope_head_dim,
|
||||||
|
rotary_dim=qk_rope_head_dim,
|
||||||
|
max_position=max_position_embeddings,
|
||||||
|
base=rope_theta,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
|
is_neox_style=False)
|
||||||
|
if rope_scaling:
|
||||||
|
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
|
||||||
|
scaling_factor = rope_scaling["factor"]
|
||||||
|
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||||
|
self.scaling = self.scaling * mscale * mscale
|
||||||
|
|
||||||
|
self.mla_attn = Attention(
|
||||||
|
num_heads=self.num_local_heads,
|
||||||
|
head_size=self.kv_lora_rank,
|
||||||
|
scale=self.scaling,
|
||||||
|
num_kv_heads=1,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn",
|
||||||
|
use_mla=True,
|
||||||
|
# MLA Args
|
||||||
|
q_lora_rank=self.q_lora_rank,
|
||||||
|
kv_lora_rank=self.kv_lora_rank,
|
||||||
|
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||||
|
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||||
|
qk_head_dim=self.qk_head_dim,
|
||||||
|
v_head_dim=self.v_head_dim,
|
||||||
|
rotary_emb=self.rotary_emb,
|
||||||
|
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
|
||||||
|
kv_b_proj=self.kv_b_proj,
|
||||||
|
o_proj=self.o_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.prefix = prefix
|
||||||
|
self.debug_layer_idx = int(self.prefix.split(".")[-2])
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if self.q_lora_rank is not None:
|
||||||
|
ckq = self.q_a_proj(hidden_states)[0]
|
||||||
|
hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
||||||
|
else:
|
||||||
|
hidden_states_or_q_c = hidden_states
|
||||||
|
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
|
||||||
|
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||||
|
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||||
|
return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
|
||||||
|
attn_metadata)
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV3DecoderLayer(nn.Module):
|
class DeepseekV3DecoderLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
prefix: str,
|
prefix: str,
|
||||||
|
model_config: ModelConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -351,7 +495,11 @@ class DeepseekV3DecoderLayer(nn.Module):
|
|||||||
# DecoderLayers are created with `make_layers` which passes the prefix
|
# DecoderLayers are created with `make_layers` which passes the prefix
|
||||||
# with the layer's index.
|
# with the layer's index.
|
||||||
layer_idx = int(prefix.split(sep='.')[-1])
|
layer_idx = int(prefix.split(sep='.')[-1])
|
||||||
self.self_attn = DeepseekV3Attention(
|
if model_config.use_mla:
|
||||||
|
attn_cls = DeepseekV3MLAAttention
|
||||||
|
else:
|
||||||
|
attn_cls = DeepseekV3Attention
|
||||||
|
self.self_attn = attn_cls(
|
||||||
config=config,
|
config=config,
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
@@ -428,6 +576,7 @@ class DeepseekV3Model(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
|
model_config = vllm_config.model_config
|
||||||
cache_config = vllm_config.cache_config
|
cache_config = vllm_config.cache_config
|
||||||
quant_config = vllm_config.quant_config
|
quant_config = vllm_config.quant_config
|
||||||
|
|
||||||
@@ -447,6 +596,7 @@ class DeepseekV3Model(nn.Module):
|
|||||||
lambda prefix: DeepseekV3DecoderLayer(
|
lambda prefix: DeepseekV3DecoderLayer(
|
||||||
config,
|
config,
|
||||||
prefix,
|
prefix,
|
||||||
|
model_config=model_config,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -110,7 +110,9 @@ class CacheEngine:
|
|||||||
parallel_config, LayerBlockType.attention)
|
parallel_config, LayerBlockType.attention)
|
||||||
|
|
||||||
key_cache_block = cache_config.block_size * num_heads * head_size
|
key_cache_block = cache_config.block_size * num_heads * head_size
|
||||||
value_cache_block = key_cache_block
|
# For MLA there is no value cache, since the latent vector
|
||||||
|
# is joint keys and values.
|
||||||
|
value_cache_block = key_cache_block if not model_config.use_mla else 0
|
||||||
total = num_attention_layers * (key_cache_block + value_cache_block)
|
total = num_attention_layers * (key_cache_block + value_cache_block)
|
||||||
if cache_config.cache_dtype == "auto":
|
if cache_config.cache_dtype == "auto":
|
||||||
dtype = model_config.dtype
|
dtype = model_config.dtype
|
||||||
|
|||||||
Reference in New Issue
Block a user