diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 0d413f115..2fd2ae08a 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -720,6 +720,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ |
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | |
| `Step3VLForConditionalGeneration` | Step3-VL | T + I+ | `stepfun-ai/step3` | | ✅︎ |
+| `StepVLForConditionalGeneration` | Step3-VL-10B | T + I+ | `stepfun-ai/Step3-VL-10B` | | ✅︎ |
| `TarsierForConditionalGeneration` | Tarsier | T + IE+ | `omni-search/Tarsier-7b`, `omni-search/Tarsier-34b` | | ✅︎ |
| `Tarsier2ForConditionalGeneration`^ | Tarsier2 | T + IE+ + VE+ | `omni-research/Tarsier2-Recap-7b`, `omni-research/Tarsier2-7b-0115` | | ✅︎ |
| `UltravoxModel` | Ultravox | T + AE+ | `fixie-ai/ultravox-v0_5-llama-3_2-1b` | ✅︎ | ✅︎ |
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 62f6c92f4..0b7d50725 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -925,6 +925,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Step3VLForConditionalGeneration": _HfExamplesInfo(
"stepfun-ai/step3", trust_remote_code=True
),
+ "StepVLForConditionalGeneration": _HfExamplesInfo(
+ "stepfun-ai/Step3-VL-10B", trust_remote_code=True
+ ),
"UltravoxModel": _HfExamplesInfo(
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
trust_remote_code=True,
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 51e7b9133..ab80feea1 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -323,6 +323,7 @@ _MULTIMODAL_MODELS = {
"hunyuan_vision",
"HunYuanVLForConditionalGeneration",
),
+ "StepVLForConditionalGeneration": ("step_vl", "StepVLForConditionalGeneration"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
"OpenCUAForConditionalGeneration": (
diff --git a/vllm/model_executor/models/step_vl.py b/vllm/model_executor/models/step_vl.py
new file mode 100644
index 000000000..ca165f43e
--- /dev/null
+++ b/vllm/model_executor/models/step_vl.py
@@ -0,0 +1,553 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""This is basically a copy from perception_models/core/vision_encoder/pe.py"""
+
+from collections.abc import Callable
+from functools import partial
+
+import torch
+from einops import rearrange, repeat
+from torch import nn
+from torch.nn import functional as F
+
+from vllm.config import VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
+from vllm.model_executor.layers.conv import Conv2dLayer
+from vllm.model_executor.layers.linear import (
+ ColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+
+from .step3_vl import Step3VLForConditionalGeneration
+from .utils import WeightsMapper, init_vllm_registered_model, maybe_prefix
+from .vision import run_dp_sharded_vision_model
+
+_DEFAULT_NORM_LAYER = partial(nn.LayerNorm, eps=1e-5)
+
+
+def rotate_half(x):
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
+ x1, x2 = x.unbind(dim=-1)
+ x = torch.stack((-x2, x1), dim=-1)
+ return rearrange(x, "... d r -> ... (d r)")
+
+
+def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
+ dtype = t.dtype
+
+ if t.ndim == 3:
+ seq_len = t.shape[seq_dim]
+ freqs = freqs[-seq_len:]
+
+ rot_dim = freqs.shape[-1]
+ end_index = start_index + rot_dim
+
+    assert rot_dim <= t.shape[-1], (
+        f"feature dimension {t.shape[-1]} is too small to rotate "
+        f"the requested {rot_dim} dimensions"
+    )
+
+ t_left, t, t_right = (
+ t[..., :start_index],
+ t[..., start_index:end_index],
+ t[..., end_index:],
+ )
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
+ out = torch.cat((t_left, t, t_right), dim=-1)
+
+ return out.type(dtype)
+
+
+class PerceptionEncoderRope2D(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ max_grid_height: int,
+ max_grid_width: int,
+ use_cls_token: bool = False,
+ theta=10000,
+ max_freq=10,
+ num_freqs=1,
+ theta_rescale_factor=1.0,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.max_grid_height = max_grid_height
+ self.max_grid_width = max_grid_width
+ self.use_cls_token = use_cls_token
+ self.theta = theta * theta_rescale_factor ** (dim / (dim - 2))
+ self.max_freq = max_freq
+ self.num_freqs = num_freqs
+ cache = self._compute_2d_freqs()
+ self.register_buffer("freqs_cache", cache, persistent=False)
+
+ def _compute_inv_freq(self, base: int | float, dim: int) -> torch.Tensor:
+ freqs = 1.0 / (base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+ return freqs
+
+ def _compute_freqs(self, t: torch.Tensor, inv_freq: torch.Tensor):
+ freqs = torch.einsum("..., f -> ... f", t.type(inv_freq.dtype), inv_freq)
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
+ return freqs
+
+ def _compute_2d_freqs(self) -> torch.Tensor:
+ grid_h_range = torch.arange(self.max_grid_height, dtype=torch.float)
+ grid_w_range = torch.arange(self.max_grid_width, dtype=torch.float)
+ if self.use_cls_token:
+ grid_h_range += 1
+ grid_w_range += 1
+ inv_freq = self._compute_inv_freq(self.theta, self.dim // 2)
+ freqs_h = self._compute_freqs(grid_h_range, inv_freq)[:, None].expand(
+ self.max_grid_height, self.max_grid_width, -1
+ )
+ freqs_w = self._compute_freqs(grid_w_range, inv_freq)[None, :].expand(
+ self.max_grid_height, self.max_grid_width, -1
+ )
+ freqs = torch.cat([freqs_w, freqs_h], dim=-1).reshape(
+ self.max_grid_height * self.max_grid_width, -1
+ )
+ if self.use_cls_token:
+ freqs = torch.cat([torch.zeros(1, freqs.shape[-1]), freqs], dim=0)
+ freqs = freqs[None, None, ...]
+ return freqs
+
+ def forward(self, q: torch.Tensor, k: torch.Tensor, grid_hw: tuple[int, int]):
+ if grid_hw[0] != self.max_grid_height or grid_hw[1] != self.max_grid_width:
+ rows = torch.arange(grid_hw[0], device=q.device).view(-1, 1)
+ cols = torch.arange(grid_hw[1], device=q.device).view(1, -1)
+ positions = (rows * self.max_grid_width + cols).reshape(-1).to(torch.long)
+ if self.use_cls_token:
+ positions = torch.cat(
+ [torch.zeros(1, device=q.device), positions + 1], dim=0
+ )
+ positions = positions.to(torch.long)
+ freqs = self.freqs_cache.index_select(2, positions)
+ else:
+ freqs = self.freqs_cache
+ q = apply_rotary_emb(freqs, q)
+ k = apply_rotary_emb(freqs, k)
+ return q, k
+
+
+class PerceptionEncoderLayerScale(nn.Module):
+ def __init__(self, dim, init_values=1e-5, inplace=False):
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x):
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class PerceptionEncoderMLP(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+        act_layer: Callable[[torch.Tensor], torch.Tensor],
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ):
+ super().__init__()
+ self.fc1 = ColumnParallelLinear(
+ input_dim,
+ hidden_dim,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.fc1",
+ disable_tp=use_data_parallel,
+ )
+ self.activation = act_layer
+ self.fc2 = RowParallelLinear(
+ hidden_dim,
+ input_dim,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.fc2",
+ disable_tp=use_data_parallel,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x, _ = self.fc1(x)
+ x = self.activation(x)
+ x, _ = self.fc2(x)
+ return x
+
+
+class PerceptionEncoderVisionAttention(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ max_grid_height: int,
+ max_grid_width: int,
+ use_cls_token: bool = False,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.total_num_heads = num_heads
+ self.head_dim = embed_dim // num_heads
+ self.scale = self.head_dim**-0.5
+
+ tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
+        assert self.total_num_heads % tp_size == 0, (
+            "num_heads must be divisible by the tensor parallel size"
+        )
+ self.num_heads = self.total_num_heads // tp_size
+
+ self.qkv_proj = QKVParallelLinear(
+ embed_dim,
+ self.head_dim,
+ self.total_num_heads,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
+ disable_tp=use_data_parallel,
+ )
+ self.out_proj = RowParallelLinear(
+ embed_dim,
+ embed_dim,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.out_proj",
+ disable_tp=use_data_parallel,
+ )
+ self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
+ self.rope = PerceptionEncoderRope2D(
+ dim=self.head_dim,
+ max_grid_height=max_grid_height,
+ max_grid_width=max_grid_width,
+ use_cls_token=use_cls_token,
+ )
+
+ def forward(self, x: torch.Tensor, grid_hw: tuple[int, int]) -> torch.Tensor:
+ bsz, seq_len, _ = x.shape
+ qkv, _ = self.qkv_proj(x)
+ q, k, v = qkv.chunk(chunks=3, dim=-1)
+
+ q = q.view(bsz, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+ k = k.view(bsz, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+ q, k = self.rope(q, k, grid_hw=grid_hw)
+ q = q.permute(0, 2, 1, 3).reshape(bsz, seq_len, self.num_heads * self.head_dim)
+ k = k.permute(0, 2, 1, 3).reshape(bsz, seq_len, self.num_heads * self.head_dim)
+
+ attn_output = self.attn(q, k, v)
+ attn_output, _ = self.out_proj(attn_output)
+ return attn_output
+
+
+class PerceptionEncoderVisionBlock(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ n_head: int,
+ max_grid_height: int,
+ max_grid_width: int,
+ mlp_ratio: float = 4.0,
+        ls_init_value: float | None = None,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = nn.LayerNorm,
+ use_cls_token: bool = False,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ):
+ super().__init__()
+ self.attn = PerceptionEncoderVisionAttention(
+ d_model,
+ n_head,
+ max_grid_height=max_grid_height,
+ max_grid_width=max_grid_width,
+ use_cls_token=use_cls_token,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn",
+ use_data_parallel=use_data_parallel,
+ )
+ self.ls_1 = (
+ PerceptionEncoderLayerScale(d_model, ls_init_value)
+ if ls_init_value is not None
+ else nn.Identity()
+ )
+ self.ls_2 = (
+ PerceptionEncoderLayerScale(d_model, ls_init_value)
+ if ls_init_value is not None
+ else nn.Identity()
+ )
+ self.ln_1 = norm_layer(d_model)
+ self.ln_2 = norm_layer(d_model)
+ hidden_dim = int(d_model * mlp_ratio)
+ self.mlp = PerceptionEncoderMLP(
+ d_model,
+ hidden_dim,
+ act_layer,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ use_data_parallel=use_data_parallel,
+ )
+
+ def forward(self, x: torch.Tensor, grid_hw: tuple[int, int]):
+ x = x + self.ls_1(self.attn(self.ln_1(x), grid_hw=grid_hw))
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
+ return x
+
+
+class PerceptionEncoderVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ width: int,
+ layers: int,
+ heads: int,
+ max_grid_height: int,
+ max_grid_width: int,
+ mlp_ratio: float = 4.0,
+        ls_init_value: float | None = None,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = nn.LayerNorm,
+ use_cls_token: bool = False,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.ModuleList(
+ [
+ PerceptionEncoderVisionBlock(
+ d_model=width,
+ n_head=heads,
+ max_grid_height=max_grid_height,
+ max_grid_width=max_grid_width,
+ mlp_ratio=mlp_ratio,
+ ls_init_value=ls_init_value,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ use_cls_token=use_cls_token,
+ quant_config=quant_config,
+ prefix=f"{prefix}.resblocks.{i}",
+ use_data_parallel=use_data_parallel,
+ )
+ for i in range(layers)
+ ]
+ )
+
+ def forward(self, x: torch.Tensor, grid_hw: tuple[int, int]):
+ for block in self.resblocks:
+ x = block(x, grid_hw=grid_hw)
+ return x
+
+
+class PerceptionEncoder(nn.Module):
+ def __init__(
+ self,
+ config,
+ act_layer: Callable,
+ norm_layer: Callable = _DEFAULT_NORM_LAYER,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ):
+ super().__init__()
+ self.patch_size = config.patch_size
+
+ self.output_dim = config.output_dim or config.width
+ self.heads = config.heads
+ self.width = config.width
+ self.layers = config.layers
+
+ self.use_abs_posemb = config.use_abs_posemb
+ self.use_cls_token = config.use_cls_token
+ self.use_rope2d = config.use_rope2d
+ if not self.use_rope2d:
+ raise ValueError("use_rope2d must be True")
+ self.image_size = config.image_size
+
+ self.conv1 = Conv2dLayer(
+ in_channels=3,
+ out_channels=config.width,
+ kernel_size=config.patch_size,
+ stride=config.patch_size,
+ bias=False,
+ )
+
+ self.ln_pre = norm_layer(config.width) if config.use_ln_pre else nn.Identity()
+ self.ln_post = norm_layer(self.width) if config.use_ln_post else nn.Identity()
+
+ self.transformer = PerceptionEncoderVisionTransformer(
+ config.width,
+ config.layers,
+ config.heads,
+ max_grid_height=self.image_size // self.patch_size,
+ max_grid_width=self.image_size // self.patch_size,
+ mlp_ratio=config.mlp_ratio,
+ ls_init_value=config.ls_init_value,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ use_cls_token=self.use_cls_token,
+ quant_config=quant_config,
+ prefix=f"{prefix}.transformer",
+ use_data_parallel=use_data_parallel,
+ )
+
+ self.vit_downsampler1 = Conv2dLayer(
+ config.width, config.width * 2, kernel_size=3, stride=2, padding=1
+ )
+ self.vit_downsampler2 = Conv2dLayer(
+ config.width * 2, config.width * 4, kernel_size=3, stride=2, padding=1
+ )
+
+ if self.use_cls_token:
+ self.class_embedding = nn.Parameter(
+ (self.width**-0.5) * torch.randn(self.width)
+ )
+
+ if self.use_abs_posemb:
+ self.posemb_grid_size = self.image_size // self.patch_size
+ self.positional_embedding = nn.Parameter(
+ (self.width**-0.5)
+ * torch.randn(
+ int(self.use_cls_token) + self.posemb_grid_size**2,
+ self.width,
+ )
+ )
+
+ def sample_abs_posemb(self, grid_h: int, grid_w: int):
+ if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
+ return self.positional_embedding[None, ...]
+
+ pos_embed = self.positional_embedding
+ if self.use_cls_token:
+ cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]
+
+ pos_embed = (
+ pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1)
+ .permute(0, 3, 1, 2)
+ .contiguous()
+ )
+ pos_embed = F.interpolate(
+ pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False
+ )
+ pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width)
+
+ if self.use_cls_token:
+ pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)
+
+ return pos_embed[None, ...]
+
+ def forward_features(self, x: torch.Tensor):
+ batch, _, h, w = x.shape
+ grid_h, grid_w = h // self.patch_size, w // self.patch_size
+
+ x = self.conv1(x)
+ x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width)
+
+ if self.use_cls_token:
+ x = torch.cat(
+ [self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x], dim=1
+ )
+
+ if self.use_abs_posemb:
+ x = x + self.sample_abs_posemb(grid_h, grid_w)
+
+ x = self.ln_pre(x)
+ x = self.transformer(x, grid_hw=(grid_h, grid_w))
+ x = self.ln_post(x)
+
+ if self.use_cls_token:
+ x = x[:, 1:, :]
+
+ return x
+
+ def forward(self, x: torch.Tensor):
+ x = self.forward_features(x)
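+        # The flattened patch tokens are assumed to form a square T x T grid.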
+ B, P, C = x.shape
+ T = int(P**0.5)
+ x = x.transpose(2, 1).contiguous()
+ x = x.view(B, C, T, T)
+
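+        # Two stride-2 convolutions downsample the token grid by 4x per side.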
+ x = self.vit_downsampler1(x)
+ x = self.vit_downsampler2(x)
+
+        B, C, H, W = x.shape
+        return x.view(B, C, H * W).transpose(1, 2)
+
+
+class StepVLForConditionalGeneration(Step3VLForConditionalGeneration):
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_prefix={
+ "model.": "language_model.model.",
+ "lm_head.": "language_model.lm_head.",
+ },
+ orig_to_new_substr={
+ ".attn.in_proj_weight": ".attn.qkv_proj.weight",
+ ".attn.in_proj_bias": ".attn.qkv_proj.bias",
+ ".mlp.c_fc": ".mlp.fc1",
+ ".mlp.c_proj": ".mlp.fc2",
+ },
+ )
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
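+        # Intentionally bypass Step3VLForConditionalGeneration.__init__:
+        # this class constructs its own vision encoder and projector below.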
+ super(Step3VLForConditionalGeneration, self).__init__()
+
+ config = vllm_config.model_config.hf_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+ quant_config = vllm_config.quant_config
+
+ self.config = config
+ self.multimodal_config = multimodal_config
+ self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
+ if multimodal_config.get_limit_per_prompt("image"):
+ self.vision_model = PerceptionEncoder(
+ config.vision_config,
+ get_act_fn(config.vision_config.hidden_act),
+ quant_config=quant_config,
+ prefix=maybe_prefix(prefix, "vision_model"),
+ use_data_parallel=self.use_data_parallel,
+ )
+ self.vit_large_projector = ColumnParallelLinear(
+ config.vision_config.width * 4,
+ config.text_config.hidden_size,
+ bias=config.projector_bias,
+ gather_output=True,
+ quant_config=quant_config,
+ prefix=maybe_prefix(prefix, "vit_large_projector"),
+ disable_tp=self.use_data_parallel,
+ )
+ else:
+ self.vision_model = None
+ self.vit_large_projector = None
+
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ hf_config=config.text_config,
+ prefix=maybe_prefix(prefix, "language_model"),
+ )
+
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors
+ )
+
+ def _get_vision_model_output(
+ self, input_tensor: torch.Tensor | None
+ ) -> torch.Tensor | None:
+ if input_tensor is None:
+ return None
+ if self.use_data_parallel:
+ return run_dp_sharded_vision_model(input_tensor, self.vision_model)
+ return self.vision_model(input_tensor)
+
+ def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
+ image_features, _ = self.vit_large_projector(image_features)
+ return image_features