[Bugfix][VLM] Add fallback to SDPA for ViT model running on CPU backend (#8061)
This commit is contained in:
@@ -7,7 +7,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from transformers import CLIPVisionConfig
|
||||
from xformers import ops as xops
|
||||
from transformers.models.clip.modeling_clip import CLIPSdpaAttention
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
@@ -22,6 +22,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
||||
|
||||
try:
|
||||
from xformers import ops as xops
|
||||
USE_XFORMERS_OPS = True
|
||||
except ImportError:
|
||||
USE_XFORMERS_OPS = False
|
||||
|
||||
|
||||
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
|
||||
assert image_size % patch_size == 0
|
||||
@@ -162,7 +168,7 @@ class CLIPVisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
class CLIPAttention(nn.Module):
|
||||
class CLIPParallelAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(
|
||||
@@ -231,7 +237,7 @@ class CLIPAttention(nn.Module):
|
||||
out = out.view(bsz, tgt_len, -1)
|
||||
attn_output, _ = self.out_proj(out)
|
||||
|
||||
return attn_output
|
||||
return attn_output, None
|
||||
|
||||
|
||||
class CLIPMLP(nn.Module):
|
||||
@@ -266,7 +272,13 @@ class CLIPEncoderLayer(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
|
||||
self.self_attn = CLIPAttention(config, quant_config=quant_config)
|
||||
num_heads = config.num_attention_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
|
||||
self.self_attn = CLIPParallelAttention(config,
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
self.self_attn = CLIPSdpaAttention(config)
|
||||
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.mlp = CLIPMLP(config, quant_config=quant_config)
|
||||
@@ -278,7 +290,7 @@ class CLIPEncoderLayer(nn.Module):
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states = self.self_attn(hidden_states=hidden_states)
|
||||
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
@@ -365,6 +377,10 @@ class CLIPVisionModel(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
num_hidden_layers_override: Optional[int] = None):
|
||||
super().__init__()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
num_heads = config.num_attention_heads
|
||||
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
|
||||
|
||||
self.vision_model = CLIPVisionTransformer(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
@@ -386,7 +402,7 @@ class CLIPVisionModel(nn.Module):
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
] if self.shard_weight else []
|
||||
params_dict = dict(self.named_parameters())
|
||||
layer_count = len(self.vision_model.encoder.layers)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user