[VLM][Model] TP support for ViTs (#7186)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
committed by
GitHub
parent
afd39a4511
commit
f97be32d1d
@@ -10,10 +10,13 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
from xformers import ops as xops
|
||||
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@@ -81,7 +84,11 @@ class InternVisionEmbeddings(nn.Module):
|
||||
class InternAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
@@ -94,9 +101,13 @@ class InternAttention(nn.Module):
|
||||
f' {self.num_heads}).')
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.qkv = nn.Linear(self.embed_dim,
|
||||
3 * self.embed_dim,
|
||||
bias=config.qkv_bias)
|
||||
self.qkv = QKVParallelLinear(
|
||||
self.embed_dim,
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
bias=config.qkv_bias,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.qk_normalization = config.qk_normalization
|
||||
|
||||
@@ -104,25 +115,40 @@ class InternAttention(nn.Module):
|
||||
self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
self.proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.proj = RowParallelLinear(
|
||||
self.embed_dim,
|
||||
self.embed_dim,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
|
||||
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
qkv, _ = self.qkv(x)
|
||||
q, k, v = qkv.chunk(3, dim=-1)
|
||||
|
||||
q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
|
||||
k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
|
||||
v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
|
||||
|
||||
if self.qk_normalization:
|
||||
B_, H_, N_, D_ = q.shape
|
||||
q = self.q_norm.forward_native(q.transpose(1, 2).flatten(
|
||||
-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
|
||||
k = self.k_norm.forward_native(k.transpose(1, 2).flatten(
|
||||
-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
|
||||
B_, N_, H_, D_ = q.shape
|
||||
q = self.q_norm.forward_native(q.flatten(-2,
|
||||
-1)).view(B_, N_, H_, D_)
|
||||
k = self.k_norm.forward_native(k.flatten(-2,
|
||||
-1)).view(B_, N_, H_, D_)
|
||||
|
||||
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = xops.memory_efficient_attention_forward(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
scale=self.scale,
|
||||
)
|
||||
x = x.view(B, N, -1)
|
||||
|
||||
x = self.proj(x)
|
||||
x, _ = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
@@ -161,7 +187,7 @@ class InternVisionEncoderLayer(nn.Module):
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.norm_type = config.norm_type
|
||||
|
||||
self.attn = InternAttention(config)
|
||||
self.attn = InternAttention(config, quant_config=quant_config)
|
||||
self.mlp = InternMLP(config, quant_config=quant_config)
|
||||
self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
Reference in New Issue
Block a user