[Attention] Refactor AttentionMetadata Preparation for Encoder-only Models (#23154)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-08-21 22:05:59 -07:00
committed by GitHub
parent 5964069367
commit 17373dcd93
12 changed files with 226 additions and 214 deletions

View File

@@ -31,6 +31,7 @@ from torch import nn
from transformers import LlamaConfig
from vllm.attention import Attention, AttentionType
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -173,7 +174,10 @@ class LlamaAttention(nn.Module):
if is_sliding:
sliding_window = config.sliding_window
self.attn = Attention(
attn_cls = (EncoderOnlyAttention
if attn_type == AttentionType.ENCODER_ONLY else Attention)
self.attn = attn_cls(
self.num_heads,
self.head_dim,
self.scaling,