[VLM][Model] TP support for ViTs (#7186)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
parent afd39a4511
commit f97be32d1d
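This change ports the BLIP, CLIP, InternViT and SigLIP vision encoders to vLLM's tensor-parallel layers: each attention block now builds its fused QKV projection with QKVParallelLinear, its output projection with RowParallelLinear, and runs xformers memory-efficient attention over only the heads owned by the local rank. A rough sketch of the resulting head partitioning, using made-up ViT-L-like sizes rather than anything taken from the diff below:

    # Illustrative arithmetic only: how attention heads end up split across
    # tensor-parallel ranks in the reworked ViT attention blocks.
    hidden_size = 1024                 # hypothetical ViT width
    num_attention_heads = 16
    head_dim = hidden_size // num_attention_heads          # 64

    for tp_size in (1, 2, 4):
        assert num_attention_heads % tp_size == 0, "heads must divide evenly"
        num_heads_per_partition = num_attention_heads // tp_size
        # per-rank width of each of q, k, v coming out of QKVParallelLinear
        local_width = num_heads_per_partition * head_dim
        print(f"tp={tp_size}: {num_heads_per_partition} heads, width {local_width}")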
@@ -6,8 +6,6 @@ import torch.nn as nn
 from huggingface_hub import snapshot_download
 from transformers import AutoConfig, AutoModel, CLIPImageProcessor
 
-from vllm.model_executor.models.intern_vit import InternVisionModel
-
 from ..conftest import _ImageAssets, cleanup
 
 pytestmark = pytest.mark.vlm
@@ -49,6 +47,7 @@ def run_intern_vit_test(
         for pixel_value in pixel_values
     ]
 
+    from vllm.model_executor.models.intern_vit import InternVisionModel
     vllm_model = InternVisionModel(config)
     vllm_model.load_weights(hf_model.state_dict().items())
 
@@ -6,9 +6,6 @@ import torch
 from PIL.Image import Image
 from transformers import AutoConfig
 
-from vllm.model_executor.models.internvl import (IMG_CONTEXT, IMG_END,
-                                                 IMG_START,
-                                                 image_to_pixel_values)
 from vllm.multimodal.utils import rescale_image_size
 from vllm.utils import is_cpu
 
@@ -33,35 +30,6 @@ models = [
 ]
 
 
-class InternVLProcessor:
-    """A simple processor for InternVL2 HF model which misses a processor."""
-
-    def __init__(self, hf_runner: HfRunner):
-        self.num_image_token = hf_runner.model.num_image_token
-        self.tokenizer = hf_runner.tokenizer
-        self.dtype = hf_runner.model.dtype
-
-        self.config = AutoConfig.from_pretrained(hf_runner.model_name)
-        self.vision_config = self.config.vision_config
-        self.use_thumbnail = self.config.use_thumbnail
-        self.min_num = self.config.min_dynamic_patch
-        self.max_num = self.config.max_dynamic_patch
-        self.image_size = self.vision_config.image_size
-
-    def __call__(self, text: str, images: Image, **kwargs):
-        pixel_values = image_to_pixel_values(images, self.image_size,
-                                             self.min_num, self.max_num,
-                                             self.use_thumbnail).to(self.dtype)
-        num_patches_list = [pixel_values.shape[0]]
-        for num_patches in num_patches_list:
-            context_tokens = IMG_CONTEXT * self.num_image_token * num_patches
-            image_tokens = IMG_START + context_tokens + IMG_END
-            text = text.replace('<image>', image_tokens, 1)
-        prompt = self.tokenizer(text, return_tensors="pt")
-        prompt.update({"pixel_values": pixel_values})
-        return prompt
-
-
 # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py
 def generate(
     self,
@@ -127,6 +95,37 @@ def run_test(
     # if we run HF first, the cuda initialization will be done and it
     # will hurt multiprocessing backend with fork method (the default method).
 
+    class InternVLProcessor:
+        """A simple processor for InternVL2 which misses a processor."""
+
+        def __init__(self, hf_runner: HfRunner):
+            self.num_image_token = hf_runner.model.num_image_token
+            self.tokenizer = hf_runner.tokenizer
+            self.dtype = hf_runner.model.dtype
+
+            self.config = AutoConfig.from_pretrained(hf_runner.model_name)
+            self.vision_config = self.config.vision_config
+            self.use_thumbnail = self.config.use_thumbnail
+            self.min_num = self.config.min_dynamic_patch
+            self.max_num = self.config.max_dynamic_patch
+            self.image_size = self.vision_config.image_size
+
+        def __call__(self, text: str, images: Image, **kwargs):
+            from vllm.model_executor.models.internvl import (
+                IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
+            pixel_values = image_to_pixel_values(
+                images, self.image_size, self.min_num, self.max_num,
+                self.use_thumbnail).to(self.dtype)
+            num_patches_list = [pixel_values.shape[0]]
+            for num_patches in num_patches_list:
+                context_tokens = IMG_CONTEXT * self.num_image_token \
+                    * num_patches
+                image_tokens = IMG_START + context_tokens + IMG_END
+                text = text.replace('<image>', image_tokens, 1)
+            prompt = self.tokenizer(text, return_tensors="pt")
+            prompt.update({"pixel_values": pixel_values})
+            return prompt
+
     # max_model_len should be greater than image_feature_size
     with vllm_runner(model,
                      max_model_len=4096,
@@ -7,12 +7,14 @@ import torch
 import torch.nn as nn
 from PIL import Image
 from transformers import Blip2VisionConfig, BlipVisionConfig
-from transformers.models.blip.modeling_blip import BlipAttention
+from xformers import ops as xops
 
 from vllm.config import ModelConfig
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.inputs import LLMInputs
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.multimodal.utils import (cached_get_tokenizer,
@@ -154,6 +156,77 @@ class BlipVisionEmbeddings(nn.Module):
         return embeddings
 
 
+class BlipAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        config: BlipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                "embed_dim must be divisible by num_heads "
+                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.num_heads}).")
+        self.scale = self.head_dim**-0.5
+        self.dropout = config.attention_dropout
+
+        self.qkv = QKVParallelLinear(
+            self.embed_dim,
+            self.head_dim,
+            self.num_heads,
+            bias=config.qkv_bias,
+            quant_config=quant_config,
+        )
+        self.projection = RowParallelLinear(
+            self.embed_dim,
+            self.embed_dim,
+            quant_config=quant_config,
+        )
+
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads,
+                           self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+    ):
+        """Input shape: Batch x Time x Channel"""
+        bsz, tgt_len, _ = hidden_states.size()
+
+        qkv_states, _ = self.qkv(hidden_states)
+        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
+        query_states = query_states.view(bsz, tgt_len,
+                                         self.num_heads_per_partition,
+                                         self.head_dim)
+        key_states = key_states.view(bsz, tgt_len,
+                                     self.num_heads_per_partition,
+                                     self.head_dim)
+        value_states = value_states.view(bsz, tgt_len,
+                                         self.num_heads_per_partition,
+                                         self.head_dim)
+
+        out = xops.memory_efficient_attention_forward(query_states,
+                                                      key_states,
+                                                      value_states,
+                                                      p=self.dropout,
+                                                      scale=self.scale)
+        out = out.view(bsz, tgt_len, -1)
+        attn_output, _ = self.projection(out)
+
+        return attn_output
+
+
 class BlipMLP(nn.Module):
 
     def __init__(self,
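The forward pass of the new BlipAttention (and the analogous CLIPAttention further down) works on per-rank shards: the fused QKV output is chunked into query/key/value, each reshaped to (batch, seq, heads_per_rank, head_dim) for xformers, and the row-parallel projection recombines the partial results across ranks. A small shape sketch under assumed sizes; none of these numbers come from the diff:

    # Hypothetical sizes, chosen only to make the shapes concrete.
    bsz, tgt_len = 2, 577                  # e.g. one CLS token + 24x24 patches
    num_heads, head_dim, tp_size = 16, 64, 2
    heads_per_rank = num_heads // tp_size

    qkv_width = 3 * heads_per_rank * head_dim   # width of self.qkv's local output
    q_shape = (bsz, tgt_len, heads_per_rank, head_dim)   # layout fed to xformers
    proj_in = heads_per_rank * head_dim         # local input width of projection
    print(qkv_width, q_shape, proj_in)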
@@ -188,7 +261,7 @@ class BlipEncoderLayer(nn.Module):
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
 
-        self.self_attn = BlipAttention(config)
+        self.self_attn = BlipAttention(config, quant_config=quant_config)
         self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
         self.mlp = BlipMLP(config, quant_config=quant_config)
@@ -199,7 +272,7 @@ class BlipEncoderLayer(nn.Module):
         residual = hidden_states
 
         hidden_states = self.layer_norm1(hidden_states)
-        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
+        hidden_states = self.self_attn(hidden_states=hidden_states)
         hidden_states = residual + hidden_states
 
         residual = hidden_states
@@ -714,8 +714,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
             use_default_weight_loading = False
             if "vision" in name:
                 if self.vision_model is not None:
-                    # We only do sharding for language model and
-                    # not vision model for now.
+                    # BlipVisionModel does not need sharding
                     use_default_weight_loading = True
             else:
                 for (param_name, weight_name,
@@ -7,12 +7,14 @@ import torch
 import torch.nn as nn
 from PIL import Image
 from transformers import CLIPVisionConfig
-from transformers.models.clip.modeling_clip import CLIPAttention
+from xformers import ops as xops
 
 from vllm.config import ModelConfig
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.inputs import LLMInputs
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -160,6 +162,78 @@ class CLIPVisionEmbeddings(nn.Module):
         return embeddings
 
 
+class CLIPAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                "embed_dim must be divisible by num_heads "
+                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.num_heads}).")
+        self.scale = self.head_dim**-0.5
+        self.dropout = config.attention_dropout
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size=self.embed_dim,
+            head_size=self.head_dim,
+            total_num_heads=self.num_heads,
+            quant_config=quant_config,
+        )
+
+        self.out_proj = RowParallelLinear(
+            input_size=self.embed_dim,
+            output_size=self.embed_dim,
+            quant_config=quant_config,
+        )
+
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads,
+                           self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+    ):
+        """Input shape: Batch x Time x Channel"""
+        bsz, tgt_len, _ = hidden_states.size()
+
+        qkv_states, _ = self.qkv_proj(hidden_states)
+        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
+
+        query_states = query_states.view(bsz, tgt_len,
+                                         self.num_heads_per_partition,
+                                         self.head_dim)
+        key_states = key_states.view(bsz, tgt_len,
+                                     self.num_heads_per_partition,
+                                     self.head_dim)
+        value_states = value_states.view(bsz, tgt_len,
+                                         self.num_heads_per_partition,
+                                         self.head_dim)
+
+        out = xops.memory_efficient_attention_forward(query_states,
+                                                      key_states,
+                                                      value_states,
+                                                      p=self.dropout,
+                                                      scale=self.scale)
+        out = out.view(bsz, tgt_len, -1)
+        attn_output, _ = self.out_proj(out)
+
+        return attn_output
+
+
 class CLIPMLP(nn.Module):
 
     def __init__(self,
@@ -192,7 +266,7 @@ class CLIPEncoderLayer(nn.Module):
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
 
-        self.self_attn = CLIPAttention(config)
+        self.self_attn = CLIPAttention(config, quant_config=quant_config)
         self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
         self.mlp = CLIPMLP(config, quant_config=quant_config)
@@ -204,7 +278,7 @@ class CLIPEncoderLayer(nn.Module):
         residual = hidden_states
 
         hidden_states = self.layer_norm1(hidden_states)
-        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
+        hidden_states = self.self_attn(hidden_states=hidden_states)
         hidden_states = residual + hidden_states
 
         residual = hidden_states
@@ -304,7 +378,15 @@ class CLIPVisionModel(nn.Module):
     def device(self):
         return next(self.parameters()).device
 
+    # (TODO) Add prefix argument for filtering out weights to be loaded
+    # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
         params_dict = dict(self.named_parameters())
         layer_count = len(self.vision_model.encoder.layers)
 
@@ -318,7 +400,16 @@ class CLIPVisionModel(nn.Module):
             if layer_idx >= layer_count:
                 continue
 
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                param = params_dict[name.replace(weight_name, param_name)]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
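The new vision-tower load_weights has to stitch the checkpoint's separate q_proj/k_proj/v_proj tensors into the fused qkv_proj parameter: each matching name is rewritten and handed to the parameter's weight_loader together with its shard id, and the for/else falls back to default loading when no mapping applies. A minimal illustration of the name rewrite, using a hypothetical checkpoint key:

    # Hypothetical checkpoint key; the mapping mirrors the one added above.
    stacked_params_mapping = [
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]
    name = "vision_model.encoder.layers.0.self_attn.k_proj.weight"
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            fused_name = name.replace(weight_name, param_name)
            print(fused_name, shard_id)   # ...self_attn.qkv_proj.weight k
            break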
@@ -10,10 +10,13 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import PretrainedConfig
+from xformers import ops as xops
 
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -81,7 +84,11 @@ class InternVisionEmbeddings(nn.Module):
 class InternAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __init__(self, config: PretrainedConfig):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
@@ -94,9 +101,13 @@ class InternAttention(nn.Module):
                 f' {self.num_heads}).')
 
         self.scale = self.head_dim**-0.5
-        self.qkv = nn.Linear(self.embed_dim,
-                             3 * self.embed_dim,
-                             bias=config.qkv_bias)
+        self.qkv = QKVParallelLinear(
+            self.embed_dim,
+            self.head_dim,
+            self.num_heads,
+            bias=config.qkv_bias,
+            quant_config=quant_config,
+        )
 
         self.qk_normalization = config.qk_normalization
 
@@ -104,25 +115,40 @@ class InternAttention(nn.Module):
             self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
             self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
 
-        self.proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.proj = RowParallelLinear(
+            self.embed_dim,
+            self.embed_dim,
+            quant_config=quant_config,
+        )
+
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
 
     def forward(self, x):
         B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
-                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv.unbind(0)
+        qkv, _ = self.qkv(x)
+        q, k, v = qkv.chunk(3, dim=-1)
+
+        q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
+        k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
+        v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
 
         if self.qk_normalization:
-            B_, H_, N_, D_ = q.shape
-            q = self.q_norm.forward_native(q.transpose(1, 2).flatten(
-                -2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
-            k = self.k_norm.forward_native(k.transpose(1, 2).flatten(
-                -2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+            B_, N_, H_, D_ = q.shape
+            q = self.q_norm.forward_native(q.flatten(-2,
+                                                     -1)).view(B_, N_, H_, D_)
+            k = self.k_norm.forward_native(k.flatten(-2,
+                                                     -1)).view(B_, N_, H_, D_)
 
-        x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
-        x = x.transpose(1, 2).reshape(B, N, C)
+        x = xops.memory_efficient_attention_forward(
+            q,
+            k,
+            v,
+            scale=self.scale,
+        )
+        x = x.view(B, N, -1)
 
-        x = self.proj(x)
+        x, _ = self.proj(x)
         return x
 
 
@@ -161,7 +187,7 @@ class InternVisionEncoderLayer(nn.Module):
         self.intermediate_size = config.intermediate_size
         self.norm_type = config.norm_type
 
-        self.attn = InternAttention(config)
+        self.attn = InternAttention(config, quant_config=quant_config)
         self.mlp = InternMLP(config, quant_config=quant_config)
         self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
                                              eps=config.layer_norm_eps)
@@ -145,7 +145,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
         self.config = config
         self.multimodal_config = multimodal_config
 
-        # TODO(ywang96): Port over SiglipVisionModel & TP
         self.vision_tower = SiglipVisionModel(config.vision_config)
         self.multi_modal_projector = PaliGemmaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
@@ -308,34 +307,27 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
                 if key_to_modify in name:
                     name = name.replace(key_to_modify, new_key)
             use_default_weight_loading = False
-            if "vision" in name:
-                if self.vision_tower is not None:
-                    # We only do sharding for language model and
-                    # not vision model for now.
-                    use_default_weight_loading = True
-            else:
-                for (param_name, shard_name,
-                     shard_id) in stacked_params_mapping:
-                    if shard_name not in name:
-                        continue
-                    name = name.replace(shard_name, param_name)
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id)
-                    break
-                else:
-                    # lm_head is not used in vllm as it is tied with
-                    # embed_token. To prevent errors, skip loading
-                    # lm_head.weight.
-                    if "lm_head.weight" in name:
-                        continue
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    use_default_weight_loading = True
+            for (param_name, shard_name, shard_id) in stacked_params_mapping:
+                if shard_name not in name:
+                    continue
+                name = name.replace(shard_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # lm_head is not used in vllm as it is tied with
+                # embed_token. To prevent errors, skip loading
+                # lm_head.weight.
+                if "lm_head.weight" in name:
+                    continue
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                use_default_weight_loading = True
 
             if use_default_weight_loading:
                 param = params_dict[name]
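The rewritten PaliGemma loader leans on Python's for/else: the else branch runs only when the loop finished without hitting break, i.e. when a weight matched none of the stacked mappings, and only then does it fall back to default loading (skipping the tied lm_head). A tiny standalone illustration of that control flow, with hypothetical names:

    mappings = [(".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1)]
    name = "language_model.model.layers.0.self_attn.o_proj.weight"  # hypothetical
    for param_name, shard_name, shard_id in mappings:
        if shard_name in name:
            print("load", name.replace(shard_name, param_name), "shard", shard_id)
            break
    else:
        # only reached when the loop never executed `break`
        print("fall back to default weight loading for", name)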
@@ -71,6 +71,23 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
                                                      projection_dim=768)
 
 
+def _init_img_processor(hf_config: PretrainedConfig):
+    clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
+    layer_idx = hf_config.img_processor.get('layer_idx', -2)
+
+    # Initialize the CLIP only up to the required feature layer
+    if layer_idx < 0:
+        num_hidden_layers = clip_config.num_hidden_layers + \
+            layer_idx + 1
+    else:
+        num_hidden_layers = layer_idx + 1
+
+    img_processor = CLIPVisionModel(
+        clip_config, num_hidden_layers_override=num_hidden_layers)
+
+    return img_processor
+
+
 class Phi3VImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: Union[torch.Tensor, List[torch.Tensor]]
@@ -139,18 +156,8 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
         hidden_size = config.n_embd if hasattr(
             config, 'n_embd') else config.hidden_size
 
-        clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
-        self.layer_idx = config.img_processor.get('layer_idx', -2)
-
-        # Initialize the CLIP only up to the required feature layer
-        if self.layer_idx < 0:
-            num_hidden_layers = clip_config.num_hidden_layers + \
-                self.layer_idx + 1
-        else:
-            num_hidden_layers = self.layer_idx + 1
-
-        self.img_processor = CLIPVisionModel(
-            clip_config, num_hidden_layers_override=num_hidden_layers)
+        self.img_processor = _init_img_processor(config)
+
         image_dim_out = config.img_processor['image_dim_out']
         self.num_img_tokens = config.img_processor['num_img_tokens']
 
@@ -656,23 +663,27 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
+
+        # TODO(ChristopherCho): This is a temporary fix to load
+        # the vision weights with CLIPVisionModel.load_weights()
+        vision_weights = []
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
-            # post_layernorm is not needed in CLIPVisionModel
-            if "vision_model.post_layernorm" in name:
+            # Skip loading the img_processor weights since they are
+            # loaded separately.
+            if "vision_embed_tokens.img_processor" in name:
+                vision_weights.append((name, loaded_weight))
                 continue
+
             for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                 if key_to_modify in name:
                     name = name.replace(key_to_modify, new_key)
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                # We only do sharding for language model
-                # and not vision model for now.
-                if "vision_embed_tokens" in name and self.vision_embed_tokens:
-                    continue
                 if weight_name not in name:
                     continue
+
                 param = params_dict[name.replace(weight_name, param_name)]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -686,3 +697,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+
+        # We use regex to extract the sub-module name
+        # from "model.vision_embed_tokens.img_processor.*"
+        vision_weights = [
+            (re.search(r"vision_embed_tokens\.img_processor\.(.*)",
+                       n).group(1), w) for n, w in vision_weights
+        ]
+        self.vision_embed_tokens.img_processor.load_weights(vision_weights)
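Phi-3-Vision now loads its CLIP tower in two passes: weights under vision_embed_tokens.img_processor are collected during the main loop, their prefix is stripped with a regex, and the renamed tensors are handed to CLIPVisionModel.load_weights(), which applies the qkv stacking shown earlier. A small sketch of the prefix stripping on a hypothetical key:

    import re

    # Hypothetical checkpoint key for illustration.
    name = ("model.vision_embed_tokens.img_processor."
            "vision_model.embeddings.patch_embedding.weight")
    sub_module = re.search(r"vision_embed_tokens\.img_processor\.(.*)", name).group(1)
    print(sub_module)   # vision_model.embeddings.patch_embedding.weight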
@@ -9,12 +9,10 @@ import torch
 from PIL import Image
 from torch import nn
 from transformers import SiglipVisionConfig
-from transformers.models.siglip.modeling_siglip import SiglipAttention
-from vllm_flash_attn import flash_attn_func
-from xformers.ops import memory_efficient_attention
+from xformers import ops as xops
 
 from vllm.config import ModelConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.inputs import LLMInputs
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -221,9 +219,7 @@ class SiglipVisionEmbeddings(nn.Module):
         return embeddings
 
 
-# NOTE: Not used - kept for later when we TP the ViT
-# TODO(ChristopherCho): Implement TP version of Attention
-class SiglipTPAttention(nn.Module):
+class SiglipAttention(nn.Module):
 
     def __init__(
         self,
@@ -233,38 +229,30 @@ class SiglipTPAttention(nn.Module):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
-        self.total_num_heads = config.num_attention_heads
-        if self.total_num_heads % tp_size != 0:
-            raise ValueError(
-                f"Number of attention heads ({self.total_num_heads}) "
-                "must be divisible by the tensor model parallel size"
-                f" ({tp_size}).")
-
-        self.num_heads = self.total_num_heads // tp_size
-        self.head_dim = self.embed_dim // self.total_num_heads
-        if self.head_dim * self.total_num_heads != self.embed_dim:
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(f"embed_dim must be divisible by num_heads (got "
                              "`embed_dim`: {self.embed_dim} and `num_heads`:"
                              f" {self.num_heads}).")
-        self.qkv_size = self.num_heads * self.head_dim
         self.scale = self.head_dim**-0.5
         self.dropout = config.attention_dropout
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size=self.embed_dim,
             head_size=self.head_dim,
-            total_num_heads=self.total_num_heads,
+            total_num_heads=self.num_heads,
             quant_config=quant_config,
         )
 
         self.out_proj = RowParallelLinear(
             input_size=self.embed_dim,
             output_size=self.embed_dim,
             quant_config=quant_config,
         )
 
-        self.attn_fn = self._basic_attention_forward
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
 
     def forward(
         self,
@@ -274,163 +262,29 @@ class SiglipTPAttention(nn.Module):
         batch_size, q_len, _ = hidden_states.size()
 
         qkv_states, _ = self.qkv_proj(hidden_states)
-        query_states, key_states, value_states = qkv_states.split(
-            [self.qkv_size] * 3, dim=-1)
+        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
 
-        attn_output = self.attn_fn(
-            q=query_states,
-            k=key_states,
-            v=value_states,
-            batch_size=batch_size,
-            q_len=q_len,
-        )
-
-        attn_output, _ = self.out_proj(attn_output)
-        return attn_output
-
-    def _basic_attention_forward(self, q, k, v, batch_size, q_len):
-        q = q.view(batch_size, q_len, self.num_heads,
-                   self.head_dim).transpose(1, 2)
-        k = k.view(batch_size, q_len, self.num_heads,
-                   self.head_dim).transpose(1, 2)
-        v = v.view(batch_size, q_len, self.num_heads,
-                   self.head_dim).transpose(1, 2)
-
-        k_v_seq_len = k.shape[-2]
-        attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale
-
-        if attn_weights.size() != (
-                batch_size,
-                self.num_heads,
-                q_len,
-                k_v_seq_len,
-        ):
-            raise ValueError(
-                "Attention weights should be of size "
-                f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
-                f" {attn_weights.size()}")
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights,
-                                             dim=-1,
-                                             dtype=torch.float32).to(q.dtype)
-        attn_weights = nn.functional.dropout(attn_weights,
-                                             p=self.dropout,
-                                             training=self.training)
-        attn_output = torch.matmul(attn_weights, v)
-
-        if attn_output.size() != (
-                batch_size,
-                self.num_heads,
-                q_len,
-                self.head_dim,
-        ):
-            raise ValueError(
-                "`attn_output` should be of size "
-                f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}")
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
+        query_states = query_states.view(batch_size, q_len,
+                                         self.num_heads_per_partition,
+                                         self.head_dim)
+        key_states = key_states.view(batch_size, q_len,
+                                     self.num_heads_per_partition,
+                                     self.head_dim)
+        value_states = value_states.view(batch_size, q_len,
+                                         self.num_heads_per_partition,
+                                         self.head_dim)
+
+        out = xops.memory_efficient_attention_forward(query_states,
+                                                      key_states,
+                                                      value_states,
+                                                      p=self.dropout,
+                                                      scale=self.scale)
+        out = out.view(batch_size, q_len, -1)
+        attn_output, _ = self.out_proj(out)
 
         return attn_output
 
-
-# NOTE: Not used - kept for later when we TP the ViT
-# TODO(ChristopherCho): flash_attn_func is not working properly.
-# It constantly throws a CUDA error.
-class SiglipFlashAttention2(SiglipTPAttention):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.attn_fn = self._flash_attention_forward
-
-    # Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449
-    # and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133
-    def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args,
-                                 **kwargs):
-        """Implements the multihead softmax attention.
-        Arguments
-        ---------
-            q, k, v: The tensor containing the
-                     query, key, and value. (B, S, H, D)
-        """
-
-        q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
-        k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
-        v = v.view(batch_size, q_len, self.num_heads, self.head_dim)
-
-        attn_output = flash_attn_func(
-            q,
-            k,
-            v,
-            dropout_p=self.dropout,
-            causal=False,
-        )
-
-        attn_output = attn_output.reshape(batch_size, q_len,
-                                          self.embed_dim).contiguous()
-
-        return attn_output
-
-
-# NOTE: Not used - kept for later when we TP the ViT
-class SiglipSdpaAttention(SiglipTPAttention):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.is_causal = False
-        self.attn_fn = self._sdpa_attention_forward
-
-    def _sdpa_attention_forward(self, q, k, v, batch_size, q_len):
-        q = q.view(batch_size, q_len, self.num_heads,
-                   self.head_dim).transpose(1, 2)
-        k = k.view(batch_size, q_len, self.num_heads,
-                   self.head_dim).transpose(1, 2)
-        v = v.view(batch_size, q_len, self.num_heads,
-                   self.head_dim).transpose(1, 2)
-
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            q, k, v, dropout_p=self.dropout, is_causal=False, scale=self.scale)
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
-
-        return attn_output
-
-
-# NOTE: Not used - kept for later when we TP the ViT
-class SiglipxFormersAttention(SiglipTPAttention):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.attn_fn = self._xformers_attention_forward
-
-    def _xformers_attention_forward(self, q, k, v, batch_size, q_len):
-        q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
-        k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
-        v = v.view(batch_size, q_len, self.num_heads, self.head_dim)
-
-        attn_output = memory_efficient_attention(q,
-                                                 k,
-                                                 v,
-                                                 p=0.0,
-                                                 scale=self.scale)
-        attn_output = attn_output.reshape(batch_size, q_len,
-                                          self.embed_dim).contiguous()
-
-        return attn_output
-
-
-# NOTE: Not used - kept for later when we TP the ViT
-SIGLIP_ATTENTION_CLASSES = {
-    "eager": SiglipTPAttention,
-    "flash_attention_2": SiglipFlashAttention2,
-    "sdpa": SiglipSdpaAttention,
-    "xformers": SiglipxFormersAttention,
-}
-
 
 class SiglipMLP(nn.Module):
 
     def __init__(
@@ -473,8 +327,7 @@ class SiglipEncoderLayer(nn.Module):
         super().__init__()
         self.embed_dim = config.hidden_size
 
-        # TODO(ChristopherCho): use TP'ed Attention block
-        self.self_attn = SiglipAttention(config)
+        self.self_attn = SiglipAttention(config, quant_config=quant_config)
         self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(
@@ -491,7 +344,7 @@ class SiglipEncoderLayer(nn.Module):
         residual = hidden_states
 
         hidden_states = self.layer_norm1(hidden_states)
-        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
+        hidden_states = self.self_attn(hidden_states=hidden_states)
         hidden_states = residual + hidden_states
 
         residual = hidden_states