[Model] Consolidate ViTs attention implementation without mask (#10893)

Signed-off-by: Isotr0py <2037008807@qq.com>
Author: Isotr0py
Date: 2024-12-05 02:11:08 +08:00 (committed by GitHub)
Parent: 01d079fd8e
Commit: 10398b4706
9 changed files with 107 additions and 224 deletions

@@ -21,8 +21,8 @@ import torch
 from torch import nn
 from transformers.models.idefics2.configuration_idefics2 import (
     Idefics2Config, Idefics2VisionConfig)
-from xformers import ops as xops
 
+from vllm.attention.layer import MultiHeadAttention
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -141,35 +141,18 @@ class Idefics2VisionAttention(nn.Module):
         )
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
-        self.is_causal = False
+        self.attn = MultiHeadAttention(self.num_heads_per_partition,
+                                       self.head_dim, self.scale)
 
     def forward(
         self,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
-        batch_size, q_len, _ = hidden_states.size()
         qkv, _ = self.qkv_proj(
             hidden_states
         )  # batch_size, q_len, 3 * num_heads_per_partition * head_dim
         query_states, key_states, value_states = qkv.chunk(3, dim=-1)
-        query_states = query_states.view(batch_size, q_len,
-                                         self.num_heads_per_partition,
-                                         self.head_dim)
-        key_states = key_states.view(batch_size, q_len,
-                                     self.num_heads_per_partition,
-                                     self.head_dim)
-        value_states = value_states.view(batch_size, q_len,
-                                         self.num_heads_per_partition,
-                                         self.head_dim)
-        # see: https://facebookresearch.github.io/xformers/components/ops.html
-        out = xops.memory_efficient_attention_forward(
-            query_states,
-            key_states,
-            value_states,
-            p=self.dropout,
-            scale=self.scale,
-        )
-        out = out.view(batch_size, q_len, -1)
+        out = self.attn(query_states, key_states, value_states)
         attn_output, _ = self.out_proj(out)
         return attn_output
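
For readers who want the consolidated semantics in one place, below is a minimal sketch of what a mask-free ViT attention layer like this does. It is an illustration, not the vLLM source: it assumes a plain torch scaled_dot_product_attention backend (the scale kwarg needs PyTorch >= 2.1), whereas the real vllm.attention.layer.MultiHeadAttention selects among available backends (e.g. xformers or SDPA) at runtime. The class name MultiHeadAttentionSketch and the example sizes are hypothetical.

import torch
import torch.nn.functional as F
from torch import nn


class MultiHeadAttentionSketch(nn.Module):
    """Mask-free, cache-free attention for vision towers (illustrative)."""

    def __init__(self, num_heads: int, head_size: int, scale: float):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale

    def forward(self, query: torch.Tensor, key: torch.Tensor,
                value: torch.Tensor) -> torch.Tensor:
        # Inputs: (batch, seq_len, num_heads * head_size). The shared layer
        # owns the per-head reshaping that each ViT previously did inline.
        bsz, q_len, _ = query.size()
        kv_len = key.size(1)
        q = query.view(bsz, q_len, self.num_heads, self.head_size).transpose(1, 2)
        k = key.view(bsz, kv_len, self.num_heads, self.head_size).transpose(1, 2)
        v = value.view(bsz, kv_len, self.num_heads, self.head_size).transpose(1, 2)
        # No attention mask: every patch attends to every other patch.
        out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
        # Flatten heads back to (batch, seq_len, num_heads * head_size).
        return out.transpose(1, 2).reshape(bsz, q_len, -1)


# Hypothetical sizes, for shape checking only.
attn = MultiHeadAttentionSketch(num_heads=16, head_size=72, scale=72 ** -0.5)
patches = torch.randn(2, 576, 16 * 72)  # a 24x24 patch grid per image
out = attn(patches, patches, patches)   # -> (2, 576, 1152)

With such a layer, each vision encoder's forward collapses to qkv projection -> chunk -> attn(q, k, v) -> output projection, which is exactly the shape of the new Idefics2 hunk above. One behavioral note grounded in the diff itself: the removed xformers call forwarded p=self.dropout, while the new layer is constructed without a dropout argument; for inference serving, where dropout is inactive, the two paths agree.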