# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This is basically a copy from perception_models/core/vision_encoder/pe.py"""

from collections.abc import Callable
from functools import partial

import torch
from einops import rearrange, repeat
from torch import nn
from torch.nn import functional as F

from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention.mm_encoder_attention import (
    MMEncoderAttention,
)
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig

from .step3_vl import Step3VLForConditionalGeneration
from .utils import WeightsMapper, init_vllm_registered_model, maybe_prefix
from .vision import run_dp_sharded_vision_model

_DEFAULT_NORM_LAYER = partial(nn.LayerNorm, eps=1e-5)
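

# Rotary-embedding helper: splits the last dimension into consecutive pairs
# (x1, x2) and maps them to (-x2, x1), i.e. a 90-degree rotation of each pair.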
def rotate_half(x):
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")
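

# Applies the standard rotary update t * cos(freqs) + rotate_half(t) * sin(freqs)
# (scaled by `scale`) to the feature slice [start_index:start_index + rot_dim];
# features outside that slice pass through unchanged.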
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
    dtype = t.dtype

    if t.ndim == 3:
        seq_len = t.shape[seq_dim]
        freqs = freqs[-seq_len:]

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    assert rot_dim <= t.shape[-1], (
        f"feature dimension {t.shape[-1]} is not of sufficient size to "
        f"rotate in all the positions {rot_dim}"
    )

    t_left, t, t_right = (
        t[..., :start_index],
        t[..., start_index:end_index],
        t[..., end_index:],
    )
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
    out = torch.cat((t_left, t, t_right), dim=-1)

    return out.type(dtype)
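

# Axial 2D RoPE: half of each head's rotary dimensions encode the column
# (width) index and the other half the row (height) index. Frequencies for
# the full max_grid_height x max_grid_width grid are precomputed once and
# cached as a non-persistent buffer; smaller grids index into that cache.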
class PerceptionEncoderRope2D(nn.Module):
    def __init__(
        self,
        dim: int,
        max_grid_height: int,
        max_grid_width: int,
        use_cls_token: bool = False,
        theta=10000,
        max_freq=10,
        num_freqs=1,
        theta_rescale_factor=1.0,
    ):
        super().__init__()
        self.dim = dim
        self.max_grid_height = max_grid_height
        self.max_grid_width = max_grid_width
        self.use_cls_token = use_cls_token
        self.theta = theta * theta_rescale_factor ** (dim / (dim - 2))
        self.max_freq = max_freq
        self.num_freqs = num_freqs
        cache = self._compute_2d_freqs()
        self.register_buffer("freqs_cache", cache, persistent=False)
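
    # Standard RoPE frequency schedule: inv_freq[i] = 1 / base^(2i / dim);
    # _compute_freqs then repeats each frequency twice so that consecutive
    # feature pairs share one rotation angle.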
    def _compute_inv_freq(self, base: int | float, dim: int) -> torch.Tensor:
        freqs = 1.0 / (base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        return freqs

    def _compute_freqs(self, t: torch.Tensor, inv_freq: torch.Tensor):
        freqs = torch.einsum("..., f -> ... f", t.type(inv_freq.dtype), inv_freq)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
        return freqs

    def _compute_2d_freqs(self) -> torch.Tensor:
        grid_h_range = torch.arange(self.max_grid_height, dtype=torch.float)
        grid_w_range = torch.arange(self.max_grid_width, dtype=torch.float)
        if self.use_cls_token:
            grid_h_range += 1
            grid_w_range += 1
        inv_freq = self._compute_inv_freq(self.theta, self.dim // 2)
        freqs_h = self._compute_freqs(grid_h_range, inv_freq)[:, None].expand(
            self.max_grid_height, self.max_grid_width, -1
        )
        freqs_w = self._compute_freqs(grid_w_range, inv_freq)[None, :].expand(
            self.max_grid_height, self.max_grid_width, -1
        )
        freqs = torch.cat([freqs_w, freqs_h], dim=-1).reshape(
            self.max_grid_height * self.max_grid_width, -1
        )
        if self.use_cls_token:
            freqs = torch.cat([torch.zeros(1, freqs.shape[-1]), freqs], dim=0)
        freqs = freqs[None, None, ...]
        return freqs
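
    # If the runtime grid is smaller than the cached maximum, gather the cache
    # rows that correspond to (row, col) positions in the max-width layout.
    # An optional CLS token occupies position 0 with zero frequencies, so it
    # is never rotated.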
    def forward(self, q: torch.Tensor, k: torch.Tensor, grid_hw: tuple[int, int]):
        if grid_hw[0] != self.max_grid_height or grid_hw[1] != self.max_grid_width:
            rows = torch.arange(grid_hw[0], device=q.device).view(-1, 1)
            cols = torch.arange(grid_hw[1], device=q.device).view(1, -1)
            positions = (rows * self.max_grid_width + cols).reshape(-1).to(torch.long)
            if self.use_cls_token:
                positions = torch.cat(
                    [torch.zeros(1, device=q.device), positions + 1], dim=0
                )
                positions = positions.to(torch.long)
            freqs = self.freqs_cache.index_select(2, positions)
        else:
            freqs = self.freqs_cache
        q = apply_rotary_emb(freqs, q)
        k = apply_rotary_emb(freqs, k)
        return q, k
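

# LayerScale: per-channel learnable scaling of a residual branch (in the style
# of CaiT), initialized near zero so each block starts close to the identity.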
class PerceptionEncoderLayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
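

# Two-layer feed-forward block using vLLM's tensor-parallel linears: fc1
# (column-parallel) shards the hidden dimension across TP ranks and fc2
# (row-parallel) reduces back to the input dimension. With
# use_data_parallel=True, TP is disabled and each rank holds full weights.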
class PerceptionEncoderMLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        act_layer: Callable[[], nn.Module],
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            input_dim,
            hidden_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=use_data_parallel,
        )
        self.activation = act_layer
        self.fc2 = RowParallelLinear(
            hidden_dim,
            input_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.fc1(x)
        x = self.activation(x)
        x, _ = self.fc2(x)
        return x
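

# Multi-head self-attention over the patch grid with 2D rotary position
# embeddings applied to q/k. Heads are split across tensor-parallel ranks
# unless the encoder runs in data-parallel mode.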
class PerceptionEncoderVisionAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        max_grid_height: int,
        max_grid_width: int,
        use_cls_token: bool = False,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.total_num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim**-0.5

        tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_size == 0, (
            "num_heads must be divisible by the tensor parallel size"
        )
        self.num_heads = self.total_num_heads // tp_size

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
            disable_tp=use_data_parallel,
        )
        self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
        self.rope = PerceptionEncoderRope2D(
            dim=self.head_dim,
            max_grid_height=max_grid_height,
            max_grid_width=max_grid_width,
            use_cls_token=use_cls_token,
        )
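
    # RoPE operates on a (batch, heads, seq, head_dim) layout, so q/k are
    # reshaped per-head, rotated, then flattened back to
    # (batch, seq, heads * head_dim) for the encoder attention kernel.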
    def forward(self, x: torch.Tensor, grid_hw: tuple[int, int]) -> torch.Tensor:
        bsz, seq_len, _ = x.shape
        qkv, _ = self.qkv_proj(x)
        q, k, v = qkv.chunk(chunks=3, dim=-1)

        q = q.view(bsz, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.view(bsz, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        q, k = self.rope(q, k, grid_hw=grid_hw)
        q = q.permute(0, 2, 1, 3).reshape(bsz, seq_len, self.num_heads * self.head_dim)
        k = k.permute(0, 2, 1, 3).reshape(bsz, seq_len, self.num_heads * self.head_dim)

        attn_output = self.attn(q, k, v)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output
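

# Pre-norm transformer block: LayerNorm -> attention -> LayerScale residual,
# then LayerNorm -> MLP -> LayerScale residual.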
class PerceptionEncoderVisionBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_head: int,
        max_grid_height: int,
        max_grid_width: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float | None = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        use_cls_token: bool = False,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.attn = PerceptionEncoderVisionAttention(
            d_model,
            n_head,
            max_grid_height=max_grid_height,
            max_grid_width=max_grid_width,
            use_cls_token=use_cls_token,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
        )
        self.ls_1 = (
            PerceptionEncoderLayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )
        self.ls_2 = (
            PerceptionEncoderLayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )
        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)
        hidden_dim = int(d_model * mlp_ratio)
        self.mlp = PerceptionEncoderMLP(
            d_model,
            hidden_dim,
            act_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(self, x: torch.Tensor, grid_hw: tuple[int, int]):
        x = x + self.ls_1(self.attn(self.ln_1(x), grid_hw=grid_hw))
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x
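

# A plain stack of identical vision blocks applied sequentially over the grid.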
class PerceptionEncoderVisionTransformer(nn.Module):
    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        max_grid_height: int,
        max_grid_width: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float | None = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        use_cls_token: bool = False,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList(
            [
                PerceptionEncoderVisionBlock(
                    d_model=width,
                    n_head=heads,
                    max_grid_height=max_grid_height,
                    max_grid_width=max_grid_width,
                    mlp_ratio=mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    use_cls_token=use_cls_token,
                    quant_config=quant_config,
                    prefix=f"{prefix}.resblocks.{i}",
                    use_data_parallel=use_data_parallel,
                )
                for i in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor, grid_hw: tuple[int, int]):
        for block in self.resblocks:
            x = block(x, grid_hw=grid_hw)
        return x
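

# The full vision tower: conv patch embedding, optional CLS token and
# (bilinearly resizable) absolute position embedding, the RoPE transformer
# stack, then two stride-2 conv downsamplers that expand channels from
# width to width * 4 while shrinking the token grid 4x along each axis.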
class PerceptionEncoder(nn.Module):
    def __init__(
        self,
        config,
        act_layer: Callable,
        norm_layer: Callable = _DEFAULT_NORM_LAYER,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.patch_size = config.patch_size

        self.output_dim = config.output_dim or config.width
        self.heads = config.heads
        self.width = config.width
        self.layers = config.layers

        self.use_abs_posemb = config.use_abs_posemb
        self.use_cls_token = config.use_cls_token
        self.use_rope2d = config.use_rope2d
        if not self.use_rope2d:
            raise ValueError("use_rope2d must be True")
        self.image_size = config.image_size

        self.conv1 = Conv2dLayer(
            in_channels=3,
            out_channels=config.width,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            bias=False,
        )

        self.ln_pre = norm_layer(config.width) if config.use_ln_pre else nn.Identity()
        self.ln_post = norm_layer(self.width) if config.use_ln_post else nn.Identity()

        self.transformer = PerceptionEncoderVisionTransformer(
            config.width,
            config.layers,
            config.heads,
            max_grid_height=self.image_size // self.patch_size,
            max_grid_width=self.image_size // self.patch_size,
            mlp_ratio=config.mlp_ratio,
            ls_init_value=config.ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            use_cls_token=self.use_cls_token,
            quant_config=quant_config,
            prefix=f"{prefix}.transformer",
            use_data_parallel=use_data_parallel,
        )

        self.vit_downsampler1 = Conv2dLayer(
            config.width, config.width * 2, kernel_size=3, stride=2, padding=1
        )
        self.vit_downsampler2 = Conv2dLayer(
            config.width * 2, config.width * 4, kernel_size=3, stride=2, padding=1
        )

        if self.use_cls_token:
            self.class_embedding = nn.Parameter(
                (self.width**-0.5) * torch.randn(self.width)
            )

        if self.use_abs_posemb:
            self.posemb_grid_size = self.image_size // self.patch_size
            self.positional_embedding = nn.Parameter(
                (self.width**-0.5)
                * torch.randn(
                    int(self.use_cls_token) + self.posemb_grid_size**2,
                    self.width,
                )
            )
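
    # The absolute position embedding is trained for a fixed square grid;
    # for other grid sizes it is bilinearly interpolated to (grid_h, grid_w).
    # The CLS token's embedding, if present, is carried over unchanged.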
    def sample_abs_posemb(self, grid_h: int, grid_w: int):
        if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
            return self.positional_embedding[None, ...]

        pos_embed = self.positional_embedding
        if self.use_cls_token:
            cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]

        pos_embed = (
            pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        pos_embed = F.interpolate(
            pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width)

        if self.use_cls_token:
            pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)

        return pos_embed[None, ...]
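
    # Patchify with the strided conv, flatten to a token sequence, prepend the
    # optional CLS token, add position embeddings, and run the transformer.
    # The CLS token is stripped from the returned features.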
    def forward_features(self, x: torch.Tensor):
        batch, _, h, w = x.shape
        grid_h, grid_w = h // self.patch_size, w // self.patch_size

        x = self.conv1(x)
        x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width)

        if self.use_cls_token:
            x = torch.cat(
                [self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x], dim=1
            )

        if self.use_abs_posemb:
            x = x + self.sample_abs_posemb(grid_h, grid_w)

        x = self.ln_pre(x)
        x = self.transformer(x, grid_hw=(grid_h, grid_w))
        x = self.ln_post(x)

        if self.use_cls_token:
            x = x[:, 1:, :]

        return x
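
    # Assumes a square token grid (T = sqrt(P)). Features are folded back into
    # a 2D map, passed through both downsamplers (giving B, 4C, T/4, T/4), and
    # flattened back into a (batch, tokens, channels) sequence.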
    def forward(self, x: torch.Tensor):
        x = self.forward_features(x)
        B, P, C = x.shape
        T = int(P**0.5)
        x = x.transpose(2, 1).contiguous()
        x = x.view(B, C, T, T)

        x = self.vit_downsampler1(x)
        x = self.vit_downsampler2(x)

        B, C, T, T = x.shape
        return x.view(B, -1, T * T).transpose(1, 2)
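

# Step-VL model: reuses the Step3-VL multimodal scaffolding but swaps in the
# PerceptionEncoder vision tower above. hf_to_vllm_mapper renames HF
# checkpoint weights (fused attn in_proj, c_fc/c_proj MLP naming) onto the
# vLLM module layout.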
class StepVLForConditionalGeneration(Step3VLForConditionalGeneration):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        },
        orig_to_new_substr={
            ".attn.in_proj_weight": ".attn.qkv_proj.weight",
            ".attn.in_proj_bias": ".attn.qkv_proj.bias",
            ".mlp.c_fc": ".mlp.fc1",
            ".mlp.c_proj": ".mlp.fc2",
        },
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
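        # Deliberately skip Step3VLForConditionalGeneration.__init__ (which
        # would build the Step3 vision tower); the call resolves to the next
        # class in the MRO so this subclass can build its own towers below.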
        super(Step3VLForConditionalGeneration, self).__init__()

        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_model = PerceptionEncoder(
                config.vision_config,
                get_act_fn(config.vision_config.hidden_act),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "vision_model"),
                use_data_parallel=self.use_data_parallel,
            )
            self.vit_large_projector = ColumnParallelLinear(
                config.vision_config.width * 4,
                config.text_config.hidden_size,
                bias=config.projector_bias,
                gather_output=True,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "vit_large_projector"),
                disable_tp=self.use_data_parallel,
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )
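
    # In data-parallel encoder mode the image batch is sharded across ranks,
    # each of which runs a full replica of the vision tower.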
    def _get_vision_model_output(
        self, input_tensor: torch.Tensor | None
    ) -> torch.Tensor | None:
        if input_tensor is None:
            return None
        if self.use_data_parallel:
            return run_dp_sharded_vision_model(input_tensor, self.vision_model)
        return self.vision_model(input_tensor)
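
    # Projects the width * 4 vision features into the language model's
    # hidden size.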
    def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
        image_features, _ = self.vit_large_projector(image_features)
        return image_features