diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index c2734ce11..f6ac29877 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -680,6 +680,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `InternS1ForConditionalGeneration` | Intern-S1 | T + IE+ + VE+ | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ |
| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + IE+ + (VE+) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ |
| `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + IE+ + VE+ | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ |
+| `KananaVForConditionalGeneration` | Kanana-V | T + I+ | `kakaocorp/kanana-1.5-v-3b-instruct`, etc. | | ✅︎ |
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ |
| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ |
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I+ | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ |
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index df205a67d..2d8c6081e 100755
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -769,6 +769,33 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
)
+# Kanana-V
+def run_kanana_v(questions: list[str], modality: str) -> ModelRequestData:
+ assert modality == "image"
+
+ model_name = "kakaocorp/kanana-1.5-v-3b-instruct"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=8192,
+ trust_remote_code=True,
+ limit_mm_per_prompt={modality: 1},
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
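+    # The "<image>" placeholder below is recognized by vLLM's Kanana-V multimodal
+    # processor and expanded into per-image visual tokens at prompt-processing time.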
+ messages = [
+        [{"role": "user", "content": f"<image>\n{question}"}] for question in questions
+ ]
+ prompts = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ )
+
+
# Keye-VL
def run_keye_vl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "Kwai-Keye/Keye-VL-8B-Preview"
@@ -1876,6 +1903,7 @@ model_example_map = {
"idefics3": run_idefics3,
"interns1": run_interns1,
"internvl_chat": run_internvl,
+ "kanana_v": run_kanana_v,
"keye_vl": run_keye_vl,
"keye_vl1_5": run_keye_vl1_5,
"kimi_vl": run_kimi_vl,
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 54d18158f..a506408a0 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -728,6 +728,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True,
),
"InternVLForConditionalGeneration": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"),
+ "KananaVForConditionalGeneration": _HfExamplesInfo(
+ "kakaocorp/kanana-1.5-v-3b-instruct",
+ trust_remote_code=True,
+ ),
"KeyeForConditionalGeneration": _HfExamplesInfo(
"Kwai-Keye/Keye-VL-8B-Preview",
trust_remote_code=True,
diff --git a/vllm/model_executor/models/kanana_v.py b/vllm/model_executor/models/kanana_v.py
new file mode 100644
index 000000000..5e1667c1e
--- /dev/null
+++ b/vllm/model_executor/models/kanana_v.py
@@ -0,0 +1,756 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Iterable, Mapping, Sequence
+from functools import partial
+from typing import Annotated, Literal, TypeAlias
+
+import numpy as np
+import regex as re
+import torch
+from einops import rearrange
+from PIL import Image
+from timm.layers import LayerNorm2d
+from timm.layers.pos_embed import resample_abs_pos_embed
+from timm.models.regnet import RegStage
+from torch import nn
+from transformers import BatchFeature
+from transformers.modeling_outputs import BaseModelOutput
+from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
+
+from vllm.config import VllmConfig
+from vllm.config.multimodal import BaseDummyOptions
+from vllm.logger import init_logger
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (
+ MultiModalDataDict,
+ MultiModalFieldConfig,
+ MultiModalKwargsItems,
+)
+from vllm.multimodal.parse import ImageSize, MultiModalDataItems
+from vllm.multimodal.processing import (
+ BaseMultiModalProcessor,
+ BaseProcessingInfo,
+ PromptReplacement,
+ PromptUpdate,
+)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.sequence import IntermediateTensors
+from vllm.utils.import_utils import resolve_obj_by_qualname
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
+
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .qwen2_vl import Qwen2VisionTransformer
+from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
+
+logger = init_logger(__name__)
+
+
+class KananaVImagePixelInputs(TensorSchema):
+ """
+ Dimensions:
+ - np: The total number of patches over all images in the batch
+ - cps: Number of channels * patch_size * patch_size
+ - ni: Number of images
+ """
+
+ type: Literal["pixel_values"]
+
+ pixel_values: Annotated[
+ torch.Tensor,
+ TensorShape("np", "cps"),
+ ]
+
+ vision_grid_thw: Annotated[
+ torch.Tensor,
+ TensorShape("ni", 3),
+ ]
+
+
+KananaVImageInputs: TypeAlias = KananaVImagePixelInputs
+
+
+def build_pos_embeds(
+ config: Qwen2VLVisionConfig,
+ num_input_tokens: int,
+ vision_hidden_size: int,
+) -> nn.Parameter | None:
+ """Build positional embeddings for the visual encoder output."""
+ if config.pos_emb:
+ pos_emb = nn.Parameter(torch.zeros(1, num_input_tokens, vision_hidden_size))
+ nn.init.trunc_normal_(pos_emb, mean=0.0, std=0.02)
+ else:
+ pos_emb = None
+
+ return pos_emb
+
+
+def build_mlp(
+ depth: int,
+ hidden_size: int,
+ output_hidden_size: int,
+) -> nn.Sequential:
+ """Simple SiLU-activated MLP used as a projector readout."""
+ layers = [nn.Linear(hidden_size, output_hidden_size)]
+ for _ in range(1, depth):
+ layers.append(nn.SiLU())
+ layers.append(nn.Linear(output_hidden_size, output_hidden_size))
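+    # depth=1 gives a single Linear; depth=2 gives Linear -> SiLU -> Linear, and so on.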
+ return nn.Sequential(*layers)
+
+
+class PatchMerge(nn.Module):
+ """Merge neighboring patches spatially to reduce resolution."""
+
+ def __init__(self, merge_size: int) -> None:
+ super().__init__()
+ self.merge_size = merge_size
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ channel_last: bool = False,
+ ) -> torch.Tensor:
+ """Merge patches by `merge_size x merge_size`."""
+ if channel_last:
+ x = rearrange(x, "B H W D -> B D H W")
+ _, _, H, W = x.shape
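+        # e.g. merge_size=2: (B, D, H, W) -> (B, 4 * D, H // 2, W // 2).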
+ merged_x = rearrange(
+ x,
+ "B D (H h2) (W w2) -> B (D h2 w2) H W",
+ h2=self.merge_size,
+ w2=self.merge_size,
+ )
+ return merged_x
+
+
+class DynamicCAbstractor(nn.Module):
+ """Dynamic C-Abstractor based on RegNet blocks."""
+
+ def __init__(
+ self,
+ config: Qwen2VLVisionConfig,
+ num_input_tokens: int,
+ ) -> None:
+ super().__init__()
+ assert hasattr(config, "merge_size"), "merge_size must be provided."
+ self.config = config
+ self.merge_size = config.merge_size
+ self.pos_emb_size = config.pos_emb_size
+ if num_input_tokens == -1:
+ num_input_tokens = config.pos_emb_size
+ self.num_input_tokens = num_input_tokens
+ self.pos_emb = build_pos_embeds(
+ config, num_input_tokens, config.encoder_hidden_size
+ )
+ self.build_net()
+
+ def _load_from_state_dict(self, state_dict, *args, **kwargs) -> None:
+ if not state_dict:
+ return
+
+ if self.pos_emb is not None:
+ key_re = re.compile(r"[\w,.]*abstractor[\w,.]*pos_emb")
+ pos_emb_key = None
+ for key in state_dict:
+ if key_re.match(key):
+ pos_emb_key = key
+ break
+
+ assert pos_emb_key is not None
+ # update old ckpt compatible with current code
+ pos_emb = state_dict[pos_emb_key]
+ if pos_emb.size(1) == self.pos_emb.size(1) + 1:
+ # remove obsolete first pos emb (for cls token originally)
+ state_dict[pos_emb_key] = pos_emb[:, 1:]
+
+ super()._load_from_state_dict(state_dict, *args, **kwargs)
+
+ def build_net(self) -> None:
+ encoder_hidden_size = self.config.encoder_hidden_size
+ hidden_size = self.config.hidden_size
+ output_hidden_size = self.config.output_hidden_size
+ depth = self.config.depth
+ mlp_depth = self.config.mlp_depth
+
+ RegBlock = partial(
+ RegStage,
+ stride=1,
+ dilation=1,
+ act_layer=nn.SiLU,
+ norm_layer=LayerNorm2d,
+ )
+
+ s1 = RegBlock(
+ depth,
+ encoder_hidden_size,
+ hidden_size,
+ )
+ sampler = PatchMerge(merge_size=self.merge_size)
+ s2 = RegBlock(
+ depth,
+ self.merge_size**2 * hidden_size,
+ hidden_size,
+ )
+
+ if depth:
+ self.net = nn.ModuleList([s1, sampler, s2])
+ self.readout = build_mlp(mlp_depth, hidden_size, output_hidden_size)
+ else:
+ self.net = sampler
+ self.readout = build_mlp(mlp_depth, encoder_hidden_size, output_hidden_size)
+
+ def forward(
+ self,
+ flattened_visual_embeds: torch.Tensor,
+ grid_thw: torch.Tensor,
+ **unused_kwargs: object,
+ ) -> BaseModelOutput:
+ """Apply the dynamic abstractor over flattened visual embeddings."""
+ n_token_loc = torch.prod(grid_thw, dim=1)
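+        # n_token_loc holds prod(grid_thw) = T*H*W patch tokens per image; split the
+        # flat sequence back into per-image chunks before abstracting.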
+ split_visual_embeds = torch.split(flattened_visual_embeds, n_token_loc.tolist())
+
+ flattened_visual_embeds = []
+ for _visual_embeds, _grid_thw in zip(split_visual_embeds, grid_thw):
+ T, H, W = _grid_thw
+ assert T == 1, "T must be 1. Video is not supported yet."
+ reshaped_visual_embeds = rearrange(
+ _visual_embeds, "(t h w) d -> 1 t h w d", t=T, h=H, w=W
+ )
+ # remove temporal dim
+ reshaped_visual_embeds = reshaped_visual_embeds[:, 0]
+
+ if self.pos_emb is not None:
+ # interpolate pos emb and add to visual embeds
+ _local_pos_emb = resample_abs_pos_embed(
+ posemb=self.pos_emb,
+ old_size=tuple([int(self.pos_emb_size**0.5)] * 2),
+ new_size=(H, W),
+ num_prefix_tokens=0,
+ )
+ _local_pos_emb = rearrange(
+ _local_pos_emb,
+ "1 (h w) d -> 1 h w d",
+ h=H,
+ w=W,
+ )
+ reshaped_visual_embeds = reshaped_visual_embeds + _local_pos_emb
+
+ reshaped_visual_embeds = self._forward(
+ reshaped_visual_embeds,
+ input_size=(H, W),
+ )
+ flattened_visual_embeds.append(reshaped_visual_embeds)
+ reshaped_visual_embeds = torch.cat(flattened_visual_embeds, dim=0)
+ return BaseModelOutput(last_hidden_state=reshaped_visual_embeds)
+
+ def _forward(
+ self,
+ x: torch.Tensor,
+ input_size: tuple[int, int],
+ ) -> torch.Tensor:
+ h, w = input_size
+ x = rearrange(x, "1 h w d -> 1 d h w", h=h, w=w)
+ if self.config.depth:
+ x = self.net[0](x)
+ x = self.net[1](x)
+ x = self.net[2](x)
+ else:
+ # When depth=0, self.net is a single PatchMerge module
+ x = self.net(x)
+ x = rearrange(x, "1 d h w -> (h w) d")
+ x = self.readout(x)
+ return x
+
+
+class CustomQwen2VLVE(Qwen2VisionTransformer):
+    """Thin wrapper around the Qwen2-VL vision transformer.
+
+ This mirrors the original HF-based vision encoder used in Kanana-V, but
+ reuses vLLM's optimized `Qwen2VisionTransformer` building blocks.
+ """
+
+ def __init__(self, config: Qwen2VLVisionConfig) -> None:
+ super().__init__(
+ vision_config=config,
+ norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+ quant_config=None,
+ prefix="",
+ )
+
+ # Kanana-V uses its own projector/abstractor instead of the Qwen2
+ # built-in patch merger, so we drop the merger module to keep the
+ # parameter set compatible with the original checkpoint.
+ if hasattr(self, "merger"):
+ del self.merger
+
+ @classmethod
+ def _from_config(cls, config: Qwen2VLVisionConfig) -> "CustomQwen2VLVE":
+ """Drop-in replacement for the HF `_from_config` constructor."""
+ return cls(config)
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ grid_thw: torch.Tensor,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ ) -> tuple | BaseModelOutput:
+ """Run the vision transformer and optionally return intermediate states.
+
+ Unlike the base `Qwen2VisionTransformer`, this wrapper exposes the
+ pre-merger patch-level representations and a HF-style `BaseModelOutput`
+ so that the existing projector / abstractor code can be reused.
+ """
+ assert return_dict, "Only return_dict=True is supported."
+
+ # Patchify
+ x = pixel_values.to(device=self.device, dtype=self.dtype)
+ x = self.patch_embed(x) # (num_patches, embed_dim)
+
+        # Prepare grid and rotary embeddings, mirroring the base implementation.
+ if isinstance(grid_thw, list):
+ grid_thw_list = grid_thw
+ grid_thw_np = np.array(grid_thw, dtype=np.int32)
+ else:
+ grid_thw_list = grid_thw.tolist()
+ grid_thw_np = grid_thw.cpu().numpy()
+
+ rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
+
+ # Compute cu_seqlens in numpy then move to device, same as base model.
+ cu_seqlens = np.repeat(
+ grid_thw_np[:, 1] * grid_thw_np[:, 2],
+ grid_thw_np[:, 0],
+ ).cumsum(axis=0, dtype=np.int32)
+ cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
+ cu_seqlens = torch.from_numpy(cu_seqlens).to(
+ self.device,
+ non_blocking=True,
+ )
+
+ # Shape to (S, B, D) with batch dimension 1 as expected by the blocks.
+ x = x.unsqueeze(1)
+
+ # Pre-compute seqlens for attention backend.
+ max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
+
+ encoder_states = () if output_hidden_states else None
+
+ for blk in self.blocks:
+ if output_hidden_states:
+ # Store patch-level states (S, D).
+ encoder_states = encoder_states + (x.squeeze(1),)
+
+ x = blk(
+ x,
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb_cos=rotary_pos_emb_cos,
+ rotary_pos_emb_sin=rotary_pos_emb_sin,
+ max_seqlen=max_seqlen,
+ )
+
+ # Final hidden state at patch level (S, D).
+ hidden_states = x.squeeze(1)
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_states,
+ )
+
+ def get_num_tokens(self) -> int:
+ # Not used in the current Kanana-V pipeline, kept for API compatibility.
+ return -1
+
+
+class KananaVProcessingInfo(BaseProcessingInfo):
+ def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+ return {"image": None}
+
+ def get_image_size_with_most_features(self) -> ImageSize:
+ max_image_size, _ = self._get_vision_info(
+ image_width=9999,
+ image_height=9999,
+ num_frames=1,
+ )
+ return max_image_size
+
+ def _get_vision_info(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ num_frames: int = 1,
+ do_resize: bool = True,
+ ) -> tuple[ImageSize, int]:
+ image_processor = self.ctx.get_hf_processor().image_processor
+ smart_resize = resolve_obj_by_qualname(
+ f"{type(image_processor).__module__}.smart_resize"
+ )
+
+ hf_config = self.get_hf_config()
+ vision_config = hf_config.vision_config
+ patch_size = vision_config.patch_size
+ merge_size = vision_config.spatial_merge_size
+ temporal_patch_size = vision_config.temporal_patch_size
+
+ if do_resize:
+ resized_height, resized_width = smart_resize(
+ height=image_height,
+ width=image_width,
+ factor=patch_size * merge_size,
+ min_pixels=image_processor.min_pixels,
+ max_pixels=image_processor.max_pixels,
+ )
+ preprocessed_size = ImageSize(width=resized_width, height=resized_height)
+ else:
+ preprocessed_size = ImageSize(width=image_width, height=image_height)
+
+ # NOTE: Frames are padded to be divisible by `temporal_patch_size`
+ # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
+ padded_num_frames = num_frames + num_frames % temporal_patch_size
+
+ grid_t = max(padded_num_frames // temporal_patch_size, 1)
+ grid_h = preprocessed_size.height // patch_size
+ grid_w = preprocessed_size.width // patch_size
+
+ num_patches = grid_t * grid_h * grid_w
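+        # Each merge_size x merge_size patch block is merged into a single visual token.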
+ num_vision_tokens = num_patches // (merge_size**2)
+
+ return preprocessed_size, num_vision_tokens
+
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
+ target_width, target_height = self.get_image_size_with_most_features()
+ num_vision_tokens = self._get_vision_info(
+ image_width=target_width,
+ image_height=target_height,
+ num_frames=1,
+ )[1]
+ return {"image": num_vision_tokens}
+
+
+class KananaVDummyInputsBuilder(BaseDummyInputsBuilder[KananaVProcessingInfo]):
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_images = mm_counts.get("image", 0)
+        return "<image>" * num_images
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ mm_options: Mapping[str, BaseDummyOptions] | None = None,
+ ) -> MultiModalDataDict:
+ num_images = mm_counts.get("image", 0)
+ return {
+ "image": self._get_dummy_images(
+ width=9999, height=9999, num_images=num_images
+ ),
+ }
+
+
+class KananaVMultiModalProcessor(BaseMultiModalProcessor[KananaVProcessingInfo]):
+ """vLLM multimodal processor for Kanana-V (text + image)."""
+
+ @property
+ def media_token_id(self) -> int:
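+        # The image placeholder id is assumed to directly follow the text EOS token id.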
+ return self.info.get_hf_config().text_config.eos_token_id + 1
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ tok_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ """Run the underlying HF processor on text and image data."""
+ # Text-only input is handled as a special case here.
+ if not mm_data or not mm_data.get("images", []):
+ prompt_ids = self.info.get_tokenizer().encode(prompt)
+ return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+ # Images
+ image_inputs = mm_data.get("images", [])
+ pixel_sizes = []
+ if not isinstance(image_inputs[0], Image.Image):
+ image_inputs = [Image.fromarray(image) for image in image_inputs]
+
+ image_processor = self.info.get_hf_processor().image_processor
+ processor_output = [image_processor(image) for image in image_inputs]
+ pixel_values = [o["pixel_values"] for o in processor_output]
+ image_meta = [o["image_meta"] for o in processor_output]
+ # list of dict -> dict of list
+ image_meta = {k: [d[k] for d in image_meta] for k in image_meta[0]}
+
+ for pixel_value in pixel_values:
+ pixel_sizes.append(pixel_value.shape[0])
+        # Concatenate per-image pixel_values into one flat (total_patches, dim) tensor.
+ pixel_values = torch.concat(pixel_values, dim=0)
+
+ tokenizer = self.info.get_tokenizer()
+ media_token = tokenizer.convert_ids_to_tokens([self.media_token_id])[0]
+        prompt_replaced = prompt.replace("<image>", media_token)
+ input_ids = tokenizer.encode(prompt_replaced)
+ input_ids = torch.tensor(input_ids)
+
+ # Ensure HF output is consistent with vLLM prompt-update expectations:
+ # if the HF tokenizer emits exactly 1 placeholder token per image, expand
+ # it to `T*H*W` placeholder tokens per image so placeholder detection works.
+ num_images = len(image_inputs)
+ image_token_thw = torch.tensor(image_meta["image_token_thw"])
+ per_image_token_counts = image_token_thw.prod(dim=1).tolist()
+ expected_total = int(sum(int(x) for x in per_image_token_counts))
+
+ n_placeholders = int((input_ids == self.media_token_id).sum().item())
+ if n_placeholders == num_images and expected_total != num_images:
+ expanded: list[int] = []
+ img_i = 0
+ for tok in input_ids.tolist():
+ if tok == self.media_token_id and img_i < num_images:
+ expanded.extend(
+ [self.media_token_id] * int(per_image_token_counts[img_i])
+ )
+ img_i += 1
+ else:
+ expanded.append(tok)
+ input_ids = input_ids.new_tensor(expanded)
+
+ combined_outputs = dict(
+ # Add batch dimension to input_ids.
+ input_ids=input_ids.unsqueeze(0),
+ pixel_values=pixel_values,
+ vision_grid_thw=torch.tensor(image_meta["vision_grid_thw"]),
+ image_token_thw=torch.tensor(image_meta["image_token_thw"]),
+ pixel_sizes=torch.tensor(pixel_sizes),
+ )
+ return BatchFeature(combined_outputs, tensor_type="pt")
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargsItems,
+ ) -> Sequence[PromptUpdate]:
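+        # Each "<image>" placeholder in the prompt expands to one media token per
+        # visual token (prod of image_token_thw).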
+ def get_replacement(idx: int) -> Sequence[int]:
+ out_item = out_mm_kwargs["image"][idx]
+ image_token_thw = out_item["image_token_thw"].data
+ assert isinstance(image_token_thw, torch.Tensor)
+
+ num_tokens = int(image_token_thw.prod().item())
+ return [self.media_token_id] * num_tokens
+
+ return [
+ PromptReplacement(
+ modality="image",
+                target="<image>",
+ replacement=get_replacement,
+ ),
+ ]
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ pixel_sizes = hf_inputs.get("pixel_sizes", torch.empty(0))
+
+ mm_fields_config = dict(
+ pixel_values=MultiModalFieldConfig.flat_from_sizes("image", pixel_sizes),
+ vision_grid_thw=MultiModalFieldConfig.batched("image"),
+ image_token_thw=MultiModalFieldConfig.batched("image"),
+ )
+ return mm_fields_config
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ KananaVMultiModalProcessor,
+ info=KananaVProcessingInfo,
+ dummy_inputs=KananaVDummyInputsBuilder,
+)
+class KananaVForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+ if modality.startswith("image"):
+            return "<image>"
+ else:
+ raise ValueError(f"Unsupported modality: {modality}")
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+
+ config = vllm_config.model_config.hf_config
+ self.config = config
+
+ self.vision_model = CustomQwen2VLVE._from_config(config.vision_config)
+ self.abstractor = DynamicCAbstractor(
+ config.projector_config, num_input_tokens=self.vision_model.get_num_tokens()
+ )
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ hf_config=config.text_config,
+ prefix=maybe_prefix(prefix, "model"),
+ architectures=["LlamaForCausalLM"],
+ )
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors
+ )
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object
+ ) -> KananaVImageInputs | None:
+ pixel_values = kwargs.pop("pixel_values", None)
+ vision_grid_thw = kwargs.pop("vision_grid_thw", None)
+
+ if pixel_values is None:
+ return None
+
+ if vision_grid_thw is None:
+ raise ValueError(
+ "vision_grid_thw is required when pixel_values is provided"
+ )
+
+ # Normalize pixel_values to 2D tensor (num_patches, channels*patch*patch)
+ if isinstance(pixel_values, torch.Tensor):
+ if pixel_values.ndim == 2:
+ pass # Already in expected shape
+ elif pixel_values.ndim == 3:
+ pixel_values = pixel_values.flatten(0, 1)
+ else:
+ raise ValueError(
+ f"pixel_values should be 2D or batched 3D tensor. "
+ f"Got ndim: {pixel_values.ndim} "
+ f"(shape={pixel_values.shape})"
+ )
+ else:
+ pixel_values = torch.concat(pixel_values)
+
+ # Normalize vision_grid_thw to 2D tensor (num_images, 3)
+ if isinstance(vision_grid_thw, torch.Tensor):
+ if vision_grid_thw.ndim == 3:
+ vision_grid_thw = vision_grid_thw.flatten(0, 1)
+ else:
+ vision_grid_thw = torch.concat(vision_grid_thw)
+
+ return KananaVImagePixelInputs(
+ type="pixel_values",
+ pixel_values=pixel_values,
+ vision_grid_thw=vision_grid_thw,
+ )
+
+    def _process_image_input(
+        self, image_input: KananaVImageInputs
+    ) -> tuple[torch.Tensor, ...]:
+ pixel_values = image_input["pixel_values"]
+ vision_grid_thw = image_input["vision_grid_thw"]
+
+ image_metas = {"vision_grid_thw": vision_grid_thw}
+ visual_embeds = self.forward_and_project_vision(pixel_values, image_metas)
+
+ merge_size = self.abstractor.merge_size
+ batch_size = vision_grid_thw.size(0)
+ multi_modal_embeddings: tuple[torch.Tensor, ...] = ()
+ sample_index = 0
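+        # Each image yields t * (h // merge) * (w // merge) projected tokens;
+        # slice them out of the flat embedding sequence per image.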
+ for i in range(batch_size):
+ t, h, w = (
+ vision_grid_thw[i][0],
+ vision_grid_thw[i][1] // merge_size,
+ vision_grid_thw[i][2] // merge_size,
+ )
+ num_tokens = t * h * w
+ visual_embed = visual_embeds[sample_index : sample_index + num_tokens]
+ multi_modal_embeddings += (visual_embed,)
+ sample_index += num_tokens
+
+ return multi_modal_embeddings
+
+ def _get_visual_feature_at(
+ self,
+ v_output: Sequence[torch.Tensor],
+ layer_index: int | Sequence[int],
+ ) -> torch.Tensor:
+ if isinstance(layer_index, (list, tuple)):
+            visual_features = torch.stack(v_output, dim=1)[
+                :, layer_index
+            ]  # (num_patches, n_scales, dim)
+        else:
+            visual_features = v_output[layer_index]  # (num_patches, dim)
+ return visual_features
+
+ def forward_vision(
+ self,
+ pixel_values: torch.Tensor,
+ image_metas: dict | None = None,
+ ) -> torch.Tensor:
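+        # Request hidden states so the projector can consume an intermediate layer.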
+ vision_model_args = {
+ "pixel_values": pixel_values,
+ "return_dict": True,
+ "output_hidden_states": True,
+ "grid_thw": image_metas["vision_grid_thw"],
+ }
+ v_outputs = self.vision_model(**vision_model_args)
+ layer_index = self.config.projector_config.feature_layer_index
+ visual_features = self._get_visual_feature_at(
+ v_outputs.hidden_states, layer_index
+ )
+ return visual_features
+
+ def forward_projector(
+ self,
+ visual_features: torch.Tensor,
+ image_metas: dict | None = None,
+ ) -> torch.Tensor:
+ visual_embeds = self.abstractor(
+ visual_features,
+ grid_thw=image_metas["vision_grid_thw"],
+ )["last_hidden_state"]
+ return visual_embeds
+
+ def forward_and_project_vision(
+ self,
+ pixel_values: torch.Tensor,
+ image_metas: dict | None = None,
+ ) -> torch.Tensor:
+ assert pixel_values is not None
+ visual_features = self.forward_vision(pixel_values, image_metas=image_metas)
+ visual_embeds = self.forward_projector(visual_features, image_metas=image_metas)
+ return visual_embeds
+
+ def get_language_model(self) -> torch.nn.Module:
+ return self.language_model
+
+ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
+ image_input = self._parse_and_validate_image_input(**kwargs)
+ if image_input is None:
+ return []
+
+ return self._process_image_input(image_input)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ **kwargs,
+ ):
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+
+ hidden_states = self.language_model(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ )
+
+ return hidden_states
+
+ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return self.language_model.compute_logits(hidden_states)
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 0b165232c..3818b83b6 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -343,6 +343,7 @@ _MULTIMODAL_MODELS = {
),
"IsaacForConditionalGeneration": ("isaac", "IsaacForConditionalGeneration"),
"SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"), # noqa: E501
+ "KananaVForConditionalGeneration": ("kanana_v", "KananaVForConditionalGeneration"),
"KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
"KeyeVL1_5ForConditionalGeneration": (
"keye_vl1_5",