[VLM][Model] TP support for ViTs (#7186)

Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Jungho Christopher Cho
2024-08-31 00:19:27 +09:00
committed by GitHub
parent afd39a4511
commit f97be32d1d
9 changed files with 336 additions and 285 deletions


@@ -9,12 +9,10 @@ import torch
 from PIL import Image
 from torch import nn
 from transformers import SiglipVisionConfig
-from transformers.models.siglip.modeling_siglip import SiglipAttention
-from vllm_flash_attn import flash_attn_func
-from xformers.ops import memory_efficient_attention
+from xformers import ops as xops

 from vllm.config import ModelConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.inputs import LLMInputs
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
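
With this change the ViT no longer dispatches between eager, flash-attention, SDPA, and xformers backends; the attention path always goes through xops.memory_efficient_attention_forward. A minimal sketch of that call, assuming an xformers install and a CUDA device (the shapes are made up, not from the diff); inputs are laid out (batch, seq_len, num_heads, head_dim):

    import torch
    from xformers import ops as xops

    # Illustrative shapes only: batch=2, seq_len=64, heads=12, head_dim=64.
    q = torch.randn(2, 64, 12, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # Forward-only memory-efficient attention; p is the dropout probability
    # and scale is 1/sqrt(head_dim).
    out = xops.memory_efficient_attention_forward(q, k, v, p=0.0, scale=64**-0.5)
    assert out.shape == q.shape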
@@ -221,9 +219,7 @@ class SiglipVisionEmbeddings(nn.Module):
         return embeddings


-# NOTE: Not used - kept for later when we TP the ViT
-# TODO(ChristopherCho): Implement TP version of Attention
-class SiglipTPAttention(nn.Module):
+class SiglipAttention(nn.Module):

     def __init__(
         self,
@@ -233,38 +229,30 @@ class SiglipTPAttention(nn.Module):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
-
-        tp_size = get_tensor_model_parallel_world_size()
-        self.total_num_heads = config.num_attention_heads
-        if self.total_num_heads % tp_size != 0:
-            raise ValueError(
-                f"Number of attention heads ({self.total_num_heads}) "
-                "must be divisible by the tensor model parallel size"
-                f" ({tp_size}).")
-
-        self.num_heads = self.total_num_heads // tp_size
-        self.head_dim = self.embed_dim // self.total_num_heads
-        if self.head_dim * self.total_num_heads != self.embed_dim:
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(f"embed_dim must be divisible by num_heads (got "
                              f"`embed_dim`: {self.embed_dim} and `num_heads`:"
                              f" {self.num_heads}).")
-        self.qkv_size = self.num_heads * self.head_dim
         self.scale = self.head_dim**-0.5
         self.dropout = config.attention_dropout
-
         self.qkv_proj = QKVParallelLinear(
             hidden_size=self.embed_dim,
             head_size=self.head_dim,
-            total_num_heads=self.total_num_heads,
+            total_num_heads=self.num_heads,
             quant_config=quant_config,
         )
         self.out_proj = RowParallelLinear(
             input_size=self.embed_dim,
             output_size=self.embed_dim,
             quant_config=quant_config,
         )
-        self.attn_fn = self._basic_attention_forward
+
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

     def forward(
         self,
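
The rewritten __init__ keeps the full head count for the projection layers and lets QKVParallelLinear shard it column-wise across tensor-parallel ranks, so each rank materializes only num_heads_per_partition = divide(num_heads, tp_size) heads; RowParallelLinear then all-reduces the output projection. A rough shape walk-through under assumed sizes (embed_dim=768, 12 heads, tp_size=4; none of these numbers come from the diff):

    # Hypothetical sizes for illustration only.
    embed_dim, num_heads, tp_size = 768, 12, 4

    head_dim = embed_dim // num_heads        # 64
    heads_per_rank = num_heads // tp_size    # 3, what divide(num_heads, tp_size) returns

    # Each rank's qkv_proj produces 3 * heads_per_rank * head_dim features...
    per_rank_qkv = 3 * heads_per_rank * head_dim   # 576
    # ...versus 3 * embed_dim = 2304 for the unsharded projection.
    assert per_rank_qkv * tp_size == 3 * embed_dim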
@@ -274,163 +262,29 @@ class SiglipTPAttention(nn.Module):
         batch_size, q_len, _ = hidden_states.size()

         qkv_states, _ = self.qkv_proj(hidden_states)
-        query_states, key_states, value_states = qkv_states.split(
-            [self.qkv_size] * 3, dim=-1)
+        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

-        attn_output = self.attn_fn(
-            q=query_states,
-            k=key_states,
-            v=value_states,
-            batch_size=batch_size,
-            q_len=q_len,
-        )
+        query_states = query_states.view(batch_size, q_len,
+                                         self.num_heads_per_partition,
+                                         self.head_dim)
+        key_states = key_states.view(batch_size, q_len,
+                                     self.num_heads_per_partition,
+                                     self.head_dim)
+        value_states = value_states.view(batch_size, q_len,
+                                         self.num_heads_per_partition,
+                                         self.head_dim)

-        attn_output, _ = self.out_proj(attn_output)
-        return attn_output
-
-    def _basic_attention_forward(self, q, k, v, batch_size, q_len):
-        q = q.view(batch_size, q_len, self.num_heads,
-                   self.head_dim).transpose(1, 2)
-        k = k.view(batch_size, q_len, self.num_heads,
-                   self.head_dim).transpose(1, 2)
-        v = v.view(batch_size, q_len, self.num_heads,
-                   self.head_dim).transpose(1, 2)
-
-        k_v_seq_len = k.shape[-2]
-        attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale
-
-        if attn_weights.size() != (
-                batch_size,
-                self.num_heads,
-                q_len,
-                k_v_seq_len,
-        ):
-            raise ValueError(
-                "Attention weights should be of size "
-                f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
-                f" {attn_weights.size()}")
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights,
-                                             dim=-1,
-                                             dtype=torch.float32).to(q.dtype)
-        attn_weights = nn.functional.dropout(attn_weights,
-                                             p=self.dropout,
-                                             training=self.training)
-        attn_output = torch.matmul(attn_weights, v)
-
-        if attn_output.size() != (
-                batch_size,
-                self.num_heads,
-                q_len,
-                self.head_dim,
-        ):
-            raise ValueError(
-                "`attn_output` should be of size "
-                f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}")
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
-
+        out = xops.memory_efficient_attention_forward(query_states,
+                                                      key_states,
+                                                      value_states,
+                                                      p=self.dropout,
+                                                      scale=self.scale)
+        out = out.view(batch_size, q_len, -1)
+        attn_output, _ = self.out_proj(out)
         return attn_output
-
-
-# NOTE: Not used - kept for later when we TP the ViT
-# TODO(ChristopherCho): flash_attn_func is not working properly.
-#                       It constantly throws a CUDA error.
-class SiglipFlashAttention2(SiglipTPAttention):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.attn_fn = self._flash_attention_forward
-
-    # Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449
-    # and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133
-    def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args,
-                                 **kwargs):
-        """Implements the multihead softmax attention.
-
-        Arguments
-        ---------
-            q, k, v: The tensor containing the
-                query, key, and value. (B, S, H, D)
-        """
-
-        q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
-        k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
-        v = v.view(batch_size, q_len, self.num_heads, self.head_dim)
-
-        attn_output = flash_attn_func(
-            q,
-            k,
-            v,
-            dropout_p=self.dropout,
-            causal=False,
-        )
-
-        attn_output = attn_output.reshape(batch_size, q_len,
-                                          self.embed_dim).contiguous()
-        return attn_output
-
-
-# NOTE: Not used - kept for later when we TP the ViT
-class SiglipSdpaAttention(SiglipTPAttention):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.is_causal = False
-        self.attn_fn = self._sdpa_attention_forward
-
-    def _sdpa_attention_forward(self, q, k, v, batch_size, q_len):
-        q = q.view(batch_size, q_len, self.num_heads,
-                   self.head_dim).transpose(1, 2)
-        k = k.view(batch_size, q_len, self.num_heads,
-                   self.head_dim).transpose(1, 2)
-        v = v.view(batch_size, q_len, self.num_heads,
-                   self.head_dim).transpose(1, 2)
-
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            q, k, v, dropout_p=self.dropout, is_causal=False, scale=self.scale)
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
-        return attn_output
-
-
-# NOTE: Not used - kept for later when we TP the ViT
-class SiglipxFormersAttention(SiglipTPAttention):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.attn_fn = self._xformers_attention_forward
-
-    def _xformers_attention_forward(self, q, k, v, batch_size, q_len):
-        q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
-        k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
-        v = v.view(batch_size, q_len, self.num_heads, self.head_dim)
-
-        attn_output = memory_efficient_attention(q,
-                                                 k,
-                                                 v,
-                                                 p=0.0,
-                                                 scale=self.scale)
-
-        attn_output = attn_output.reshape(batch_size, q_len,
-                                          self.embed_dim).contiguous()
-        return attn_output
-
-
-# NOTE: Not used - kept for later when we TP the ViT
-SIGLIP_ATTENTION_CLASSES = {
-    "eager": SiglipTPAttention,
-    "flash_attention_2": SiglipFlashAttention2,
-    "sdpa": SiglipSdpaAttention,
-    "xformers": SiglipxFormersAttention,
-}


 class SiglipMLP(nn.Module):

     def __init__(
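
The removed eager path and the retained xformers kernel compute the same softmax attention; only the memory layout differs ((batch, heads, seq, dim) for eager versus (batch, seq, heads, dim) for xformers). A small self-contained check, illustrative only and assuming xformers and a CUDA device; tolerances are loose because the eager path upcasts the softmax to fp32 while the fused kernel runs in fp16:

    import torch
    from xformers import ops as xops

    b, s, h, d = 1, 16, 4, 32
    q = torch.randn(b, s, h, d, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # Eager reference in (b, h, s, d), as in the removed _basic_attention_forward.
    qt, kt, vt = (t.transpose(1, 2) for t in (q, k, v))
    ref = torch.softmax((qt @ kt.transpose(2, 3)) * d**-0.5, dim=-1) @ vt

    out = xops.memory_efficient_attention_forward(q, k, v, scale=d**-0.5)
    torch.testing.assert_close(out.transpose(1, 2), ref, rtol=2e-2, atol=2e-2)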
@@ -473,8 +327,7 @@ class SiglipEncoderLayer(nn.Module):
         super().__init__()
         self.embed_dim = config.hidden_size
-        # TODO(ChristopherCho): use TP'ed Attention block
-        self.self_attn = SiglipAttention(config)
+        self.self_attn = SiglipAttention(config, quant_config=quant_config)
         self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(
@@ -491,7 +344,7 @@ class SiglipEncoderLayer(nn.Module):
         residual = hidden_states

         hidden_states = self.layer_norm1(hidden_states)
-        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
+        hidden_states = self.self_attn(hidden_states=hidden_states)
         hidden_states = residual + hidden_states

         residual = hidden_states
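
The dropped `, _` at the call site reflects a return-type change: transformers' SiglipAttention returns an (attn_output, attn_weights) tuple, while the TP'ed version above returns only the output tensor. A stand-in sketch of the difference (both classes below are hypothetical stand-ins, not the real modules):

    import torch
    from torch import nn

    class TupleAttn(nn.Module):
        """Stand-in for transformers' SiglipAttention return convention."""
        def forward(self, hidden_states):
            return hidden_states, None  # (attn_output, attn_weights)

    class TensorAttn(nn.Module):
        """Stand-in for the TP'ed SiglipAttention return convention."""
        def forward(self, hidden_states):
            return hidden_states  # attn_output only

    x = torch.zeros(1, 4, 8)
    hidden, _ = TupleAttn()(hidden_states=x)  # old call site unpacks a tuple
    hidden = TensorAttn()(hidden_states=x)    # new call site takes the tensor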