[Attention] MLA get rid of materialization (#14770)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -7,22 +7,22 @@ First we define:
|
||||
Sq as Q sequence length
|
||||
Skv as KV sequence length
|
||||
|
||||
MLA has two possible ways of computing, a data-movement friendly approach and a
|
||||
compute friendly approach, we generally want to use the compute friendly
|
||||
approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1)
|
||||
and the data-movement friendly approach for "decode" (i.e. the ratio
|
||||
Sq / Skv is "large").
|
||||
MLA has two possible ways of computing, a data-movement friendly approach and a
|
||||
compute friendly approach, we generally want to use the compute friendly
|
||||
approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1)
|
||||
and the data-movement friendly approach for "decode" (i.e. the ratio
|
||||
Sq / Skv is "large").
|
||||
|
||||
NOTE what we deem small and large is currently determined by if it's labelled
|
||||
prefill or decode by the scheduler, but this is something we should probably
|
||||
NOTE what we deem small and large is currently determined by if it's labelled
|
||||
prefill or decode by the scheduler, but this is something we should probably
|
||||
tune.
|
||||
|
||||
Main reference: DeepseekV2 paper, and FlashInfer Implementation
|
||||
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
|
||||
|
||||
Deepseek's MLA attention works the following way:
|
||||
* Use a single latent vector to represent the per-token entry of the KV cache.
|
||||
* For decode (i.e. the memory friendly approach) the attention "simulates" a
|
||||
* Use a single latent vector to represent the per-token entry of the KV cache.
|
||||
* For decode (i.e. the memory friendly approach) the attention "simulates" a
|
||||
multi-head attention, while the compute is similar to multi-query attention.
|
||||
|
||||
Below is example of both paths assuming batchsize = 1
|
||||
@@ -54,9 +54,9 @@ W_DQ project h_t to q_c shape [H, Lq]
|
||||
W_UQ project q_c to q_nope shape [Lq, N * P]
|
||||
W_QR project q_c to q_pe shape [Lq, N * R]
|
||||
W_DKV project h_t to kv_c shape [H, Lkv]
|
||||
W_UK project kv_c to k_nope shape [Lkv, N * P]
|
||||
W_KR project h_t to k_pe shape [H, N * R]
|
||||
W_UV project kv_c to v shape [Lkv, N * V]
|
||||
W_UK project kv_c to k_nope shape [Lkv, N, P]
|
||||
W_KR project h_t to k_pe shape [H, R]
|
||||
W_UV project kv_c to v shape [Lkv, N, V]
|
||||
W_O project v to h_t shape [N * V, H]
|
||||
|
||||
|
||||
@@ -69,8 +69,8 @@ new_kv_c = h_t @ W_DKV
|
||||
new_k_pe = RoPE(h_t @ W_KR)
|
||||
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
|
||||
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
|
||||
k_nope = (kv_c @ W_UK).view(Skv, N, P)
|
||||
v = (kv_c @ W_UV).view(Skv, N, V)
|
||||
k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
|
||||
v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)
|
||||
|
||||
// MHA with QK headdim = P + R
|
||||
// V headdim = V
|
||||
@@ -90,20 +90,10 @@ NOTE: in the actual code,
|
||||
|
||||
## Data-Movement Friendly Approach (i.e. "_forward_decode"):
|
||||
|
||||
Ahead of time, compute:
|
||||
|
||||
% this projects from q_c to [Sq, N * Lkv]
|
||||
W_UQ_UK = einsum("qnp,knp -> qnk"
|
||||
W_UQ.view(Lq, N, P), W_UK.view(Lkv, N, P)
|
||||
).view(Lkv, N * Lkv)
|
||||
% this projects from attn output [Sq, N * Lkv] to [Sq, H]
|
||||
W_UV_O = einsum("knv,nvh -> nkh"
|
||||
W_UV.view(Lkv, N, V), W_O.view(N, V, H)
|
||||
).view(N * Lkv, H)
|
||||
|
||||
Runtime
|
||||
q_c = h_t @ W_DQ
|
||||
q_latent = q_c @ W_UQ_UK.view(Sq, N, Lkv)
|
||||
q_nope = (q_c @ W_UQ).view(-1, N, P)
|
||||
ql_nope = einsum("snh,lnh->snl", q, W_UK)
|
||||
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
|
||||
new_kv_c = h_t @ W_DKV
|
||||
new_k_pe = RoPE(h_t @ W_KR)
|
||||
@@ -116,11 +106,13 @@ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
|
||||
// NOTE: this is less compute-friendly since Lkv > P
|
||||
// but is more data-movement friendly since its MQA vs MHA
|
||||
spda_o = scaled_dot_product_attention(
|
||||
torch.cat([q_latent, q_pe], dim=-1),
|
||||
torch.cat([ql_nope, q_pe], dim=-1),
|
||||
torch.cat([kv_c, k_pe], dim=-1),
|
||||
kv_c
|
||||
)
|
||||
return spda_o.reshape(-1, N * Lkv) @ W_UV_O
|
||||
|
||||
o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
|
||||
return o.view(-1, N * V) @ self.num_heads @ W_O
|
||||
|
||||
|
||||
## Chunked Prefill
|
||||
@@ -146,8 +138,8 @@ q_nope = (q_c @ W_UQ).view(Sq, N, P)
|
||||
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
|
||||
new_kv_c = h_t @ W_DKV
|
||||
new_k_pe = RoPE(h_t @ W_KR)
|
||||
new_k_nope = (new_kv_c @ W_UK).view(Sq, N, P)
|
||||
new_v = (new_kv_c @ W_UV).view(Sq, N, V)
|
||||
new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
|
||||
new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)
|
||||
|
||||
// MHA between queries and new KV
|
||||
// with QK headdim = P + R
|
||||
@@ -171,17 +163,17 @@ for chunk_idx in range(cdiv(C, MCC)):
|
||||
cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end]
|
||||
cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
|
||||
cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V)
|
||||
|
||||
|
||||
chunk_o, chunk_lse = scaled_dot_product_attention(
|
||||
torch.cat([q_nope, q_pe], dim=-1),
|
||||
torch.cat([cache_k_nope_chunk,
|
||||
cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
|
||||
torch.cat([cache_k_nope_chunk,
|
||||
cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
|
||||
dim=-1),
|
||||
cache_v_chunk,
|
||||
causal=False,
|
||||
return_softmax_lse=True
|
||||
)
|
||||
|
||||
|
||||
curr_o, curr_lse = merge_attn_states(
|
||||
suffix_output=curr_o,
|
||||
suffix_lse=curr_lse,
|
||||
@@ -202,7 +194,6 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
|
||||
Type, TypeVar)
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
@@ -215,20 +206,9 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
||||
get_flash_attn_version,
|
||||
is_block_tables_empty)
|
||||
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsW8A8Fp8)
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
Fp8LinearGenericOp, is_fp8)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
scaled_quantize)
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
@@ -1057,7 +1037,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
self.kv_b_proj = kv_b_proj
|
||||
self.o_proj = o_proj
|
||||
self.triton_fa_func = triton_attention
|
||||
self.fp8_linear_generic = Fp8LinearGenericOp()
|
||||
|
||||
# Handle the differences between the flash_attn_varlen from flash_attn
|
||||
# and the one from vllm_flash_attn. The former is used on RoCM and the
|
||||
@@ -1070,80 +1049,29 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
fa_version=self.vllm_flash_attn_version)
|
||||
|
||||
def _v_up_proj_and_o_proj(self, x):
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
if is_fp8(self.W_UV_O):
|
||||
output_parallel = self.fp8_linear_generic.apply(
|
||||
x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
|
||||
self.reqaunt_input_group_shape,
|
||||
self.reqaunt_weight_group_shape)
|
||||
else:
|
||||
output_parallel = torch.matmul(x.flatten(start_dim=1),
|
||||
self.W_UV_O)
|
||||
if self.tp_size > 1:
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
return output
|
||||
else:
|
||||
x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
|
||||
return self.o_proj(x.reshape(-1,
|
||||
self.num_heads * self.v_head_dim))[0]
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||
x = torch.bmm(x, self.W_UV)
|
||||
# Convert from (N, B, V) to (B, N * V)
|
||||
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
||||
return self.o_proj(x)[0]
|
||||
|
||||
# Return `ql_nope`, `q_pe`
|
||||
def _q_proj_and_k_up_proj(self, x):
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
if is_fp8(self.W_Q_UK):
|
||||
return self.fp8_linear_generic.apply(
|
||||
x, self.W_Q_UK, self.W_Q_UK_scales,
|
||||
self.reqaunt_input_group_shape,
|
||||
self.reqaunt_weight_group_shape).view(
|
||||
-1, self.num_heads, self.kv_lora_rank)
|
||||
return torch.matmul(x, self.W_Q_UK)\
|
||||
.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
else:
|
||||
x = torch.matmul(x, self.W_Q)\
|
||||
.view(-1, self.num_heads, self.qk_nope_head_dim)
|
||||
return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
|
||||
.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
q_nope, q_pe = self.q_proj(x)[0]\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)\
|
||||
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
# Convert from (B, N, P) to (N, B, P)
|
||||
q_nope = q_nope.transpose(0, 1)
|
||||
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
||||
ql_nope = torch.bmm(q_nope, self.W_UK_T)
|
||||
# Convert from (N, B, L) to (B, N, L)
|
||||
return ql_nope.transpose(0, 1), q_pe
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
|
||||
# TODO(lucas) This is very gross, we need a more wide scale refactor of
|
||||
# all the FP8 code with a more standard way of
|
||||
# defining schemes/group-shapes, we should also potentially force
|
||||
# quant_methods to support a decompress function
|
||||
#
|
||||
# returns input_group_shape, weight_group_shape
|
||||
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
|
||||
Tuple[Tuple[int, int], Tuple[int, int]]:
|
||||
if isinstance(layer.quant_method, Fp8LinearMethod):
|
||||
if layer.quant_method.block_quant:
|
||||
weight_block_size = \
|
||||
layer.quant_method.quant_config.weight_block_size
|
||||
# per-token-group (1, X), block-quantized (X, Y)
|
||||
return (1, weight_block_size[-1]), weight_block_size
|
||||
else:
|
||||
return (-1, -1), (-1, -1) # per-tensor, per-tensor
|
||||
elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
|
||||
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# this is hacky but we always assume that for
|
||||
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
|
||||
# we ignore if it is static-per-tensor since we are going to
|
||||
# requantize later anyway
|
||||
strategy = layer.scheme.strategy
|
||||
if strategy == QuantizationStrategy.TENSOR:
|
||||
return (1, -1), (-1, -1) # per-token, per-tensor
|
||||
elif strategy == QuantizationStrategy.CHANNEL:
|
||||
return (1, -1), (-1, 1) # per-token, per-channel
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"QuantizationStrategy.{strategy} is not supported for "
|
||||
"fp8 MLA, please run with VLLM_MLA_DISABLE=1")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Can't determine scale group shapes for "
|
||||
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
|
||||
)
|
||||
|
||||
def get_layer_weight(layer):
|
||||
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
|
||||
for attr in WEIGHT_NAMES:
|
||||
@@ -1167,10 +1095,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
return dequant_weights.T
|
||||
return layer.weight
|
||||
|
||||
weight_dtype = get_layer_weight(self.kv_b_proj).dtype
|
||||
assert get_layer_weight(self.o_proj).dtype == weight_dtype
|
||||
assert get_layer_weight(self.q_proj).dtype == weight_dtype
|
||||
|
||||
# we currently do not have quantized bmm's which are needed for
|
||||
# `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
|
||||
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank,
|
||||
@@ -1189,89 +1116,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
|
||||
# can be W_Q or W_UQ depending on q_lora_rank, the former if
|
||||
# q_lora_rank is None, the latter otherwise. From the Attention backend
|
||||
# perspective though we call these both W_Q and rely on the layer
|
||||
# to pass in the correct matrix
|
||||
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
|
||||
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
|
||||
# W_QR is small so for simplicity we dont bother requantizing it
|
||||
self.W_QR = self.W_QR.to(act_dtype)
|
||||
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
# This assumes it is wise to requantize using the same group shapes
|
||||
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
|
||||
# weights were originally quantized
|
||||
requant_input_group_shape, requant_weight_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(self.q_proj)
|
||||
assert (requant_input_group_shape, requant_weight_group_shape)\
|
||||
== get_scale_group_shapes_for_fp8(self.kv_b_proj)
|
||||
assert (requant_input_group_shape, requant_weight_group_shape)\
|
||||
== get_scale_group_shapes_for_fp8(self.o_proj)
|
||||
self.reqaunt_input_group_shape = requant_input_group_shape
|
||||
self.reqaunt_weight_group_shape = requant_weight_group_shape
|
||||
|
||||
#
|
||||
# Perform matrix-absorption following
|
||||
# https://github.com/flashinfer-ai/flashinfer/pull/551
|
||||
# for decode, as a result we end up with absorbed weights for decode
|
||||
# and another copy of raw weights for prefill.
|
||||
#
|
||||
self.W_UK, self.W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
# We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
|
||||
# depending on q_lora_rank, the former if q_lora_rank is None, the
|
||||
# latter otherwise
|
||||
# basically if q_lora_rank is none we are absorbing into q_proj
|
||||
# instead of UQ
|
||||
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
W_Q_UK, W_Q_UK_scales = scaled_quantize(
|
||||
W_Q_UK,
|
||||
self.reqaunt_weight_group_shape,
|
||||
quant_dtype=current_platform.fp8_dtype())
|
||||
# For FP8 save the transpose so we can use
|
||||
# `apply_w8a8_block_fp8_linear` directly
|
||||
self.W_Q_UK = W_Q_UK.T.contiguous()
|
||||
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
|
||||
else:
|
||||
self.W_Q_UK = W_Q_UK.to(act_dtype)
|
||||
|
||||
W_O = get_and_maybe_dequant_weights(self.o_proj)\
|
||||
.view(-1, self.num_heads, self.v_head_dim)
|
||||
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
|
||||
.flatten(start_dim=0, end_dim=1).contiguous()
|
||||
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
W_UV_O, W_UV_O_scales = scaled_quantize(
|
||||
W_UV_O,
|
||||
self.reqaunt_weight_group_shape,
|
||||
quant_dtype=current_platform.fp8_dtype())
|
||||
# For FP8 save the transpose so we can use
|
||||
# `apply_w8a8_block_fp8_linear` directly
|
||||
self.W_UV_O = W_UV_O.T.contiguous()
|
||||
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
|
||||
else:
|
||||
self.W_UV_O = W_UV_O.to(act_dtype)
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
else:
|
||||
if is_fp8(weight_dtype):
|
||||
raise NotImplementedError(
|
||||
"Currently fp8 requires matrix absorption")
|
||||
|
||||
self.W_UV = W_UV
|
||||
self.W_UK = W_UK
|
||||
self.W_Q = W_Q.flatten(start_dim=1)
|
||||
# Convert from (L, N, V) to (N, L, V)
|
||||
self.W_UV = W_UV.transpose(0, 1)
|
||||
# Convert from (L, N, P) to (N, P, L)
|
||||
self.W_UK_T = W_UK.permute(1, 2, 0)
|
||||
|
||||
def _compute_prefill_context(
|
||||
self,
|
||||
@@ -1471,7 +1319,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
@abstractmethod
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
ql_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
@@ -1525,9 +1373,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
prefill_k_c_normed = k_c_normed[:num_prefill_tokens]
|
||||
|
||||
if has_decode:
|
||||
decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
||||
decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\
|
||||
.view(-1, self.num_heads, self.qk_rope_head_dim)
|
||||
decode_ql_nope, decode_q_pe = \
|
||||
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
||||
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
||||
decode_input_positions, decode_q_pe, decode_k_pe)
|
||||
|
||||
@@ -1561,6 +1408,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
|
||||
if has_decode:
|
||||
output[num_prefill_tokens:] = self._forward_decode(
|
||||
decode_q_nope, decode_q_pe, kv_cache, attn_metadata)
|
||||
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
|
||||
|
||||
return output
|
||||
|
||||
19
vllm/envs.py
19
vllm/envs.py
@@ -84,8 +84,6 @@ if TYPE_CHECKING:
|
||||
VLLM_SERVER_DEV_MODE: bool = False
|
||||
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
|
||||
VLLM_MLA_DISABLE: bool = False
|
||||
VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
|
||||
VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
|
||||
VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True
|
||||
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
|
||||
VLLM_RAY_PER_WORKER_GPUS: float = 1.0
|
||||
@@ -563,23 +561,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_MLA_DISABLE":
|
||||
lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))),
|
||||
|
||||
# Flag that can control whether or not we perform matrix-absorption for MLA
|
||||
# decode, i.e. absorb W_UK into W_Q/W_UK and W_UV into W_O, absorbing the
|
||||
# matrices reduces the runtime FLOPs needed to compute MLA but requires
|
||||
# storing more weights, W_Q_UK and W_UV_O, so can increase memory usage,
|
||||
# this is enabled by default
|
||||
"VLLM_MLA_PERFORM_MATRIX_ABSORPTION":
|
||||
lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1"))),
|
||||
|
||||
# When running MLA with matrix-absorption enabled and fp8 quantized weights
|
||||
# we perform the matrix-absorption in float32 precision, after the matrices
|
||||
# are absorbed we requantize the weights back to fp8, this flag can be used
|
||||
# to disable the requantization step, and instead convert the absorbed
|
||||
# matrices to match the activation type. This can lead to higher memory and
|
||||
# compute usage but better preserves the accuracy of the original model.
|
||||
"VLLM_MLA_DISABLE_REQUANTIZATION":
|
||||
lambda: bool(int(os.getenv("VLLM_MLA_DISABLE_REQUANTIZATION", "0"))),
|
||||
|
||||
# If set, vLLM will use the Triton implementation of moe_align_block_size,
|
||||
# i.e. moe_align_block_size_triton in fused_moe.py.
|
||||
"VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON":
|
||||
|
||||
@@ -13,10 +13,9 @@ import triton.language as tl
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
_normalize_quant_group_shape, scaled_dequantize)
|
||||
scaled_dequantize)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED, Fp8LinearOp, cutlass_block_fp8_supported,
|
||||
cutlass_fp8_supported)
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
@@ -101,60 +100,6 @@ direct_register_custom_op(
|
||||
)
|
||||
|
||||
|
||||
# Unify the interface between `apply_w8a8_block_fp8_linear` and
|
||||
# `apply_fp8_linear`
|
||||
# NOTE(lucas): this is quite messy, we should think through this more formally
|
||||
# TODO(luka): unify this better
|
||||
# https://github.com/vllm-project/vllm/issues/14397
|
||||
class Fp8LinearGenericOp:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cutlass_fp8_supported: bool = cutlass_fp8_supported(),
|
||||
cutlass_block_fp8_supported: bool = cutlass_block_fp8_supported(),
|
||||
):
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
cutlass_fp8_supported=cutlass_fp8_supported)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
input_group_shape: Tuple[int, int],
|
||||
weight_group_shape: Tuple[int, int],
|
||||
input_scale: Optional[torch.Tensor] = None, # static scale if one
|
||||
) -> torch.Tensor:
|
||||
# View input as 2D matrix for fp8 methods
|
||||
input = input.view(-1, input.shape[-1])
|
||||
|
||||
weight_group_shape = _normalize_quant_group_shape( \
|
||||
weight, weight_group_shape)
|
||||
input_group_shape = _normalize_quant_group_shape(
|
||||
input, input_group_shape)
|
||||
|
||||
def is_dim_blocked(dim, shape, group_shape):
|
||||
return group_shape < shape[dim] and group_shape > 1
|
||||
|
||||
if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
|
||||
and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
|
||||
input_group_shape == (1, weight_group_shape[1]):
|
||||
return apply_w8a8_block_fp8_linear(
|
||||
input,
|
||||
weight,
|
||||
list(weight_group_shape),
|
||||
weight_scale,
|
||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported)
|
||||
else:
|
||||
# Despite having linear in the name it doesn't conform to
|
||||
# `torch.nn.functional.linear` which is defined as
|
||||
# `input @ weight.T` so we explicitly transpose the weight matrix
|
||||
return self.fp8_linear.apply(input, weight.T, weight_scale.T,
|
||||
use_per_token_if_dynamic=\
|
||||
(input_group_shape == (1, input.shape[1])))
|
||||
|
||||
|
||||
def input_to_float8(
|
||||
x: torch.Tensor,
|
||||
dtype: Optional[torch.dtype] = None
|
||||
|
||||
@@ -21,7 +21,7 @@ Main reference: DeepseekV2 paper, and FlashInfer Implementation
|
||||
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
|
||||
|
||||
Deepseek's MLA attention works the following way:
|
||||
* Use a single latent vector to represent the per-token entry of the KV cache.
|
||||
* Use a single latent vector to represent the per-token entry of the KV cache.
|
||||
* For decode (i.e. the memory friendly approach) the attention "simulates" a
|
||||
multi-head attention, while the compute is similar to multi-query attention.
|
||||
|
||||
@@ -54,9 +54,9 @@ W_DQ project h_t to q_c shape [H, Lq]
|
||||
W_UQ project q_c to q_nope shape [Lq, N * P]
|
||||
W_QR project q_c to q_pe shape [Lq, N * R]
|
||||
W_DKV project h_t to kv_c shape [H, Lkv]
|
||||
W_UK project kv_c to k_nope shape [Lkv, N * P]
|
||||
W_KR project h_t to k_pe shape [H, N * R]
|
||||
W_UV project kv_c to v shape [Lkv, N * V]
|
||||
W_UK project kv_c to k_nope shape [Lkv, N, P]
|
||||
W_KR project h_t to k_pe shape [H, R]
|
||||
W_UV project kv_c to v shape [Lkv, N, V]
|
||||
W_O project v to h_t shape [N * V, H]
|
||||
|
||||
|
||||
@@ -69,8 +69,8 @@ new_kv_c = h_t @ W_DKV
|
||||
new_k_pe = RoPE(h_t @ W_KR)
|
||||
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
|
||||
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
|
||||
k_nope = (kv_c @ W_UK).view(Skv, N, P)
|
||||
v = (kv_c @ W_UV).view(Skv, N, V)
|
||||
k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
|
||||
v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)
|
||||
|
||||
// MHA with QK headdim = P + R
|
||||
// V headdim = V
|
||||
@@ -79,7 +79,7 @@ spda_o = scaled_dot_product_attention(
|
||||
torch.cat([q_nope, q_pe], dim=-1),
|
||||
torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
|
||||
v
|
||||
)
|
||||
)
|
||||
return spda_o @ W_O
|
||||
|
||||
NOTE: in the actual code,
|
||||
@@ -90,20 +90,10 @@ NOTE: in the actual code,
|
||||
|
||||
## Data-Movement Friendly Approach (i.e. "_forward_decode"):
|
||||
|
||||
Ahead of time, compute:
|
||||
|
||||
% this projects from q_c to [Sq, N * Lkv]
|
||||
W_UQ_UK = einsum("qnp,knp -> qnk"
|
||||
W_UQ.view(Lq, N, P), W_UK.view(Lkv, N, P)
|
||||
).view(Lkv, N * Lkv)
|
||||
% this projects from attn output [Sq, N * Lkv] to [Sq, H]
|
||||
W_UV_O = einsum("knv,nvh -> nkh"
|
||||
W_UV.view(Lkv, N, V), W_O.view(N, V, H)
|
||||
).view(N * Lkv, H)
|
||||
|
||||
Runtime
|
||||
q_c = h_t @ W_DQ
|
||||
q_latent = q_c @ W_UQ_UK.view(Sq, N, Lkv)
|
||||
q_nope = (q_c @ W_UQ).view(-1, N, P)
|
||||
ql_nope = einsum("snh,lnh->snl", q, W_UK)
|
||||
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
|
||||
new_kv_c = h_t @ W_DKV
|
||||
new_k_pe = RoPE(h_t @ W_KR)
|
||||
@@ -116,29 +106,31 @@ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
|
||||
// NOTE: this is less compute-friendly since Lkv > P
|
||||
// but is more data-movement friendly since its MQA vs MHA
|
||||
spda_o = scaled_dot_product_attention(
|
||||
torch.cat([q_latent, q_pe], dim=-1),
|
||||
torch.cat([ql_nope, q_pe], dim=-1),
|
||||
torch.cat([kv_c, k_pe], dim=-1),
|
||||
kv_c
|
||||
)
|
||||
return spda_o.reshape(-1, N * Lkv) @ W_UV_O
|
||||
|
||||
o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
|
||||
return o.view(-1, N * V) @ self.num_heads @ W_O
|
||||
|
||||
|
||||
## Chunked Prefill
|
||||
|
||||
For chunked prefill we want to use the compute friendly algorithm. We are
|
||||
assuming sufficiently large Sq / Skv ratio, in the future may want to switch to
|
||||
For chunked prefill we want to use the compute friendly algorithm. We are
|
||||
assuming sufficiently large Sq / Skv ratio, in the future may want to switch to
|
||||
the data-movement friendly approach if the chunk (i.e. `Sq`) is small.
|
||||
|
||||
However, the compute-friendly approach can potentially run out of memory if Skv
|
||||
is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`
|
||||
|
||||
To mitigate this, we chunk the computation of attention with respect to the
|
||||
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can use a
|
||||
To mitigate this, we chunk the computation of attention with respect to the
|
||||
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can use a
|
||||
fixed workspace size.
|
||||
|
||||
The chunked prefill approach is as follows:
|
||||
|
||||
MCC Max chunk of context to process per iter, computed dynamically,
|
||||
MCC Max chunk of context to process per iter, computed dynamically,
|
||||
used to bound the memory usage
|
||||
|
||||
q_c = h_t @ W_DQ
|
||||
@@ -146,8 +138,8 @@ q_nope = (q_c @ W_UQ).view(Sq, N, P)
|
||||
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
|
||||
new_kv_c = h_t @ W_DKV
|
||||
new_k_pe = RoPE(h_t @ W_KR)
|
||||
new_k_nope = (new_kv_c @ W_UK).view(Sq, N, P)
|
||||
new_v = (new_kv_c @ W_UV).view(Sq, N, V)
|
||||
new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
|
||||
new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)
|
||||
|
||||
// MHA between queries and new KV
|
||||
// with QK headdim = P + R
|
||||
@@ -160,7 +152,7 @@ curr_o, curr_lse = scaled_dot_product_attention(
|
||||
new_v,
|
||||
causal=True,
|
||||
return_softmax_lse=True
|
||||
)
|
||||
)
|
||||
|
||||
// Compute attention with the already existing context
|
||||
for chunk_idx in range(cdiv(C, MCC)):
|
||||
@@ -198,30 +190,17 @@ from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl)
|
||||
from vllm.attention.backends.utils import get_flash_attn_version
|
||||
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsW8A8Fp8)
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
Fp8LinearGenericOp, is_fp8)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
scaled_quantize)
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv, round_down
|
||||
@@ -646,7 +625,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
self.kv_b_proj = kv_b_proj
|
||||
self.o_proj = o_proj
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
self.fp8_linear_generic = Fp8LinearGenericOp()
|
||||
|
||||
# Handle the differences between the flash_attn_varlen from flash_attn
|
||||
# and the one from vllm_flash_attn. The former is used on RoCM and the
|
||||
@@ -658,88 +636,37 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
fa_version=self.vllm_flash_attn_version)
|
||||
|
||||
def _v_up_proj_and_o_proj(self, x):
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
if is_fp8(self.W_UV_O):
|
||||
output_parallel = self.fp8_linear_generic.apply(
|
||||
x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
|
||||
self.reqaunt_input_group_shape,
|
||||
self.reqaunt_weight_group_shape)
|
||||
else:
|
||||
output_parallel = torch.matmul(x.flatten(start_dim=1),
|
||||
self.W_UV_O)
|
||||
if self.tp_size > 1:
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
return output
|
||||
else:
|
||||
x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
|
||||
return self.o_proj(x.reshape(-1,
|
||||
self.num_heads * self.v_head_dim))[0]
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||
x = torch.bmm(x, self.W_UV)
|
||||
# Convert from (N, B, V) to (B, N * V)
|
||||
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
||||
return self.o_proj(x)[0]
|
||||
|
||||
# Return `ql_nope`, `q_pe`
|
||||
def _q_proj_and_k_up_proj(self, x):
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
if is_fp8(self.W_Q_UK):
|
||||
return self.fp8_linear_generic.apply(
|
||||
x, self.W_Q_UK, self.W_Q_UK_scales,
|
||||
self.reqaunt_input_group_shape,
|
||||
self.reqaunt_weight_group_shape).view(
|
||||
-1, self.num_heads, self.kv_lora_rank)
|
||||
return torch.matmul(x, self.W_Q_UK)\
|
||||
.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
else:
|
||||
x = torch.matmul(x, self.W_Q)\
|
||||
.view(-1, self.num_heads, self.qk_nope_head_dim)
|
||||
return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
|
||||
.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
q_nope, q_pe = self.q_proj(x)[0]\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)\
|
||||
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
# Convert from (B, N, P) to (N, B, P)
|
||||
q_nope = q_nope.transpose(0, 1)
|
||||
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
||||
ql_nope = torch.bmm(q_nope, self.W_UK_T)
|
||||
# Convert from (N, B, L) to (B, N, L)
|
||||
return ql_nope.transpose(0, 1), q_pe
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
|
||||
# TODO(lucas) This is very gross, we need a more wide scale refactor of
|
||||
# all the FP8 code with a more standard way of
|
||||
# defining schemes/group-shapes, we should also potentially force
|
||||
# quant_methods to support a decompress function
|
||||
#
|
||||
# returns input_group_shape, weight_group_shape
|
||||
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
|
||||
tuple[tuple[int, int], tuple[int, int]]:
|
||||
if isinstance(layer.quant_method, Fp8LinearMethod):
|
||||
if layer.quant_method.block_quant:
|
||||
weight_block_size = \
|
||||
layer.quant_method.quant_config.weight_block_size
|
||||
# per-token-group (1, X), block-quantized (X, Y)
|
||||
return (1, weight_block_size[-1]), weight_block_size
|
||||
else:
|
||||
return (-1, -1), (-1, -1) # per-tensor, per-tensor
|
||||
elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
|
||||
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# this is hacky but we always assume the for
|
||||
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
|
||||
# we ignore if it is static-per-tensor since we are going to
|
||||
# requantize after later anyways
|
||||
strategy = layer.scheme.strategy
|
||||
if strategy == QuantizationStrategy.TENSOR:
|
||||
return (1, -1), (-1, -1) # per-token, per-tensor
|
||||
elif strategy == QuantizationStrategy.CHANNEL:
|
||||
return (1, -1), (-1, 1) # per-token, per-channel
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"QuantizationStrategy.{strategy} is not supported for "
|
||||
"fp8 MLA, please run with VLLM_MLA_DISABLE=1")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Can't determine scale group shapes for "
|
||||
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
|
||||
)
|
||||
|
||||
def get_layer_weight(layer):
|
||||
if hasattr(layer, "weight"):
|
||||
return layer.weight
|
||||
elif hasattr(layer, "qweight"):
|
||||
return layer.qweight
|
||||
else:
|
||||
raise AttributeError(
|
||||
f"Layer '{layer}' has neither weight nor qweight")
|
||||
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
|
||||
for attr in WEIGHT_NAMES:
|
||||
if hasattr(layer, attr):
|
||||
return getattr(layer, attr)
|
||||
raise AttributeError(
|
||||
f"Layer '{layer}' has no recognized weight attribute:"
|
||||
f" {WEIGHT_NAMES}.")
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: LinearBase):
|
||||
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
|
||||
@@ -755,10 +682,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
return dequant_weights.T
|
||||
return layer.weight
|
||||
|
||||
weight_dtype = get_layer_weight(self.kv_b_proj).dtype
|
||||
assert get_layer_weight(self.o_proj).dtype == weight_dtype
|
||||
assert get_layer_weight(self.q_proj).dtype == weight_dtype
|
||||
|
||||
# we currently do not have quantized bmm's which are needed for
|
||||
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
|
||||
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank,
|
||||
@@ -777,89 +703,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
|
||||
# can be W_Q or W_UQ depending q_lora_rank, the former if
|
||||
# q_lora_rank is None, the latter otherwise. From the Attention backend
|
||||
# perspective though we call these both W_Q and rely on the layer
|
||||
# to pass in the correct matrix
|
||||
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
|
||||
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
|
||||
# W_QR is small so for simplicity we dont bother requantizing it
|
||||
self.W_QR = self.W_QR.to(act_dtype)
|
||||
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
# This assumes it wise to requantize using the same group shapes
|
||||
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
|
||||
# weights were originally quantized
|
||||
requant_input_group_shape, requant_weight_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(self.q_proj)
|
||||
assert (requant_input_group_shape, requant_weight_group_shape)\
|
||||
== get_scale_group_shapes_for_fp8(self.kv_b_proj)
|
||||
assert (requant_input_group_shape, requant_weight_group_shape)\
|
||||
== get_scale_group_shapes_for_fp8(self.o_proj)
|
||||
self.reqaunt_input_group_shape = requant_input_group_shape
|
||||
self.reqaunt_weight_group_shape = requant_weight_group_shape
|
||||
|
||||
#
|
||||
# Perform matrix-absorption following
|
||||
# https://github.com/flashinfer-ai/flashinfer/pull/551
|
||||
# for decode, as a result we end up with absorbed weights for decode
|
||||
# and another copy of raw weights for prefill.
|
||||
#
|
||||
self.W_UK, self.W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
# We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
|
||||
# depending q_lora_rank, the former if q_lora_rank is None, the
|
||||
# latter otherwise
|
||||
# basically if q_lora_rank is none we are absorbing into q_proj
|
||||
# instead of UQ
|
||||
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
W_Q_UK, W_Q_UK_scales = scaled_quantize(
|
||||
W_Q_UK,
|
||||
self.reqaunt_weight_group_shape,
|
||||
quant_dtype=current_platform.fp8_dtype())
|
||||
# For FP8 save the transpose so we can use
|
||||
# `apply_w8a8_block_fp8_linear` directly
|
||||
self.W_Q_UK = W_Q_UK.T.contiguous()
|
||||
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
|
||||
else:
|
||||
self.W_Q_UK = W_Q_UK.to(act_dtype)
|
||||
|
||||
W_O = get_and_maybe_dequant_weights(self.o_proj)\
|
||||
.view(-1, self.num_heads, self.v_head_dim)
|
||||
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
|
||||
.flatten(start_dim=0, end_dim=1).contiguous()
|
||||
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
W_UV_O, W_UV_O_scales = scaled_quantize(
|
||||
W_UV_O,
|
||||
self.reqaunt_weight_group_shape,
|
||||
quant_dtype=current_platform.fp8_dtype())
|
||||
# For FP8 save the transpose so we can use
|
||||
# `apply_w8a8_block_fp8_linear` directly
|
||||
self.W_UV_O = W_UV_O.T.contiguous()
|
||||
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
|
||||
else:
|
||||
self.W_UV_O = W_UV_O.to(act_dtype)
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
else:
|
||||
if is_fp8(weight_dtype):
|
||||
raise NotImplementedError(
|
||||
"Currently fp8 requires matrix absorption")
|
||||
|
||||
self.W_UV = W_UV
|
||||
self.W_UK = W_UK
|
||||
self.W_Q = W_Q.flatten(start_dim=1)
|
||||
# Convert from (L, N, V) to (N, L, V)
|
||||
self.W_UV = W_UV.transpose(0, 1)
|
||||
# Convert from (L, N, P) to (N, P, L)
|
||||
self.W_UK_T = W_UK.permute(1, 2, 0)
|
||||
|
||||
def _compute_prefill_context(
|
||||
self,
|
||||
@@ -998,7 +845,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
@abstractmethod
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
ql_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: M,
|
||||
@@ -1051,10 +898,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
|
||||
if has_decode:
|
||||
assert attn_metadata.decode is not None
|
||||
decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
||||
decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\
|
||||
.view(-1, self.num_heads, self.qk_rope_head_dim)
|
||||
|
||||
decode_ql_nope, decode_q_pe = \
|
||||
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
||||
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
||||
attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
|
||||
decode_k_pe)
|
||||
@@ -1087,6 +932,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
|
||||
if has_decode:
|
||||
output[:num_decode_tokens] = self._forward_decode(
|
||||
decode_q_nope, decode_q_pe, kv_cache, attn_metadata)
|
||||
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
|
||||
|
||||
return output_padded
|
||||
|
||||
Reference in New Issue
Block a user