Feature/vit attention unification# 23880 (#23978)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -16,6 +16,7 @@ from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
from transformers import BatchFeature, PretrainedConfig, TensorType
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
@@ -682,9 +683,9 @@ class Step3VisionAttention(nn.Module):
|
||||
prefix=f"{prefix}.out_proj",
|
||||
disable_tp=use_data_parallel)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2).contiguous()
|
||||
# Use unified MultiHeadAttention with automatic backend selection
|
||||
self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
|
||||
self.scale)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -696,19 +697,9 @@ class Step3VisionAttention(nn.Module):
|
||||
# get query proj
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
q = q.view(bsz, tgt_len, self.num_heads, self.head_dim)
|
||||
k = k.view(bsz, tgt_len, self.num_heads, self.head_dim)
|
||||
v = v.view(bsz, tgt_len, self.num_heads, self.head_dim)
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
attn_output = F.scaled_dot_product_attention(q,
|
||||
k,
|
||||
v,
|
||||
scale=self.scale,
|
||||
is_causal=False)
|
||||
attn_output = attn_output.transpose(1, 2).reshape(
|
||||
bsz, tgt_len, self.num_heads * self.head_dim)
|
||||
|
||||
# Use unified MultiHeadAttention with automatic backend selection
|
||||
attn_output = self.attn(q, k, v)
|
||||
|
||||
attn_output, _ = self.out_proj(attn_output)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user