[Bugfix][VLM] Add fallback to SDPA for ViT model running on CPU backend (#8061)

Author: Isotr0py
Date: 2024-09-03 21:37:52 +08:00
Committed by: GitHub
parent 0fbc6696c2
commit ec266536b7
5 changed files with 157 additions and 44 deletions
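
The gist of the fix: torch's F.scaled_dot_product_attention (SDPA) computes the same attention as xformers' memory_efficient_attention_forward, but it is available on the CPU backend where xformers usually is not, and it expects a (batch, heads, seq, head_dim) layout instead of xformers' (batch, seq, heads, head_dim). A minimal self-contained sketch of that equivalence (not part of the commit; the xformers comparison only runs where that package is installed and has a kernel for the given tensors):

import torch
import torch.nn.functional as F

B, N, H, D = 2, 16, 4, 32                      # batch, patches, heads, head_dim
q, k, v = (torch.randn(B, N, H, D) for _ in range(3))
scale = D**-0.5

# SDPA wants (B, H, N, D), so transpose in and back out.
out_sdpa = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
    scale=scale).transpose(1, 2)

try:
    from xformers import ops as xops
    # xformers consumes (B, N, H, D) directly.
    out_xf = xops.memory_efficient_attention_forward(q, k, v, scale=scale)
    print(torch.allclose(out_sdpa, out_xf, atol=1e-4))
except ImportError:
    print("xformers unavailable; SDPA output shape:", tuple(out_sdpa.shape))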

@@ -10,7 +10,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import PretrainedConfig
-from xformers import ops as xops
 
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -21,6 +20,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+try:
+    from xformers import ops as xops
+    USE_XFORMERS_OPS = True
+except ImportError:
+    USE_XFORMERS_OPS = False
+
 NORM2FN = {
     'rms_norm': RMSNorm,
     'layer_norm': nn.LayerNorm,
@@ -81,7 +86,7 @@ class InternVisionEmbeddings(nn.Module):
         return embeddings
 
 
-class InternAttention(nn.Module):
+class InternParallelAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     def __init__(
@@ -140,18 +145,67 @@ class InternAttention(nn.Module):
             k = self.k_norm.forward_native(k.flatten(-2,
                                                      -1)).view(B_, N_, H_, D_)
 
-        x = xops.memory_efficient_attention_forward(
-            q,
-            k,
-            v,
-            scale=self.scale,
-        )
+        x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
         x = x.view(B, N, -1)
 
         x, _ = self.proj(x)
         return x
 
 
+class InternSdpaAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                f'embed_dim must be divisible by num_heads '
+                f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
+                f' {self.num_heads}).')
+
+        self.scale = self.head_dim**-0.5
+        self.qkv = nn.Linear(self.embed_dim,
+                             3 * self.embed_dim,
+                             bias=config.qkv_bias)
+
+        self.qk_normalization = config.qk_normalization
+
+        if self.qk_normalization:
+            self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+            self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+        self.proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        qkv = self.qkv(x)
+        q, k, v = qkv.chunk(3, dim=-1)
+
+        q = q.view(B, N, self.num_heads, self.head_dim)
+        k = k.view(B, N, self.num_heads, self.head_dim)
+        v = v.view(B, N, self.num_heads, self.head_dim)
+
+        if self.qk_normalization:
+            B_, N_, H_, D_ = q.shape
+            q = self.q_norm.forward_native(q.flatten(-2,
+                                                     -1)).view(B_, N_, H_, D_)
+            k = self.k_norm.forward_native(k.flatten(-2,
+                                                     -1)).view(B_, N_, H_, D_)
+
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
+        x = x.transpose(1, 2).view(B, N, -1)
+
+        x = self.proj(x)
+        return x
+
+
 class InternMLP(nn.Module):
 
     def __init__(self,
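
The new InternSdpaAttention keeps the xformers class's projections and qk-normalization but swaps the attention kernel. For reference, a standalone trace of its shape bookkeeping with toy sizes (a sketch mirroring the diff above, not the vLLM class itself; reshape is used here where the class uses view):

import torch
import torch.nn.functional as F

B, N, C, H = 2, 16, 64, 4                  # batch, patches, hidden, heads
D = C // H
x = torch.randn(B, N, C)

qkv = torch.nn.Linear(C, 3 * C)(x)         # (B, N, 3C)
q, k, v = qkv.chunk(3, dim=-1)             # each (B, N, C)
q = q.view(B, N, H, D).transpose(1, 2)     # (B, H, N, D)
k = k.view(B, N, H, D).transpose(1, 2)
v = v.view(B, N, H, D).transpose(1, 2)

out = F.scaled_dot_product_attention(q, k, v, scale=D**-0.5)  # (B, H, N, D)
out = out.transpose(1, 2).reshape(B, N, C)                    # back to (B, N, C)
print(out.shape)                           # torch.Size([2, 16, 64])
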
@@ -187,7 +241,14 @@ class InternVisionEncoderLayer(nn.Module):
         self.intermediate_size = config.intermediate_size
         self.norm_type = config.norm_type
 
-        self.attn = InternAttention(config, quant_config=quant_config)
+        # fallback to sdpa attention if tp unavailable
+        tp_size = get_tensor_model_parallel_world_size()
+        num_heads = config.num_attention_heads
+        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
+            self.attn = InternParallelAttention(config,
+                                                quant_config=quant_config)
+        else:
+            self.attn = InternSdpaAttention(config)
         self.mlp = InternMLP(config, quant_config=quant_config)
         self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
                                              eps=config.layer_norm_eps)
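
So which implementation a layer gets depends only on whether xformers imported and whether the head count divides across the tensor-parallel ranks. A tiny illustration of that condition (use_parallel_attention is a hypothetical helper, not part of the diff):

def use_parallel_attention(xformers_available: bool, num_heads: int,
                           tp_size: int) -> bool:
    # Mirrors the __init__ branch above: xformers attention only when it is
    # importable and the heads divide evenly across TP ranks; otherwise SDPA.
    return xformers_available and num_heads % tp_size == 0

# CPU-only install: xformers missing, always falls back to SDPA.
print(use_parallel_attention(False, 16, 1))   # False
# GPU install, 16 heads over 4 TP ranks: parallel xformers path.
print(use_parallel_attention(True, 16, 4))    # True
# GPU install, 25 heads over 4 TP ranks: heads don't divide, SDPA fallback.
print(use_parallel_attention(True, 25, 4))    # False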