[Attention] Move MLA forward from backend to layer (#33284)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -63,7 +63,7 @@ W_UV project kv_c to v shape [Lkv, N, V]
|
||||
W_O project v to h_t shape [N * V, H]
|
||||
|
||||
|
||||
## Compute Friendly Approach (i.e. "_forward_prefill"):
|
||||
## Compute Friendly Approach (i.e. "forward_mha"):
|
||||
|
||||
q_c = h_t @ W_DQ
|
||||
q_nope = (q_c @ W_UQ).view(Sq, N, P)
|
||||
@@ -91,7 +91,7 @@ NOTE: in the actual code,
|
||||
`out_proj` is W_O
|
||||
|
||||
|
||||
## Data-Movement Friendly Approach (i.e. "_forward_decode"):
|
||||
## Data-Movement Friendly Approach (i.e. "forward_mqa"):
|
||||
|
||||
Runtime
|
||||
q_c = h_t @ W_DQ
|
||||
@@ -243,6 +243,7 @@ from vllm.v1.attention.backend import (
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
MLAAttentionImpl,
|
||||
SparseMLAAttentionImpl,
|
||||
)
|
||||
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
@@ -266,6 +267,9 @@ logger = init_logger(__name__)
|
||||
class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
"""Multi-Head Latent Attention layer.
|
||||
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
|
||||
This class takes query, and compressed key/value tensors as input.
|
||||
The class does the following:
|
||||
|
||||
@@ -289,6 +293,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
prefix: str = "",
|
||||
use_sparse: bool = False,
|
||||
indexer: object | None = None,
|
||||
q_pad_num_heads: int | None = None,
|
||||
**extra_impl_args,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -299,8 +304,14 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
self.v_head_dim = v_head_dim
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.kv_b_proj = kv_b_proj
|
||||
self.head_size = kv_lora_rank + qk_rope_head_dim
|
||||
self.layer_name = prefix
|
||||
self.indexer = indexer
|
||||
self.q_pad_num_heads = q_pad_num_heads
|
||||
|
||||
self.num_kv_heads = 1
|
||||
self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
|
||||
|
||||
if cache_config is not None:
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
@@ -364,6 +375,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
v_head_dim=self.v_head_dim,
|
||||
kv_b_proj=kv_b_proj,
|
||||
indexer=indexer,
|
||||
q_pad_num_heads=q_pad_num_heads,
|
||||
**extra_impl_args,
|
||||
)
|
||||
|
||||
@@ -388,6 +400,26 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
|
||||
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
|
||||
|
||||
self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
|
||||
|
||||
# If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
|
||||
self.is_aiter_triton_fp4_bmm_enabled = (
|
||||
rocm_aiter_ops.is_fp4bmm_enabled()
|
||||
and self.kv_b_proj.weight.dtype == torch.bfloat16
|
||||
)
|
||||
|
||||
# Attributes for forward_impl method
|
||||
self.chunked_prefill_workspace_size = (
|
||||
MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
|
||||
get_current_vllm_config()
|
||||
)
|
||||
)
|
||||
self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
|
||||
static=True,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
compile_native=True,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
@@ -407,8 +439,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
|
||||
if self.attn_backend.accept_output_buffer:
|
||||
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
|
||||
self.impl.forward(
|
||||
self,
|
||||
self.forward_impl(
|
||||
q,
|
||||
kv_c_normed,
|
||||
k_pe,
|
||||
@@ -418,8 +449,8 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
)
|
||||
return output
|
||||
else:
|
||||
return self.impl.forward(
|
||||
self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
|
||||
return self.forward_impl(
|
||||
q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
|
||||
)
|
||||
else:
|
||||
if self.attn_backend.accept_output_buffer:
|
||||
@@ -440,9 +471,282 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
self.layer_name,
|
||||
)
|
||||
|
||||
def forward_impl(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k_c_normed: torch.Tensor, # key in unified attn
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: "MLACommonMetadata",
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported for MLA"
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# During the profile run try to simulate to worse case output size
|
||||
# for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
|
||||
# since this can be large
|
||||
_ = torch.empty(
|
||||
(
|
||||
self.chunked_prefill_workspace_size,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
),
|
||||
device=k_c_normed.device,
|
||||
dtype=k_c_normed.dtype,
|
||||
)
|
||||
|
||||
# The zero fill is required when used with DP + EP
|
||||
# to ensure all ranks within a DP group compute the
|
||||
# same expert outputs.
|
||||
return output.fill_(0)
|
||||
|
||||
if self.impl.dcp_world_size == -1:
|
||||
self.impl.dcp_world_size = get_dcp_group().world_size
|
||||
|
||||
fp8_attention = self.kv_cache_dtype.startswith("fp8")
|
||||
|
||||
num_actual_toks = attn_metadata.num_actual_tokens
|
||||
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
output_padded = output
|
||||
output = output[:num_actual_toks, ...]
|
||||
q = q[:num_actual_toks, ...]
|
||||
k_c_normed = k_c_normed[:num_actual_toks, ...]
|
||||
k_pe = k_pe[:num_actual_toks, ...]
|
||||
|
||||
assert (
|
||||
attn_metadata.num_decodes is not None
|
||||
and attn_metadata.num_prefills is not None
|
||||
and attn_metadata.num_decode_tokens is not None
|
||||
)
|
||||
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
|
||||
decode_q = q[:num_decode_tokens]
|
||||
|
||||
prefill_q = q[num_decode_tokens:]
|
||||
prefill_k_pe = k_pe[num_decode_tokens:]
|
||||
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
|
||||
|
||||
# write the latent and rope to kv cache
|
||||
if kv_cache.numel() > 0:
|
||||
ops.concat_and_cache_mla(
|
||||
k_c_normed,
|
||||
k_pe.squeeze(1),
|
||||
kv_cache,
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
scale=self._k_scale,
|
||||
)
|
||||
|
||||
if fp8_attention:
|
||||
kv_cache = kv_cache.view(current_platform.fp8_dtype())
|
||||
|
||||
# Sparse MLA impls only support forward_mqa (decode-style attention)
|
||||
is_sparse_impl = isinstance(self.impl, SparseMLAAttentionImpl)
|
||||
|
||||
if has_prefill and not is_sparse_impl:
|
||||
self.impl.forward_mha(
|
||||
prefill_q,
|
||||
prefill_k_c_normed,
|
||||
prefill_k_pe,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
self._k_scale,
|
||||
output=output[num_decode_tokens:],
|
||||
)
|
||||
|
||||
if has_decode or (has_prefill and is_sparse_impl):
|
||||
# For sparse impl, we always use forward_mqa for all tokens
|
||||
# For non-sparse impl, we only use forward_mqa for decode tokens
|
||||
if is_sparse_impl:
|
||||
mqa_q = q
|
||||
mqa_output_slice = output
|
||||
else:
|
||||
assert attn_metadata.decode is not None
|
||||
mqa_q = decode_q
|
||||
mqa_output_slice = output[:num_decode_tokens]
|
||||
|
||||
mqa_q_nope, mqa_q_pe = mqa_q.split(
|
||||
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
|
||||
# Convert from (B, N, P) to (N, B, P)
|
||||
mqa_q_nope = mqa_q_nope.transpose(0, 1)
|
||||
|
||||
if self.q_pad_num_heads is not None:
|
||||
B, N, L = mqa_q_pe.shape
|
||||
mqa_pe_padded = mqa_q_pe.new_empty((B, self.q_pad_num_heads, L))
|
||||
mqa_pe_padded.resize_((B, N, L))
|
||||
mqa_pe_padded.copy_(mqa_q_pe)
|
||||
mqa_q_pe = mqa_pe_padded
|
||||
|
||||
if self.is_aiter_triton_fp4_bmm_enabled:
|
||||
from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4
|
||||
|
||||
mqa_ql_nope = batched_gemm_a16wfp4(
|
||||
mqa_q_nope,
|
||||
self.W_K,
|
||||
self.W_K_scale,
|
||||
transpose_bm=True,
|
||||
prequant=True,
|
||||
y_scale=self._q_scale if fp8_attention else None,
|
||||
)
|
||||
elif self.is_aiter_triton_fp8_bmm_enabled:
|
||||
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
|
||||
mqa_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
|
||||
mqa_q_nope,
|
||||
self.W_K,
|
||||
self.W_K_scale,
|
||||
group_size=128,
|
||||
transpose_bm=True,
|
||||
)
|
||||
else:
|
||||
# Pads the head_dim if necessary (for the underlying kernel)
|
||||
N, B, P = mqa_q_nope.shape
|
||||
_, _, L = self.W_UK_T.shape
|
||||
|
||||
if self.q_pad_num_heads is not None:
|
||||
mqa_ql_nope = mqa_q_nope.new_empty((self.q_pad_num_heads, B, L))
|
||||
mqa_ql_nope.resize_((N, B, L))
|
||||
else:
|
||||
mqa_ql_nope = mqa_q_nope.new_empty((N, B, L))
|
||||
|
||||
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
||||
torch.bmm(mqa_q_nope, self.W_UK_T, out=mqa_ql_nope)
|
||||
|
||||
# Convert from (N, B, L) to (B, N, L)
|
||||
mqa_ql_nope = mqa_ql_nope.transpose(0, 1)
|
||||
|
||||
if fp8_attention:
|
||||
assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0]
|
||||
assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1]
|
||||
mqa_q = self._decode_concat_quant_fp8_op(
|
||||
mqa_ql_nope, mqa_q_pe, self._q_scale
|
||||
)
|
||||
else:
|
||||
mqa_q = (mqa_ql_nope, mqa_q_pe)
|
||||
if self.impl.dcp_world_size > 1:
|
||||
assert not fp8_attention, "DCP not support fp8 kvcache now."
|
||||
# concatenate mqa_ql_nope and mqa_q_pe -> (B, N, L + P)
|
||||
mqa_q = torch.cat(mqa_q, dim=-1)
|
||||
# mqa_q do allgather in head dim.
|
||||
mqa_q = get_dcp_group().all_gather(mqa_q, dim=1)
|
||||
|
||||
# call decode attn
|
||||
attn_out, lse = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)
|
||||
|
||||
# correct dcp attn_out with lse.
|
||||
if self.impl.dcp_world_size > 1:
|
||||
attn_out = cp_lse_ag_out_rs(
|
||||
attn_out,
|
||||
lse,
|
||||
get_dcp_group(),
|
||||
is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
|
||||
)
|
||||
|
||||
# v_up projection
|
||||
self._v_up_proj(attn_out, out=mqa_output_slice)
|
||||
return output_padded
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
if hasattr(self.impl, "process_weights_after_loading"):
|
||||
self.impl.process_weights_after_loading(act_dtype)
|
||||
# we currently do not have quantized bmm's which are needed for
|
||||
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
|
||||
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(
|
||||
self.kv_b_proj, out_dtype=act_dtype
|
||||
).T
|
||||
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
), (
|
||||
f"{kv_b_proj_weight.shape=}, "
|
||||
f"{self.kv_lora_rank=}, "
|
||||
f"{self.num_heads=}, "
|
||||
f"{self.qk_nope_head_dim=}, "
|
||||
f"{self.v_head_dim=}"
|
||||
)
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
|
||||
)
|
||||
|
||||
# If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
|
||||
if self.is_aiter_triton_fp4_bmm_enabled:
|
||||
from vllm.model_executor.layers.quantization.quark.utils import (
|
||||
quark_quantize_weight_to_mxfp4,
|
||||
)
|
||||
|
||||
self.W_K, self.W_K_scale = quark_quantize_weight_to_mxfp4(W_UK)
|
||||
# Convert from (L, N, P) to (N, L, P)
|
||||
self.W_K = self.W_K.transpose(0, 1)
|
||||
self.W_K_scale = self.W_K_scale.transpose(0, 1)
|
||||
|
||||
self.W_V, self.W_V_scale = quark_quantize_weight_to_mxfp4(
|
||||
W_UV.permute(1, 2, 0)
|
||||
)
|
||||
elif self.is_aiter_triton_fp8_bmm_enabled:
|
||||
W_K = W_UK.transpose(0, 1) # 16 512 128
|
||||
W_V = W_UV.permute(1, 2, 0) # 16 128 512
|
||||
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
|
||||
W_K, dtype=current_platform.fp8_dtype()
|
||||
)
|
||||
self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
|
||||
W_V, dtype=current_platform.fp8_dtype()
|
||||
)
|
||||
|
||||
# The kernel operates on non-padded inputs. Hence, pre-compiling
|
||||
# triton kernel to avoid runtime compilation for unseen batch sizes
|
||||
# Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
|
||||
# On DS-R1, this step adds roughly 50s to the model loading time.
|
||||
max_batch_size = 1024 # [ToDo] Find the optimal upper limit
|
||||
pre_compilation_list = list(range(1, max_batch_size + 1))
|
||||
if is_global_first_rank():
|
||||
pre_compilation_list = tqdm(
|
||||
pre_compilation_list,
|
||||
desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
|
||||
total=max_batch_size,
|
||||
)
|
||||
|
||||
for m in pre_compilation_list:
|
||||
x = torch.empty(
|
||||
(self.W_K.shape[0], m, self.W_K.shape[2]),
|
||||
dtype=torch.bfloat16,
|
||||
device=self.W_K.device,
|
||||
)
|
||||
rocm_aiter_ops.triton_fp8_bmm(
|
||||
x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
|
||||
)
|
||||
|
||||
x = torch.empty(
|
||||
(self.W_V.shape[0], m, self.W_V.shape[2]),
|
||||
dtype=torch.bfloat16,
|
||||
device=self.W_V.device,
|
||||
)
|
||||
rocm_aiter_ops.triton_fp8_bmm(
|
||||
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
|
||||
)
|
||||
else:
|
||||
# Convert from (L, N, V) to (N, L, V)
|
||||
self.W_UV = W_UV.transpose(0, 1)
|
||||
# Convert from (L, N, P) to (N, P, L)
|
||||
self.W_UK_T = W_UK.permute(1, 2, 0)
|
||||
|
||||
# If we should not load quant weights, we initialize the scales to 1.0
|
||||
# as the default value. See [Note: Register q/k/v/prob scales in state dict]
|
||||
@@ -492,6 +796,41 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
cache_dtype_str=vllm_config.cache_config.cache_dtype,
|
||||
)
|
||||
|
||||
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
out = out.view(-1, self.num_heads, self.v_head_dim)
|
||||
if self.is_aiter_triton_fp4_bmm_enabled:
|
||||
out = rocm_aiter_ops.batched_gemm_a16wfp4(
|
||||
x,
|
||||
self.W_V,
|
||||
self.W_V_scale,
|
||||
out,
|
||||
transpose_bm=True,
|
||||
prequant=True,
|
||||
y_scale=None,
|
||||
)
|
||||
x = out.view(-1, self.num_heads * self.v_head_dim)
|
||||
elif self.is_aiter_triton_fp8_bmm_enabled:
|
||||
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
|
||||
x = rocm_aiter_ops.triton_fp8_bmm(
|
||||
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
|
||||
)
|
||||
else:
|
||||
# Convert from (B, N * V) to (N, B, V)
|
||||
out = out.transpose(0, 1)
|
||||
|
||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||
torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"
|
||||
|
||||
# Convert from (N, B, V) to (B, N * V)
|
||||
out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
||||
|
||||
# Adjust output buffer shape back to the original (B, N * V)
|
||||
N, B, V = out.shape
|
||||
out.resize_((B, N * V))
|
||||
out.copy_(out_new) # Copy result
|
||||
|
||||
|
||||
@maybe_transfer_kv_layer
|
||||
def unified_mla_attention(
|
||||
@@ -500,8 +839,8 @@ def unified_mla_attention(
|
||||
k_pe: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
attn_metadata, self, kv_cache = get_attention_context(layer_name)
|
||||
output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)
|
||||
attn_metadata, layer, kv_cache = get_attention_context(layer_name)
|
||||
output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata)
|
||||
|
||||
return output
|
||||
|
||||
@@ -534,9 +873,8 @@ def unified_mla_attention_with_output(
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
attn_metadata, self, kv_cache = get_attention_context(layer_name)
|
||||
self.impl.forward(
|
||||
self,
|
||||
attn_metadata, layer, kv_cache = get_attention_context(layer_name)
|
||||
layer.forward_impl(
|
||||
q,
|
||||
kv_c_normed,
|
||||
k_pe,
|
||||
@@ -1511,9 +1849,7 @@ def reorg_kvcache(
|
||||
return reorganized_kv_c_normed, reorganized_k_pe
|
||||
|
||||
|
||||
# TODO(Lucas): rename MLACommonBaseImpl -> MLACommonImpl,
|
||||
# and MLACommonImpl -> MLACommonDenseImpl or somthing like that
|
||||
class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
|
||||
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
@@ -1539,7 +1875,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
kv_b_proj: ColumnParallelLinear,
|
||||
indexer=None,
|
||||
indexer: object | None = None,
|
||||
q_pad_num_heads: int | None = None,
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
@@ -1560,147 +1896,6 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
|
||||
self.kv_b_proj = kv_b_proj
|
||||
self.indexer = indexer
|
||||
self.q_pad_num_heads = q_pad_num_heads
|
||||
self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
|
||||
|
||||
# If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
|
||||
self.is_aiter_triton_fp4_bmm_enabled = (
|
||||
rocm_aiter_ops.is_fp4bmm_enabled()
|
||||
and self.kv_b_proj.weight.dtype == torch.bfloat16
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
# we currently do not have quantized bmm's which are needed for
|
||||
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
|
||||
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(
|
||||
self.kv_b_proj, out_dtype=act_dtype
|
||||
).T
|
||||
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
), (
|
||||
f"{kv_b_proj_weight.shape=}, "
|
||||
f"{self.kv_lora_rank=}, "
|
||||
f"{self.num_heads=}, "
|
||||
f"{self.qk_nope_head_dim=}, "
|
||||
f"{self.v_head_dim=}"
|
||||
)
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
|
||||
)
|
||||
|
||||
# If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
|
||||
if self.is_aiter_triton_fp4_bmm_enabled:
|
||||
from vllm.model_executor.layers.quantization.quark.utils import (
|
||||
quark_quantize_weight_to_mxfp4,
|
||||
)
|
||||
|
||||
self.W_K, self.W_K_scale = quark_quantize_weight_to_mxfp4(W_UK)
|
||||
# Convert from (L, N, P) to (N, L, P)
|
||||
self.W_K = self.W_K.transpose(0, 1)
|
||||
self.W_K_scale = self.W_K_scale.transpose(0, 1)
|
||||
|
||||
self.W_V, self.W_V_scale = quark_quantize_weight_to_mxfp4(
|
||||
W_UV.permute(1, 2, 0)
|
||||
)
|
||||
elif self.is_aiter_triton_fp8_bmm_enabled:
|
||||
W_K = W_UK.transpose(0, 1) # 16 512 128
|
||||
W_V = W_UV.permute(1, 2, 0) # 16 128 512
|
||||
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
|
||||
W_K, dtype=current_platform.fp8_dtype()
|
||||
)
|
||||
self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
|
||||
W_V, dtype=current_platform.fp8_dtype()
|
||||
)
|
||||
|
||||
# The kernel operates on non-padded inputs. Hence, pre-compiling
|
||||
# triton kernel to avoid runtime compilation for unseen batch sizes
|
||||
# Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
|
||||
# On DS-R1, this step adds roughly 50s to the model loading time.
|
||||
max_batch_size = 1024 # [ToDo] Find the optimal upper limit
|
||||
pre_compilation_list = list(range(1, max_batch_size + 1))
|
||||
if is_global_first_rank():
|
||||
pre_compilation_list = tqdm(
|
||||
pre_compilation_list,
|
||||
desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
|
||||
total=max_batch_size,
|
||||
)
|
||||
|
||||
for m in pre_compilation_list:
|
||||
x = torch.empty(
|
||||
(self.W_K.shape[0], m, self.W_K.shape[2]),
|
||||
dtype=torch.bfloat16,
|
||||
device=self.W_K.device,
|
||||
)
|
||||
rocm_aiter_ops.triton_fp8_bmm(
|
||||
x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
|
||||
)
|
||||
|
||||
x = torch.empty(
|
||||
(self.W_V.shape[0], m, self.W_V.shape[2]),
|
||||
dtype=torch.bfloat16,
|
||||
device=self.W_V.device,
|
||||
)
|
||||
rocm_aiter_ops.triton_fp8_bmm(
|
||||
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
|
||||
)
|
||||
else:
|
||||
# Convert from (L, N, V) to (N, L, V)
|
||||
self.W_UV = W_UV.transpose(0, 1)
|
||||
# Convert from (L, N, P) to (N, P, L)
|
||||
self.W_UK_T = W_UK.permute(1, 2, 0)
|
||||
|
||||
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
out = out.view(-1, self.num_heads, self.v_head_dim)
|
||||
if self.is_aiter_triton_fp4_bmm_enabled:
|
||||
out = rocm_aiter_ops.batched_gemm_a16wfp4(
|
||||
x,
|
||||
self.W_V,
|
||||
self.W_V_scale,
|
||||
out,
|
||||
transpose_bm=True,
|
||||
prequant=True,
|
||||
y_scale=None,
|
||||
)
|
||||
x = out.view(-1, self.num_heads * self.v_head_dim)
|
||||
elif self.is_aiter_triton_fp8_bmm_enabled:
|
||||
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
|
||||
x = rocm_aiter_ops.triton_fp8_bmm(
|
||||
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
|
||||
)
|
||||
else:
|
||||
# Convert from (B, N * V) to (N, B, V)
|
||||
out = out.transpose(0, 1)
|
||||
|
||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||
torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"
|
||||
|
||||
# Convert from (N, B, V) to (B, N * V)
|
||||
out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
||||
|
||||
# Adjust output buffer shape back to the original (B, N * V)
|
||||
N, B, V = out.shape
|
||||
out.resize_((B, N * V))
|
||||
out.copy_(out_new) # Copy result
|
||||
|
||||
|
||||
class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if use_trtllm_ragged_deepseek_prefill():
|
||||
logger.info_once(
|
||||
@@ -1750,19 +1945,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
|
||||
self.dcp_world_size: int = -1
|
||||
|
||||
self.chunked_prefill_workspace_size = (
|
||||
MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
|
||||
get_current_vllm_config()
|
||||
)
|
||||
)
|
||||
self.cp_kv_cache_interleave_size: int = (
|
||||
get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size
|
||||
)
|
||||
self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
|
||||
static=True,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
compile_native=True,
|
||||
)
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(
|
||||
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
|
||||
@@ -2193,7 +2378,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
|
||||
return output, output_lse
|
||||
|
||||
def _forward_prefill(
|
||||
def forward_mha(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
@@ -2258,7 +2443,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
output.copy_(output_prefill)
|
||||
|
||||
@abstractmethod
|
||||
def _forward_decode(
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
@@ -2266,185 +2451,3 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
q: torch.Tensor,
|
||||
k_c_normed: torch.Tensor, # key in unified attn
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: M,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported for MLACommonImpl"
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# During the profile run try to simulate to worse case output size
|
||||
# for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
|
||||
# since this can be large
|
||||
_ = torch.empty(
|
||||
(
|
||||
self.chunked_prefill_workspace_size,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
),
|
||||
device=k_c_normed.device,
|
||||
dtype=k_c_normed.dtype,
|
||||
)
|
||||
|
||||
# The zero fill is required when used with DP + EP
|
||||
# to ensure all ranks within a DP group compute the
|
||||
# same expert outputs.
|
||||
return output.fill_(0)
|
||||
|
||||
if self.dcp_world_size == -1:
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
|
||||
fp8_attention = self.kv_cache_dtype.startswith("fp8")
|
||||
|
||||
num_actual_toks = attn_metadata.num_actual_tokens
|
||||
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
output_padded = output
|
||||
output = output[:num_actual_toks, ...]
|
||||
q = q[:num_actual_toks, ...]
|
||||
k_c_normed = k_c_normed[:num_actual_toks, ...]
|
||||
k_pe = k_pe[:num_actual_toks, ...]
|
||||
|
||||
assert (
|
||||
attn_metadata.num_decodes is not None
|
||||
and attn_metadata.num_prefills is not None
|
||||
and attn_metadata.num_decode_tokens is not None
|
||||
)
|
||||
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
|
||||
decode_q = q[:num_decode_tokens]
|
||||
|
||||
prefill_q = q[num_decode_tokens:]
|
||||
prefill_k_pe = k_pe[num_decode_tokens:]
|
||||
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
|
||||
|
||||
# write the latent and rope to kv cache
|
||||
if kv_cache.numel() > 0:
|
||||
ops.concat_and_cache_mla(
|
||||
k_c_normed,
|
||||
k_pe.squeeze(1),
|
||||
kv_cache,
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
scale=layer._k_scale,
|
||||
)
|
||||
|
||||
if fp8_attention:
|
||||
kv_cache = kv_cache.view(current_platform.fp8_dtype())
|
||||
|
||||
if has_prefill:
|
||||
self._forward_prefill(
|
||||
prefill_q,
|
||||
prefill_k_c_normed,
|
||||
prefill_k_pe,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
layer._k_scale,
|
||||
output=output[num_decode_tokens:],
|
||||
)
|
||||
|
||||
if has_decode:
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
decode_q_nope, decode_q_pe = decode_q.split(
|
||||
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
|
||||
# Convert from (B, N, P) to (N, B, P)
|
||||
decode_q_nope = decode_q_nope.transpose(0, 1)
|
||||
|
||||
if self.q_pad_num_heads is not None:
|
||||
B, N, L = decode_q_pe.shape
|
||||
decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L))
|
||||
decode_pe_padded.resize_((B, N, L))
|
||||
decode_pe_padded.copy_(decode_q_pe)
|
||||
decode_q_pe = decode_pe_padded
|
||||
|
||||
if self.is_aiter_triton_fp4_bmm_enabled:
|
||||
from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4
|
||||
|
||||
decode_ql_nope = batched_gemm_a16wfp4(
|
||||
decode_q_nope,
|
||||
self.W_K,
|
||||
self.W_K_scale,
|
||||
transpose_bm=True,
|
||||
prequant=True,
|
||||
y_scale=layer._q_scale if fp8_attention else None,
|
||||
)
|
||||
elif self.is_aiter_triton_fp8_bmm_enabled:
|
||||
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
|
||||
decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
|
||||
decode_q_nope,
|
||||
self.W_K,
|
||||
self.W_K_scale,
|
||||
group_size=128,
|
||||
transpose_bm=True,
|
||||
)
|
||||
else:
|
||||
# Pads the head_dim if necessary (for the underlying kernel)
|
||||
N, B, P = decode_q_nope.shape
|
||||
_, _, L = self.W_UK_T.shape
|
||||
|
||||
if self.q_pad_num_heads is not None:
|
||||
decode_ql_nope = decode_q_nope.new_empty(
|
||||
(self.q_pad_num_heads, B, L)
|
||||
)
|
||||
decode_ql_nope.resize_((N, B, L))
|
||||
else:
|
||||
decode_ql_nope = decode_q_nope.new_empty((N, B, L))
|
||||
|
||||
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
||||
torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
|
||||
|
||||
# Convert from (N, B, L) to (B, N, L)
|
||||
decode_ql_nope = decode_ql_nope.transpose(0, 1)
|
||||
|
||||
if fp8_attention:
|
||||
assert decode_ql_nope.shape[0] == decode_q_pe.shape[0]
|
||||
assert decode_ql_nope.shape[1] == decode_q_pe.shape[1]
|
||||
decode_q = self._decode_concat_quant_fp8_op(
|
||||
decode_ql_nope, decode_q_pe, layer._q_scale
|
||||
)
|
||||
else:
|
||||
decode_q = (decode_ql_nope, decode_q_pe)
|
||||
if self.dcp_world_size > 1:
|
||||
assert not fp8_attention, "DCP not support fp8 kvcache now."
|
||||
# concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
|
||||
decode_q = torch.cat(decode_q, dim=-1)
|
||||
# decode_q do allgather in head dim.
|
||||
decode_q = get_dcp_group().all_gather(decode_q, dim=1)
|
||||
|
||||
# call decode attn
|
||||
attn_out, lse = self._forward_decode(
|
||||
decode_q, kv_cache, attn_metadata, layer
|
||||
)
|
||||
|
||||
# correct dcp attn_out with lse.
|
||||
if self.dcp_world_size > 1:
|
||||
attn_out = cp_lse_ag_out_rs(
|
||||
attn_out,
|
||||
lse,
|
||||
get_dcp_group(),
|
||||
is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
|
||||
)
|
||||
|
||||
# v_up projection
|
||||
self._v_up_proj(attn_out, out=output[:num_decode_tokens])
|
||||
return output_padded
|
||||
|
||||
Reference in New Issue
Block a user