[Model] Remove transformers attention porting in VITs (#10414)
Signed-off-by: Isotr0py <2037008807@qq.com>
@@ -12,6 +12,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import PretrainedConfig
 
+from vllm.attention.selector import _Backend
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
@@ -24,11 +25,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-try:
-    from xformers import ops as xops
-    USE_XFORMERS_OPS = True
-except ImportError:
-    USE_XFORMERS_OPS = False
+from .utils import get_vit_attn_backend
 
 NORM2FN = {
     'rms_norm': RMSNorm,
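
For context: the module-level try/except probe for xFormers (USE_XFORMERS_OPS) is dropped in favour of the shared get_vit_attn_backend helper from .utils, which resolves a _Backend value at runtime. A minimal sketch of that kind of probe, assuming only a fallback from xFormers to torch SDPA when the import fails (the real vLLM helper applies additional checks), could look like:

from enum import Enum, auto


class _Backend(Enum):
    # Stand-in for vllm.attention.selector._Backend.
    XFORMERS = auto()
    TORCH_SDPA = auto()


def probe_vit_attn_backend() -> _Backend:
    # Illustrative only: prefer xFormers when it imports, otherwise fall
    # back to PyTorch's scaled_dot_product_attention.
    try:
        import xformers.ops  # noqa: F401
        return _Backend.XFORMERS
    except ImportError:
        return _Backend.TORCH_SDPA
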
@@ -186,6 +183,11 @@ class InternParallelAttention(nn.Module):
             prefix=f"{prefix}.proj",
         )
 
+        self.attn_backend = get_vit_attn_backend(support_fa=False)
+        if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
+            raise RuntimeError(
+                f"InternViT does not support {self.attn_backend} backend now.")
+
     def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
         if self.tp_size > 1:
             q = tensor_model_parallel_all_gather(q.contiguous())
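
The constructor now resolves the attention backend once and fails fast on anything outside the supported set, so forward() can branch on self.attn_backend without re-probing availability. A tiny, hypothetical module showing the same pattern (names here are illustrative, not vLLM API):

import torch.nn as nn


class TinyAttention(nn.Module):
    # Resolve and validate the backend at construction time; forward() can
    # then trust self.attn_backend and simply dispatch on it.
    SUPPORTED = {"TORCH_SDPA", "XFORMERS"}

    def __init__(self, backend: str) -> None:
        super().__init__()
        if backend not in self.SUPPORTED:
            raise RuntimeError(f"{backend} backend is not supported")
        self.attn_backend = backend


attn = TinyAttention("TORCH_SDPA")   # ok
# TinyAttention("FLASH_ATTN")        # would raise RuntimeError
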
@@ -211,11 +213,21 @@ class InternParallelAttention(nn.Module):
         k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
         v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
 
-        x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
-        x = x.view(B, N, -1)
+        if self.attn_backend == _Backend.XFORMERS:
+            from xformers import ops as xops
 
-        x, _ = self.proj(x)
-        return x
+            out = xops.memory_efficient_attention_forward(q,
+                                                          k,
+                                                          v,
+                                                          scale=self.scale)
+        elif self.attn_backend == _Backend.TORCH_SDPA:
+            q, k, v = (x.transpose(1, 2) for x in (q, k, v))
+            out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
+            out = out.transpose(1, 2)
+
+        out = out.view(B, N, -1)
+        out, _ = self.proj(out)
+        return out
 
 
 class InternSdpaAttention(nn.Module):
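
Both branches of the new forward() compute the same attention; they just expect different layouts. xops.memory_efficient_attention_forward consumes (B, N, heads, head_dim) directly, while F.scaled_dot_product_attention expects (B, heads, N, head_dim), which is why the SDPA branch transposes in and out. A self-contained sketch of the SDPA path with made-up shapes (the head_dim ** -0.5 scale is an assumption standing in for self.scale):

import torch
import torch.nn.functional as F

# B = batch, N = tokens, H = heads per partition, D = head dim (illustrative).
B, N, H, D = 2, 16, 4, 32
scale = D ** -0.5

q = torch.randn(B, N, H, D)
k = torch.randn(B, N, H, D)
v = torch.randn(B, N, H, D)

# SDPA wants (B, H, N, D): transpose in, attend, transpose back, then flatten
# the head dimension before the output projection, mirroring the diff above.
out = F.scaled_dot_product_attention(q.transpose(1, 2),
                                     k.transpose(1, 2),
                                     v.transpose(1, 2),
                                     scale=scale)
out = out.transpose(1, 2).reshape(B, N, H * D)
print(out.shape)  # torch.Size([2, 16, 128])
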
@@ -362,7 +374,7 @@ class InternVisionEncoderLayer(nn.Module):
         tp_size = get_tensor_model_parallel_world_size()
         num_heads = config.num_attention_heads
 
-        if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0:
+        if (num_heads + num_dummy_heads) % tp_size == 0:
             return InternParallelAttention(config,
                                            quant_config=quant_config,
                                            num_dummy_heads=num_dummy_heads,
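
With backend availability handled inside InternParallelAttention itself, the builder above only has to check that the (dummy-padded) head count divides evenly across tensor-parallel ranks; otherwise it presumably falls back to the non-parallel InternSdpaAttention shown earlier. A toy version of the check with made-up numbers:

# Hypothetical numbers, purely for illustration.
num_heads = 25        # head count that does not split across 8 ranks on its own
num_dummy_heads = 7   # padding so 25 + 7 = 32 shards evenly
tp_size = 8

if (num_heads + num_dummy_heads) % tp_size == 0:
    print("use InternParallelAttention (heads shard evenly)")
else:
    print("fall back to InternSdpaAttention")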