[Attention] Refactor AttentionMetadata Preparation for Encoder-only Models (#23154)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -32,6 +32,7 @@ from torch import nn
|
||||
from transformers import Qwen2Config
|
||||
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@@ -159,7 +160,9 @@ class Qwen2Attention(nn.Module):
|
||||
rope_scaling=rope_scaling,
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
)
|
||||
self.attn = Attention(
|
||||
attn_cls = (EncoderOnlyAttention
|
||||
if attn_type == AttentionType.ENCODER_ONLY else Attention)
|
||||
self.attn = attn_cls(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
|
||||
Reference in New Issue
Block a user