[Model] Consolidate ViTs attention implementation without mask (#10893)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -21,8 +21,8 @@ import torch
|
||||
from torch import nn
|
||||
from transformers.models.idefics2.configuration_idefics2 import (
|
||||
Idefics2Config, Idefics2VisionConfig)
|
||||
from xformers import ops as xops
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@@ -141,35 +141,18 @@ class Idefics2VisionAttention(nn.Module):
|
||||
)
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
self.is_causal = False
|
||||
self.attn = MultiHeadAttention(self.num_heads_per_partition,
|
||||
self.head_dim, self.scale)
|
||||
|
||||
def forward(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Run self-attention over the vision hidden states.

    Args:
        hidden_states: input of shape (batch_size, q_len, embed_dim).

    Returns:
        Attention output of shape (batch_size, q_len, embed_dim).
    """
    # Fused QKV projection:
    # (batch_size, q_len, 3 * num_heads_per_partition * head_dim).
    qkv, _ = self.qkv_proj(hidden_states)
    query_states, key_states, value_states = qkv.chunk(3, dim=-1)
    # The consolidated MultiHeadAttention layer replaces the previous
    # direct xformers memory_efficient_attention_forward call; it takes
    # the flattened (batch, seq, heads * head_dim) tensors, so no
    # explicit per-head .view() reshape is needed here.
    out = self.attn(query_states, key_states, value_states)
    attn_output, _ = self.out_proj(out)
    return attn_output
|
||||
|
||||
|
||||
Reference in New Issue
Block a user