diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 0d413f115..2fd2ae08a 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -720,6 +720,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ |
 | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | |
 | `Step3VLForConditionalGeneration` | Step3-VL | T + I+ | `stepfun-ai/step3` | | ✅︎ |
+| `StepVLForConditionalGeneration` | Step3-VL-10B | T + I+ | `stepfun-ai/Step3-VL-10B` | | ✅︎ |
 | `TarsierForConditionalGeneration` | Tarsier | T + IE+ | `omni-research/Tarsier-7b`, `omni-research/Tarsier-34b` | | ✅︎ |
 | `Tarsier2ForConditionalGeneration`^ | Tarsier2 | T + IE+ + VE+ | `omni-research/Tarsier2-Recap-7b`, `omni-research/Tarsier2-7b-0115` | | ✅︎ |
 | `UltravoxModel` | Ultravox | T + AE+ | `fixie-ai/ultravox-v0_5-llama-3_2-1b` | ✅︎ | ✅︎ |
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 62f6c92f4..0b7d50725 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -925,6 +925,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     "Step3VLForConditionalGeneration": _HfExamplesInfo(
         "stepfun-ai/step3", trust_remote_code=True
     ),
+    "StepVLForConditionalGeneration": _HfExamplesInfo(
+        "stepfun-ai/Step3-VL-10B", trust_remote_code=True
+    ),
     "UltravoxModel": _HfExamplesInfo(
         "fixie-ai/ultravox-v0_5-llama-3_2-1b",
         trust_remote_code=True,
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 51e7b9133..ab80feea1 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -323,6 +323,7 @@ _MULTIMODAL_MODELS = {
         "hunyuan_vision",
         "HunYuanVLForConditionalGeneration",
     ),
+    "StepVLForConditionalGeneration": ("step_vl", "StepVLForConditionalGeneration"),
     "InternVLChatModel": ("internvl", "InternVLChatModel"),
     "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
     "OpenCUAForConditionalGeneration": (
diff --git a/vllm/model_executor/models/step_vl.py b/vllm/model_executor/models/step_vl.py
new file mode 100644
index 000000000..ca165f43e
--- /dev/null
+++ b/vllm/model_executor/models/step_vl.py
@@ -0,0 +1,549 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""This is basically a copy from perception_models/core/vision_encoder/pe.py"""
+
+from collections.abc import Callable
+from functools import partial
+
+import torch
+from einops import rearrange, repeat
+from torch import nn
+from torch.nn import functional as F
+
+from vllm.config import VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
+from vllm.model_executor.layers.conv import Conv2dLayer
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+
+from .step3_vl import Step3VLForConditionalGeneration
+from .utils import WeightsMapper, init_vllm_registered_model, maybe_prefix
+from .vision import run_dp_sharded_vision_model
+
+_DEFAULT_NORM_LAYER = partial(nn.LayerNorm, eps=1e-5)
+
+
+def rotate_half(x):
+    x = rearrange(x, "... (d r) -> ... d r", r=2)
+    x1, x2 = x.unbind(dim=-1)
+    x = torch.stack((-x2, x1), dim=-1)
+    return rearrange(x, "... d r -> ... (d r)")
+
+
+def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
+    dtype = t.dtype
+
+    if t.ndim == 3:
+        seq_len = t.shape[seq_dim]
+        freqs = freqs[-seq_len:]
+
+    rot_dim = freqs.shape[-1]
+    end_index = start_index + rot_dim
+
+    assert rot_dim <= t.shape[-1], (
+        "feature dimension {} is not of sufficient size to rotate in all the "
+        "positions {}".format(t.shape[-1], rot_dim)
+    )
+
+    t_left, t, t_right = (
+        t[..., :start_index],
+        t[..., start_index:end_index],
+        t[..., end_index:],
+    )
+    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
+    out = torch.cat((t_left, t, t_right), dim=-1)
+
+    return out.type(dtype)
+
+
+class PerceptionEncoderRope2D(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        max_grid_height: int,
+        max_grid_width: int,
+        use_cls_token: bool = False,
+        theta=10000,
+        max_freq=10,
+        num_freqs=1,
+        theta_rescale_factor=1.0,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.max_grid_height = max_grid_height
+        self.max_grid_width = max_grid_width
+        self.use_cls_token = use_cls_token
+        self.theta = theta * theta_rescale_factor ** (dim / (dim - 2))
+        self.max_freq = max_freq
+        self.num_freqs = num_freqs
+        cache = self._compute_2d_freqs()
+        self.register_buffer("freqs_cache", cache, persistent=False)
+
+    def _compute_inv_freq(self, base: int | float, dim: int) -> torch.Tensor:
+        freqs = 1.0 / (base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+        return freqs
+
+    def _compute_freqs(self, t: torch.Tensor, inv_freq: torch.Tensor):
+        freqs = torch.einsum("..., f -> ... f", t.type(inv_freq.dtype), inv_freq)
+        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
+        return freqs
+
+    def _compute_2d_freqs(self) -> torch.Tensor:
+        grid_h_range = torch.arange(self.max_grid_height, dtype=torch.float)
+        grid_w_range = torch.arange(self.max_grid_width, dtype=torch.float)
+        if self.use_cls_token:
+            grid_h_range += 1
+            grid_w_range += 1
+        inv_freq = self._compute_inv_freq(self.theta, self.dim // 2)
+        freqs_h = self._compute_freqs(grid_h_range, inv_freq)[:, None].expand(
+            self.max_grid_height, self.max_grid_width, -1
+        )
+        freqs_w = self._compute_freqs(grid_w_range, inv_freq)[None, :].expand(
+            self.max_grid_height, self.max_grid_width, -1
+        )
+        freqs = torch.cat([freqs_w, freqs_h], dim=-1).reshape(
+            self.max_grid_height * self.max_grid_width, -1
+        )
+        if self.use_cls_token:
+            freqs = torch.cat([torch.zeros(1, freqs.shape[-1]), freqs], dim=0)
+        freqs = freqs[None, None, ...]
+        return freqs
+
+    def forward(self, q: torch.Tensor, k: torch.Tensor, grid_hw: tuple[int, int]):
+        if grid_hw[0] != self.max_grid_height or grid_hw[1] != self.max_grid_width:
+            rows = torch.arange(grid_hw[0], device=q.device).view(-1, 1)
+            cols = torch.arange(grid_hw[1], device=q.device).view(1, -1)
+            positions = (rows * self.max_grid_width + cols).reshape(-1).to(torch.long)
+            if self.use_cls_token:
+                positions = torch.cat(
+                    [torch.zeros(1, device=q.device), positions + 1], dim=0
+                )
+                positions = positions.to(torch.long)
+            freqs = self.freqs_cache.index_select(2, positions)
+        else:
+            freqs = self.freqs_cache
+        q = apply_rotary_emb(freqs, q)
+        k = apply_rotary_emb(freqs, k)
+        return q, k
+
+
+class PerceptionEncoderLayerScale(nn.Module):
+    def __init__(self, dim, init_values=1e-5, inplace=False):
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x):
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class PerceptionEncoderMLP(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        act_layer: Callable[[], nn.Module],
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
+        super().__init__()
+        self.fc1 = ColumnParallelLinear(
+            input_dim,
+            hidden_dim,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc1",
+            disable_tp=use_data_parallel,
+        )
+        self.activation = act_layer
+        self.fc2 = RowParallelLinear(
+            hidden_dim,
+            input_dim,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
+            disable_tp=use_data_parallel,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, _ = self.fc1(x)
+        x = self.activation(x)
+        x, _ = self.fc2(x)
+        return x
+
+
+class PerceptionEncoderVisionAttention(nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        max_grid_height: int,
+        max_grid_width: int,
+        use_cls_token: bool = False,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.total_num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        self.scale = self.head_dim**-0.5
+
+        tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
+        assert self.total_num_heads % tp_size == 0, (
+            "num_heads must be divisible by the tensor parallel size"
+        )
+        self.num_heads = self.total_num_heads // tp_size
+
+        self.qkv_proj = QKVParallelLinear(
+            embed_dim,
+            self.head_dim,
+            self.total_num_heads,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+            disable_tp=use_data_parallel,
+        )
+        self.out_proj = RowParallelLinear(
+            embed_dim,
+            embed_dim,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
+            disable_tp=use_data_parallel,
+        )
+        self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
+        self.rope = PerceptionEncoderRope2D(
+            dim=self.head_dim,
+            max_grid_height=max_grid_height,
+            max_grid_width=max_grid_width,
+            use_cls_token=use_cls_token,
+        )
+
+    def forward(self, x: torch.Tensor, grid_hw: tuple[int, int]) -> torch.Tensor:
+        bsz, seq_len, _ = x.shape
+        qkv, _ = self.qkv_proj(x)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+
+        q = q.view(bsz, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+        k = k.view(bsz, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+        q, k = self.rope(q, k, grid_hw=grid_hw)
+        q = q.permute(0, 2, 1, 3).reshape(bsz, seq_len, self.num_heads * self.head_dim)
+        k = k.permute(0, 2, 1, 3).reshape(bsz, seq_len, self.num_heads * self.head_dim)
+
+        attn_output = self.attn(q, k, v)
+        attn_output, _ = self.out_proj(attn_output)
+        return attn_output
+
+
+class PerceptionEncoderVisionBlock(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        n_head: int,
+        max_grid_height: int,
+        max_grid_width: int,
+        mlp_ratio: float = 4.0,
+        ls_init_value: float | None = None,
+        act_layer: Callable = nn.GELU,
+        norm_layer: Callable = nn.LayerNorm,
+        use_cls_token: bool = False,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
+        super().__init__()
+        self.attn = PerceptionEncoderVisionAttention(
+            d_model,
+            n_head,
+            max_grid_height=max_grid_height,
+            max_grid_width=max_grid_width,
+            use_cls_token=use_cls_token,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+            use_data_parallel=use_data_parallel,
+        )
+        self.ls_1 = (
+            PerceptionEncoderLayerScale(d_model, ls_init_value)
+            if ls_init_value is not None
+            else nn.Identity()
+        )
+        self.ls_2 = (
+            PerceptionEncoderLayerScale(d_model, ls_init_value)
+            if ls_init_value is not None
+            else nn.Identity()
+        )
+        self.ln_1 = norm_layer(d_model)
+        self.ln_2 = norm_layer(d_model)
+        hidden_dim = int(d_model * mlp_ratio)
+        self.mlp = PerceptionEncoderMLP(
+            d_model,
+            hidden_dim,
+            act_layer,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+            use_data_parallel=use_data_parallel,
+        )
+
+    def forward(self, x: torch.Tensor, grid_hw: tuple[int, int]):
+        x = x + self.ls_1(self.attn(self.ln_1(x), grid_hw=grid_hw))
+        x = x + self.ls_2(self.mlp(self.ln_2(x)))
+        return x
+
+
+class PerceptionEncoderVisionTransformer(nn.Module):
+    def __init__(
+        self,
+        width: int,
+        layers: int,
+        heads: int,
+        max_grid_height: int,
+        max_grid_width: int,
+        mlp_ratio: float = 4.0,
+        ls_init_value: float | None = None,
+        act_layer: Callable = nn.GELU,
+        norm_layer: Callable = nn.LayerNorm,
+        use_cls_token: bool = False,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.ModuleList(
+            [
+                PerceptionEncoderVisionBlock(
+                    d_model=width,
+                    n_head=heads,
+                    max_grid_height=max_grid_height,
+                    max_grid_width=max_grid_width,
+                    mlp_ratio=mlp_ratio,
+                    ls_init_value=ls_init_value,
+                    act_layer=act_layer,
+                    norm_layer=norm_layer,
+                    use_cls_token=use_cls_token,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.resblocks.{i}",
+                    use_data_parallel=use_data_parallel,
+                )
+                for i in range(layers)
+            ]
+        )
+
+    def forward(self, x: torch.Tensor, grid_hw: tuple[int, int]):
+        for block in self.resblocks:
+            x = block(x, grid_hw=grid_hw)
+        return x
+
+
+class PerceptionEncoder(nn.Module):
+    def __init__(
+        self,
+        config,
+        act_layer: Callable,
+        norm_layer: Callable = _DEFAULT_NORM_LAYER,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+    ):
+        super().__init__()
+        self.patch_size = config.patch_size
+
+        self.output_dim = config.output_dim or config.width
+        self.heads = config.heads
+        self.width = config.width
+        self.layers = config.layers
+
+        self.use_abs_posemb = config.use_abs_posemb
+        self.use_cls_token = config.use_cls_token
+        self.use_rope2d = config.use_rope2d
+        if not self.use_rope2d:
+            raise ValueError("use_rope2d must be True")
+        self.image_size = config.image_size
+
+        self.conv1 = Conv2dLayer(
+            in_channels=3,
+            out_channels=config.width,
+            kernel_size=config.patch_size,
+            stride=config.patch_size,
+            bias=False,
+        )
+
+        self.ln_pre = norm_layer(config.width) if config.use_ln_pre else nn.Identity()
+        self.ln_post = norm_layer(self.width) if config.use_ln_post else nn.Identity()
+
+        self.transformer = PerceptionEncoderVisionTransformer(
+            config.width,
+            config.layers,
+            config.heads,
+            max_grid_height=self.image_size // self.patch_size,
+            max_grid_width=self.image_size // self.patch_size,
+            mlp_ratio=config.mlp_ratio,
+            ls_init_value=config.ls_init_value,
+            act_layer=act_layer,
+            norm_layer=norm_layer,
+            use_cls_token=self.use_cls_token,
+            quant_config=quant_config,
+            prefix=f"{prefix}.transformer",
+            use_data_parallel=use_data_parallel,
+        )
+
+        self.vit_downsampler1 = Conv2dLayer(
+            config.width, config.width * 2, kernel_size=3, stride=2, padding=1
+        )
+        self.vit_downsampler2 = Conv2dLayer(
+            config.width * 2, config.width * 4, kernel_size=3, stride=2, padding=1
+        )
+
+        if self.use_cls_token:
+            self.class_embedding = nn.Parameter(
+                (self.width**-0.5) * torch.randn(self.width)
+            )
+
+        if self.use_abs_posemb:
+            self.posemb_grid_size = self.image_size // self.patch_size
+            self.positional_embedding = nn.Parameter(
+                (self.width**-0.5)
+                * torch.randn(
+                    int(self.use_cls_token) + self.posemb_grid_size**2,
+                    self.width,
+                )
+            )
+
+    def sample_abs_posemb(self, grid_h: int, grid_w: int):
+        if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
+            return self.positional_embedding[None, ...]
+
+        pos_embed = self.positional_embedding
+        if self.use_cls_token:
+            cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]
+
+        pos_embed = (
+            pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1)
+            .permute(0, 3, 1, 2)
+            .contiguous()
+        )
+        pos_embed = F.interpolate(
+            pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False
+        )
+        pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width)
+
+        if self.use_cls_token:
+            pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)
+
+        return pos_embed[None, ...]
+
+    def forward_features(self, x: torch.Tensor):
+        batch, _, h, w = x.shape
+        grid_h, grid_w = h // self.patch_size, w // self.patch_size
+
+        x = self.conv1(x)
+        x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width)
+
+        if self.use_cls_token:
+            x = torch.cat(
+                [self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x], dim=1
+            )
+
+        if self.use_abs_posemb:
+            x = x + self.sample_abs_posemb(grid_h, grid_w)
+
+        x = self.ln_pre(x)
+        x = self.transformer(x, grid_hw=(grid_h, grid_w))
+        x = self.ln_post(x)
+
+        if self.use_cls_token:
+            x = x[:, 1:, :]
+
+        return x
+
+    def forward(self, x: torch.Tensor):
+        x = self.forward_features(x)
+        B, P, C = x.shape
+        T = int(P**0.5)
+        x = x.transpose(2, 1).contiguous()
+        x = x.view(B, C, T, T)
+
+        x = self.vit_downsampler1(x)
+        x = self.vit_downsampler2(x)
+
+        B, C, T, T = x.shape
+        return x.view(B, -1, T * T).transpose(1, 2)
+
+
+class StepVLForConditionalGeneration(Step3VLForConditionalGeneration):
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "model.": "language_model.model.",
+            "lm_head.": "language_model.lm_head.",
+        },
+        orig_to_new_substr={
+            ".attn.in_proj_weight": ".attn.qkv_proj.weight",
+            ".attn.in_proj_bias": ".attn.qkv_proj.bias",
+            ".mlp.c_fc": ".mlp.fc1",
+            ".mlp.c_proj": ".mlp.fc2",
+        },
+    )
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+        super(Step3VLForConditionalGeneration, self).__init__()
+
+        config = vllm_config.model_config.hf_config
+        multimodal_config = vllm_config.model_config.multimodal_config
+        quant_config = vllm_config.quant_config
+
+        self.config = config
+        self.multimodal_config = multimodal_config
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
+        if multimodal_config.get_limit_per_prompt("image"):
+            self.vision_model = PerceptionEncoder(
+                config.vision_config,
+                get_act_fn(config.vision_config.hidden_act),
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "vision_model"),
+                use_data_parallel=self.use_data_parallel,
+            )
+            self.vit_large_projector = ColumnParallelLinear(
+                config.vision_config.width * 4,
+                config.text_config.hidden_size,
+                bias=config.projector_bias,
+                gather_output=True,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "vit_large_projector"),
+                disable_tp=self.use_data_parallel,
+            )
+        else:
+            self.vision_model = None
+            self.vit_large_projector = None
+
+        self.language_model = init_vllm_registered_model(
+            vllm_config=vllm_config,
+            hf_config=config.text_config,
+            prefix=maybe_prefix(prefix, "language_model"),
+        )
+
+        self.make_empty_intermediate_tensors = (
+            self.language_model.make_empty_intermediate_tensors
+        )
+
+    def _get_vision_model_output(
+        self, input_tensor: torch.Tensor | None
+    ) -> torch.Tensor | None:
+        if input_tensor is None:
+            return None
+        if self.use_data_parallel:
+            return run_dp_sharded_vision_model(input_tensor, self.vision_model)
+        return self.vision_model(input_tensor)
+
+    def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
+        image_features, _ = self.vit_large_projector(image_features)
+        return image_features