# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The vLLM team.
# Copyright 2025 The Qwen Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3VL model compatible with HuggingFace weights."""
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from functools import lru_cache, partial
from itertools import islice
from typing import Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BatchFeature
from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
smart_resize as image_smart_resize,
)
from transformers.models.qwen3_vl import Qwen3VLProcessor, Qwen3VLVideoProcessor
from transformers.models.qwen3_vl.configuration_qwen3_vl import (
Qwen3VLConfig,
Qwen3VLVisionConfig,
)
from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
smart_resize as video_smart_resize,
)
from transformers.video_utils import VideoMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_pp_group, parallel_state
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.attention.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.evs import (
compute_mrope_for_media,
compute_retained_tokens_count,
compute_retention_mask,
recompute_mrope_positions,
)
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFeatureSpec,
MultiModalFieldConfig,
MultiModalFieldElem,
MultiModalKwargsItem,
MultiModalKwargsItems,
PlaceholderRange,
VideoItem,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
from vllm.tokenizers.protocol import TokenizerLike
from vllm.tokenizers.registry import cached_tokenizer_from_config
from vllm.utils.collection_utils import is_list_of
from vllm.utils.math_utils import round_up
from .interfaces import (
MultiModalEmbeddings,
SupportsEagle,
SupportsEagle3,
SupportsLoRA,
SupportsMRoPE,
SupportsMultiModal,
SupportsMultiModalPruning,
SupportsPP,
_require_is_multimodal,
)
from .qwen2_5_vl import (
Qwen2_5_VisionAttention,
Qwen2_5_VLImageEmbeddingInputs,
Qwen2_5_VLImageInputs,
Qwen2_5_VLImagePixelInputs,
Qwen2_5_VLVideoEmbeddingInputs,
Qwen2_5_VLVideoInputs,
Qwen2_5_VLVideoPixelInputs,
)
from .qwen2_vl import (
Qwen2VLMultiModalDataParser,
Qwen2VLProcessingInfo,
_create_qwen2vl_field_factory,
)
from .qwen3 import Qwen3ForCausalLM, Qwen3Model
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
WeightsMapper,
_merge_multimodal_embeddings,
maybe_prefix,
)
from .vision import (
get_vit_attn_backend,
is_vit_use_data_parallel,
run_dp_sharded_mrope_vision_model,
)
logger = init_logger(__name__)
# We use 2048 dummy video frames that would generate vision embeddings
# of the maximum size.
DUMMY_VIDEO_NUM_FRAMES = 2048
class Qwen3_VisionPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 14,
temporal_patch_size: int = 2,
in_channels: int = 3,
hidden_size: int = 1152,
) -> None:
super().__init__()
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = Conv3dLayer(
in_channels,
hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=True,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
L, C = x.shape
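# x holds flattened patches of shape (num_patches, C), where
# C = in_channels * temporal_patch_size * patch_size**2; the view below
# restores the patch volume expected by the Conv3d projection.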
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.hidden_size)
return x
class Qwen3_VisionMLP(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int,
bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
use_data_parallel = is_vit_use_data_parallel()
self.linear_fc1 = ColumnParallelLinear(
in_features,
hidden_features,
bias=bias,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.linear_fc1",
disable_tp=use_data_parallel,
)
self.linear_fc2 = RowParallelLinear(
hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.linear_fc2",
disable_tp=use_data_parallel,
)
self.act_fn = act_fn
def forward(self, x: torch.Tensor):
mlp_output = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
return mlp_output
class Qwen3_VisionBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim)
self.attn = Qwen2_5_VisionAttention(
embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
self.mlp = Qwen3_VisionMLP(
dim,
mlp_hidden_dim,
act_fn=act_fn,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor, # Only used for Flash Attention
sequence_lengths: torch.Tensor, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)
x = x + self.mlp(self.norm2(x))
return x
class Qwen3_VisionPatchMerger(nn.Module):
def __init__(
self,
d_model: int,
context_dim: int,
norm_layer: Callable[[int], nn.Module] | None = None,
spatial_merge_size: int = 2,
use_postshuffle_norm: bool = False,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
use_data_parallel = is_vit_use_data_parallel()
self.hidden_size = context_dim * (spatial_merge_size**2)
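# Each merged token concatenates the features of a
# spatial_merge_size x spatial_merge_size block of patches, hence the
# (spatial_merge_size**2) factor on the channel dimension.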
self.use_postshuffle_norm = use_postshuffle_norm
if self.use_postshuffle_norm:
context_dim = self.hidden_size
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.norm = norm_layer(context_dim)
self.linear_fc1 = ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.linear_fc1",
disable_tp=use_data_parallel,
)
self.act_fn = nn.GELU()
self.linear_fc2 = RowParallelLinear(
self.hidden_size,
d_model,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.linear_fc2",
disable_tp=use_data_parallel,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_postshuffle_norm:
x = self.norm(x.view(-1, self.hidden_size))
else:
x = self.norm(x).view(-1, self.hidden_size)
x_parallel, _ = self.linear_fc1(x)
x_parallel = self.act_fn(x_parallel)
out, _ = self.linear_fc2(x_parallel)
return out
class Qwen3_VisionTransformer(nn.Module):
def __init__(
self,
vision_config: Qwen3VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
self.num_heads = vision_config.num_heads
self.num_position_embeddings = vision_config.num_position_embeddings
self.patch_size = vision_config.patch_size
self.spatial_merge_size = vision_config.spatial_merge_size
self.spatial_merge_unit = self.spatial_merge_size**2
self.temporal_patch_size = vision_config.temporal_patch_size
self.deepstack_visual_indexes = (
vision_config.deepstack_visual_indexes
if hasattr(vision_config, "deepstack_visual_indexes")
else []
)
self.num_grid_per_side = int(self.num_position_embeddings**0.5)
use_data_parallel = is_vit_use_data_parallel()
self.tp_size = (
1
if use_data_parallel
else parallel_state.get_tensor_model_parallel_world_size()
)
# NOTE: This is used for creating the empty tensor for all_gather in
# DP ViT. out_hidden_size is enlarged here because of deepstack.
self.out_hidden_size = vision_config.out_hidden_size * (
1 + len(self.deepstack_visual_indexes)
)
self.patch_embed = Qwen3_VisionPatchEmbed(
patch_size=self.patch_size,
temporal_patch_size=self.temporal_patch_size,
in_channels=vision_config.in_channels,
hidden_size=self.hidden_size,
)
self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = get_rope(
head_size=head_dim,
max_position=8192,
is_neox_style=True,
rope_parameters={"partial_rotary_factor": 0.5},
)
self.merger = Qwen3_VisionPatchMerger(
d_model=vision_config.out_hidden_size,
context_dim=self.hidden_size,
norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
prefix=f"{prefix}.merger",
)
self.deepstack_merger_list = nn.ModuleList(
[
Qwen3_VisionPatchMerger(
d_model=vision_config.out_hidden_size,
context_dim=self.hidden_size,
spatial_merge_size=self.spatial_merge_size,
use_postshuffle_norm=True,
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
)
for layer_idx in range(len(self.deepstack_visual_indexes))
]
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
)
self.blocks = nn.ModuleList(
[
Qwen3_VisionBlock(
dim=self.hidden_size,
num_heads=self.num_heads,
mlp_hidden_dim=vision_config.intermediate_size,
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
)
for layer_idx in range(vision_config.depth)
]
)
@property
def dtype(self) -> torch.dtype:
return self.patch_embed.proj.weight.dtype
@property
def device(self) -> torch.device:
return self.patch_embed.proj.weight.device
@staticmethod
@lru_cache(maxsize=1024)
def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
h_div = h // spatial_merge_size
w_div = w // spatial_merge_size
hpos_ids = hpos_ids.reshape(
h_div,
spatial_merge_size,
w_div,
spatial_merge_size,
)
hpos_ids = hpos_ids.transpose(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()
wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
wpos_ids = wpos_ids.reshape(
h_div,
spatial_merge_size,
w_div,
spatial_merge_size,
)
wpos_ids = wpos_ids.transpose(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
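# Illustrative example: for h = w = 4 and spatial_merge_size = 2, the
# flattened (h, w) pairs come out grouped by 2x2 merge blocks:
# (0,0),(0,1),(1,0),(1,1), (0,2),(0,3),(1,2),(1,3), (2,0),(2,1),(3,0),(3,1), ...
# matching the order in which the patch merger consumes patches.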
return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1))
def rot_pos_emb(self, grid_thw: list[list[int]]):
max_grid_size = max(max(h, w) for _, h, w in grid_thw)
pos_ids = [
self.rot_pos_ids(h, w, self.spatial_merge_size)
if t == 1
else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
for t, h, w in grid_thw
]
pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True)
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
cos_combined = cos[pos_ids].flatten(1)
sin_combined = sin[pos_ids].flatten(1)
return cos_combined, sin_combined
def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
num_grid_per_side = self.num_grid_per_side
m_size = self.spatial_merge_size
hidden_dim = self.pos_embed.embedding_dim
outputs = []
for t, h, w in grid_thw:
h_idxs = torch.linspace(
0, num_grid_per_side - 1, h, dtype=torch.float32, device=self.device
)
w_idxs = torch.linspace(
0, num_grid_per_side - 1, w, dtype=torch.float32, device=self.device
)
h_floor = h_idxs.to(torch.long)
w_floor = w_idxs.to(torch.long)
h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)
dh = h_idxs - h_floor
dw = w_idxs - w_floor
# Create meshgrid view for all h, w vars
dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij")
h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij")
# original computation of weights
# w00 = (1 - dh_grid) * (1 - dw_grid)
# w01 = (1 - dh_grid) * dw_grid
# w10 = dh_grid * (1 - dw_grid)
# w11 = dh_grid * dw_grid
# we reuse w11 here to avoid duplicate
# dh_grid * dw_grid computation
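# (w10 = dh*(1 - dw) = dh - w11, w01 = (1 - dh)*dw = dw - w11, and
# w00 = (1 - dh)*(1 - dw) = 1 - dh - dw + dh*dw = 1 - dh - w01, so the
# rewritten form below is algebraically identical.)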
w11 = dh_grid * dw_grid
w10 = dh_grid - w11
w01 = dw_grid - w11
w00 = 1 - dh_grid - w01
h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid])
w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid])
h_grid_idx = h_grid * num_grid_per_side
indices = (h_grid_idx + w_grid).reshape(4, -1)
weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
weights = weights.to(dtype=self.dtype)
embeds = self.pos_embed(indices)
embeds *= weights
combined = embeds.sum(dim=0)
combined = combined.reshape(
h // m_size, m_size, w // m_size, m_size, hidden_dim
)
combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim)
repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim)
outputs.append(repeated)
return torch.cat(outputs, dim=0)
def forward(
self,
x: torch.Tensor,
grid_thw: torch.Tensor | list[list[int]],
) -> torch.Tensor:
hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
hidden_states = self.patch_embed(hidden_states)
if isinstance(grid_thw, list):
grid_thw_list = grid_thw
grid_thw = np.array(grid_thw, dtype=np.int32)
else:
grid_thw_list = grid_thw.tolist()
grid_thw = grid_thw.numpy()
pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
hidden_states = hidden_states + pos_embeds
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
axis=0, dtype=np.int32
)
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
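# Illustrative example: for grid_thw = [[2, 4, 6]] (t=2, h=4, w=6), each
# temporal slice contributes h * w = 24 patches repeated t=2 times, so
# cu_seqlens becomes [0, 24, 48] (one attention segment per slice).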
sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
self.attn_backend, cu_seqlens, self.device
)
max_seqlen = torch.tensor(
MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens),
dtype=torch.int32,
)
cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens(
self.attn_backend,
cu_seqlens,
self.hidden_size,
self.tp_size,
self.device,
)
hidden_states = hidden_states.unsqueeze(1)
deepstack_feature_lists = []
for layer_num, blk in enumerate(self.blocks):
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
)
if layer_num in self.deepstack_visual_indexes:
deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
deepstack_feature = self.deepstack_merger_list[deepstack_merger_idx](
hidden_states
)
deepstack_feature_lists.append(deepstack_feature)
hidden_states = self.merger(hidden_states)
hidden_states = torch.cat(
[hidden_states] + deepstack_feature_lists, dim=1
) # [seq_len, hidden_size * (1 + depth_of_deepstack)]
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("attn.qkv.", "attn.q.", "q"),
("attn.qkv.", "attn.k.", "k"),
("attn.qkv.", "attn.v.", "v"),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(Qwen3VLConfig)
def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessor:
return self.ctx.get_hf_processor(
Qwen3VLProcessor,
use_fast=kwargs.pop("use_fast", True),
**kwargs,
)
def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessorFast:
return self.get_hf_processor(**kwargs).image_processor
def get_video_processor(self, **kwargs: object) -> Qwen3VLVideoProcessor:
return self.get_hf_processor(**kwargs).video_processor
def get_data_parser(self):
return Qwen2VLMultiModalDataParser(
self.get_hf_config().vision_config.spatial_merge_size,
video_needs_metadata=True,
expected_hidden_size=self._get_expected_hidden_size(),
)
def _get_vision_info(
self,
*,
image_width: int,
image_height: int,
num_frames: int = 2,
do_resize: bool = True,
image_processor: Qwen2VLImageProcessorFast | Qwen3VLVideoProcessor,
mm_kwargs: Mapping[str, object],
) -> tuple[ImageSize, int]:
is_video = isinstance(image_processor, Qwen3VLVideoProcessor)
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
size = image_processor.size
if override_size := mm_kwargs.get("size"):
size = size | override_size
if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None:
size = size | {"shortest_edge": override_min_pixels}
if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None:
size = size | {"longest_edge": override_max_pixels}
if do_resize:
if is_video:
smart_resize = video_smart_resize
extra_kwargs = {
"num_frames": num_frames,
"temporal_factor": temporal_patch_size,
}
else:
smart_resize = image_smart_resize
extra_kwargs = {}
resized_height, resized_width = smart_resize(
height=image_height,
width=image_width,
factor=patch_size * merge_size,
min_pixels=size["shortest_edge"],
max_pixels=size["longest_edge"],
**extra_kwargs,
)
preprocessed_size = ImageSize(width=resized_width, height=resized_height)
else:
preprocessed_size = ImageSize(width=image_width, height=image_height)
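# Illustrative example (values for illustration only): with patch_size=16,
# merge_size=2 and temporal_patch_size=2, a single 448x448 image gives
# grid_t=1 and grid_h=grid_w=28, i.e. 784 patches and 784 // 4 = 196
# vision tokens.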
padded_num_frames = round_up(num_frames, temporal_patch_size)
grid_t = max(padded_num_frames // temporal_patch_size, 1)
grid_h = preprocessed_size.height // patch_size
grid_w = preprocessed_size.width // patch_size
num_patches = grid_t * grid_h * grid_w
num_vision_tokens = num_patches // (merge_size**2)
return preprocessed_size, num_vision_tokens
def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 2) -> int:
return super()._get_max_video_frames(
max_tokens, start_num_frames=start_num_frames
)
def get_num_frames_with_most_features(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
return super().get_num_frames_with_most_features(
seq_len, mm_counts, max_frames_per_video=DUMMY_VIDEO_NUM_FRAMES
)
def get_max_video_tokens(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
video_processor = self.get_video_processor()
mm_kwargs = self.ctx.get_merged_mm_kwargs({})
video_size = mm_kwargs.get("size", video_processor.size)
temporal_patch_size = mm_kwargs.get(
"temporal_patch_size", video_processor.temporal_patch_size
)
# video_max_pixels includes the temporal compression factor,
# so we divide by temporal_patch_size to get the maximum number of
# pixels per frame.
video_max_pixels = video_size["longest_edge"]
target_width, target_height = self.get_image_size_with_most_features(
max_pixels=video_max_pixels // temporal_patch_size
)
num_video_soft_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=2,
image_processor=video_processor,
mm_kwargs={},
)
return num_video_soft_tokens
def _calculate_timestamps(
self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int
):
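# Illustrative example: indices=[0, 4, 8, 12], video_fps=2.0 and
# merge_size=2 give per-frame timestamps [0.0, 2.0, 4.0, 6.0], which are
# then averaged per temporal merge group to [1.0, 5.0].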
if not isinstance(indices, list):
indices = indices.tolist()
if len(indices) % merge_size != 0:
# don't update metadata's frames_indices directly
indices = indices + [indices[-1]] * (merge_size - len(indices) % merge_size)
timestamps = [idx / video_fps for idx in indices]
timestamps = [
(timestamps[i] + timestamps[i + merge_size - 1]) / 2
for i in range(0, len(timestamps), merge_size)
]
return timestamps
def _get_video_second_idx(
self,
metadata: dict[str, Any],
do_sample_frames: bool | None = None,
sampled_fps: float | None = None,
sampled_num_frames: int | None = None,
) -> list[int]:
video_processor = self.get_video_processor()
temporal_patch_size = video_processor.temporal_patch_size
indices = metadata["frames_indices"]
# metadata["fps"] refers to the true fps of the input video.
video_fps = metadata["fps"]
if do_sample_frames is None:
do_sample_frames = metadata.get("do_sample_frames", False)
# If video frames are sampled in HF processor (instead of vLLM
# video loader), we need to re-calculate the indices from original
# metadata.
if do_sample_frames:
total_num_frames = metadata["total_num_frames"]
# When num_frames is explicitly provided, use it directly
# instead of computing from fps. This mirrors the behavior of
# HF's Qwen3VLVideoProcessor.sample_frames where num_frames
# and fps are mutually exclusive.
if sampled_num_frames is not None:
num_frames = sampled_num_frames
else:
# here video_fps is the fps of the sampled video, and
# metadata["fps"] refers to the fps of the original video.
sampled_fps = sampled_fps if sampled_fps else video_processor.fps
num_frames = int(total_num_frames / metadata["fps"] * sampled_fps)
num_frames = min(
min(
max(num_frames, video_processor.min_frames),
video_processor.max_frames,
),
total_num_frames,
)
indices = (
np.linspace(0, total_num_frames - 1, num_frames)
.round()
.astype(int)
.tolist()
)
timestamps = self._calculate_timestamps(indices, video_fps, temporal_patch_size)
return timestamps
class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
image_token = "<|vision_start|><|image_pad|><|vision_end|>"
video_token = "<|vision_start|><|video_pad|><|vision_end|>"
return image_token * num_images + video_token * num_videos
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
image_overrides = mm_options.get("image")
video_overrides = mm_options.get("video")
target_image_width, target_image_height = (
self.info.get_image_size_with_most_features()
)
# treat videos as special images
target_num_frames = 2
if video_overrides:
assert isinstance(video_overrides, VideoDummyOptions)
num_frames_override = video_overrides.num_frames
if num_frames_override:
if num_frames_override > target_num_frames:
logger.warning(
"video.num_frames override (%d) exceeds model's "
"maximum number of frames (%d), will be ignored",
num_frames_override,
target_num_frames,
)
if num_frames_override < 2:
logger.warning(
"video.num_frames override (%d) cannot be less "
"than 2, will be ignored",
num_frames_override,
)
target_num_frames = min(target_num_frames, num_frames_override)
target_num_frames = max(target_num_frames, 2)
video_processor = self.info.get_video_processor()
mm_kwargs = self.info.ctx.get_merged_mm_kwargs({})
video_size = mm_kwargs.get("size", video_processor.size)
temporal_patch_size = mm_kwargs.get(
"temporal_patch_size", video_processor.temporal_patch_size
)
# video_max_pixels includes the temporal compression factor,
# so we divide by temporal_patch_size to get the maximum number of
# pixels per frame.
video_max_pixels = video_size["longest_edge"]
target_video_width, target_video_height = (
self.info.get_image_size_with_most_features(
max_pixels=video_max_pixels // temporal_patch_size
)
)
target_video_size, _ = self.info._get_vision_info(
image_width=target_video_width,
image_height=target_video_height,
num_frames=target_num_frames,
image_processor=video_processor,
mm_kwargs={},
)
# NOTE: we need to do this check here since Qwen3-VL resizes video
# frames depending on how many frames there are.
target_video_width, target_video_height = (
target_video_size.width,
target_video_size.height,
)
if video_overrides:
assert isinstance(video_overrides, VideoDummyOptions)
width_override = video_overrides.width
if width_override:
if width_override > target_video_width:
logger.warning(
"video.width override (%d) exceeds model's "
"maximum width (%d), will be ignored",
width_override,
target_video_width,
)
target_video_width = min(target_video_width, width_override)
height_override = video_overrides.height
if height_override:
if height_override > target_video_height:
logger.warning(
"video.height override (%d) exceeds model's "
"maximum height (%d), will be ignored",
height_override,
target_video_height,
)
target_video_height = min(target_video_height, height_override)
return {
"image": self._get_dummy_images(
width=target_image_width,
height=target_image_height,
num_images=num_images,
overrides=image_overrides,
),
"video": self._get_dummy_videos(
width=target_video_width,
height=target_video_height,
num_frames=target_num_frames,
num_videos=num_videos,
),
}
def _get_dummy_videos(
self,
*,
width: int,
height: int,
num_frames: int,
num_videos: int,
) -> list[VideoItem]:
video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
video_items = []
for i in range(num_videos):
video_metadata = {
"fps": 2.0,
"duration": num_frames / 2.0,
"total_num_frames": num_frames,
"frames_indices": [i for i in range(num_frames)],
"video_backend": "opencv",
"do_sample_frames": False,
}
video_item = (video.copy(), video_metadata)
video_items.append(video_item)
return video_items
class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
mm_data = dict(mm_data)
processor = self.info.get_hf_processor(**mm_kwargs)
# Separate video processing from image processing, because videos
# are processed into several image patches.
if videos := mm_data.pop("videos", []):
video_grid_thw_lst = []
pixel_values_videos_lst = []
timestamps_per_video = []
for item in videos:
video_array, metadata = item
# NOTE: @JJJYmmm the new attribute metadata.frames_indices holds
# the sampled frame indices of pre-sampled videos and is used to
# calculate the timestamps. Make sure that do_sample_frames in
# mm_kwargs is false for pre-sampled videos.
# NOTE: a copy of mm_kwargs is created to update do_sample_frames,
# otherwise mm_hash for the object would be incorrect.
video_mm_kwargs = dict(**mm_kwargs)
if "do_sample_frames" not in video_mm_kwargs:
# qwen_vl_utils already has "do_sample_frames" in
# mm_kwargs, don't overwrite it.
video_mm_kwargs["do_sample_frames"] = metadata.get(
"do_sample_frames", False
)
metadata = VideoMetadata(
**{k: metadata[k] for k in metadata if k != "do_sample_frames"}
)
# Compute timestamps here where we have access to metadata
timestamps = self.info._get_video_second_idx(
metadata=metadata,
do_sample_frames=video_mm_kwargs["do_sample_frames"],
sampled_fps=video_mm_kwargs.get("fps"),
sampled_num_frames=video_mm_kwargs.get("num_frames"),
)
timestamps_per_video.append(timestamps)
video_mm_data = dict()
video_mm_data["videos"] = [[video_array]]
video_mm_data["video_metadata"] = [[metadata]]
# When num_frames is specified, explicitly set fps=None
# to prevent HF's BaseVideoProcessor.preprocess() from
# filling in the class default (fps=2) via setdefault(),
# which would conflict with num_frames (mutually exclusive).
if "num_frames" in video_mm_kwargs and "fps" not in video_mm_kwargs:
video_mm_kwargs["fps"] = None
video_outputs = super()._call_hf_processor(
prompt="<|vision_start|><|video_pad|><|vision_end|>",
mm_data=video_mm_data,
mm_kwargs=video_mm_kwargs,
tok_kwargs=tok_kwargs,
)
merge_size = processor.video_processor.merge_size
# Get video grid info for EVS calculation.
video_grid_thw = video_outputs["video_grid_thw"]
num_frames = int(video_grid_thw[0, 0])
tokens_per_frame_base = int(video_grid_thw[0, 1:].prod()) // (
merge_size**2
)
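# Illustrative example: a grid of (t=4, h=24, w=32) with merge_size=2
# yields 24 * 32 // 4 = 192 tokens per frame before any EVS pruning.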
# Apply EVS if enabled.
video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
if video_pruning_rate is not None and video_pruning_rate > 0.0:
num_tokens = compute_retained_tokens_count(
tokens_per_frame=tokens_per_frame_base,
num_frames=num_frames,
q=video_pruning_rate,
)
# Here we just need placeholders that won't actually be replaced;
# we only need the total number of tokens to be correct, so we
# assign all of them to the first frame.
tokens_per_frame = [num_tokens] + [0] * (num_frames - 1)
select_token_id = False
else:
tokens_per_frame = [tokens_per_frame_base] * num_frames
select_token_id = True
# Generate the video replacement with EVS-adjusted token counts
tokenizer = self.info.get_tokenizer()
hf_config = self.info.get_hf_config()
video_repl = Qwen3VLMultiModalProcessor.get_video_repl(
tokens_per_frame=tokens_per_frame,
timestamps=timestamps,
tokenizer=tokenizer,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
video_token_id=hf_config.video_token_id,
select_token_id=select_token_id,
)
# Convert token IDs to text for the HF processor flow
video_placeholder = tokenizer.decode(
video_repl.full, skip_special_tokens=False
)
input_ids = video_outputs.pop("input_ids")
video_placeholder = processor.tokenizer.batch_decode(input_ids)[0]
prompt = prompt.replace(
"<|vision_start|><|video_pad|><|vision_end|>",
video_placeholder,
1,
)
video_grid_thw_lst.append(video_outputs["video_grid_thw"])
pixel_values_videos_lst.append(video_outputs["pixel_values_videos"])
video_outputs = dict(
pixel_values_videos=torch.cat(pixel_values_videos_lst),
video_grid_thw=torch.cat(video_grid_thw_lst),
timestamps=timestamps_per_video,
)
else:
video_outputs = dict()
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
combined_outputs = dict(
processed_outputs,
**video_outputs,
)
return BatchFeature(combined_outputs)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return _create_qwen2vl_field_factory(
self.info.get_hf_config().vision_config.spatial_merge_size
)(hf_inputs)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
hf_config = self.info.get_hf_config()
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
vision_end_token_id = hf_config.vision_end_token_id
merge_length = image_processor.merge_size**2
def get_image_replacement_qwen3vl(item_idx: int):
out_item = out_mm_kwargs["image"][item_idx]
grid_thw = out_item["image_grid_thw"].data
assert isinstance(grid_thw, torch.Tensor)
num_tokens = int(grid_thw.prod()) // merge_length
return [hf_processor.image_token_id] * num_tokens
def get_video_replacement_qwen3vl(item_idx: int):
out_item = out_mm_kwargs["video"][item_idx]
grid_thw = out_item["video_grid_thw"].data
assert isinstance(grid_thw, torch.Tensor)
sampled_fps = hf_processor_mm_kwargs.get("fps")
if is_list_of(sampled_fps, float):
sampled_fps = sampled_fps[item_idx]
timestamps = out_item["timestamps"].data
assert len(timestamps) == grid_thw[0], (
f"The timestamps length({len(timestamps)}) should be equal "
f"video length ({grid_thw[0]})."
)
# Compute tokens per frame, with EVS support
num_frames = int(grid_thw[0])
tokens_per_frame_base = int(grid_thw[1:].prod()) // merge_length
video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
if video_pruning_rate is not None and video_pruning_rate > 0.0:
num_tokens = compute_retained_tokens_count(
tokens_per_frame=tokens_per_frame_base,
num_frames=num_frames,
q=video_pruning_rate,
)
tokens_per_frame = [num_tokens] + [0] * (num_frames - 1)
select_token_id = False
else:
tokens_per_frame = [tokens_per_frame_base] * num_frames
select_token_id = True
return Qwen3VLMultiModalProcessor.get_video_repl(
tokens_per_frame=tokens_per_frame,
timestamps=timestamps,
tokenizer=tokenizer,
vision_start_token_id=vision_start_token_id,
vision_end_token_id=vision_end_token_id,
video_token_id=video_token_id,
select_token_id=select_token_id,
)
return [
PromptReplacement(
modality="image",
target=hf_processor.image_token,
replacement=get_image_replacement_qwen3vl,
),
# NOTE: We match on the string on purpose, since searching for a
# sequence of token IDs takes more time.
PromptReplacement(
modality="video",
target="<|vision_start|><|video_pad|><|vision_end|>",
replacement=get_video_replacement_qwen3vl,
),
]
@staticmethod
def get_video_repl(
*,
tokens_per_frame: list[int],
timestamps: list[float | int],
tokenizer: TokenizerLike,
vision_start_token_id: int,
vision_end_token_id: int,
video_token_id: int,
select_token_id: bool = False,
) -> PromptUpdateDetails[list[int]]:
"""Build prompt replacement for a video in Qwen3VL format.
The replacement structure for each frame is:
timestamp_tokens + vision_start_token + video_tokens + vision_end_token
Args:
tokens_per_frame: Number of video tokens per frame (can vary per frame for
EVS).
timestamps: List of timestamps in seconds for each frame
tokenizer: Tokenizer to encode timestamp strings
vision_start_token_id: Token ID for vision start marker
vision_end_token_id: Token ID for vision end marker
video_token_id: Token ID for video content
Returns:
PromptUpdateDetails with full token sequence
"""
assert len(timestamps) == len(tokens_per_frame), (
"timestamps and tokens_per_frame must have the same length"
)
# Tokenize timestamp strings independently to avoid tokenizer merging
# tokens across boundaries.
# TODO: switch to `_seq2tokens` which has some caching.
timestamp_token_ids = [
tokenizer.encode(f"<{timestamp:.1f} seconds>", add_special_tokens=False)
for timestamp in timestamps
]
# Build the full token sequence
all_token_ids = []
for frame_timestamp_ids, num_tokens in zip(
timestamp_token_ids, tokens_per_frame
):
# Add timestamp tokens
all_token_ids.extend(frame_timestamp_ids)
# Add vision tokens: vision_start + video_tokens + vision_end
all_token_ids.append(vision_start_token_id)
all_token_ids.extend([video_token_id] * num_tokens)
all_token_ids.append(vision_end_token_id)
if select_token_id:
return PromptUpdateDetails.select_token_id(all_token_ids, video_token_id)
# NOTE: we use `from_seq` instead of `select_token_id` because we want all
# tokens in the placeholder to be initially marked as candidates. Then,
# in `get_input_embeddings`, we refine the mask to only replace
# `video_token_id` / `image_token_id` positions with video/image embeddings,
# keeping text embeddings for timestamps and structural tokens.
return PromptUpdateDetails.from_seq(all_token_ids)
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
# otherwise (seq_len, ).
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
# the same shape as input_embeds
"deepstack_input_embeds": 0,
}
)
class Qwen3LLMModel(Qwen3Model):
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
# args for deepstack
deepstack_input_embeds: IntermediateTensors | None = None,
) -> torch.Tensor | IntermediateTensors:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for layer_idx, layer in islice(
enumerate(self.layers), self.start_layer, self.end_layer
):
hidden_states, residual = layer(
positions,
hidden_states,
residual,
)
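# Deepstack: the first len(deepstack_input_embeds) decoder layers each
# receive an extra additive visual feature map produced by the
# corresponding vision-tower merger level.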
if deepstack_input_embeds is not None and layer_idx in range(
0, len(deepstack_input_embeds)
):
hidden_states = (
hidden_states
+ deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"]
)
self._maybe_add_hidden_state(
aux_hidden_states, layer_idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states
class Qwen3LLMForCausalLM(Qwen3ForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super(Qwen3ForCausalLM, self).__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = Qwen3LLMModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix="lm_head",
)
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
@MULTIMODAL_REGISTRY.register_processor(
Qwen3VLMultiModalProcessor,
info=Qwen3VLProcessingInfo,
dummy_inputs=Qwen3VLDummyInputsBuilder,
)
class Qwen3VLForConditionalGeneration(
nn.Module,
SupportsMultiModal,
SupportsLoRA,
SupportsPP,
SupportsMRoPE,
SupportsEagle,
SupportsEagle3,
SupportsMultiModalPruning,
):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
"qkv": ["qkv"], # For vision tower's already-packed QKV
}
supports_encoder_tp_data = True
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.visual.": "visual.",
"lm_head.": "language_model.lm_head.",
"model.language_model.": "language_model.model.",
}
)
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return "<|vision_start|><|image_pad|><|vision_end|>"
if modality.startswith("video"):
return "<|vision_start|><|video_pad|><|vision_end|>"
raise ValueError("Only image or video modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
super().__init__()
config: Qwen3VLConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self._tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.video_pruning_rate = multimodal_config.video_pruning_rate
self.is_multimodal_pruning_enabled = (
multimodal_config.is_multimodal_pruning_enabled()
)
self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes")
self.deepstack_num_level = (
len(config.vision_config.deepstack_visual_indexes)
if self.use_deepstack
else 0
)
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
)
# register buffer for deepstack
if self.use_deepstack:
self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size,
)
for _ in range(self.deepstack_num_level)
]
with self._mark_language_model(vllm_config):
self.language_model = Qwen3LLMForCausalLM(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model"),
)
if not get_pp_group().is_first_rank and hasattr(
config.vision_config, "deepstack_visual_indexes"
):
assert self.language_model.start_layer >= len(
config.vision_config.deepstack_visual_indexes
), (
"start_layer should be greater than or equal to "
"len(deepstack_visual_indexes)"
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
def _get_deepstack_input_embeds(
self,
num_tokens: int,
) -> IntermediateTensors | None:
if not getattr(self, "deepstack_input_embeds", None):
return None # If vision tower is skipped
# Get deepstack_input_embeds from the buffer (the buffer itself is
# cleared separately via _clear_deepstack_input_embeds).
return IntermediateTensors(
{
f"deepstack_input_embeds_{idx}": self.deepstack_input_embeds[idx][
:num_tokens
]
for idx in range(self.deepstack_num_level)
}
)
def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
if not getattr(self, "deepstack_input_embeds", None):
return
# set deepstack_input_embeds to buffer
num_tokens = deepstack_input_embeds.size(1)
if num_tokens > self.deepstack_input_embeds[0].size(0):
self.deepstack_input_embeds = [
torch.zeros(
num_tokens,
self.config.text_config.hidden_size,
device=self.deepstack_input_embeds[0].device,
dtype=self.deepstack_input_embeds[0].dtype,
)
for _ in range(self.deepstack_num_level)
]
for idx in range(self.deepstack_num_level):
self.deepstack_input_embeds[idx][:num_tokens].copy_(
deepstack_input_embeds[idx]
)
def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
if not getattr(self, "deepstack_input_embeds", None):
return
# clear deepstack_input_embeds in buffer
if num_tokens > 0:
for idx in range(self.deepstack_num_level):
self.deepstack_input_embeds[idx][:num_tokens].zero_()
def _parse_and_validate_image_input(
self, **kwargs: object
) -> Qwen2_5_VLImageInputs | None:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
image_grid_thw = kwargs.pop("image_grid_thw", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
return Qwen2_5_VLImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
)
if image_embeds is not None:
return Qwen2_5_VLImageEmbeddingInputs(
type="image_embeds",
image_embeds=image_embeds,
image_grid_thw=image_grid_thw,
)
def _parse_and_validate_video_input(
self, **kwargs: object
) -> Qwen2_5_VLVideoInputs | None:
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
video_embeds = kwargs.pop("video_embeds", None)
video_grid_thw = kwargs.pop("video_grid_thw", None)
second_per_grid_ts = kwargs.pop("second_per_grid_ts", None)
timestamps = kwargs.pop("timestamps", None)
if pixel_values_videos is None and video_embeds is None:
return None
if pixel_values_videos is not None:
return Qwen2_5_VLVideoPixelInputs(
type="pixel_values_videos",
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
timestamps=timestamps,
)
if video_embeds is not None:
return Qwen2_5_VLVideoEmbeddingInputs(
type="video_embeds",
video_embeds=video_embeds,
video_grid_thw=video_grid_thw,
timestamps=timestamps,
)
def _process_image_input(
self, image_input: Qwen2_5_VLImageInputs
) -> tuple[torch.Tensor, ...]:
grid_thw = image_input["image_grid_thw"]
assert grid_thw.ndim == 2
if image_input["type"] == "image_embeds":
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
else:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
)
else:
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
# Split concatenated embeddings for each image item.
merge_size = self.visual.spatial_merge_size
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
return image_embeds.split(sizes)
def _process_video_input(
self, video_input: Qwen2_5_VLVideoInputs
) -> tuple[torch.Tensor, ...]:
grid_thw = video_input["video_grid_thw"]
assert grid_thw.ndim == 2
if video_input["type"] == "video_embeds":
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
else:
pixel_values_videos = video_input["pixel_values_videos"].type(
self.visual.dtype
)
if self.use_data_parallel:
grid_thw_list = grid_thw.tolist()
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
)
else:
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
# Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
return video_embeds.split(sizes)
def _postprocess_image_embeds_evs(
self,
image_embeds_split: tuple[torch.Tensor, ...],
image_input: Qwen2_5_VLImageInputs,
) -> tuple[torch.Tensor, ...]:
"""
Append mrope positions for each image.
This is necessary to recover correct mrope
positions after video pruning.
Args:
image_embeds_split: Tuple of image embeddings for
each image item.
image_input: Image input data.
Returns:
Tuple of image embeddings for each image item.
The resulting embeddings will have 5 extra channels for the
"""
if self.is_multimodal_pruning_enabled:
merge_size = self.visual.spatial_merge_size
grid_thw = image_input["image_grid_thw"]
grid_thw_list = grid_thw.tolist()
image_embeds_out = []
for emb, size in zip(image_embeds_split, grid_thw_list):
positions = compute_mrope_for_media(size, merge_size).to(emb.device)
positions = torch.cat(
[
positions,
torch.zeros_like(
positions[:, 0:1]
), # Dummy extra fifth channel
],
dim=1,
)
emb = torch.cat([emb, positions], dim=1)
image_embeds_out.append(emb)
image_embeds_split = tuple(image_embeds_out)
return image_embeds_split
def _postprocess_video_embeds_evs(
self,
video_embeds_split: tuple[torch.Tensor, ...],
video_input: Qwen2_5_VLVideoInputs,
) -> tuple[torch.Tensor, ...]:
"""
Prunes video embeddings via Efficient Video Sampling (EVS)
and then appends mrope positions for each retained embedding.
Args:
video_embeds_split: Tuple of video embeddings for each video item.
video_input: Video input data.
Returns:
Tuple of video embeddings for each video item.
The resulting embeddings will have 5 extra channels for the computed
mrope positions and a flag for whether the index corresponds to a
video embedding.
"""
grid_thw = video_input["video_grid_thw"]
assert grid_thw.ndim == 2
grid_thw_list = grid_thw.tolist()
merge_size = self.visual.spatial_merge_size
# Apply EVS to each video.
video_embeds_out = []
for video_idx, (emb, size) in enumerate(zip(video_embeds_split, grid_thw_list)):
# Compute positions.
timestamps = video_input.timestamps[video_idx]
num_frames = len(timestamps)
t, h, w = size
if self.is_multimodal_pruning_enabled:
# For each video, compute retention mask using EVS.
# retention_mask is a boolean mask with one entry per merged video token.
retention_mask = compute_retention_mask(
emb,
size,
spatial_merge_size=self.visual.spatial_merge_size,
q=self.video_pruning_rate,
)
# Apply retention mask.
emb = emb[retention_mask]
# Calculate the actual number of retained tokens per frame.
num_frames, rows, cols = (
t,
h // merge_size,
w // merge_size,
)
retention_mask_thw = retention_mask.reshape(num_frames, rows, cols)
num_tokens_per_frame = (
retention_mask_thw.sum(dim=(1, 2)).long().tolist()
)
else:
feature_size = emb.shape[0] // num_frames
num_tokens_per_frame = [feature_size] * num_frames
retention_mask = None
emb = self._create_final_video_embeddings(
video_embeddings=emb,
num_tokens_per_frame=num_tokens_per_frame,
timestamps=timestamps,
video_grid_thw=size,
retention_mask=retention_mask,
)
video_embeds_out.append(emb)
return tuple(video_embeds_out)
def _create_final_video_embeddings(
self,
video_embeddings: torch.Tensor,
num_tokens_per_frame: list[int],
timestamps: list[float],
video_grid_thw: list[int],
retention_mask: torch.Tensor,
) -> torch.Tensor:
"""Create final embeddings that combine video embeddings with
text embeddings of indicator tokens.
These final embeddings contain:
- Actual video embeddings in positions corresponding to video content
- Text embeddings for indicator tokens (<img>, </img>, and
frame separation text) in their respective positions
These embeddings will replace the placeholder embeddings to create
input_embeds for the LLM.
"""
device = video_embeddings.device
# Generate video replacement token IDs using get_video_repl
# This tokenizes each frame separator independently, then uses pre-tokenized
# special tokens to ensure consistent tokenization regardless of
# num_tokens_per_frame values.
video_repl = Qwen3VLMultiModalProcessor.get_video_repl(
tokens_per_frame=num_tokens_per_frame,
tokenizer=self._tokenizer,
timestamps=timestamps,
vision_start_token_id=self.config.vision_start_token_id,
vision_end_token_id=self.config.vision_end_token_id,
video_token_id=self.config.video_token_id,
select_token_id=self.is_multimodal_pruning_enabled,
)
repl_token_ids = torch.tensor(video_repl.full, device=device)
embed_token_id = _cached_tensor(self.config.video_token_id, device=device)
is_video_embed = torch.isin(repl_token_ids, embed_token_id)
# Get text embeddings for indicator tokens (these have only `visual_dim` channels).
text_embeddings = self.get_language_model().embed_input_ids(repl_token_ids)
if self.use_deepstack:
(
deepstack_input_embeds,
multimodal_embeddings,
) = self._compute_deepstack_embeds(
inputs_embeds=text_embeddings,
multimodal_embeddings=[video_embeddings],
is_multimodal=is_video_embed,
)
else:
deepstack_input_embeds = None
multimodal_embeddings = [video_embeddings]
merged_embeddings = _merge_multimodal_embeddings(
inputs_embeds=text_embeddings,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_video_embed,
)
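# Flatten the per-level deepstack features so they can ride along the
# channel dim: assuming a [num_levels, seq_len, visual_dim] layout, the
# permute/reshape below yields [seq_len, num_levels * visual_dim].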
to_concat = [merged_embeddings]
if deepstack_input_embeds is not None:
to_concat.append(
deepstack_input_embeds.permute(1, 0, 2).reshape(
deepstack_input_embeds.shape[1], -1
)
)
expanded_positions = None
if self.is_multimodal_pruning_enabled:
is_vision_start = repl_token_ids.eq(self.config.vision_start_token_id)
expanded_positions = self._get_expanded_positions(
device=merged_embeddings.device,
seq_len=merged_embeddings.shape[0],
video_grid_thw=video_grid_thw,
num_tokens_per_frame=num_tokens_per_frame,
timestamps=timestamps,
is_video_embed=is_video_embed,
is_vision_start=is_vision_start,
retention_mask=retention_mask,
)
to_concat.append(expanded_positions)
final_video_embeddings = torch.cat(to_concat, dim=-1)
return final_video_embeddings
def _get_expanded_positions(
self,
device,
seq_len,
video_grid_thw,
num_tokens_per_frame,
timestamps,
is_video_embed,
is_vision_start,
retention_mask,
):
embed_token_id = _cached_tensor(self.config.video_token_id, device=device)
# Expand positions to match the full sequence length
# (includes both video tokens and indicator tokens)
# Shape: [full_length, 5] where positions are filled for video tokens
# and zeros for indicator tokens.
# Channel 3 flags VISION_START tokens so that
# recompute_mrope_positions can reliably count timestamp tokens
# (even when early frames have all video tokens pruned).
# Channel 4 flags video-embedding tokens.
expanded_positions = torch.zeros(
seq_len,
5, # [t_index, h_index, w_index, is_vision_start, is_video]
device=device,
dtype=torch.long,
)
_, h, w = video_grid_thw
merge_size = self.visual.spatial_merge_size
num_frames = len(num_tokens_per_frame)
unpruned_token_ids = Qwen3VLMultiModalProcessor.get_video_repl(
tokens_per_frame=[(h // merge_size) * (w // merge_size)] * num_frames,
tokenizer=self._tokenizer,
timestamps=timestamps,
vision_start_token_id=self.config.vision_start_token_id,
vision_end_token_id=self.config.vision_end_token_id,
video_token_id=self.config.video_token_id,
).full
unpruned_token_ids_tensor = torch.tensor(unpruned_token_ids, device=device)
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem(
{
"video_grid_thw": MultiModalFieldElem(
data=torch.tensor(video_grid_thw),
field=None, # HACK.
),
}
),
modality="video",
identifier="DUMMY",
mm_position=PlaceholderRange(offset=0, length=len(unpruned_token_ids)),
)
original_mrope = (
self.get_mrope_input_positions(
input_tokens=unpruned_token_ids,
mm_features=[mm_feature],
)[0]
.to(device)
.permute(1, 0)
)
full_is_video_embed = unpruned_token_ids_tensor == embed_token_id
expanded_positions[is_video_embed, :3] = original_mrope[full_is_video_embed][
retention_mask
]
expanded_positions[~is_video_embed, :3] = original_mrope[~full_is_video_embed]
expanded_positions[..., 3] = is_vision_start
expanded_positions[..., 4] = is_video_embed
return expanded_positions
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {}
for input_key in kwargs:
if (
input_key in ("pixel_values", "image_embeds")
and "image" not in mm_input_by_modality
):
mm_input_by_modality["image"] = self._parse_and_validate_image_input(
**kwargs
)
if (
input_key in ("pixel_values_videos", "video_embeds")
and "video" not in mm_input_by_modality
):
mm_input_by_modality["video"] = self._parse_and_validate_video_input(
**kwargs
)
return mm_input_by_modality
@staticmethod
def _iter_mm_grid_hw(
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
video_token_id: int,
vision_start_token_id: int,
vision_end_token_id: int,
spatial_merge_size: int,
) -> Iterator[tuple[int, int, int, int]]:
"""Iterate over multimodal features and yield position info.
Args:
input_tokens: List of token IDs in the input sequence.
mm_features: List of multimodal feature specifications containing
image/video data and position information.
video_token_id: Token ID used for video tokens.
vision_start_token_id: Token ID marking the start of a vision sequence.
vision_end_token_id: Token ID marking the end of a vision sequence.
spatial_merge_size: Size of the spatial merge operation used to
compute logical grid dimensions from the original feature grid.
Yields:
offset: Position of the first video/image token in the sequence.
llm_grid_h: Logical grid height (may not match actual token count with EVS).
llm_grid_w: Logical grid width (may not match actual token count with EVS).
actual_num_tokens: Actual number of video/image tokens in the placeholder.
"""
for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
offset = mm_feature.mm_position.offset
if mm_feature.modality == "image":
t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
assert t == 1, f"Image must have 1 frame, got {t}"
llm_grid_h = h // spatial_merge_size
llm_grid_w = w // spatial_merge_size
yield offset, llm_grid_h, llm_grid_w, llm_grid_h * llm_grid_w
elif mm_feature.modality == "video":
t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
llm_grid_h = h // spatial_merge_size
llm_grid_w = w // spatial_merge_size
for _ in range(t):
# When EVS is enabled, some frames may have 0 video tokens in the
# placeholder. We use `vision_start_token_id` to locate each frame
# since it is always present for every frame.
# We then look for the first `video_token_id` after
# `vision_start_token_id` and before `vision_end_token_id`.
offset = input_tokens.index(vision_start_token_id, offset)
vision_end_offset = input_tokens.index(vision_end_token_id, offset)
try:
actual_num_tokens = 0
video_offset = input_tokens.index(
video_token_id, offset, vision_end_offset
)
# NOTE: looking at the
# `Qwen3VLMultiModalProcessor.get_video_repl` code, we can
# see that the formula below gives the token count, since
# everything in between `video_offset` and
# `vision_end_offset` is populated as `video_token_id`.
# This saves us from manually counting the number of tokens
# that match `video_token_id` in between.
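# E.g. (hypothetical indices): with `vision_start` at position `offset` and
# tokens [..., vision_start, video, video, video, vision_end, ...], we get
# video_offset = offset + 1 and vision_end_offset = offset + 4, so this frame
# contributes (offset + 4) - (offset + 1) = 3 tokens.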
actual_num_tokens += vision_end_offset - video_offset
except ValueError:
# No `video_token_id` in this frame (EVS with 0 tokens for
# this frame) -> use `offset + 1` to move past
# `vision_start_token_id`.
video_offset = offset + 1
yield video_offset, llm_grid_h, llm_grid_w, actual_num_tokens
# Move offset past this frame for next iteration.
offset = vision_end_offset + 1
else:
raise ValueError(f"Unsupported modality: {mm_feature.modality}")
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
return self._get_mrope_input_positions(
input_tokens=input_tokens,
mm_features=mm_features,
config=self.config,
)
@staticmethod
def _get_mrope_input_positions(
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
config: Qwen3VLConfig,
):
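# mrope uses three position rows: (temporal, height, width). Text tokens get
# the same value on all three rows; vision tokens get grid coordinates.
# Minimal sketch (assuming one text token followed by a 1x2x2 merged frame):
#   text : t/h/w = [0]
#   frame: t = [1, 1, 1, 1], h = [1, 1, 2, 2], w = [1, 2, 1, 2]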
llm_pos_ids_list = []
st = 0
for (
offset,
llm_grid_h,
llm_grid_w,
actual_num_tokens,
) in Qwen3VLForConditionalGeneration._iter_mm_grid_hw(
input_tokens,
mm_features,
video_token_id=config.video_token_id,
vision_start_token_id=config.vision_start_token_id,
vision_end_token_id=config.vision_end_token_id,
spatial_merge_size=config.vision_config.spatial_merge_size,
):
# Skip frames with 0 tokens (EVS placeholder with tokens lumped elsewhere)
if actual_num_tokens == 0:
continue
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
# Check if this is a "lumped placeholder" (all tokens from multiple frames
# assigned to the 0-th frame); see
# `Qwen3VLMultiModalProcessor.get_video_repl`.
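# E.g. (hypothetical numbers): with llm_grid_h = llm_grid_w = 4 and 48 video
# tokens lumped into the first placeholder, this expands to 48 // 16 = 3
# logical frames of 16 tokens each, with no remainder.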
expected_tokens_per_frame = llm_grid_h * llm_grid_w
if actual_num_tokens > expected_tokens_per_frame:
# Lumped placeholder: create grid positions for all "logical" frames
# represented.
num_logical_frames = actual_num_tokens // expected_tokens_per_frame
remainder = actual_num_tokens % expected_tokens_per_frame
# Create positions for complete frames.
for _ in range(num_logical_frames):
grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(
3, -1
)
llm_pos_ids_list.append(grid_indices + text_len + st_idx)
st_idx = llm_pos_ids_list[-1].max() + 1
text_len = 0 # No text between frames within the lump
# Handle remainder tokens if any (partial frame).
# NOTE: this should never be the case. Should we have an assert?
if remainder > 0:
# Create a partial grid - take first 'remainder' positions
full_grid = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
grid_indices = full_grid[:, :remainder]
llm_pos_ids_list.append(grid_indices + text_len + st_idx)
else:
# Normal case: frame has exactly the expected tokens (after actual EVS
# pruning).
grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
llm_pos_ids_list.append(grid_indices + text_len + st_idx)
st = offset + actual_num_tokens
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return torch.from_numpy(llm_positions), mrope_position_delta
def recompute_mrope_positions(
self,
input_ids: list[int],
multimodal_embeddings: MultiModalEmbeddings,
mrope_positions: torch.LongTensor,
num_computed_tokens: int,
) -> tuple[MultiModalEmbeddings, torch.Tensor, int]:
"""
Update part of the input mrope positions (starting at index
num_computed_tokens). The original mrope_positions are computed
for the unpruned sequence and become incorrect once pruning
occurs, so after pruning media tokens we must reflect this in
mrope_positions before feeding them to the LLM.
Args:
input_ids: (N,) All input tokens of the prompt, covering the
entire sequence.
multimodal_embeddings: Tuple of multimodal embeddings that
fit into the prefill chunk being processed.
mrope_positions: Existing mrope positions (3, N) for the
entire sequence.
num_computed_tokens: The number of tokens computed so far.
Returns:
Tuple of (multimodal_embeddings, mrope_positions,
mrope_position_delta).
"""
return self._recompute_mrope_positions(
input_ids=input_ids,
multimodal_embeddings=multimodal_embeddings,
mrope_positions=mrope_positions,
num_computed_tokens=num_computed_tokens,
image_token_id=self.config.image_token_id,
video_token_id=self.config.video_token_id,
vision_start_token_id=self.config.vision_start_token_id,
)
@staticmethod
def _recompute_mrope_positions(
input_ids: list[int],
multimodal_embeddings: MultiModalEmbeddings,
mrope_positions: torch.LongTensor,
num_computed_tokens: int,
vision_start_token_id: int,
image_token_id: int,
video_token_id: int,
) -> tuple[MultiModalEmbeddings, torch.Tensor, int]:
# Device
device = (
multimodal_embeddings[0].device
if len(multimodal_embeddings)
else mrope_positions.device
)
# Tensors
input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long)
mm_embeddings_out = []
mm_embeddings_pos = []
# Strip position information from embeddings (last 5 channels)
# For Qwen3 VL, handle potentially empty frames (from unpacking)
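# The 5 trailing channels are the 3 mrope rows plus the `is_vision_start`
# and `is_video_embed` flags packed during EVS postprocessing above.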
for mm in multimodal_embeddings:
if mm.shape[0] > 0: # Only process non-empty frames
mm_embeddings_out.append(mm[:, :-5])
mm_embeddings_pos.append(mm[:, -5:].permute(1, 0).long())
else:
# Empty frame - keep as is
mm_embeddings_out.append(mm)
# Create empty position tensor with correct shape
mm_embeddings_pos.append(
torch.empty(5, 0, device=device, dtype=torch.long)
)
positions, mrope_positions_delta = recompute_mrope_positions(
input_ids_t,
mm_embeddings_pos,
mrope_positions,
num_computed_tokens,
vision_start_token_id,
image_token_id,
video_token_id,
)
return tuple(mm_embeddings_out), positions, mrope_positions_delta
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality:
return None
# The resulting multimodal_embeddings is a tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or video).
multimodal_embeddings: list[torch.Tensor] = []
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in mm_input_by_modality:
multimodal_input = mm_input_by_modality[modality]
if modality == "image":
image_embeddings = self._process_image_input(multimodal_input)
image_embeddings = self._postprocess_image_embeds_evs(
image_embeddings, multimodal_input
)
multimodal_embeddings.extend(image_embeddings)
if modality == "video":
video_embeddings = self._process_video_input(multimodal_input)
if self.is_multimodal_pruning_enabled:
video_embeddings = self._postprocess_video_embeds_evs(
video_embeddings, multimodal_input
)
multimodal_embeddings.extend(video_embeddings)
embeddings_tuple = tuple(multimodal_embeddings)
return embeddings_tuple
def _compute_deepstack_embeds(
self,
inputs_embeds: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings,
is_multimodal: torch.Tensor,
) -> tuple[torch.Tensor, MultiModalEmbeddings]:
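# Each incoming embedding is assumed to concatenate the main visual features
# (visual_dim) with `deepstack_num_level` extra feature levels
# (multiscale_dim) along the hidden dimension; split them apart and scatter
# the multiscale part into one flattened per-level buffer aligned with the
# input embeddings.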
visual_lens = [len(x) for x in multimodal_embeddings]
multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0)
(
multimodal_embeddings_main,
multimodal_embeddings_multiscale,
) = torch.split(
multimodal_embeddings_cat,
[self.visual_dim, self.multiscale_dim],
dim=-1,
)
multimodal_embeddings = torch.split(
multimodal_embeddings_main, visual_lens, dim=0
)
multimodal_embeddings_multiscale = torch.split(
multimodal_embeddings_multiscale, visual_lens, dim=0
)
deepstack_input_embeds = inputs_embeds.new_zeros(
inputs_embeds.size(0), self.deepstack_num_level * inputs_embeds.size(1)
)
deepstack_input_embeds = _merge_multimodal_embeddings(
inputs_embeds=deepstack_input_embeds,
multimodal_embeddings=multimodal_embeddings_multiscale,
is_multimodal=is_multimodal,
)
deepstack_input_embeds = deepstack_input_embeds.view(
inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim
)
deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2)
return deepstack_input_embeds, multimodal_embeddings
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.language_model.embed_input_ids,
is_multimodal=is_multimodal,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
is_multimodal = _require_is_multimodal(is_multimodal)
if self.use_deepstack:
(
deepstack_input_embeds,
multimodal_embeddings,
) = self._compute_deepstack_embeds(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
else:
deepstack_input_embeds = None
inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
if deepstack_input_embeds is not None:
self._set_deepstack_input_embeds(deepstack_input_embeds)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
"""Run forward pass for Qwen3VL.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Flattened (concatenated) position ids corresponding to a
batch.
**NOTE**: If mrope is enabled (the default for Qwen3VL
open-source models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,)`.
intermediate_tensors: Intermediate tensors from previous pipeline
stages.
inputs_embeds: Pre-computed input embeddings.
**kwargs: Additional keyword arguments including:
- pixel_values: Pixel values to be fed to a model.
`None` if no images are passed.
- image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in
LLM. `None` if no images are passed.
- pixel_values_videos: Pixel values of videos to be fed to a
model. `None` if no videos are passed.
- video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in
LLM. `None` if no videos are passed.
"""
if intermediate_tensors is not None:
inputs_embeds = None
if inputs_embeds is not None and get_pp_group().is_first_rank:
deepstack_input_embeds = self._get_deepstack_input_embeds(
inputs_embeds.size(0)
)
else:
deepstack_input_embeds = None
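# deepstack_input_embeds were stashed by `embed_input_ids` on the first PP
# rank; non-first ranks receive hidden states via `intermediate_tensors` and
# therefore skip them.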
hidden_states = self.language_model.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
# args for deepstack
deepstack_input_embeds=deepstack_input_embeds,
)
if inputs_embeds is not None and get_pp_group().is_first_rank:
self._clear_deepstack_input_embeds(inputs_embeds.size(0))
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector=["visual.merger", "visual.deepstack_merger_list"],
tower_model="visual.",
)
def get_num_mm_encoder_tokens(
self,
num_image_tokens: int,
) -> int:
hf_config = self.config
vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size
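# Worked example: with spatial_merge_size == 2, each LLM-visible placeholder
# token maps to 2 ** 2 = 4 vision-encoder patch tokens.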
return num_image_tokens * merge_size**2
def get_num_mm_connector_tokens(
self,
num_vision_tokens: int,
) -> int:
hf_config = self.config
vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size
return num_vision_tokens // merge_size**2
@lru_cache
def _cached_tensor(x, device) -> torch.Tensor:
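# Cache small constant tensors per (value, device) pair, presumably to avoid
# re-allocating them on every call.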
return torch.tensor(x, device=device)