[Bugfix][VLM] Add fallback to SDPA for ViT model running on CPU backend (#8061)
This commit is contained in:
@@ -9,7 +9,7 @@ import torch
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
from transformers import SiglipVisionConfig
|
||||
from xformers import ops as xops
|
||||
from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
@@ -26,6 +26,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
||||
|
||||
try:
|
||||
from xformers import ops as xops
|
||||
USE_XFORMERS_OPS = True
|
||||
except ImportError:
|
||||
USE_XFORMERS_OPS = False
|
||||
|
||||
|
||||
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||
# Since interpolation is applied, the image size need not be divisible
|
||||
@@ -219,7 +225,7 @@ class SiglipVisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
class SiglipAttention(nn.Module):
|
||||
class SiglipParallelAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -282,7 +288,7 @@ class SiglipAttention(nn.Module):
|
||||
out = out.view(batch_size, q_len, -1)
|
||||
attn_output, _ = self.out_proj(out)
|
||||
|
||||
return attn_output
|
||||
return attn_output, None
|
||||
|
||||
|
||||
class SiglipMLP(nn.Module):
|
||||
@@ -327,7 +333,14 @@ class SiglipEncoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
self.self_attn = SiglipAttention(config, quant_config=quant_config)
|
||||
num_heads = config.num_attention_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
|
||||
self.self_attn = SiglipParallelAttention(config,
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
self.self_attn = SiglipSdpaAttention(config)
|
||||
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = SiglipMLP(
|
||||
@@ -344,7 +357,7 @@ class SiglipEncoderLayer(nn.Module):
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states = self.self_attn(hidden_states=hidden_states)
|
||||
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
@@ -476,6 +489,10 @@ class SiglipVisionModel(nn.Module):
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
num_heads = config.num_attention_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
|
||||
|
||||
self.vision_model = SiglipVisionTransformer(
|
||||
config,
|
||||
quant_config,
|
||||
|
||||
Reference in New Issue
Block a user