[Model] Consolidate ViTs attention implementation without mask (#10893)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2024-12-05 02:11:08 +08:00
committed by GitHub
parent 01d079fd8e
commit 10398b4706
9 changed files with 107 additions and 224 deletions

View File

@@ -12,7 +12,7 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.attention.selector import _Backend
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
@@ -25,8 +25,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .utils import get_vit_attn_backend
NORM2FN = {
'rms_norm': RMSNorm,
'layer_norm': nn.LayerNorm,
@@ -183,10 +181,8 @@ class InternParallelAttention(nn.Module):
prefix=f"{prefix}.proj",
)
self.attn_backend = get_vit_attn_backend(support_fa=False)
if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
raise RuntimeError(
f"InternViT does not support {self.attn_backend} backend now.")
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
if self.tp_size > 1:
@@ -209,23 +205,7 @@ class InternParallelAttention(nn.Module):
if self.qk_normalization:
q, k = self._apply_qk_norm(q, k)
q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(q,
k,
v,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
q, k, v = (x.transpose(1, 2) for x in (q, k, v))
out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
out = out.transpose(1, 2)
out = out.view(B, N, -1)
out = self.attn(q, k, v)
out, _ = self.proj(out)
return out