[Model] Support NVLM-D and fix QK Norm in InternViT (#9045)

Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Author: Cyrus Leung
Date: 2024-10-07 19:55:12 +08:00
Committed by: GitHub
Parent: f19da64871
Commit: 151ef4efd2
12 changed files with 518 additions and 236 deletions


@@ -4,6 +4,7 @@
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from functools import partial
from typing import Iterable, Optional, Tuple
import torch
@@ -11,7 +12,10 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -54,7 +58,7 @@ class InternVisionEmbeddings(nn.Module):
self.position_embedding = nn.Parameter(
torch.randn(1, self.num_positions, self.embed_dim))
def _get_pos_embed(self, pos_embed, H, W):
def _get_pos_embed(self, pos_embed: torch.Tensor, H: int, W: int):
target_dtype = pos_embed.dtype
pos_embed = pos_embed.float().reshape(
1, self.image_size // self.patch_size,
@@ -63,9 +67,21 @@ class InternVisionEmbeddings(nn.Module):
size=(H, W),
mode='bicubic',
align_corners=False)
pos_embed = pos_embed.reshape(1, -1, H * W).permute(0, 2,
1).to(target_dtype)
return pos_embed
return pos_embed.reshape(1, -1, H * W).permute(0, 2,
1).to(target_dtype)
def _get_position_embedding(self, H: int, W: int) -> torch.Tensor:
position_embedding = self.position_embedding
if self.num_patches == H * W:
return position_embedding
return torch.cat(
[
position_embedding[:, :1, :],
self._get_pos_embed(position_embedding[:, 1:, :], H, W),
],
dim=1,
)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
target_dtype = self.patch_embedding.weight.dtype
@@ -76,12 +92,7 @@ class InternVisionEmbeddings(nn.Module):
class_embeds = self.class_embedding.expand(batch_size, 1,
-1).to(target_dtype)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
position_embedding = torch.cat([
self.position_embedding[:, :1, :],
self._get_pos_embed(self.position_embedding[:, 1:, :], height,
width)
],
dim=1)
position_embedding = self._get_position_embedding(height, width)
embeddings = embeddings + position_embedding.to(target_dtype)
return embeddings
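
The new _get_position_embedding helper only interpolates when the runtime patch grid differs from the pretrained one, keeping the class token untouched. A minimal standalone sketch of that interpolation (hypothetical shapes, independent of vLLM):

import torch
import torch.nn.functional as F

def interpolate_pos_embed(pos_embed: torch.Tensor, grid: int, H: int, W: int) -> torch.Tensor:
    # pos_embed: (1, 1 + grid*grid, C), a class token followed by a grid of patch positions.
    if grid * grid == H * W:
        return pos_embed  # pretrained resolution, nothing to resize
    cls_tok, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    patch_pos = patch_pos.reshape(1, grid, grid, -1).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(patch_pos.float(), size=(H, W),
                              mode='bicubic', align_corners=False)
    patch_pos = patch_pos.reshape(1, -1, H * W).permute(0, 2, 1).to(pos_embed.dtype)
    return torch.cat([cls_tok, patch_pos], dim=1)

# e.g. a checkpoint trained on a 32x32 patch grid served with a 64x48 grid:
pe = torch.randn(1, 1 + 32 * 32, 1024)
print(interpolate_pos_embed(pe, 32, 64, 48).shape)  # torch.Size([1, 3073, 1024])
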
@@ -93,8 +104,11 @@ class InternParallelAttention(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
*,
num_dummy_heads: int = 0,
) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
@@ -105,11 +119,19 @@ class InternParallelAttention(nn.Module):
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).')
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
# Additional dummy heads are used to enable TP for common GPU counts.
self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads,
self.tp_size)
self.scale = self.head_dim**-0.5
self.qkv = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
num_dummy_heads + self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
)
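
Padding the head count with dummy heads is what lets the encoder shard over common GPU counts: a 25-head encoder, for example, becomes 32 heads with num_dummy_heads=7 and then divides evenly by tensor-parallel sizes of 1, 2, 4, or 8. A rough sketch of the partition arithmetic (plain Python, not the vLLM divide helper):

def heads_per_partition(num_heads: int, num_dummy_heads: int, tp_size: int) -> int:
    # The padded head count must split evenly across tensor-parallel ranks.
    total = num_heads + num_dummy_heads
    assert total % tp_size == 0, "padded head count must be divisible by tp_size"
    return total // tp_size

# Hypothetical numbers: 25 real heads padded with 7 dummy heads.
for tp in (1, 2, 4, 8):
    print(tp, heads_per_partition(25, 7, tp))  # 32, 16, 8, 4 heads per rank
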
@@ -117,34 +139,44 @@ class InternParallelAttention(nn.Module):
self.qk_normalization = config.qk_normalization
if self.qk_normalization:
self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.q_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.k_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.proj = RowParallelLinear(
self.embed_dim,
self.dummy_dim,
self.embed_dim,
quant_config=quant_config,
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
if self.tp_size > 1:
q = tensor_model_parallel_all_gather(q.contiguous())
k = tensor_model_parallel_all_gather(k.contiguous())
q = self.q_norm.forward_native(q)
k = self.k_norm.forward_native(k)
if self.tp_size > 1:
splitter = partial(split_tensor_along_last_dim,
num_partitions=self.tp_size)
q = splitter(q)[self.tp_rank]
k = splitter(k)[self.tp_rank]
return q, k
def forward(self, x):
B, N, C = x.shape
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, _ = x.shape
qkv, _ = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
if self.qk_normalization:
q, k = self._apply_qk_norm(q, k)
q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
if self.qk_normalization:
B_, N_, H_, D_ = q.shape
q = self.q_norm.forward_native(q.flatten(-2,
-1)).view(B_, N_, H_, D_)
k = self.k_norm.forward_native(k.flatten(-2,
-1)).view(B_, N_, H_, D_)
x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
x = x.view(B, N, -1)
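
Under tensor parallelism each rank only holds a slice of the q/k hidden dimension, so the QK Norm path above all-gathers before normalizing and re-splits afterwards. A single-process sketch of that round trip, with the all-gather simulated by concatenation rather than vLLM's tensor-parallel communication ops:

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Minimal RMSNorm over the last dimension (stand-in for RMSNorm.forward_native).
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight

tp_size, hidden = 4, 64
weight = torch.ones(hidden)
# Each "rank" holds one slice of q along the hidden dimension.
q_shards = list(torch.randn(2, 8, hidden).chunk(tp_size, dim=-1))

q_full = torch.cat(q_shards, dim=-1)            # simulated all-gather
q_full = rms_norm(q_full, weight)               # normalize over the full hidden width
q_shards = list(q_full.chunk(tp_size, dim=-1))  # split back to per-rank slices
print(q_shards[0].shape)  # torch.Size([2, 8, 16])
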
@@ -155,8 +187,14 @@ class InternParallelAttention(nn.Module):
class InternSdpaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: PretrainedConfig):
def __init__(
self,
config: PretrainedConfig,
*,
num_dummy_heads: int = 0,
) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
@@ -167,20 +205,27 @@ class InternSdpaAttention(nn.Module):
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).')
# Additional dummy heads are used to enable TP for common GPU counts.
self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
self.scale = self.head_dim**-0.5
self.qkv = nn.Linear(self.embed_dim,
3 * self.embed_dim,
3 * self.dummy_dim,
bias=config.qkv_bias)
self.qk_normalization = config.qk_normalization
if self.qk_normalization:
self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.q_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.k_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.proj = nn.Linear(self.embed_dim, self.embed_dim)
self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
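
The var_hidden_size argument is central to the QK Norm fix: with dummy heads, q and k are dummy_dim wide, but the normalization statistics should come only from the original embed_dim channels so the padded channels do not skew them. A rough standalone sketch of that behaviour (an approximation, not the vLLM RMSNorm class itself):

import torch

def rms_norm_var_hidden(x: torch.Tensor, weight: torch.Tensor,
                        var_hidden_size: int, eps: float = 1e-6) -> torch.Tensor:
    # Variance is computed over the first var_hidden_size channels (the real heads)
    # while the learned scale is applied across the full padded width.
    var = x[..., :var_hidden_size].float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight

embed_dim, dummy_dim = 3200, 4096  # hypothetical: 25 real + 7 dummy heads of size 128
q = torch.randn(2, 16, dummy_dim)
print(rms_norm_var_hidden(q, torch.ones(dummy_dim), embed_dim).shape)  # (2, 16, 4096)
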
@@ -233,22 +278,23 @@ class InternMLP(nn.Module):
class InternVisionEncoderLayer(nn.Module):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_dummy_heads: int = 0,
) -> None:
super().__init__()
self.embed_dim = config.hidden_size
self.intermediate_size = config.intermediate_size
self.norm_type = config.norm_type
# fallback to sdpa attention if tp unavailable
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.attn = InternParallelAttention(config,
quant_config=quant_config)
else:
self.attn = InternSdpaAttention(config)
self.attn = self._init_attn(config,
quant_config,
num_dummy_heads=num_dummy_heads)
self.mlp = InternMLP(config, quant_config=quant_config)
self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
eps=config.layer_norm_eps)
@@ -260,6 +306,24 @@ class InternVisionEncoderLayer(nn.Module):
self.ls2 = nn.Parameter(config.initializer_factor *
torch.ones(self.embed_dim))
def _init_attn(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
*,
num_dummy_heads: int,
):
# fallback to sdpa attention if tp unavailable
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0:
return InternParallelAttention(config,
quant_config=quant_config,
num_dummy_heads=num_dummy_heads)
return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
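
_init_attn keeps the existing fallback rule but applies it to the padded head count: the memory-efficient parallel attention is used only when xFormers is available and the padded heads divide the TP world size; otherwise the layer drops back to SDPA. A toy sketch of that predicate:

def pick_attn_backend(num_heads: int, num_dummy_heads: int, tp_size: int,
                      xformers_available: bool) -> str:
    # Parallel attention needs the padded head count to shard evenly across ranks.
    if xformers_available and (num_heads + num_dummy_heads) % tp_size == 0:
        return "InternParallelAttention"
    return "InternSdpaAttention"

# Hypothetical: 25 heads, optionally padded with 7 dummy heads, TP=8.
print(pick_attn_backend(25, 7, 8, True))  # InternParallelAttention
print(pick_attn_backend(25, 0, 8, True))  # InternSdpaAttention
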
def forward(
self,
hidden_states: torch.Tensor,
@@ -275,19 +339,27 @@ class InternVisionEncoderLayer(nn.Module):
class InternVisionEncoder(nn.Module):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0,
):
super().__init__()
self.config = config
if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([
InternVisionEncoderLayer(config=config, quant_config=quant_config)
InternVisionEncoderLayer(config,
quant_config,
num_dummy_heads=num_dummy_heads)
for _ in range(num_hidden_layers)
])
@@ -302,35 +374,25 @@ class InternVisionEncoder(nn.Module):
class InternVisionModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0,
):
super().__init__()
self.config = config
self.embeddings = InternVisionEmbeddings(config)
self.encoder = InternVisionEncoder(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override)
def resize_pos_embeddings(self, old_size, new_size, patch_size):
pos_emb = self.embeddings.position_embedding
_, num_positions, embed_dim = pos_emb.shape
cls_emb = pos_emb[:, :1, :]
pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size,
old_size // patch_size,
-1).permute(0, 3, 1, 2)
pos_emb = F.interpolate(pos_emb.float(),
size=new_size // patch_size,
mode='bicubic',
align_corners=False)
pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim,
-1).permute(0, 2, 1)
pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
self.embeddings.position_embedding = nn.Parameter(pos_emb)
self.embeddings.image_size = new_size
num_hidden_layers_override=num_hidden_layers_override,
num_dummy_heads=num_dummy_heads,
)
def get_input_embeddings(self):
return self.embeddings