2026-01-13 01:39:02 +09:00
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
|
from collections.abc import Iterable, Mapping, Sequence
|
|
|
|
|
|
from functools import partial
|
|
|
|
|
|
from typing import Annotated, Literal, TypeAlias
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
import regex as re
|
|
|
|
|
|
import torch
|
|
|
|
|
|
from einops import rearrange
|
|
|
|
|
|
from PIL import Image
|
|
|
|
|
|
from timm.layers import LayerNorm2d
|
|
|
|
|
|
from timm.layers.pos_embed import resample_abs_pos_embed
|
|
|
|
|
|
from timm.models.regnet import RegStage
|
|
|
|
|
|
from torch import nn
|
|
|
|
|
|
from transformers import BatchFeature
|
|
|
|
|
|
from transformers.modeling_outputs import BaseModelOutput
|
|
|
|
|
|
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
|
|
|
|
|
|
|
|
|
|
|
|
from vllm.config import VllmConfig
|
|
|
|
|
|
from vllm.config.multimodal import BaseDummyOptions
|
|
|
|
|
|
from vllm.logger import init_logger
|
|
|
|
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
|
|
|
|
|
from vllm.multimodal.inputs import (
|
|
|
|
|
|
MultiModalDataDict,
|
|
|
|
|
|
MultiModalFieldConfig,
|
|
|
|
|
|
MultiModalKwargsItems,
|
|
|
|
|
|
)
|
|
|
|
|
|
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
|
|
|
|
|
|
from vllm.multimodal.processing import (
|
2026-01-14 23:25:31 +08:00
|
|
|
|
BaseDummyInputsBuilder,
|
2026-01-13 01:39:02 +09:00
|
|
|
|
BaseMultiModalProcessor,
|
|
|
|
|
|
BaseProcessingInfo,
|
|
|
|
|
|
PromptReplacement,
|
|
|
|
|
|
PromptUpdate,
|
|
|
|
|
|
)
|
|
|
|
|
|
from vllm.sequence import IntermediateTensors
|
|
|
|
|
|
from vllm.utils.import_utils import resolve_obj_by_qualname
|
|
|
|
|
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
|
|
|
|
|
|
|
|
|
|
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
|
|
|
|
|
from .qwen2_vl import Qwen2VisionTransformer
|
|
|
|
|
|
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger named after this module.
logger = init_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class KananaVImagePixelInputs(TensorSchema):
    """Schema for the pixel-level image inputs passed to the model.

    Symbolic dimensions used in the shapes below:
        - np: total number of patches across every image in the batch
        - cps: channels * patch_size * patch_size
        - ni: number of images
    """

    # Discriminator tag; always the literal "pixel_values".
    type: Literal["pixel_values"]

    # Flattened patches of all images, concatenated along the first axis.
    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]

    # One row of three grid values per image (unpacked elsewhere as T, H, W).
    vision_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Alias for the accepted image-input schemas; currently only the pixel form.
