[Bugfix][VLM] Add fallback to SDPA for ViT model running on CPU backend (#8061)

Isotr0py authored 2024-09-03 21:37:52 +08:00, committed by GitHub
parent 0fbc6696c2
commit ec266536b7
5 changed files with 157 additions and 44 deletions
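The change makes xformers an optional dependency for the SigLIP vision tower: the xformers-backed parallel attention is used only when the library imports successfully and the attention head count divides evenly across tensor-parallel ranks; otherwise the encoder layer falls back to the upstream SiglipSdpaAttention, which also runs on the CPU backend. Below is a minimal, self-contained sketch of that selection pattern (illustrative only, not the vLLM source; the vit_attention helper and its tensor shapes are assumptions for the example):

```python
# Sketch of the fallback pattern this commit introduces: prefer xformers
# memory-efficient attention when available, otherwise use torch SDPA,
# which also works on the CPU backend.
import torch
import torch.nn.functional as F

try:
    from xformers import ops as xops  # optional, GPU-oriented dependency
    USE_XFORMERS_OPS = True
except ImportError:
    USE_XFORMERS_OPS = False


def vit_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                  num_heads: int, tp_size: int = 1) -> torch.Tensor:
    """q, k, v: (batch, seq_len, num_heads, head_dim)."""
    if USE_XFORMERS_OPS and num_heads % tp_size == 0:
        # xformers expects (batch, seq_len, num_heads, head_dim)
        return xops.memory_efficient_attention(q, k, v,
                                               scale=q.shape[-1]**-0.5)
    # SDPA expects (batch, num_heads, seq_len, head_dim)
    out = F.scaled_dot_product_attention(q.transpose(1, 2),
                                         k.transpose(1, 2),
                                         v.transpose(1, 2))
    return out.transpose(1, 2)
```

The shard_weight flag added to SiglipVisionModel records which path was taken, presumably so weight loading can account for the non-sharded SDPA fallback.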

vllm/model_executor/models/siglip.py

@@ -9,7 +9,7 @@ import torch
 from PIL import Image
 from torch import nn
 from transformers import SiglipVisionConfig
-from xformers import ops as xops
+from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention

 from vllm.config import ModelConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -26,6 +26,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
                                    repeat_and_pad_placeholder_tokens)
 from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

+try:
+    from xformers import ops as xops
+    USE_XFORMERS_OPS = True
+except ImportError:
+    USE_XFORMERS_OPS = False
+

 def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
     # Since interpolation is applied, the image size need not be divisible
@@ -219,7 +225,7 @@ class SiglipVisionEmbeddings(nn.Module):
         return embeddings


-class SiglipAttention(nn.Module):
+class SiglipParallelAttention(nn.Module):

     def __init__(
         self,
@@ -282,7 +288,7 @@ class SiglipAttention(nn.Module):
         out = out.view(batch_size, q_len, -1)
         attn_output, _ = self.out_proj(out)

-        return attn_output
+        return attn_output, None


 class SiglipMLP(nn.Module):
@@ -327,7 +333,14 @@ class SiglipEncoderLayer(nn.Module):
         super().__init__()
         self.embed_dim = config.hidden_size

-        self.self_attn = SiglipAttention(config, quant_config=quant_config)
+        num_heads = config.num_attention_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
+            self.self_attn = SiglipParallelAttention(config,
+                                                     quant_config=quant_config)
+        else:
+            self.self_attn = SiglipSdpaAttention(config)
+
         self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(
@@ -344,7 +357,7 @@ class SiglipEncoderLayer(nn.Module):
         residual = hidden_states

         hidden_states = self.layer_norm1(hidden_states)
-        hidden_states = self.self_attn(hidden_states=hidden_states)
+        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
         hidden_states = residual + hidden_states

         residual = hidden_states
@@ -476,6 +489,10 @@ class SiglipVisionModel(nn.Module):
         num_hidden_layers_override: Optional[int] = None,
     ):
         super().__init__()
+        num_heads = config.num_attention_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
+
         self.vision_model = SiglipVisionTransformer(
             config,
             quant_config,