[VLM][Model] TP support for ViTs (#7186)

Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Jungho Christopher Cho
2024-08-31 00:19:27 +09:00
committed by GitHub
parent afd39a4511
commit f97be32d1d
9 changed files with 336 additions and 285 deletions

View File

@@ -6,8 +6,6 @@ import torch.nn as nn
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModel, CLIPImageProcessor from transformers import AutoConfig, AutoModel, CLIPImageProcessor
from vllm.model_executor.models.intern_vit import InternVisionModel
from ..conftest import _ImageAssets, cleanup from ..conftest import _ImageAssets, cleanup
pytestmark = pytest.mark.vlm pytestmark = pytest.mark.vlm
@@ -49,6 +47,7 @@ def run_intern_vit_test(
for pixel_value in pixel_values for pixel_value in pixel_values
] ]
from vllm.model_executor.models.intern_vit import InternVisionModel
vllm_model = InternVisionModel(config) vllm_model = InternVisionModel(config)
vllm_model.load_weights(hf_model.state_dict().items()) vllm_model.load_weights(hf_model.state_dict().items())

View File

@@ -6,9 +6,6 @@ import torch
from PIL.Image import Image from PIL.Image import Image
from transformers import AutoConfig from transformers import AutoConfig
from vllm.model_executor.models.internvl import (IMG_CONTEXT, IMG_END,
IMG_START,
image_to_pixel_values)
from vllm.multimodal.utils import rescale_image_size from vllm.multimodal.utils import rescale_image_size
from vllm.utils import is_cpu from vllm.utils import is_cpu
@@ -33,35 +30,6 @@ models = [
] ]
class InternVLProcessor:
"""A simple processor for InternVL2 HF model which misses a processor."""
def __init__(self, hf_runner: HfRunner):
self.num_image_token = hf_runner.model.num_image_token
self.tokenizer = hf_runner.tokenizer
self.dtype = hf_runner.model.dtype
self.config = AutoConfig.from_pretrained(hf_runner.model_name)
self.vision_config = self.config.vision_config
self.use_thumbnail = self.config.use_thumbnail
self.min_num = self.config.min_dynamic_patch
self.max_num = self.config.max_dynamic_patch
self.image_size = self.vision_config.image_size
def __call__(self, text: str, images: Image, **kwargs):
pixel_values = image_to_pixel_values(images, self.image_size,
self.min_num, self.max_num,
self.use_thumbnail).to(self.dtype)
num_patches_list = [pixel_values.shape[0]]
for num_patches in num_patches_list:
context_tokens = IMG_CONTEXT * self.num_image_token * num_patches
image_tokens = IMG_START + context_tokens + IMG_END
text = text.replace('<image>', image_tokens, 1)
prompt = self.tokenizer(text, return_tensors="pt")
prompt.update({"pixel_values": pixel_values})
return prompt
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py
def generate( def generate(
self, self,
@@ -127,6 +95,37 @@ def run_test(
# if we run HF first, the cuda initialization will be done and it # if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method). # will hurt multiprocessing backend with fork method (the default method).
class InternVLProcessor:
"""A simple processor for InternVL2 which misses a processor."""
def __init__(self, hf_runner: HfRunner):
self.num_image_token = hf_runner.model.num_image_token
self.tokenizer = hf_runner.tokenizer
self.dtype = hf_runner.model.dtype
self.config = AutoConfig.from_pretrained(hf_runner.model_name)
self.vision_config = self.config.vision_config
self.use_thumbnail = self.config.use_thumbnail
self.min_num = self.config.min_dynamic_patch
self.max_num = self.config.max_dynamic_patch
self.image_size = self.vision_config.image_size
def __call__(self, text: str, images: Image, **kwargs):
from vllm.model_executor.models.internvl import (
IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
pixel_values = image_to_pixel_values(
images, self.image_size, self.min_num, self.max_num,
self.use_thumbnail).to(self.dtype)
num_patches_list = [pixel_values.shape[0]]
for num_patches in num_patches_list:
context_tokens = IMG_CONTEXT * self.num_image_token \
* num_patches
image_tokens = IMG_START + context_tokens + IMG_END
text = text.replace('<image>', image_tokens, 1)
prompt = self.tokenizer(text, return_tensors="pt")
prompt.update({"pixel_values": pixel_values})
return prompt
# max_model_len should be greater than image_feature_size # max_model_len should be greater than image_feature_size
with vllm_runner(model, with vllm_runner(model,
max_model_len=4096, max_model_len=4096,

View File

@@ -7,12 +7,14 @@ import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig from transformers import Blip2VisionConfig, BlipVisionConfig
from transformers.models.blip.modeling_blip import BlipAttention from xformers import ops as xops
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs from vllm.inputs import LLMInputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import (cached_get_tokenizer,
@@ -154,6 +156,77 @@ class BlipVisionEmbeddings(nn.Module):
return embeddings return embeddings
class BlipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.qkv = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
)
self.projection = RowParallelLinear(
self.embed_dim,
self.embed_dim,
quant_config=quant_config,
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
):
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
qkv_states, _ = self.qkv(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
query_states = query_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.projection(out)
return attn_output
class BlipMLP(nn.Module): class BlipMLP(nn.Module):
def __init__(self, def __init__(self,
@@ -188,7 +261,7 @@ class BlipEncoderLayer(nn.Module):
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.self_attn = BlipAttention(config) self.self_attn = BlipAttention(config, quant_config=quant_config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size, self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.mlp = BlipMLP(config, quant_config=quant_config) self.mlp = BlipMLP(config, quant_config=quant_config)
@@ -199,7 +272,7 @@ class BlipEncoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = self.layer_norm1(hidden_states) hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states) hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
residual = hidden_states residual = hidden_states

View File

@@ -714,8 +714,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
use_default_weight_loading = False use_default_weight_loading = False
if "vision" in name: if "vision" in name:
if self.vision_model is not None: if self.vision_model is not None:
# We only do sharding for language model and # BlipVisionModel does not need sharding
# not vision model for now.
use_default_weight_loading = True use_default_weight_loading = True
else: else:
for (param_name, weight_name, for (param_name, weight_name,

View File

@@ -7,12 +7,14 @@ import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image from PIL import Image
from transformers import CLIPVisionConfig from transformers import CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPAttention from xformers import ops as xops
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs from vllm.inputs import LLMInputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -160,6 +162,78 @@ class CLIPVisionEmbeddings(nn.Module):
return embeddings return embeddings
class CLIPAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
total_num_heads=self.num_heads,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
quant_config=quant_config,
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
):
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
qkv_states, _ = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
query_states = query_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.out_proj(out)
return attn_output
class CLIPMLP(nn.Module): class CLIPMLP(nn.Module):
def __init__(self, def __init__(self,
@@ -192,7 +266,7 @@ class CLIPEncoderLayer(nn.Module):
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.self_attn = CLIPAttention(config) self.self_attn = CLIPAttention(config, quant_config=quant_config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size, self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config, quant_config=quant_config) self.mlp = CLIPMLP(config, quant_config=quant_config)
@@ -204,7 +278,7 @@ class CLIPEncoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = self.layer_norm1(hidden_states) hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states) hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
residual = hidden_states residual = hidden_states
@@ -304,7 +378,15 @@ class CLIPVisionModel(nn.Module):
def device(self): def device(self):
return next(self.parameters()).device return next(self.parameters()).device
# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
layer_count = len(self.vision_model.encoder.layers) layer_count = len(self.vision_model.encoder.layers)
@@ -318,6 +400,15 @@ class CLIPVisionModel(nn.Module):
if layer_idx >= layer_count: if layer_idx >= layer_count:
continue continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)

View File

@@ -10,10 +10,13 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from xformers import ops as xops
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -81,7 +84,11 @@ class InternVisionEmbeddings(nn.Module):
class InternAttention(nn.Module): class InternAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: PretrainedConfig): def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
@@ -94,9 +101,13 @@ class InternAttention(nn.Module):
f' {self.num_heads}).') f' {self.num_heads}).')
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
self.qkv = nn.Linear(self.embed_dim, self.qkv = QKVParallelLinear(
3 * self.embed_dim, self.embed_dim,
bias=config.qkv_bias) self.head_dim,
self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
)
self.qk_normalization = config.qk_normalization self.qk_normalization = config.qk_normalization
@@ -104,25 +115,40 @@ class InternAttention(nn.Module):
self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps) self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps) self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.proj = nn.Linear(self.embed_dim, self.embed_dim) self.proj = RowParallelLinear(
self.embed_dim,
self.embed_dim,
quant_config=quant_config,
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
def forward(self, x): def forward(self, x):
B, N, C = x.shape B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, qkv, _ = self.qkv(x)
C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.chunk(3, dim=-1)
q, k, v = qkv.unbind(0)
q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
if self.qk_normalization: if self.qk_normalization:
B_, H_, N_, D_ = q.shape B_, N_, H_, D_ = q.shape
q = self.q_norm.forward_native(q.transpose(1, 2).flatten( q = self.q_norm.forward_native(q.flatten(-2,
-2, -1)).view(B_, N_, H_, D_).transpose(1, 2) -1)).view(B_, N_, H_, D_)
k = self.k_norm.forward_native(k.transpose(1, 2).flatten( k = self.k_norm.forward_native(k.flatten(-2,
-2, -1)).view(B_, N_, H_, D_).transpose(1, 2) -1)).view(B_, N_, H_, D_)
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) x = xops.memory_efficient_attention_forward(
x = x.transpose(1, 2).reshape(B, N, C) q,
k,
v,
scale=self.scale,
)
x = x.view(B, N, -1)
x = self.proj(x) x, _ = self.proj(x)
return x return x
@@ -161,7 +187,7 @@ class InternVisionEncoderLayer(nn.Module):
self.intermediate_size = config.intermediate_size self.intermediate_size = config.intermediate_size
self.norm_type = config.norm_type self.norm_type = config.norm_type
self.attn = InternAttention(config) self.attn = InternAttention(config, quant_config=quant_config)
self.mlp = InternMLP(config, quant_config=quant_config) self.mlp = InternMLP(config, quant_config=quant_config)
self.norm1 = NORM2FN[self.norm_type](self.embed_dim, self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)

View File

@@ -145,7 +145,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
# TODO(ywang96): Port over SiglipVisionModel & TP
self.vision_tower = SiglipVisionModel(config.vision_config) self.vision_tower = SiglipVisionModel(config.vision_config)
self.multi_modal_projector = PaliGemmaMultiModalProjector( self.multi_modal_projector = PaliGemmaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size, vision_hidden_size=config.vision_config.hidden_size,
@@ -308,14 +307,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
if key_to_modify in name: if key_to_modify in name:
name = name.replace(key_to_modify, new_key) name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False use_default_weight_loading = False
if "vision" in name: for (param_name, shard_name, shard_id) in stacked_params_mapping:
if self.vision_tower is not None:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading = True
else:
for (param_name, shard_name,
shard_id) in stacked_params_mapping:
if shard_name not in name: if shard_name not in name:
continue continue
name = name.replace(shard_name, param_name) name = name.replace(shard_name, param_name)

View File

@@ -71,6 +71,23 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
projection_dim=768) projection_dim=768)
def _init_img_processor(hf_config: PretrainedConfig):
clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
layer_idx = hf_config.img_processor.get('layer_idx', -2)
# Initialize the CLIP only up to the required feature layer
if layer_idx < 0:
num_hidden_layers = clip_config.num_hidden_layers + \
layer_idx + 1
else:
num_hidden_layers = layer_idx + 1
img_processor = CLIPVisionModel(
clip_config, num_hidden_layers_override=num_hidden_layers)
return img_processor
class Phi3VImagePixelInputs(TypedDict): class Phi3VImagePixelInputs(TypedDict):
type: Literal["pixel_values"] type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]] data: Union[torch.Tensor, List[torch.Tensor]]
@@ -139,18 +156,8 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
hidden_size = config.n_embd if hasattr( hidden_size = config.n_embd if hasattr(
config, 'n_embd') else config.hidden_size config, 'n_embd') else config.hidden_size
clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG self.img_processor = _init_img_processor(config)
self.layer_idx = config.img_processor.get('layer_idx', -2)
# Initialize the CLIP only up to the required feature layer
if self.layer_idx < 0:
num_hidden_layers = clip_config.num_hidden_layers + \
self.layer_idx + 1
else:
num_hidden_layers = self.layer_idx + 1
self.img_processor = CLIPVisionModel(
clip_config, num_hidden_layers_override=num_hidden_layers)
image_dim_out = config.img_processor['image_dim_out'] image_dim_out = config.img_processor['image_dim_out']
self.num_img_tokens = config.img_processor['num_img_tokens'] self.num_img_tokens = config.img_processor['num_img_tokens']
@@ -656,23 +663,27 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
(".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1), (".gate_up_proj", ".up_proj", 1),
] ]
# TODO(ChristopherCho): This is a temporary fix to load
# the vision weights with CLIPVisionModel.load_weights()
vision_weights = []
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
# post_layernorm is not needed in CLIPVisionModel # Skip loading the img_processor weights since they are
if "vision_model.post_layernorm" in name: # loaded separately.
if "vision_embed_tokens.img_processor" in name:
vision_weights.append((name, loaded_weight))
continue continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name: if key_to_modify in name:
name = name.replace(key_to_modify, new_key) name = name.replace(key_to_modify, new_key)
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
# We only do sharding for language model
# and not vision model for now.
if "vision_embed_tokens" in name and self.vision_embed_tokens:
continue
if weight_name not in name: if weight_name not in name:
continue continue
param = params_dict[name.replace(weight_name, param_name)] param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
@@ -686,3 +697,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
# We use regex to extract the sub-module name
# from "model.vision_embed_tokens.img_processor.*"
vision_weights = [
(re.search(r"vision_embed_tokens\.img_processor\.(.*)",
n).group(1), w) for n, w in vision_weights
]
self.vision_embed_tokens.img_processor.load_weights(vision_weights)

View File

@@ -9,12 +9,10 @@ import torch
from PIL import Image from PIL import Image
from torch import nn from torch import nn
from transformers import SiglipVisionConfig from transformers import SiglipVisionConfig
from transformers.models.siglip.modeling_siglip import SiglipAttention from xformers import ops as xops
from vllm_flash_attn import flash_attn_func
from xformers.ops import memory_efficient_attention
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs from vllm.inputs import LLMInputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -221,9 +219,7 @@ class SiglipVisionEmbeddings(nn.Module):
return embeddings return embeddings
# NOTE: Not used - kept for later when we TP the ViT class SiglipAttention(nn.Module):
# TODO(ChristopherCho): Implement TP version of Attention
class SiglipTPAttention(nn.Module):
def __init__( def __init__(
self, self,
@@ -233,38 +229,30 @@ class SiglipTPAttention(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size() self.head_dim = self.embed_dim // self.num_heads
self.total_num_heads = config.num_attention_heads if self.head_dim * self.num_heads != self.embed_dim:
if self.total_num_heads % tp_size != 0:
raise ValueError(
f"Number of attention heads ({self.total_num_heads}) "
"must be divisible by the tensor model parallel size"
f" ({tp_size}).")
self.num_heads = self.total_num_heads // tp_size
self.head_dim = self.embed_dim // self.total_num_heads
if self.head_dim * self.total_num_heads != self.embed_dim:
raise ValueError(f"embed_dim must be divisible by num_heads (got " raise ValueError(f"embed_dim must be divisible by num_heads (got "
"`embed_dim`: {self.embed_dim} and `num_heads`:" "`embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).") f" {self.num_heads}).")
self.qkv_size = self.num_heads * self.head_dim
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout self.dropout = config.attention_dropout
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim, hidden_size=self.embed_dim,
head_size=self.head_dim, head_size=self.head_dim,
total_num_heads=self.total_num_heads, total_num_heads=self.num_heads,
quant_config=quant_config, quant_config=quant_config,
) )
self.out_proj = RowParallelLinear( self.out_proj = RowParallelLinear(
input_size=self.embed_dim, input_size=self.embed_dim,
output_size=self.embed_dim, output_size=self.embed_dim,
quant_config=quant_config, quant_config=quant_config,
) )
self.attn_fn = self._basic_attention_forward self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
def forward( def forward(
self, self,
@@ -274,163 +262,29 @@ class SiglipTPAttention(nn.Module):
batch_size, q_len, _ = hidden_states.size() batch_size, q_len, _ = hidden_states.size()
qkv_states, _ = self.qkv_proj(hidden_states) qkv_states, _ = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv_states.split( query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
[self.qkv_size] * 3, dim=-1)
attn_output = self.attn_fn( query_states = query_states.view(batch_size, q_len,
q=query_states, self.num_heads_per_partition,
k=key_states, self.head_dim)
v=value_states, key_states = key_states.view(batch_size, q_len,
batch_size=batch_size, self.num_heads_per_partition,
q_len=q_len, self.head_dim)
) value_states = value_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
attn_output, _ = self.out_proj(attn_output) out = xops.memory_efficient_attention_forward(query_states,
return attn_output key_states,
value_states,
def _basic_attention_forward(self, q, k, v, batch_size, q_len):
q = q.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
k = k.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
v = v.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
k_v_seq_len = k.shape[-2]
attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale
if attn_weights.size() != (
batch_size,
self.num_heads,
q_len,
k_v_seq_len,
):
raise ValueError(
"Attention weights should be of size "
f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}")
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights,
dim=-1,
dtype=torch.float32).to(q.dtype)
attn_weights = nn.functional.dropout(attn_weights,
p=self.dropout, p=self.dropout,
training=self.training)
attn_output = torch.matmul(attn_weights, v)
if attn_output.size() != (
batch_size,
self.num_heads,
q_len,
self.head_dim,
):
raise ValueError(
"`attn_output` should be of size "
f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
return attn_output
# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): flash_attn_func is not working properly.
# It constantly throws a CUDA error.
class SiglipFlashAttention2(SiglipTPAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.attn_fn = self._flash_attention_forward
# Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449
# and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133
def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args,
**kwargs):
"""Implements the multihead softmax attention.
Arguments
---------
q, k, v: The tensor containing the
query, key, and value. (B, S, H, D)
"""
q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
v = v.view(batch_size, q_len, self.num_heads, self.head_dim)
attn_output = flash_attn_func(
q,
k,
v,
dropout_p=self.dropout,
causal=False,
)
attn_output = attn_output.reshape(batch_size, q_len,
self.embed_dim).contiguous()
return attn_output
# NOTE: Not used - kept for later when we TP the ViT
class SiglipSdpaAttention(SiglipTPAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_causal = False
self.attn_fn = self._sdpa_attention_forward
def _sdpa_attention_forward(self, q, k, v, batch_size, q_len):
q = q.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
k = k.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
v = v.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
attn_output = torch.nn.functional.scaled_dot_product_attention(
q, k, v, dropout_p=self.dropout, is_causal=False, scale=self.scale)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
return attn_output
# NOTE: Not used - kept for later when we TP the ViT
class SiglipxFormersAttention(SiglipTPAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.attn_fn = self._xformers_attention_forward
def _xformers_attention_forward(self, q, k, v, batch_size, q_len):
q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
v = v.view(batch_size, q_len, self.num_heads, self.head_dim)
attn_output = memory_efficient_attention(q,
k,
v,
p=0.0,
scale=self.scale) scale=self.scale)
attn_output = attn_output.reshape(batch_size, q_len, out = out.view(batch_size, q_len, -1)
self.embed_dim).contiguous() attn_output, _ = self.out_proj(out)
return attn_output return attn_output
# NOTE: Not used - kept for later when we TP the ViT
SIGLIP_ATTENTION_CLASSES = {
"eager": SiglipTPAttention,
"flash_attention_2": SiglipFlashAttention2,
"sdpa": SiglipSdpaAttention,
"xformers": SiglipxFormersAttention,
}
class SiglipMLP(nn.Module): class SiglipMLP(nn.Module):
def __init__( def __init__(
@@ -473,8 +327,7 @@ class SiglipEncoderLayer(nn.Module):
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
# TODO(ChristopherCho): use TP'ed Attention block self.self_attn = SiglipAttention(config, quant_config=quant_config)
self.self_attn = SiglipAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.mlp = SiglipMLP( self.mlp = SiglipMLP(
@@ -491,7 +344,7 @@ class SiglipEncoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = self.layer_norm1(hidden_states) hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states) hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
residual = hidden_states residual = hidden_states