[VLM][Model] TP support for ViTs (#7186)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
commit f97be32d1d
parent afd39a4511
@@ -9,12 +9,10 @@ import torch
 from PIL import Image
 from torch import nn
 from transformers import SiglipVisionConfig
-from transformers.models.siglip.modeling_siglip import SiglipAttention
-from vllm_flash_attn import flash_attn_func
-from xformers.ops import memory_efficient_attention
+from xformers import ops as xops

 from vllm.config import ModelConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.inputs import LLMInputs
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -221,9 +219,7 @@ class SiglipVisionEmbeddings(nn.Module):
         return embeddings


-# NOTE: Not used - kept for later when we TP the ViT
-# TODO(ChristopherCho): Implement TP version of Attention
-class SiglipTPAttention(nn.Module):
+class SiglipAttention(nn.Module):

     def __init__(
         self,
@@ -233,38 +229,30 @@ class SiglipTPAttention(nn.Module):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
-
-        tp_size = get_tensor_model_parallel_world_size()
-        self.total_num_heads = config.num_attention_heads
-        if self.total_num_heads % tp_size != 0:
-            raise ValueError(
-                f"Number of attention heads ({self.total_num_heads}) "
-                "must be divisible by the tensor model parallel size"
-                f" ({tp_size}).")
-
-        self.num_heads = self.total_num_heads // tp_size
-        self.head_dim = self.embed_dim // self.total_num_heads
-        if self.head_dim * self.total_num_heads != self.embed_dim:
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(f"embed_dim must be divisible by num_heads (got "
                              "`embed_dim`: {self.embed_dim} and `num_heads`:"
                              f" {self.num_heads}).")
         self.qkv_size = self.num_heads * self.head_dim

         self.scale = self.head_dim**-0.5
         self.dropout = config.attention_dropout

         self.qkv_proj = QKVParallelLinear(
             hidden_size=self.embed_dim,
             head_size=self.head_dim,
-            total_num_heads=self.total_num_heads,
+            total_num_heads=self.num_heads,
             quant_config=quant_config,
         )

         self.out_proj = RowParallelLinear(
             input_size=self.embed_dim,
             output_size=self.embed_dim,
             quant_config=quant_config,
         )

-        self.attn_fn = self._basic_attention_forward
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

     def forward(
         self,
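A note on what the replacement `__init__` sets up: `QKVParallelLinear` shards the fused QKV projection column-wise, so each tensor-parallel rank holds only `num_heads / tp_size` heads, and `divide` is vLLM's exact-division helper, which fails loudly when the head count does not split evenly across ranks. Below is a minimal single-process sketch of that sharding logic, not vLLM's actual implementation; `shard_qkv_weight`, `tp_rank`, and the weight layout are illustrative assumptions.

import torch

def divide(numerator: int, denominator: int) -> int:
    # Exact division; mirrors the guard vLLM's `divide` helper provides.
    assert numerator % denominator == 0, (
        f"{numerator} is not divisible by {denominator}")
    return numerator // denominator

def shard_qkv_weight(qkv_weight: torch.Tensor, num_heads: int, head_dim: int,
                     tp_size: int, tp_rank: int) -> torch.Tensor:
    # qkv_weight: (3 * num_heads * head_dim, hidden_size), rows ordered
    # as [Q | K | V]. Column-parallel sharding keeps a contiguous block
    # of heads from each of Q, K, and V on every rank.
    heads_per_rank = divide(num_heads, tp_size)
    q, k, v = qkv_weight.chunk(3, dim=0)
    shards = []
    for block in (q, k, v):
        per_head = block.view(num_heads, head_dim, -1)
        lo = tp_rank * heads_per_rank
        shards.append(per_head[lo:lo + heads_per_rank].reshape(
            heads_per_rank * head_dim, -1))
    return torch.cat(shards, dim=0)

w = torch.randn(3 * 4 * 2, 8)
print(shard_qkv_weight(w, num_heads=4, head_dim=2, tp_size=2, tp_rank=0).shape)
# torch.Size([12, 8]): each rank holds 2 of the 4 heads for each of Q, K, V

Keeping Q, K, and V sharded head-by-head on each rank is what lets the later `view(..., self.num_heads_per_partition, self.head_dim)` in `forward` work without any cross-rank reshuffling.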
@@ -274,163 +262,29 @@ class SiglipTPAttention(nn.Module):
         batch_size, q_len, _ = hidden_states.size()

         qkv_states, _ = self.qkv_proj(hidden_states)
-        query_states, key_states, value_states = qkv_states.split(
-            [self.qkv_size] * 3, dim=-1)
+        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

-        attn_output = self.attn_fn(
-            q=query_states,
-            k=key_states,
-            v=value_states,
-            batch_size=batch_size,
-            q_len=q_len,
-        )
+        query_states = query_states.view(batch_size, q_len,
+                                         self.num_heads_per_partition,
+                                         self.head_dim)
+        key_states = key_states.view(batch_size, q_len,
+                                     self.num_heads_per_partition,
+                                     self.head_dim)
+        value_states = value_states.view(batch_size, q_len,
+                                         self.num_heads_per_partition,
+                                         self.head_dim)

-        attn_output, _ = self.out_proj(attn_output)
-        return attn_output
-
-    def _basic_attention_forward(self, q, k, v, batch_size, q_len):
-        q = q.view(batch_size, q_len, self.num_heads,
-                   self.head_dim).transpose(1, 2)
-        k = k.view(batch_size, q_len, self.num_heads,
-                   self.head_dim).transpose(1, 2)
-        v = v.view(batch_size, q_len, self.num_heads,
-                   self.head_dim).transpose(1, 2)
-
-        k_v_seq_len = k.shape[-2]
-        attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale
-
-        if attn_weights.size() != (
-                batch_size,
-                self.num_heads,
-                q_len,
-                k_v_seq_len,
-        ):
-            raise ValueError(
-                "Attention weights should be of size "
-                f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
-                f" {attn_weights.size()}")
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights,
-                                             dim=-1,
-                                             dtype=torch.float32).to(q.dtype)
-        attn_weights = nn.functional.dropout(attn_weights,
-                                             p=self.dropout,
-                                             training=self.training)
-        attn_output = torch.matmul(attn_weights, v)
-
-        if attn_output.size() != (
-                batch_size,
-                self.num_heads,
-                q_len,
-                self.head_dim,
-        ):
-            raise ValueError(
-                "`attn_output` should be of size "
-                f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}")
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
+        out = xops.memory_efficient_attention_forward(query_states,
+                                                      key_states,
+                                                      value_states,
+                                                      p=self.dropout,
+                                                      scale=self.scale)
+        out = out.view(batch_size, q_len, -1)
+        attn_output, _ = self.out_proj(out)

         return attn_output
-
-
-# NOTE: Not used - kept for later when we TP the ViT
-# TODO(ChristopherCho): flash_attn_func is not working properly.
-# It constantly throws a CUDA error.
-class SiglipFlashAttention2(SiglipTPAttention):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.attn_fn = self._flash_attention_forward
-
-    # Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449
-    # and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133
-    def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args,
-                                 **kwargs):
-        """Implements the multihead softmax attention.
-        Arguments
-        ---------
-            q, k, v: The tensor containing the
-                query, key, and value. (B, S, H, D)
-        """
-
-        q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
-        k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
-        v = v.view(batch_size, q_len, self.num_heads, self.head_dim)
-
-        attn_output = flash_attn_func(
-            q,
-            k,
-            v,
-            dropout_p=self.dropout,
-            causal=False,
-        )
-
-        attn_output = attn_output.reshape(batch_size, q_len,
-                                          self.embed_dim).contiguous()
-
-        return attn_output
-
-
-# NOTE: Not used - kept for later when we TP the ViT
-class SiglipSdpaAttention(SiglipTPAttention):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.is_causal = False
-        self.attn_fn = self._sdpa_attention_forward
-
-    def _sdpa_attention_forward(self, q, k, v, batch_size, q_len):
-        q = q.view(batch_size, q_len, self.num_heads,
-                   self.head_dim).transpose(1, 2)
-        k = k.view(batch_size, q_len, self.num_heads,
-                   self.head_dim).transpose(1, 2)
-        v = v.view(batch_size, q_len, self.num_heads,
-                   self.head_dim).transpose(1, 2)
-
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            q, k, v, dropout_p=self.dropout, is_causal=False, scale=self.scale)
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
-
-        return attn_output
-
-
-# NOTE: Not used - kept for later when we TP the ViT
-class SiglipxFormersAttention(SiglipTPAttention):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.attn_fn = self._xformers_attention_forward
-
-    def _xformers_attention_forward(self, q, k, v, batch_size, q_len):
-        q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
-        k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
-        v = v.view(batch_size, q_len, self.num_heads, self.head_dim)
-
-        attn_output = memory_efficient_attention(q,
-                                                 k,
-                                                 v,
-                                                 p=0.0,
-                                                 scale=self.scale)
-        attn_output = attn_output.reshape(batch_size, q_len,
-                                          self.embed_dim).contiguous()
-
-        return attn_output
-
-
-# NOTE: Not used - kept for later when we TP the ViT
-SIGLIP_ATTENTION_CLASSES = {
-    "eager": SiglipTPAttention,
-    "flash_attention_2": SiglipFlashAttention2,
-    "sdpa": SiglipSdpaAttention,
-    "xformers": SiglipxFormersAttention,
-}


 class SiglipMLP(nn.Module):

     def __init__(
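To make the data flow of the new `forward` concrete: each rank computes attention over only its `num_heads_per_partition` head slice, and `RowParallelLinear` sums the per-rank partial outputs of `out_proj` with an all-reduce. The sketch below emulates that scheme in plain single-process PyTorch, substituting ordinary softmax attention for `xops.memory_efficient_attention_forward`; it illustrates the parallelization math, not the vLLM code path, and all names in it are illustrative.

import torch

def rankwise_attention(x, q_w, k_w, v_w, o_w, num_heads, tp_size):
    # x: (B, S, E); weights: (E, E). Emulates TP by slicing heads per rank.
    B, S, E = x.shape
    head_dim = E // num_heads
    heads_pp = num_heads // tp_size  # num_heads_per_partition
    scale = head_dim ** -0.5
    partials = []
    for rank in range(tp_size):
        h = slice(rank * heads_pp * head_dim, (rank + 1) * heads_pp * head_dim)
        # Column-parallel QKV: each rank projects into its own head slice.
        q = (x @ q_w[:, h]).view(B, S, heads_pp, head_dim).transpose(1, 2)
        k = (x @ k_w[:, h]).view(B, S, heads_pp, head_dim).transpose(1, 2)
        v = (x @ v_w[:, h]).view(B, S, heads_pp, head_dim).transpose(1, 2)
        attn = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1) @ v
        attn = attn.transpose(1, 2).reshape(B, S, heads_pp * head_dim)
        # Row-parallel out_proj: partial products are summed (the all-reduce).
        partials.append(attn @ o_w[h, :])
    return sum(partials)

B, S, E, H = 2, 5, 16, 4
torch.manual_seed(0)
x = torch.randn(B, S, E)
ws = [torch.randn(E, E) / E ** 0.5 for _ in range(4)]
full = rankwise_attention(x, *ws, num_heads=H, tp_size=1)
tp2 = rankwise_attention(x, *ws, num_heads=H, tp_size=2)
print(torch.allclose(full, tp2, atol=1e-4))  # expected: True

Since attention is computed independently per head, splitting heads across ranks and summing the row-parallel output projection is numerically equivalent to single-rank attention, which is why the `tp_size=1` and `tp_size=2` outputs match.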
@@ -473,8 +327,7 @@ class SiglipEncoderLayer(nn.Module):
         super().__init__()
         self.embed_dim = config.hidden_size

-        # TODO(ChristopherCho): use TP'ed Attention block
-        self.self_attn = SiglipAttention(config)
+        self.self_attn = SiglipAttention(config, quant_config=quant_config)
         self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(
@@ -491,7 +344,7 @@ class SiglipEncoderLayer(nn.Module):
         residual = hidden_states

         hidden_states = self.layer_norm1(hidden_states)
-        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
+        hidden_states = self.self_attn(hidden_states=hidden_states)
         hidden_states = residual + hidden_states

         residual = hidden_states
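The call-site change in this last hunk tracks a return-type difference: the upstream transformers `SiglipAttention` returns an `(attn_output, attn_weights)` tuple, while the new vLLM `SiglipAttention.forward` ends with `return attn_output` and yields a single tensor. A toy illustration of the two calling conventions; both classes here are stand-ins, not the real implementations.

import torch
from torch import nn

class HFStyleAttention(nn.Module):
    # Stand-in for transformers' SiglipAttention: returns a tuple.
    def forward(self, hidden_states):
        return hidden_states, None

class VLLMStyleAttention(nn.Module):
    # Stand-in for the TP'ed SiglipAttention: returns the tensor only.
    def forward(self, hidden_states):
        return hidden_states

x = torch.randn(1, 3, 8)
hidden_states, _ = HFStyleAttention()(hidden_states=x)  # old call site
hidden_states = VLLMStyleAttention()(hidden_states=x)   # new call site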