Separate MLAAttention class from Attention (#25103)

Signed-off-by: Naveenraj Kamalakannan <therealnaveenkamal@gmail.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: Naveenraj Kamalakannan
Date: 2025-10-08 20:11:11 -04:00
Committed by: GitHub
Parent: 2a03f93de9
Commit: e614ab7806
10 changed files with 502 additions and 163 deletions


@@ -14,6 +14,7 @@ from torch import nn
 from typing_extensions import assert_never
 from vllm.attention import Attention
+from vllm.attention.layer import MLAAttention
 from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import QKVCrossParallelLinear
@@ -122,11 +123,10 @@ def process_weights_after_loading(
             with device_loading_context(module, target_device):
                 quant_method.process_weights_after_loading(module)
-    # Currently only used by MLA.
-    # NOTE: This intentionally happens after other modules so we can easily
-    # decompress the weights for MLA.
+    # Initialize post-load attention weights for both Attention and MLA.
+    # NOTE: Happens after other modules so we can easily decompress weights.
     for _, module in model.named_modules():
-        if isinstance(module, Attention) and hasattr(
+        if isinstance(module, (Attention, MLAAttention)) and hasattr(
             module, "process_weights_after_loading"
         ):
             # TODO(lucas): see if there is a way to unify the signatures
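For context on what this hunk buys: once MLA layers stop subclassing `Attention`, the loader's post-load pass has to match both classes explicitly. The sketch below illustrates only that dispatch pattern; apart from the `Attention`/`MLAAttention` names and the `process_weights_after_loading` hook taken from the diff, the class bodies, the `run_post_load_hooks` helper, and the dtype argument are assumptions rather than vLLM's actual signatures.

```python
# Minimal sketch of the post-load dispatch pattern shown in the hunk above.
# Everything except the two class names and the hook name is illustrative.
import torch
from torch import nn


class Attention(nn.Module):
    def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
        # e.g. repack or cast attention weights after (quantized) loading
        pass


class MLAAttention(nn.Module):
    def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
        # e.g. decompress MLA's low-rank projections into full matrices
        pass


def run_post_load_hooks(model: nn.Module, act_dtype: torch.dtype) -> None:
    # Same shape as the loop in the diff: one isinstance check covers both
    # layer types, and the hasattr guard keeps the hook optional.
    for _, module in model.named_modules():
        if isinstance(module, (Attention, MLAAttention)) and hasattr(
            module, "process_weights_after_loading"
        ):
            module.process_weights_after_loading(act_dtype)
```

Keeping a single hasattr-guarded loop means any future attention variant only needs to expose the same hook to take part in post-load weight processing, which is also what the TODO about unifying the signatures points toward.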