[Model] Consolidate ViTs attention implementation without mask (#10893)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2024-12-05 02:11:08 +08:00
committed by GitHub
parent 01d079fd8e
commit 10398b4706
9 changed files with 107 additions and 224 deletions

View File

@@ -6,12 +6,11 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from transformers import SiglipVisionConfig
from vllm.attention.selector import _Backend
from vllm.attention.layer import MultiHeadAttention
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs
@@ -29,8 +28,6 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
resolve_visual_encoder_outputs)
from vllm.sequence import SequenceData
from .utils import get_vit_attn_backend
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
# Since interpolation is applied, the image size need not be divisible
@@ -291,52 +288,18 @@ class SiglipAttention(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn_backend = get_vit_attn_backend(support_fa=False)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"SIGLIP does not support {self.attn_backend} backend now.")
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)
def forward(
    self,
    hidden_states: torch.Tensor,
) -> Tuple[torch.Tensor, None]:
    """Run self-attention over ``hidden_states``.

    Input shape: Batch x Time x Channel

    The fused ``qkv_proj`` output is split into three equal chunks along
    the last dimension and handed to ``self.attn`` unchanged; the result
    goes through ``out_proj``.

    Args:
        hidden_states: Input activations.

    Returns:
        A 2-tuple ``(attn_output, None)``. The second slot is always
        ``None``; it is kept so callers that unpack two values keep
        working.
    """
    # The projection layers return a pair; only the first element is used.
    qkv_states, _ = self.qkv_proj(hidden_states)
    query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

    # Pass the projections through without any per-head reshape or
    # transpose here: a backend-specific result computed in this method
    # would simply be overwritten by this call.
    # NOTE(review): presumably MultiHeadAttention does its own head
    # splitting and backend selection from 3-D input — confirm against
    # vllm.attention.layer. It is built without a dropout argument, so
    # self.dropout is not applied on this path — confirm that is intended.
    out = self.attn(query_states, key_states, value_states)
    attn_output, _ = self.out_proj(out)

    return attn_output, None