[Model] Remove transformers attention porting in VITs (#10414)

Signed-off-by: Isotr0py <2037008807@qq.com>
Author:     Isotr0py
Date:       2024-11-18 21:45:21 +08:00
Committed:  GitHub
Parent:     5be4e52b65
Commit:     e7ebb662d7
7 changed files with 139 additions and 102 deletions

@@ -6,11 +6,12 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
 import numpy as np
 import torch
+import torch.nn.functional as F
 from PIL import Image
 from torch import nn
 from transformers import SiglipVisionConfig
-from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
+from vllm.attention.selector import _Backend
 from vllm.config import ModelConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.inputs import DecoderOnlyInputs, token_inputs
@@ -27,11 +28,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
                                    repeat_and_pad_placeholder_tokens)
 from vllm.sequence import SequenceData
-try:
-    from xformers import ops as xops
-    USE_XFORMERS_OPS = True
-except ImportError:
-    USE_XFORMERS_OPS = False
+from .utils import get_vit_attn_backend
 def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
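Note: the removed module-level try/except that set USE_XFORMERS_OPS is replaced by a lazy import inside the forward path (see the attention hunk below), so the file stays importable when xformers is not installed. A minimal illustrative sketch of that pattern, not code from this commit:

def xformers_attention(q, k, v, scale, dropout=0.0):
    # Imported only when the XFORMERS backend is actually selected, so
    # environments without xformers can still import the model module.
    from xformers import ops as xops
    return xops.memory_efficient_attention_forward(q, k, v, p=dropout, scale=scale)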
@@ -254,7 +251,7 @@ class SiglipVisionEmbeddings(nn.Module):
         return embeddings
-class SiglipParallelAttention(nn.Module):
+class SiglipAttention(nn.Module):
     def __init__(
         self,
@@ -293,6 +290,11 @@ class SiglipParallelAttention(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
+        self.attn_backend = get_vit_attn_backend(support_fa=False)
+        if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
+            raise RuntimeError(
+                f"SIGLIP does not support {self.attn_backend} backend now.")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -313,11 +315,26 @@
                                          self.num_heads_per_partition,
                                          self.head_dim)
-        out = xops.memory_efficient_attention_forward(query_states,
-                                                      key_states,
-                                                      value_states,
-                                                      p=self.dropout,
-                                                      scale=self.scale)
+        if self.attn_backend == _Backend.XFORMERS:
+            from xformers import ops as xops
+            out = xops.memory_efficient_attention_forward(query_states,
+                                                          key_states,
+                                                          value_states,
+                                                          p=self.dropout,
+                                                          scale=self.scale)
+        elif self.attn_backend == _Backend.TORCH_SDPA:
+            query_states, key_states, value_states = (x.transpose(1, 2)
+                                                      for x in (query_states,
+                                                                key_states,
+                                                                value_states))
+            out = F.scaled_dot_product_attention(query_states,
+                                                 key_states,
+                                                 value_states,
+                                                 dropout_p=self.dropout,
+                                                 scale=self.scale)
+            out = out.transpose(1, 2)
         out = out.view(batch_size, q_len, -1)
         attn_output, _ = self.out_proj(out)
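The two branches differ only in expected tensor layout: xformers' memory_efficient_attention_forward consumes (batch, seq, heads, head_dim), while torch.nn.functional.scaled_dot_product_attention expects (batch, heads, seq, head_dim), hence the transpose(1, 2) wrapping the SDPA call. A small self-contained sketch of the SDPA path (shapes and values are illustrative only):

import torch
import torch.nn.functional as F

batch, seq, heads, head_dim = 2, 16, 4, 32
q = torch.randn(batch, seq, heads, head_dim)
k = torch.randn(batch, seq, heads, head_dim)
v = torch.randn(batch, seq, heads, head_dim)

# Mirror the TORCH_SDPA branch: transpose into (batch, heads, seq, head_dim),
# attend, transpose back, then flatten the head dimension.
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
    dropout_p=0.0, scale=head_dim**-0.5).transpose(1, 2)
out = out.reshape(batch, seq, heads * head_dim)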
@@ -372,17 +389,11 @@ class SiglipEncoderLayer(nn.Module):
         self.embed_dim = config.hidden_size
-        num_heads = config.num_attention_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
-            self.self_attn = SiglipParallelAttention(
-                config,
-                quant_config=quant_config,
-                prefix=f"{prefix}.self_attn",
-            )
-        else:
-            self.self_attn = SiglipSdpaAttention(config)
+        self.self_attn = SiglipAttention(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+        )
         self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(
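The old USE_XFORMERS_OPS / head-divisibility fallback to SiglipSdpaAttention is gone; SiglipAttention now always partitions heads across tensor-parallel ranks via divide(num_heads, tp_size). A rough sketch of the divisibility requirement this relies on (the assert-based body below is an assumption about vllm.distributed.divide, not a copy of it):

def divide(numerator: int, denominator: int) -> int:
    # Heads must split evenly across TP ranks, otherwise shard sizes differ.
    assert numerator % denominator == 0, (
        f"{numerator} is not evenly divisible by {denominator}")
    return numerator // denominator

num_heads_per_partition = divide(16, 4)  # e.g. 4 heads per rank at tp_size=4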
@@ -569,10 +580,6 @@
     ) -> None:
         super().__init__()
-        num_heads = config.num_attention_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
         self.vision_model = SiglipVisionTransformer(
             config,
             quant_config,
@@ -601,7 +608,7 @@
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
-        ] if self.shard_weight else []
+        ]
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         layer_count = len(self.vision_model.encoder.layers)
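With shard_weight removed, the stacked_params_mapping is applied unconditionally, so checkpoint q_proj/k_proj/v_proj weights are always folded into the fused qkv_proj parameter during load_weights. A rough sketch of how such a mapping is typically consumed (the remap helper is hypothetical, simplified from vLLM's weight-loading convention):

stacked_params_mapping = [
    # (fused param substring, checkpoint weight substring, shard id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

def remap(name: str):
    # Point a checkpoint tensor name at the fused parameter and its shard.
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None

print(remap("vision_model.encoder.layers.0.self_attn.q_proj.weight"))
# ('vision_model.encoder.layers.0.self_attn.qkv_proj.weight', 'q')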