Separate MLAAttention class from Attention (#25103)
Signed-off-by: Naveenraj Kamalakannan <therealnaveenkamal@gmail.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
parent 2a03f93de9
commit e614ab7806
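For context, a minimal sketch of the split named in the commit title. The classes below are hypothetical stand-ins, not vLLM's real implementations: the point is only that MLAAttention becomes a sibling nn.Module of Attention instead of a special case routed through it, since MLA (Multi-head Latent Attention) works on a compressed latent KV representation and handles its weights differently.

# Hypothetical stand-ins for vllm.attention.Attention and
# vllm.attention.layer.MLAAttention; shapes and signatures are simplified.
import torch
import torch.nn.functional as F
from torch import nn


class Attention(nn.Module):
    """Stand-in for the standard multi-head attention path."""

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        return F.scaled_dot_product_attention(q, k, v)


class MLAAttention(nn.Module):
    """Stand-in for the MLA path, which consumes a compressed latent
    in place of separate K/V tensors."""

    def forward(self, q: torch.Tensor, kv_latent: torch.Tensor) -> torch.Tensor:
        # Real MLA would up-project the latent back into K and V; here we
        # just reuse it for both to keep the sketch runnable.
        return F.scaled_dot_product_attention(q, kv_latent, kv_latent)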
@@ -14,6 +14,7 @@ from torch import nn
 from typing_extensions import assert_never
 
 from vllm.attention import Attention
+from vllm.attention.layer import MLAAttention
 from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import QKVCrossParallelLinear
@@ -122,11 +123,10 @@ def process_weights_after_loading(
         with device_loading_context(module, target_device):
             quant_method.process_weights_after_loading(module)
 
-    # Currently only used by MLA.
-    # NOTE: This intentionally happens after other modules so we can easily
-    # decompress the weights for MLA.
+    # Initialize post-load attention weights for both Attention and MLA.
+    # NOTE: Happens after other modules so we can easily decompress weights.
     for _, module in model.named_modules():
-        if isinstance(module, Attention) and hasattr(
+        if isinstance(module, (Attention, MLAAttention)) and hasattr(
             module, "process_weights_after_loading"
         ):
             # TODO(lucas): see if there is a way to unify the signatures
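And a runnable sketch of the loader-side dispatch that this hunk changes, again with hypothetical stubs in place of vLLM's layers (the hook name matches the diff, but its argument here is an assumption). The hasattr() guard keeps the call duck-typed, which is also why the TODO about unifying the two hook signatures is still open.

import torch
from torch import nn


class Attention(nn.Module):
    def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
        pass  # stub: backend-specific weight re-layout would go here


class MLAAttention(nn.Module):
    def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
        pass  # stub: decompress MLA weights after a (quantized) load


def run_post_load_hooks(model: nn.Module, act_dtype: torch.dtype) -> None:
    # Mirrors the diff: accept either layer type, then duck-type the hook.
    for _, module in model.named_modules():
        if isinstance(module, (Attention, MLAAttention)) and hasattr(
            module, "process_weights_after_loading"
        ):
            module.process_weights_after_loading(act_dtype)


if __name__ == "__main__":
    run_post_load_hooks(nn.Sequential(Attention(), MLAAttention()), torch.float16)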