KananaVImageInputs: TypeAlias = KananaVImagePixelInputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_pos_embeds(
    config: Qwen2VLVisionConfig,
    num_input_tokens: int,
    vision_hidden_size: int,
) -> nn.Parameter | None:
    """Create the learnable positional embedding for the abstractor input.

    Args:
        config: Projector configuration; only ``config.pos_emb`` is read.
        num_input_tokens: Number of positions in the embedding table.
        vision_hidden_size: Feature width of each position.

    Returns:
        A ``(1, num_input_tokens, vision_hidden_size)`` parameter initialised
        from a truncated normal (mean 0, std 0.02), or ``None`` when
        ``config.pos_emb`` is falsy.
    """
    if not config.pos_emb:
        return None

    table = nn.Parameter(torch.zeros(1, num_input_tokens, vision_hidden_size))
    nn.init.trunc_normal_(table, mean=0.0, std=0.02)
    return table
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_mlp(
    depth: int,
    hidden_size: int,
    output_hidden_size: int,
) -> nn.Sequential:
    """Assemble the SiLU-activated MLP used as the projector readout.

    The stack always starts with a ``hidden_size -> output_hidden_size``
    linear layer. Each additional unit of ``depth`` beyond the first appends
    a ``SiLU`` followed by a square ``output_hidden_size`` linear layer.

    Args:
        depth: Total number of linear layers requested; values below 2 yield
            just the first linear layer.
        hidden_size: Input feature width.
        output_hidden_size: Output feature width.

    Returns:
        The layers wrapped in an ``nn.Sequential``.
    """
    stack: list[nn.Module] = [nn.Linear(hidden_size, output_hidden_size)]
    for _ in range(depth - 1):
        # Tuple elements are built left to right: activation, then linear.
        stack.extend(
            (nn.SiLU(), nn.Linear(output_hidden_size, output_hidden_size))
        )
    return nn.Sequential(*stack)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PatchMerge(nn.Module):
    """Fold each ``merge_size x merge_size`` spatial window into channels."""

    def __init__(self, merge_size: int) -> None:
        """Store the side length of the window that gets merged."""
        super().__init__()
        self.merge_size = merge_size

    def forward(
        self,
        x: torch.Tensor,
        channel_last: bool = False,
    ) -> torch.Tensor:
        """Merge patches by `merge_size x merge_size`.

        Args:
            x: 4-D feature map, laid out ``B D H W`` by default or
                ``B H W D`` when ``channel_last`` is true.
            channel_last: Whether the channel axis is the last one.

        Returns:
            A ``B (D * merge_size**2) H/merge_size W/merge_size`` tensor.
        """
        feature_map = rearrange(x, "B H W D -> B D H W") if channel_last else x
        # The unpack is kept purely as a rank check: it fails unless the
        # tensor has exactly four dimensions.
        _batch, _channels, _height, _width = feature_map.shape
        window = self.merge_size
        return rearrange(
            feature_map,
            "B D (H h2) (W w2) -> B (D h2 w2) H W",
            h2=window,
            w2=window,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DynamicCAbstractor(nn.Module):
    """Dynamic C-Abstractor: projects vision features with RegNet stages.

    With a non-zero ``config.depth`` the pipeline is RegNet stage ->
    ``PatchMerge`` -> RegNet stage; otherwise only ``PatchMerge`` runs. An MLP
    readout then maps the result to ``config.output_hidden_size``.
    """

    def __init__(
        self,
        config: Qwen2VLVisionConfig,
        num_input_tokens: int,
    ) -> None:
        """Build the positional embedding (if enabled) and the network.

        Args:
            config: Projector configuration. Must expose ``merge_size``; the
                other attributes read here are ``pos_emb_size``,
                ``encoder_hidden_size`` and whatever ``build_net`` needs.
            num_input_tokens: Number of positional-embedding slots, or ``-1``
                to fall back to ``config.pos_emb_size``.
        """
        super().__init__()
        assert hasattr(config, "merge_size"), "merge_size must be provided."

        self.config = config
        self.merge_size = config.merge_size
        self.pos_emb_size = config.pos_emb_size
        # -1 is a sentinel meaning "take the size stored in the config".
        self.num_input_tokens = (
            config.pos_emb_size if num_input_tokens == -1 else num_input_tokens
        )
        self.pos_emb = build_pos_embeds(
            config, self.num_input_tokens, config.encoder_hidden_size
        )
        self.build_net()

    def _load_from_state_dict(self, state_dict, *args, **kwargs) -> None:
        """Load weights, trimming a legacy leading positional-embedding slot.

        An empty ``state_dict`` is ignored. When this module has a positional
        embedding, the first key matching ``*abstractor*pos_emb`` must exist;
        if the stored table has exactly one more position than the current
        one, its first position is dropped before delegating to the parent.
        """
        if not state_dict:
            return

        if self.pos_emb is not None:
            pattern = re.compile(r"[\w,.]*abstractor[\w,.]*pos_emb")
            pos_emb_key = next(
                (name for name in state_dict if pattern.match(name)), None
            )
            assert pos_emb_key is not None

            # Keep older checkpoints loadable with the current layout.
            stored = state_dict[pos_emb_key]
            if stored.size(1) == self.pos_emb.size(1) + 1:
                # The extra leading slot was originally for a cls token.
                state_dict[pos_emb_key] = stored[:, 1:]

        super()._load_from_state_dict(state_dict, *args, **kwargs)

    def build_net(self) -> None:
        """Create ``self.net`` and ``self.readout`` from the configuration."""
        cfg = self.config
        in_dim = cfg.encoder_hidden_size
        mid_dim = cfg.hidden_size
        out_dim = cfg.output_hidden_size
        num_blocks = cfg.depth
        readout_depth = cfg.mlp_depth

        make_stage = partial(
            RegStage,
            stride=1,
            dilation=1,
            act_layer=nn.SiLU,
            norm_layer=LayerNorm2d,
        )

        # All three modules are always constructed, in this order, even when
        # only the sampler ends up being used.
        first_stage = make_stage(num_blocks, in_dim, mid_dim)
        merger = PatchMerge(merge_size=self.merge_size)
        second_stage = make_stage(
            num_blocks,
            self.merge_size**2 * mid_dim,
            mid_dim,
        )

        if num_blocks:
            self.net = nn.ModuleList([first_stage, merger, second_stage])
            readout_in = mid_dim
        else:
            self.net = merger
            readout_in = in_dim
        self.readout = build_mlp(readout_depth, readout_in, out_dim)

    def forward(
        self,
        flattened_visual_embeds: torch.Tensor,
        grid_thw: torch.Tensor,
        **unused_kwargs: object,
    ) -> BaseModelOutput:
        """Apply the abstractor to the concatenated embeddings of all images.

        Args:
            flattened_visual_embeds: Patch embeddings of every image stacked
                along the first axis.
            grid_thw: One ``(T, H, W)`` row per image; ``T`` must be 1.
            **unused_kwargs: Accepted and ignored.

        Returns:
            A ``BaseModelOutput`` whose ``last_hidden_state`` is the
            per-image results concatenated along the first axis.
        """
        tokens_per_image = torch.prod(grid_thw, dim=1).tolist()
        per_image_embeds = torch.split(flattened_visual_embeds, tokens_per_image)

        projected = []
        for image_embeds, image_grid in zip(per_image_embeds, grid_thw):
            T, H, W = image_grid
            assert T == 1, "T must be 1. Video is not supported yet."

            # Restore the grid layout, then index away the temporal axis.
            grid_embeds = rearrange(
                image_embeds, "(t h w) d -> 1 t h w d", t=T, h=H, w=W
            )[:, 0]

            if self.pos_emb is not None:
                # Resample the stored square table to this image's grid.
                side = int(self.pos_emb_size**0.5)
                resized_pos = resample_abs_pos_embed(
                    posemb=self.pos_emb,
                    old_size=(side, side),
                    new_size=(H, W),
                    num_prefix_tokens=0,
                )
                grid_embeds = grid_embeds + rearrange(
                    resized_pos,
                    "1 (h w) d -> 1 h w d",
                    h=H,
                    w=W,
                )

            projected.append(self._forward(grid_embeds, input_size=(H, W)))

        return BaseModelOutput(last_hidden_state=torch.cat(projected, dim=0))

    def _forward(
        self,
        x: torch.Tensor,
        input_size: tuple[int, int],
    ) -> torch.Tensor:
        """Run one image's ``1 h w d`` feature map through net and readout."""
        height, width = input_size
        feats = rearrange(x, "1 h w d -> 1 d h w", h=height, w=width)
        if self.config.depth:
            # self.net is the three-element ModuleList built in build_net.
            for stage in self.net:
                feats = stage(feats)
        else:
            # With depth == 0, self.net is the lone PatchMerge module.
            feats = self.net(feats)
        tokens = rearrange(feats, "1 d h w -> (h w) d")
        return self.readout(tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CustomQwen2VLVE(Qwen2VisionTransformer):
    """Qwen2-VL vision transformer wrapped for use as Kanana-V's encoder.

    It reuses vLLM's `Qwen2VisionTransformer` building blocks while exposing
    an HF-style interface (``_from_config`` and ``BaseModelOutput``) that the
    Kanana-V projector code consumes.
    """

    def __init__(self, config: Qwen2VLVisionConfig) -> None:
        """Construct the base transformer and remove its patch merger."""
        eps = getattr(config, "rms_norm_eps", 1e-6)
        super().__init__(
            vision_config=config,
            norm_eps=eps,
            quant_config=None,
            prefix="",
        )

        # Kanana-V has its own projector/abstractor, so the built-in merger
        # is dropped to keep the parameter set aligned with the checkpoint.
        if hasattr(self, "merger"):
            del self.merger

    @classmethod
    def _from_config(cls, config: Qwen2VLVisionConfig) -> "CustomQwen2VLVE":
        """Mirror the HF `_from_config` constructor; simply calls ``cls``."""
        return cls(config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: torch.Tensor,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
    ) -> tuple | BaseModelOutput:
        """Encode patches and optionally collect every block's input/output.

        Args:
            pixel_values: Flattened patches fed to ``self.patch_embed``.
            grid_thw: Per-image ``(t, h, w)`` grid, as a tensor or a list.
            output_hidden_states: When truthy, the state before each block
                and the final state are returned as a tuple.
            return_dict: Must be truthy; anything else trips the assertion.

        Returns:
            A ``BaseModelOutput`` with ``last_hidden_state`` (patch-level
            features, no merger applied) and ``hidden_states`` (the collected
            tuple, or ``None``).
        """
        assert return_dict, "Only return_dict=True is supported."

        # Patchify.
        x = self.patch_embed(pixel_values.to(device=self.device, dtype=self.dtype))

        # Grid in both list and numpy form, mirroring the base implementation.
        if isinstance(grid_thw, list):
            grid_rows = grid_thw
            grid_np = np.array(grid_thw, dtype=np.int32)
        else:
            grid_rows = grid_thw.tolist()
            grid_np = grid_thw.cpu().numpy()

        rotary_cos, rotary_sin = self.rot_pos_emb(grid_rows)

        # Cumulative sequence boundaries: one h*w segment per temporal step,
        # computed in numpy and then moved to the device.
        patches_per_frame = grid_np[:, 1] * grid_np[:, 2]
        boundaries = np.repeat(patches_per_frame, grid_np[:, 0]).cumsum(
            axis=0, dtype=np.int32
        )
        boundaries = np.concatenate([np.zeros(1, dtype=np.int32), boundaries])
        cu_seqlens = torch.from_numpy(boundaries).to(
            self.device,
            non_blocking=True,
        )

        # Insert a batch axis of size 1, the layout the blocks are called with.
        x = x.unsqueeze(1)

        # Pre-computed for the attention backend.
        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)

        collected = [] if output_hidden_states else None
        for blk in self.blocks:
            if output_hidden_states:
                # State entering this block, without the batch axis.
                collected.append(x.squeeze(1))

            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb_cos=rotary_cos,
                rotary_pos_emb_sin=rotary_sin,
                max_seqlen=max_seqlen,
            )

        # Final patch-level state.
        hidden_states = x.squeeze(1)
        if output_hidden_states:
            collected.append(hidden_states)
        encoder_states = None if collected is None else tuple(collected)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
        )

    def get_num_tokens(self) -> int:
        """Return ``-1``; retained only for API compatibility."""
        # Not used in the current Kanana-V pipeline.
        return -1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class KananaVProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Kanana-V: image limits and token counts."""

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Declare images as the only modality, with no fixed count limit."""
        return {"image": None}

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the preprocessed size obtained from a 9999x9999 image."""
        largest_size, _num_tokens = self._get_vision_info(
            image_width=9999,
            image_height=9999,
            num_frames=1,
        )
        return largest_size

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count for an input.

        Args:
            image_width: Original width in pixels.
            image_height: Original height in pixels.
            num_frames: Number of frames (1 for a still image).
            do_resize: Whether to apply the image processor's
                ``smart_resize`` before counting patches.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``.
        """
        image_processor = self.ctx.get_hf_processor().image_processor
        # smart_resize is looked up in the image processor's own module.
        smart_resize = resolve_obj_by_qualname(
            f"{type(image_processor).__module__}.smart_resize"
        )

        vision_config = self.get_hf_config().vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = vision_config.temporal_patch_size

        final_width, final_height = image_width, image_height
        if do_resize:
            final_height, final_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
        preprocessed_size = ImageSize(width=final_width, height=final_height)

        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        padded_num_frames = num_frames + num_frames % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_vision_tokens = (grid_t * grid_h * grid_w) // (merge_size**2)
        return preprocessed_size, num_vision_tokens

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the vision-token count of the largest supported image."""
        max_width, max_height = self.get_image_size_with_most_features()
        _size, max_tokens = self._get_vision_info(
            image_width=max_width,
            image_height=max_height,
            num_frames=1,
        )
        return {"image": max_tokens}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class KananaVDummyInputsBuilder(BaseDummyInputsBuilder[KananaVProcessingInfo]):
    """Builds placeholder prompts and images for Kanana-V."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<image>`` placeholder per requested image."""
        return "<image>" * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalDataDict:
        """Return 9999x9999 dummy images, as many as ``mm_counts`` requests.

        ``seq_len``, ``mm_options`` and ``mm_processor_kwargs`` are accepted
        but not consulted here.
        """
        image_count = mm_counts.get("image", 0)
        dummy_images = self._get_dummy_images(
            width=9999, height=9999, num_images=image_count
        )
        return {"image": dummy_images}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class KananaVMultiModalProcessor(BaseMultiModalProcessor[KananaVProcessingInfo]):
    """vLLM multimodal processor for Kanana-V (text + image)."""

    @property
    def media_token_id(self) -> int:
        """Token id of the image placeholder: the text EOS id plus one."""
        text_config = self.info.get_hf_config().text_config
        return text_config.eos_token_id + 1

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Tokenize the prompt and preprocess any images.

        Text-only requests return just ``input_ids``. With images, the
        result also carries ``pixel_values``, ``vision_grid_thw``,
        ``image_token_thw`` and ``pixel_sizes``. ``mm_kwargs`` and
        ``tok_kwargs`` are accepted but not consulted here.
        """
        image_inputs = mm_data.get("images", []) if mm_data else []

        # Text-only input is handled as a special case.
        if not image_inputs:
            text_ids = self.info.get_tokenizer().encode(prompt)
            return BatchFeature(dict(input_ids=[text_ids]), tensor_type="pt")

        # Anything that is not already a PIL image goes through fromarray.
        if not isinstance(image_inputs[0], Image.Image):
            image_inputs = [Image.fromarray(image) for image in image_inputs]

        image_processor = self.info.get_hf_processor().image_processor
        per_image_outputs = [image_processor(image) for image in image_inputs]
        patch_chunks = [out["pixel_values"] for out in per_image_outputs]
        meta_records = [out["image_meta"] for out in per_image_outputs]
        # list of dict -> dict of list
        image_meta = {
            key: [record[key] for record in meta_records]
            for key in meta_records[0]
        }

        pixel_sizes = [chunk.shape[0] for chunk in patch_chunks]
        # One flat patch tensor for the whole example.
        pixel_values = torch.concat(patch_chunks, dim=0)

        tokenizer = self.info.get_tokenizer()
        media_id = self.media_token_id
        media_token = tokenizer.convert_ids_to_tokens([media_id])[0]
        input_ids = torch.tensor(
            tokenizer.encode(prompt.replace("<image>", media_token))
        )

        # Keep the output consistent with vLLM's prompt-update expectations:
        # if the tokenizer produced exactly one placeholder per image, expand
        # each to that image's `T*H*W` placeholders so detection works.
        num_images = len(image_inputs)
        image_token_thw = torch.tensor(image_meta["image_token_thw"])
        per_image_token_counts = image_token_thw.prod(dim=1).tolist()
        expected_total = sum(int(count) for count in per_image_token_counts)

        n_placeholders = int((input_ids == media_id).sum().item())
        if n_placeholders == num_images and expected_total != num_images:
            expanded: list[int] = []
            next_image = 0
            for token in input_ids.tolist():
                is_placeholder = token == media_id and next_image < num_images
                if not is_placeholder:
                    expanded.append(token)
                    continue
                repeat = int(per_image_token_counts[next_image])
                expanded.extend([media_id] * repeat)
                next_image += 1
            input_ids = input_ids.new_tensor(expanded)

        combined_outputs = {
            # Add batch dimension to input_ids.
            "input_ids": input_ids.unsqueeze(0),
            "pixel_values": pixel_values,
            "vision_grid_thw": torch.tensor(image_meta["vision_grid_thw"]),
            "image_token_thw": image_token_thw,
            "pixel_sizes": torch.tensor(pixel_sizes),
        }
        return BatchFeature(combined_outputs, tensor_type="pt")

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each ``<image>`` with that image's placeholder tokens."""

        def get_replacement(idx: int) -> Sequence[int]:
            """Return the placeholder ids for the image at position ``idx``."""
            image_token_thw = out_mm_kwargs["image"][idx]["image_token_thw"].data
            assert isinstance(image_token_thw, torch.Tensor)

            # The product of the image's token grid gives its token count.
            token_count = int(image_token_thw.prod().item())
            return [self.media_token_id] * token_count

        image_update = PromptReplacement(
            modality="image",
            target="<image>",
            replacement=get_replacement,
        )
        return [image_update]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output is split across images.

        ``pixel_values`` is sliced by ``pixel_sizes`` (an empty tensor when
        absent); the two grid tensors are batched one row per image.
        """
        pixel_sizes = hf_inputs.get("pixel_sizes", torch.empty(0))

        return {
            "pixel_values": MultiModalFieldConfig.flat_from_sizes(
                "image", pixel_sizes
            ),
            "vision_grid_thw": MultiModalFieldConfig.batched("image"),
            "image_token_thw": MultiModalFieldConfig.batched("image"),
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@MULTIMODAL_REGISTRY.register_processor(
    KananaVMultiModalProcessor,
    info=KananaVProcessingInfo,
    dummy_inputs=KananaVDummyInputsBuilder,
)
class KananaVForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Kanana-V: vision encoder + abstractor feeding a Llama language model."""

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for ``modality``.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if not modality.startswith("image"):
            raise ValueError(f"Unsupported modality: {modality}")
        return "<image>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Instantiate the vision tower, the abstractor and the language model."""
        super().__init__()

        config = vllm_config.model_config.hf_config
        self.config = config

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_model = CustomQwen2VLVE._from_config(config.vision_config)
            self.abstractor = DynamicCAbstractor(
                config.projector_config,
                num_input_tokens=self.vision_model.get_num_tokens(),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "model"),
                architectures=["LlamaForCausalLM"],
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> KananaVImageInputs | None:
        """Extract and normalise the image tensors from ``kwargs``.

        Returns:
            ``None`` when no ``pixel_values`` were supplied; otherwise a
            ``KananaVImagePixelInputs`` with 2-D ``pixel_values`` and
            ``vision_grid_thw`` flattened to one row per image.

        Raises:
            ValueError: If ``pixel_values`` is given without
                ``vision_grid_thw``, or is a tensor that is neither 2-D nor
                3-D.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        vision_grid_thw = kwargs.pop("vision_grid_thw", None)

        if pixel_values is None:
            return None

        if vision_grid_thw is None:
            raise ValueError(
                "vision_grid_thw is required when pixel_values is provided"
            )

        # Normalize pixel_values to 2D (num_patches, channels*patch*patch).
        if not isinstance(pixel_values, torch.Tensor):
            pixel_values = torch.concat(pixel_values)
        elif pixel_values.ndim == 3:
            pixel_values = pixel_values.flatten(0, 1)
        elif pixel_values.ndim != 2:
            raise ValueError(
                f"pixel_values should be 2D or batched 3D tensor. "
                f"Got ndim: {pixel_values.ndim} "
                f"(shape={pixel_values.shape})"
            )

        # Normalize vision_grid_thw to 2D (num_images, 3).
        if not isinstance(vision_grid_thw, torch.Tensor):
            vision_grid_thw = torch.concat(vision_grid_thw)
        elif vision_grid_thw.ndim == 3:
            vision_grid_thw = vision_grid_thw.flatten(0, 1)

        return KananaVImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            vision_grid_thw=vision_grid_thw,
        )

    def _process_image_input(
        self, image_input: KananaVImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode all images and split the result into one tensor per image.

        Returns:
            A tuple with one slice of the projected embeddings per row of
            ``vision_grid_thw``, each ``t * (h // merge) * (w // merge)``
            tokens long, in image order.
        """
        pixel_values = image_input["pixel_values"]
        vision_grid_thw = image_input["vision_grid_thw"]

        visual_embeds = self.forward_and_project_vision(
            pixel_values, {"vision_grid_thw": vision_grid_thw}
        )

        merge_size = self.abstractor.merge_size
        per_image: list[torch.Tensor] = []
        start = 0
        # Convert the grid to Python ints once: slicing with 0-dim tensors
        # would need a separate tensor-to-int conversion for every element.
        for t, h, w in vision_grid_thw.tolist():
            num_tokens = t * (h // merge_size) * (w // merge_size)
            per_image.append(visual_embeds[start : start + num_tokens])
            start += num_tokens

        return tuple(per_image)

    def _get_visual_feature_at(
        self,
        v_output: Sequence[torch.Tensor],
        layer_index: int | Sequence[int],
    ) -> torch.Tensor:
        """Select hidden states from ``v_output`` by layer index.

        A list/tuple of indices stacks all layers along a new axis 1 and
        picks the requested ones there; a single index returns that layer's
        tensor unchanged.
        """
        if isinstance(layer_index, (list, tuple)):
            return torch.stack(v_output, dim=1)[:, layer_index]
        return v_output[layer_index]

    def forward_vision(
        self,
        pixel_values: torch.Tensor,
        image_metas: dict | None = None,
    ) -> torch.Tensor:
        """Run the vision encoder and pick the configured feature layer(s).

        ``image_metas`` must contain ``"vision_grid_thw"``; the ``None``
        default is not usable because the dict is indexed unconditionally.
        """
        v_outputs = self.vision_model(
            pixel_values=pixel_values,
            return_dict=True,
            output_hidden_states=True,
            grid_thw=image_metas["vision_grid_thw"],
        )
        layer_index = self.config.projector_config.feature_layer_index
        return self._get_visual_feature_at(v_outputs.hidden_states, layer_index)

    def forward_projector(
        self,
        visual_features: torch.Tensor,
        image_metas: dict | None = None,
    ) -> torch.Tensor:
        """Project encoder features with the abstractor.

        ``image_metas`` must contain ``"vision_grid_thw"``.
        """
        projector_output = self.abstractor(
            visual_features,
            grid_thw=image_metas["vision_grid_thw"],
        )
        return projector_output["last_hidden_state"]

    def forward_and_project_vision(
        self,
        pixel_values: torch.Tensor,
        image_metas: dict | None = None,
    ) -> torch.Tensor:
        """Encode ``pixel_values`` and project them in one call."""
        assert pixel_values is not None
        visual_features = self.forward_vision(pixel_values, image_metas=image_metas)
        return self.forward_projector(visual_features, image_metas=image_metas)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or an empty list without images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ):
        """Run the language model and return its hidden states.

        ``inputs_embeds`` is discarded whenever ``intermediate_tensors`` is
        supplied. Extra ``kwargs`` are accepted and not forwarded.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``weights`` through ``AutoWeightsLoader``; return its result."""
        return AutoWeightsLoader(self).load_weights(weights)
|