Separate MLAAttention class from Attention (#25103)

Signed-off-by: Naveenraj Kamalakannan <therealnaveenkamal@gmail.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: Naveenraj Kamalakannan
Date: 2025-10-08 20:11:11 -04:00
Committed by: GitHub
Parent: 2a03f93de9
Commit: e614ab7806
10 changed files with 502 additions and 163 deletions


@@ -5,7 +5,7 @@ from typing import Optional
 
 import torch
 
-from vllm.attention import Attention
+from vllm.attention.layer import MLAAttention
 from vllm.config import CacheConfig
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
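In effect, MLA models stop routing through the generic Attention entry point and import the dedicated layer instead. A minimal before/after of the import, assuming MLAAttention is exposed from vllm.attention.layer exactly as this hunk shows:

# Old: MLA was expressed through the generic Attention layer (use_mla=True).
from vllm.attention import Attention

# New: a dedicated class, imported from vllm.attention.layer per this diff
# (any additional re-export path is not shown here).
from vllm.attention.layer import MLAAttention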
@@ -30,8 +30,9 @@ class MLAModules:
 
 @CustomOp.register("multi_head_latent_attention")
-class MultiHeadLatentAttention(CustomOp):
-    """MLA layer registered as CustomOp.
+class MultiHeadLatentAttentionWrapper(CustomOp):
+    """MLA layer registered as CustomOp to allow OOT backends to add
+    custom implementations of the outer MLA layer (including rope & o_proj).
 
     Note that currently MLA ignores the enable/disable mechanism of CustomOp
     because there is only one in-tree implementation in forward_native.
     TODO: implement this with a new PluggableLayer mechanism.
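For readers unfamiliar with the CustomOp mechanism the docstring refers to, here is a standalone sketch of the register-and-dispatch pattern it provides. This is not vLLM's actual implementation; names such as CustomOpSketch and _registry are illustrative only. Subclasses register under a string key so an out-of-tree (OOT) backend could later swap in its own implementation of the outer MLA layer without touching model code.

from typing import Dict, Type

import torch
import torch.nn as nn


class CustomOpSketch(nn.Module):
    # Illustrative registry mapping op names to implementing classes.
    _registry: Dict[str, Type["CustomOpSketch"]] = {}

    @classmethod
    def register(cls, name: str):
        def decorator(op_cls: Type["CustomOpSketch"]) -> Type["CustomOpSketch"]:
            cls._registry[name] = op_cls
            return op_cls

        return decorator

    def forward(self, *args, **kwargs):
        # The real CustomOp dispatches to a platform-specific forward_*;
        # the wrapper in this PR only provides forward_native, which is why
        # the enable/disable mechanism is ignored for now.
        return self.forward_native(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError


@CustomOpSketch.register("multi_head_latent_attention")
class MLAWrapperSketch(CustomOpSketch):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x  # placeholder for the real rope + attention + o_proj pipeline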
@@ -87,30 +88,19 @@ class MultiHeadLatentAttention(CustomOp):
             self.topk_tokens = self.indexer.topk_tokens
             self.topk_indices_buffer = mla_modules.topk_indices_buffer
 
-        # In the MLA backend, kv_cache includes both k_c and
-        # pe (i.e. decoupled position embeddings). In particular,
-        # the concat_and_cache_mla op requires
-        #     k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
-        # i.e.
-        #     kv_lora_rank + qk_rope_head_dim == head_size
-        self.mla_attn = Attention(
+        self.mla_attn = MLAAttention(
             num_heads=self.num_heads,
-            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
             scale=scale,
-            num_kv_heads=1,
+            qk_nope_head_dim=self.qk_nope_head_dim,
+            qk_rope_head_dim=self.qk_rope_head_dim,
+            v_head_dim=self.v_head_dim,
+            q_lora_rank=self.q_lora_rank,
+            kv_lora_rank=self.kv_lora_rank,
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
-            use_mla=True,
-            use_sparse=mla_modules.is_sparse,
-            # MLA Args
-            q_lora_rank=self.q_lora_rank,
-            kv_lora_rank=self.kv_lora_rank,
-            qk_nope_head_dim=self.qk_nope_head_dim,
-            qk_rope_head_dim=self.qk_rope_head_dim,
-            qk_head_dim=self.qk_head_dim,
-            v_head_dim=self.v_head_dim,
             kv_b_proj=self.kv_b_proj,
+            use_sparse=self.is_sparse,
             indexer=self.indexer,
         )
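The comment removed above encodes a concrete invariant: the MLA kv_cache stores the compressed latent k_c and the decoupled rotary part k_pe concatenated, so the cache head size must equal their sum. A worked example of that arithmetic, using DeepSeek-V3-style sizes as assumed example values (they are model hyperparameters, not taken from this diff):

kv_lora_rank = 512       # width of the compressed latent KV (k_c)
qk_rope_head_dim = 64    # width of the decoupled rotary position embedding (k_pe)

# concat_and_cache_mla needs k_c.size(1) + k_pe.size(1) == kv_cache.size(2),
# i.e. kv_lora_rank + qk_rope_head_dim == head_size.
head_size = kv_lora_rank + qk_rope_head_dim
assert head_size == 576

The old Attention call received this sum explicitly as head_size; with the dedicated MLAAttention taking kv_lora_rank and qk_rope_head_dim directly, the head size is presumably derived internally, which is why the comment and the explicit head_size argument disappear in this hunk.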