diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index a38a88ce3..6d5ba168b 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -686,6 +686,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ |
| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ |
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I+ | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ |
+| `KimiK25ForConditionalGeneration` | Kimi-K2.5 | T + I+ | `moonshotai/Kimi-K2.5` | | ✅︎ |
| `LightOnOCRForConditionalGeneration` | LightOnOCR-1B | T + I+ | `lightonai/LightOnOCR-1B`, etc | ✅︎ | ✅︎ |
| `Lfm2VlForConditionalGeneration` | LFM2-VL | T + I+ | `LiquidAI/LFM2-VL-450M`, `LiquidAI/LFM2-VL-3B`, `LiquidAI/LFM2-VL-8B-A1B`, etc. | ✅︎ | ✅︎ |
| `Llama4ForConditionalGeneration` | Llama 4 | T + I+ | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ |
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 5c6db71b1..bea86c081 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -771,6 +771,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
)
},
),
+ "KimiK25ForConditionalGeneration": _HfExamplesInfo(
+ "moonshotai/Kimi-K2.5",
+ trust_remote_code=True,
+ is_available_online=False,
+ ),
"LightOnOCRForConditionalGeneration": _HfExamplesInfo(
"lightonai/LightOnOCR-1B-1025"
),
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index b4f2e9dbe..bcc8d3c65 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -46,6 +46,9 @@ from vllm.multimodal.inputs import (
MultiModalBatchedField,
MultiModalFlatField,
MultiModalSharedField,
+ VisionChunk,
+ VisionChunkImage,
+ VisionChunkVideo,
)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
@@ -336,7 +339,9 @@ ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]
ChatTemplateContentFormat = Literal["string", "openai"]
-ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
+ModalityStr = Literal[
+ "image", "audio", "video", "image_embeds", "audio_embeds", "vision_chunk"
+]
_T = TypeVar("_T")
@@ -449,6 +454,78 @@ def _get_embeds_data(
raise NotImplementedError(type(data_items))
+def rebuild_mm_uuids_from_mm_data(
+ mm_uuids: MultiModalUUIDDict,
+ mm_data: MultiModalDataDict,
+) -> MultiModalUUIDDict:
+ """Rebuild mm_uuids after vision_chunk processing.
+
+ When videos are split into chunks, the original UUIDs need to be updated
+ to reflect the new UUIDs generated for each chunk.
+
+ Args:
+ mm_uuids: Original UUIDs dictionary
+ mm_data: Processed multimodal data with vision_chunk items
+
+ Returns:
+ Updated UUIDs dictionary with chunk UUIDs
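+
+        Illustrative example (assumed values): if a video with uuid "vid" was
+        split into two chunks carrying uuids "vid-0" and "vid-1", then
+        {"vision_chunk": ["img", "vid"]} is rebuilt as
+        {"vision_chunk": ["img", "vid-0", "vid-1"]}.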
+ """
+ vision_chunks = mm_data.get("vision_chunk")
+ if vision_chunks is None:
+ return mm_uuids
+
+ new_uuids = dict(mm_uuids)
+ vision_chunk_uuids = []
+
+ for item in vision_chunks:
+ # vision_chunk items are always dicts (VisionChunkImage/VisionChunkVideo)
+ assert isinstance(item, dict)
+ uuid_val = item.get("uuid")
+ if uuid_val is not None:
+ vision_chunk_uuids.append(uuid_val)
+
+ if vision_chunk_uuids:
+ new_uuids["vision_chunk"] = vision_chunk_uuids
+
+ return new_uuids
+
+
+def build_video_prompts_from_mm_data(
+ mm_data: MultiModalDataDict,
+) -> list[str]:
+ """Build video prompts from vision_chunk data.
+
+ Collects prompts from video chunks and groups them by video_idx.
+
+ Args:
+ mm_data: Processed multimodal data with vision_chunk items
+
+ Returns:
+ List of video prompts, one per video.
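+
+        Illustrative example (assumed values): if video 0 was split into two
+        chunks with prompts "<c0>" and "<c1>" and video 1 into one chunk with
+        prompt "<c2>", this returns ["<c0><c1>", "<c2>"].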
+ """
+ vision_chunks = mm_data.get("vision_chunk")
+ if vision_chunks is None:
+ return []
+
+ # Group chunks by video_idx
+ video_prompts_dict: dict[int, list[str]] = defaultdict(list)
+
+ for item in vision_chunks:
+ # vision_chunk items are always dicts (VisionChunkImage/VisionChunkVideo)
+ assert isinstance(item, dict)
+ if item.get("type") == "video_chunk":
+ video_idx = item.get("video_idx", 0)
+ prompt = item.get("prompt", "")
+ video_prompts_dict[video_idx].append(prompt)
+
+ # Build prompts in video order
+ video_prompts = []
+ for video_idx in sorted(video_prompts_dict.keys()):
+ video_prompts.append("".join(video_prompts_dict[video_idx]))
+
+ return video_prompts
+
+
class BaseMultiModalItemTracker(ABC, Generic[_T]):
"""
Tracks multi-modal items in a given request and ensures that the number
@@ -462,6 +539,13 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self._model_config = model_config
self._items_by_modality = defaultdict[str, list[_T]](list)
+ # Track original modality for each vision_chunk item (image or video)
+ self._modality_order = defaultdict[str, list[str]](list)
+
+ @cached_property
+ def use_unified_vision_chunk_modality(self) -> bool:
+ """Check if model uses unified vision_chunk modality for images/videos."""
+ return getattr(self._model_config.hf_config, "use_unified_vision_chunk", False)
@property
def model_config(self) -> ModelConfig:
@@ -499,11 +583,31 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
media.
"""
input_modality = modality.replace("_embeds", "")
- num_items = len(self._items_by_modality[modality]) + 1
+ original_modality = modality
+ use_vision_chunk = (
+ self.use_unified_vision_chunk_modality
+ and original_modality in ["video", "image"]
+ )
+
+ # If use_unified_vision_chunk_modality is enabled,
+ # map image/video to vision_chunk
+ if use_vision_chunk:
+            # Remap to vision_chunk to avoid a validation failure: models with
+            # use_unified_vision_chunk enabled only accept the vision_chunk modality.
+ input_modality = "vision_chunk"
+ num_items = len(self._items_by_modality[input_modality]) + 1
+ else:
+ num_items = len(self._items_by_modality[original_modality]) + 1
self.mm_processor.validate_num_items(input_modality, num_items)
- self._items_by_modality[modality].append(item)
+ # Track original modality for vision_chunk items
+ if use_vision_chunk:
+ self._items_by_modality[input_modality].append(item) # type: ignore
+ self._modality_order["vision_chunk"].append(original_modality)
+ else:
+ self._items_by_modality[original_modality].append(item)
return self.model_cls.get_placeholder_str(modality, num_items)
@@ -515,6 +619,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def _resolve_items(
items_by_modality: dict[str, list[tuple[object, str | None]]],
mm_processor: BaseMultiModalProcessor,
+ vision_chunk_modality_order: dict[str, list[str]],
) -> tuple[MultiModalDataDict, MultiModalUUIDDict]:
if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError("Mixing raw image and embedding inputs is not allowed")
@@ -546,6 +651,74 @@ def _resolve_items(
if "video" in items_by_modality:
mm_data["video"] = [data for data, uuid in items_by_modality["video"]]
mm_uuids["video"] = [uuid for data, uuid in items_by_modality["video"]]
+ if "vision_chunk" in items_by_modality:
+ # Process vision_chunk items - extract from (data, modality) tuples
+ # and convert to VisionChunk types with proper UUID handling
+ vision_chunk_items = items_by_modality["vision_chunk"]
+ modality_order = vision_chunk_modality_order.get("vision_chunk", [])
+ mm_uuids["vision_chunk"] = [
+ uuid for data, uuid in items_by_modality["vision_chunk"]
+ ]
+
+ # Filter out None items (from asyncio.sleep(0) placeholders)
+ filtered_items = [
+ (idx, item)
+ for idx, item in enumerate(vision_chunk_items)
+ if item is not None
+ ]
+
+ assert len(filtered_items) == len(modality_order), (
+ f"vision_chunk items ({len(filtered_items)}) and "
+ f"modality_order ({len(modality_order)}) must have same length"
+ )
+
+ processed_chunks: list[VisionChunk] = []
+ video_idx = 0
+ for i, (idx, item) in enumerate(filtered_items):
+ inner_modality = modality_order[i]
+ data, uuid = item
+ uuid_val = uuid if idx < len(mm_uuids["vision_chunk"]) else None
+ if inner_modality == "image":
+ # Cast data to proper type for image
+ # Use .media (PIL.Image) directly to avoid redundant
+ # bytes→PIL conversion in media_processor
+ if hasattr(data, "media"):
+ image_data = data.media # type: ignore[union-attr]
+ processed_chunks.append(
+ VisionChunkImage(type="image", image=image_data, uuid=uuid_val)
+ )
+ else:
+ processed_chunks.append(data) # type: ignore[arg-type]
+ elif inner_modality == "video":
+ # For video, we may need to split into chunks
+ # if processor supports it
+ # For now, just wrap as a video chunk placeholder
+ if hasattr(mm_processor, "split_video_chunks") and data is not None:
+ try:
+ video_uuid = uuid_val or random_uuid()
+ # video await result is (video_data, video_meta) tuple
+ if isinstance(data, tuple) and len(data) >= 1:
+ video_data = data[0]
+ else:
+ video_data = data
+ video_chunks = mm_processor.split_video_chunks(video_data)
+                        for chunk_idx, vc in enumerate(video_chunks):
+                            processed_chunks.append(
+                                VisionChunkVideo(
+                                    type="video_chunk",
+                                    video_chunk=vc["video_chunk"],
+                                    uuid=f"{video_uuid}-{chunk_idx}",
+ video_idx=video_idx,
+ prompt=vc["prompt"],
+ )
+ )
+ video_idx += 1
+ except Exception as e:
+ logger.warning("Failed to split video chunks: %s", e)
+ processed_chunks.append(data) # type: ignore[arg-type]
+ else:
+ processed_chunks.append(data) # type: ignore[arg-type]
+ mm_data["vision_chunk"] = processed_chunks
return mm_data, mm_uuids
@@ -557,7 +730,9 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[tuple[object, str | None]]
if not self._items_by_modality:
return None, None
- return _resolve_items(dict(self._items_by_modality), self.mm_processor)
+ return _resolve_items(
+ dict(self._items_by_modality), self.mm_processor, self._modality_order
+ )
def create_parser(self) -> "BaseMultiModalContentParser":
return MultiModalContentParser(self)
@@ -577,7 +752,9 @@ class AsyncMultiModalItemTracker(
for modality, coros in self._items_by_modality.items()
}
- return _resolve_items(resolved_items_by_modality, self.mm_processor)
+ return _resolve_items(
+ resolved_items_by_modality, self.mm_processor, self._modality_order
+ )
def create_parser(self) -> "BaseMultiModalContentParser":
return AsyncMultiModalContentParser(self)
diff --git a/vllm/envs.py b/vllm/envs.py
index d6243c02d..4541de72c 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -782,6 +782,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
),
# Backend for Video IO
# - "opencv": Default backend that uses OpenCV stream buffered backend.
+ # - "identity": Returns raw video bytes for model processor to handle.
#
# Custom backend implementations can be registered
# via `@VIDEO_LOADER_REGISTRY.register("my_custom_video_loader")` and
diff --git a/vllm/model_executor/models/kimi_k25.py b/vllm/model_executor/models/kimi_k25.py
new file mode 100644
index 000000000..dccf05c14
--- /dev/null
+++ b/vllm/model_executor/models/kimi_k25.py
@@ -0,0 +1,581 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# ruff: noqa: E501
+"""
+Kimi-K2.5 Model Implementation for vLLM.
+
+Kimi-K2.5 extends Kimi-K2 with vision support.
+
+This module defines:
+- KimiK25ProcessingInfo/KimiK25MultiModalProcessor: Processing logic
+- KimiK25ForConditionalGeneration: Main model class
+"""
+
+import copy
+from collections.abc import Iterable, Mapping, Sequence
+from dataclasses import dataclass
+from typing import Annotated, Any, Literal
+
+import torch
+from torch import nn
+from transformers import BatchFeature
+from transformers.processing_utils import ProcessorMixin
+
+from vllm.config import VllmConfig
+from vllm.config.multimodal import BaseDummyOptions
+from vllm.distributed import get_pp_group
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader,
+ maybe_remap_kv_scale_name,
+)
+from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
+from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP
+from vllm.model_executor.models.kimi_k25_vit import (
+ KimiK25MultiModalProjector,
+ MoonViT3dPretrainedModel,
+ vision_tower_forward,
+)
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (
+ MultiModalDataDict,
+ MultiModalFieldConfig,
+ MultiModalKwargsItems,
+ NestedTensors,
+ VisionChunk,
+ VisionChunkImage,
+ VisionChunkVideo,
+)
+from vllm.multimodal.parse import MultiModalDataItems, VisionChunkProcessorItems
+from vllm.multimodal.processing import (
+ BaseDummyInputsBuilder,
+ BaseMultiModalProcessor,
+ BaseProcessingInfo,
+ InputProcessingContext,
+ PromptReplacement,
+ PromptUpdate,
+)
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.configs import KimiK25Config
+from vllm.transformers_utils.processor import cached_get_image_processor
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
+
+from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix
+
+logger = init_logger(__name__)
+
+
+# Dummy input dimensions for profiling.
+@dataclass
+class MaxImageTokenMeta:
+ width: int = 3000
+ height: int = 3000
+
+
+class KimiK25MediaPixelInputs(TensorSchema):
+ """
+    Media input schema for the Kimi-K2.5 model.
+
+ Dimensions:
+ - np: Number of patches (flattened from all media items)
+ - ps: Patch size
+ - nm: Number of media items
+ """
+
+ type: Literal["pixel_values"] = "pixel_values"
+
+ pixel_values: Annotated[
+ torch.Tensor | list[torch.Tensor],
+ TensorShape("np", 3, "ps", "ps"),
+ ]
+
+ grid_thws: Annotated[torch.Tensor, TensorShape("nm", 3)]
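+
+    # Illustrative example (assumed values): two media items with grids
+    # [1, 4, 4] and [2, 4, 4] give pixel_values of shape (48, 3, ps, ps)
+    # and grid_thws of shape (2, 3).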
+
+
+class MoonshotKimiVAutoProcessor(ProcessorMixin):
+ attributes = ["tokenizer"]
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(self, media_processor=None, tokenizer=None):
+ super().__init__(tokenizer)
+ self.media_processor = media_processor
+
+ # We do not support str input for text here
+ def __call__(
+ self,
+ vision_chunks: list[VisionChunk] | None = None,
+ *,
+ text: list[int],
+ **kwargs,
+ ) -> BatchFeature:
+ """
+ Args:
+ vision_chunks: List of VisionChunk items to be processed.
+ For image: VisionChunkImage with type='image', image=PIL.Image
+ For video_chunk: VisionChunkVideo with type='video_chunk', video_chunk=list[PIL.Image]
+ text: The token ids to be fed to a model (required).
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- list of token ids to be fed to a model.
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `vision_chunks` is not `None`.
+ - **grid_thws** -- list of image 3D grid in LLM. Returned when `vision_chunks` is not `None`.
+ """
+ mm_inputs = {}
+ if vision_chunks is not None:
+ assert isinstance(vision_chunks, list)
+ mm_inputs = self.media_processor.preprocess(vision_chunks)
+ # XXX: _apply_hf_processor_text_mm will call tolist() on input_ids
+ return BatchFeature(
+ data={
+ "input_ids": torch.tensor([text]),
+ **mm_inputs,
+ }
+ )
+
+
+class KimiK25ProcessingInfo(BaseProcessingInfo):
+ """Processing information for Kimi-K2.5 model.
+
+ Provides configuration and utilities for processing both
+ images and video-chunks.
+ """
+
+ def __init__(self, ctx: InputProcessingContext) -> None:
+ super().__init__(ctx)
+ self.hf_config = self.get_hf_config()
+ self.media_token_id = self.hf_config.media_placeholder_token_id
+ media_processor = cached_get_image_processor(
+ self.ctx.model_config.model, trust_remote_code=True
+ )
+ self.media_processor = media_processor
+ self.hf_processor = MoonshotKimiVAutoProcessor(
+ media_processor=self.media_processor,
+ tokenizer=self.get_tokenizer(),
+ )
+ self.media_tokens_calculator = self.media_processor.media_tokens_calculator
+
+ def get_hf_processor(self):
+ return self.hf_processor
+
+ def get_hf_config(self):
+ return self.ctx.get_hf_config(KimiK25Config)
+
+ def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+ # None means unlimited
+ return {"vision_chunk": None}
+
+
+class KimiK25DummyInputsBuilder(BaseDummyInputsBuilder[KimiK25ProcessingInfo]):
+ """Builds dummy inputs for Kimi-K2.5 model profiling."""
+
+ def __init__(self, info: KimiK25ProcessingInfo) -> None:
+ super().__init__(info)
+ self.media_token_id = self.info.media_token_id
+ self.frame_per_chunk = self.info.media_processor.num_frames_per_chunk
+
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> list[int]:
+ num_media = mm_counts.get("vision_chunk", 0)
+ return [self.media_token_id] * num_media
+
+ def get_dummy_mm_items(self):
+ dummy_videos = self._get_dummy_images(
+ height=MaxImageTokenMeta.height,
+ width=MaxImageTokenMeta.width,
+ num_images=self.frame_per_chunk,
+ )
+
+ video_chunk_dummy_item = VisionChunkVideo(
+ type="video_chunk", video_chunk=dummy_videos
+ )
+ video_chunk_num_tokens = self.info.media_tokens_calculator(
+ video_chunk_dummy_item
+ )
+
+ image_dummy_item = VisionChunkImage(
+ type="image",
+ image=self._get_dummy_images(
+ height=MaxImageTokenMeta.height,
+ width=MaxImageTokenMeta.width,
+ num_images=1,
+ )[0],
+ )
+ image_num_tokens = self.info.media_tokens_calculator(image_dummy_item)
+ # return the larger one
+ if video_chunk_num_tokens >= image_num_tokens:
+ return [video_chunk_dummy_item]
+ else:
+ return [image_dummy_item]
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ mm_options: Mapping[str, BaseDummyOptions] | None = None,
+ ) -> MultiModalDataDict:
+ # TODO: Support mm_options for vision_chunk to allow user configuration
+ dummy_items = self.get_dummy_mm_items()
+ return {"vision_chunk": dummy_items}
+
+
+class KimiK25MultiModalProcessor(BaseMultiModalProcessor[KimiK25ProcessingInfo]):
+ """Multi-modal processor for Kimi-K2.5.
+
+ Handles both image and video-chunk modalities.
+ """
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ """Indicates how to slice media input into multiple items.
+
+        pixel_values: [N, 3, patch_size, patch_size], all patches collected from B media items.
+        grid_thws: [B, 3], each row [N_t, N_h, N_w] gives the grid size in the
+            time/height/width directions for the corresponding item.
+
+        Multiplying N_t * N_h * N_w gives the number of patches per media item, so
+        pixel_values[start:start + N_t*N_h*N_w] recovers the patches of one item.
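+
+        Illustrative example (assumed values): grid_thws = [[1, 4, 4], [2, 4, 4]]
+        gives grid_sizes = [16, 32], so item 0 owns pixel_values[0:16] and item 1
+        owns pixel_values[16:48]; grid_thws itself is batched one row per item.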
+
+ """
+ grid_thws = hf_inputs.get("grid_thws", torch.empty((0, 3)))
+ grid_sizes = grid_thws.prod(-1)
+
+ return dict(
+ pixel_values=MultiModalFieldConfig.flat_from_sizes(
+ "vision_chunk", grid_sizes
+ ),
+ grid_thws=MultiModalFieldConfig.batched("vision_chunk"),
+ )
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, Any],
+ out_mm_kwargs: MultiModalKwargsItems,
+ ) -> Sequence[PromptUpdate]:
+ hf_config = self.info.get_hf_config()
+ media_token_id = hf_config.media_placeholder_token_id
+
+ def get_replacement(item_idx: int):
+ media = mm_items.get_items("vision_chunk", (VisionChunkProcessorItems,))
+ num_media_token = self.info.media_tokens_calculator(media[item_idx])
+ return [media_token_id] * num_media_token
+
+ return [
+ PromptReplacement(
+ modality="vision_chunk",
+ target=[media_token_id],
+ replacement=get_replacement,
+ ),
+ ]
+
+ def split_video_chunks(self, video):
+ return self.info.media_processor.split_video_chunks(video)
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ KimiK25MultiModalProcessor,
+ info=KimiK25ProcessingInfo,
+ dummy_inputs=KimiK25DummyInputsBuilder,
+)
+class KimiK25ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+ """Kimi-K2.5 model for conditional generation.
+
+ Supports both image and video-chunk modalities.
+ Video-chunks are temporal segments (typically 4 frames) that are
+ processed with temporal pooling.
+ """
+
+ supports_encoder_tp_data = True
+
+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+        # Kimi-K2.5 routes both images and videos through the unified vision_chunk modality
+ if modality == "image":
+ return "<|media_begin|>image<|media_content|><|media_pad|><|media_end|>"
+ elif modality == "video":
+            # Return a sentinel placeholder; the renderer later replaces it with
+            # per-chunk video prompts (see build_video_prompts_from_mm_data).
+ return "<|kimi_k25_video_placeholder|>"
+
+ raise ValueError(f"Unsupported modality: {modality}")
+
+ def __init__(
+ self,
+ vllm_config: VllmConfig,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ model_config = vllm_config.model_config
+ config: KimiK25Config = model_config.hf_config
+ self.config = config
+ quant_config = vllm_config.quant_config
+
+ # Check for MoonViT config compatibility
+ self.use_data_parallel = (
+ model_config.multimodal_config.mm_encoder_tp_mode == "data"
+ )
+ self.hidden_size = config.text_config.hidden_size
+ self.device = torch.cuda.current_device()
+ # Build vision tower directly with KimiK25VisionConfig
+ self.vision_tower = MoonViT3dPretrainedModel(
+ config.vision_config,
+ prefix=maybe_prefix(prefix, "vision_tower"),
+ )
+ self.vision_tower = self.vision_tower.to(
+ device=self.device, dtype=model_config.dtype
+ )
+
+ self.mm_projector = KimiK25MultiModalProjector(
+ config=config.vision_config,
+ use_data_parallel=self.use_data_parallel,
+ prefix=maybe_prefix(prefix, "mm_projector"),
+ )
+ self.mm_projector = self.mm_projector.to(
+ device=self.device, dtype=model_config.dtype
+ )
+
+ self.quant_config = quant_config
+ sub_vllm_config = copy.deepcopy(vllm_config)
+ sub_vllm_config.model_config.hf_config = (
+ sub_vllm_config.model_config.hf_config.text_config
+ )
+ self.language_model = DeepseekV2Model(
+ vllm_config=sub_vllm_config,
+ prefix=maybe_prefix(prefix, "language_model"),
+ )
+ if get_pp_group().is_last_rank:
+ self.lm_head = ParallelLMHead(
+ config.vocab_size,
+ config.text_config.hidden_size,
+ prefix=maybe_prefix(prefix, "lm_head"),
+ )
+ else:
+ self.lm_head = PPMissingLayer()
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors
+ )
+ logit_scale = getattr(config, "logit_scale", 1.0)
+ self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
+ self.media_placeholder: int = self.config.media_placeholder_token_id
+
+ def _parse_and_validate_media_input(
+ self, **kwargs: object
+ ) -> KimiK25MediaPixelInputs | None:
+ pixel_values = kwargs.pop("pixel_values", None)
+ grid_thws = kwargs.pop("grid_thws", None)
+ if pixel_values is None:
+ return None
+
+ if isinstance(pixel_values, list):
+ pixel_values = torch.cat(pixel_values, dim=0)
+
+ if len(pixel_values.shape) == 5 or len(pixel_values.shape) == 3:
+ pixel_values = pixel_values.reshape(
+ pixel_values.shape[0] * pixel_values.shape[1], *pixel_values.shape[2:]
+ )
+
+ # The batch dimension of pixel_values has been flattened into shape[0]
+ target_dtype = next(self.vision_tower.parameters()).dtype
+ pixel_values = pixel_values.to(target_dtype)
+ assert isinstance(grid_thws, torch.Tensor), (
+ f"expect grid_thws to be a tensor, get {type(grid_thws)}"
+ )
+ # In some cases (e.g. with merger), grid_thws has an extra middle dimension
+ grid_thws = grid_thws.reshape(-1, grid_thws.shape[-1])
+ assert grid_thws.ndim == 2 and grid_thws.size(1) == 3, (
+ f"unexpected shape for grid_thws: {grid_thws.shape}"
+ )
+
+ return KimiK25MediaPixelInputs(
+ type="pixel_values",
+ pixel_values=pixel_values,
+ grid_thws=grid_thws,
+ )
+
+ def _process_media_input(
+ self, media_input: KimiK25MediaPixelInputs
+ ) -> list[torch.Tensor]:
+ # NOTE(moyan): This forward will automatically batch the forward pass internally
+ media_features = vision_tower_forward(
+ self.vision_tower,
+ media_input["pixel_values"],
+ media_input["grid_thws"],
+ mm_projector=self.mm_projector,
+ use_data_parallel=self.use_data_parallel,
+ )
+ return media_features
+
+ def embed_multimodal(self, **kwargs: object) -> NestedTensors | None:
+ # Validate the multimodal input keyword arguments
+ media_input = self._parse_and_validate_media_input(**kwargs)
+ if media_input is None:
+ return None
+
+ # Run multimodal inputs through encoder and projector
+ vision_embeddings = self._process_media_input(media_input)
+ return vision_embeddings
+
+ def get_language_model(self) -> torch.nn.Module:
+ return self.language_model
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ **kwargs: object,
+ ) -> IntermediateTensors:
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+ hidden_states = self.language_model(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ )
+
+ return hidden_states
+
+ def compute_logits(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
+ logits = self.logits_processor(self.lm_head, hidden_states, **kwargs)
+ return logits
+
+ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+ # Params for weights, fp8 weight scales, fp8 activation scales
+ # (param_name, weight_name, expert_id, shard_id)
+ config = self.config.text_config
+ if not getattr(config, "n_routed_experts", None):
+ return []
+ return SharedFusedMoE.make_expert_params_mapping(
+ self,
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=config.n_routed_experts,
+ num_redundant_experts=0,
+ )
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+ config = self.config.text_config
+ _KEYS_TO_MODIFY_MAPPING = {
+ "language_model.lm_head": "lm_head",
+ "language_model.model": "language_model",
+ # mm_projector -> mm_projector mapping
+ # "mm_projector": "mm_projector",
+ "mm_projector.proj.0": "mm_projector.linear_1",
+ "mm_projector.proj.2": "mm_projector.linear_2",
+ }
+ stacked_params_mapping = [
+ (".gate_up_proj", ".gate_proj", 0),
+ (".gate_up_proj", ".up_proj", 1),
+ ]
+ if getattr(config, "kv_lora_rank", None) and getattr(
+ config, "q_lora_rank", None
+ ):
+ stacked_params_mapping += [
+ (".fused_qkv_a_proj", ".q_a_proj", 0),
+ (".fused_qkv_a_proj", ".kv_a_proj_with_mqa", 1),
+ ]
+ expert_params_mapping = self.get_expert_mapping()
+
+ params_dict = dict(self.named_parameters())
+
+ for args in weights:
+ name, loaded_weight = args[:2]
+ kwargs = args[2] if len(args) > 2 else {}
+ if "rotary_emb.inv_freq" in name:
+ continue
+
+ spec_layer = get_spec_layer_idx_from_weight_name(config, name)
+ if spec_layer is not None:
+ continue # skip spec decode layers for main model
+
+ if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+ continue
+
+ for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
+ if key_to_modify in name:
+ name = name.replace(key_to_modify, new_key)
+
+ use_default_weight_loading = False
+ if "vision" in name:
+ if self.vision_tower is not None:
+ use_default_weight_loading = True
+ else:
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ if ("mlp.experts." in name) and name not in params_dict:
+ continue
+ name = name.replace(weight_name, param_name)
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id, **kwargs)
+ break
+ else:
+ for _, (
+ param_name,
+ weight_name,
+ expert_id,
+ shard_id,
+ ) in enumerate(expert_params_mapping):
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(
+ param,
+ loaded_weight,
+ name,
+ expert_id=expert_id,
+ shard_id=shard_id,
+ **kwargs,
+ )
+ break
+ else:
+ use_default_weight_loading = True
+
+ if use_default_weight_loading:
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
+
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight, **kwargs)
+
+
+def get_spec_layer_idx_from_weight_name(
+ config: KimiK25Config, weight_name: str
+) -> int | None:
+ if hasattr(config, "num_nextn_predict_layers") and (
+ config.num_nextn_predict_layers > 0
+ ):
+ layer_idx = config.num_hidden_layers
+ for i in range(config.num_nextn_predict_layers):
+ # might start with language_model.model.layers
+ if f"model.layers.{layer_idx + i}." in weight_name:
+ return layer_idx + i
+ return None
diff --git a/vllm/model_executor/models/kimi_k25_vit.py b/vllm/model_executor/models/kimi_k25_vit.py
new file mode 100644
index 000000000..650ff7d21
--- /dev/null
+++ b/vllm/model_executor/models/kimi_k25_vit.py
@@ -0,0 +1,678 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Vision tower implementation for Kimi-K2.5 model.
+
+This module provides the vision encoder components for Kimi-K2.5,
+including 3D patch embedding, RoPE position embedding, and
+temporal pooling for video chunks.
+"""
+
+from collections.abc import Sequence
+from copy import deepcopy
+from typing import Any
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.activations import GELUActivation
+
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
+from vllm.model_executor.layers.linear import (
+ ColumnParallelLinear,
+ QKVParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear,
+)
+from vllm.model_executor.models.utils import maybe_prefix
+from vllm.model_executor.models.vision import (
+ is_vit_use_data_parallel,
+ run_dp_sharded_mrope_vision_model,
+)
+from vllm.transformers_utils.configs.kimi_k25 import KimiK25VisionConfig
+
+logger = init_logger(__name__)
+
+
+def _apply_rope_input_validation(x, freqs_cis):
+ assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
+ assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape)
+ assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape)
+ assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype
+
+
+def get_rope_shape_decorate(func):
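+    """Warm up the compiled ``get_rope_shape`` once per
+    (requires_grad, grad_enabled, interpolation_mode) key with a fixed
+    (64, 64) shape before forwarding the real call."""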
+ _get_rope_shape_first_call_flag = set()
+
+ def wrapper(org, interpolation_mode, shape):
+ key = (org.requires_grad, torch.is_grad_enabled(), interpolation_mode)
+ if key not in _get_rope_shape_first_call_flag:
+ _get_rope_shape_first_call_flag.add(key)
+ _ = func(org, interpolation_mode, shape=(64, 64))
+ return func(org, interpolation_mode, shape)
+
+ return wrapper
+
+
+@get_rope_shape_decorate
+@torch.compile(dynamic=True)
+def get_rope_shape(org, interpolation_mode, shape):
+ return (
+ F.interpolate(
+ org.permute((2, 0, 1)).unsqueeze(0),
+ size=shape,
+ mode=interpolation_mode,
+ )
+ .squeeze(0)
+ .permute((1, 2, 0))
+ .flatten(end_dim=1)
+ )
+
+
+def apply_rope(
+ xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args: (The leading dimensions of all inputs should be the same)
+ xq: query, tensor of shape (..., num_heads, head_dim)
+ xk: key, tensor of shape (..., num_heads, head_dim)
+ freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64.
+ Returns:
+ xq_out, xk_out: tensors of shape (..., num_heads, head_dim)
+ """
+ _apply_rope_input_validation(xq, freqs_cis)
+ _apply_rope_input_validation(xk, freqs_cis)
+
+ freqs_cis = freqs_cis.unsqueeze(-2) # ..., 1, head_dim/2
+ # ..., num_heads, head_dim/2
+ xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
+    xk_ = torch.view_as_complex(xk.float().view(*xk.shape[:-1], -1, 2))
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2) # ..., num_heads, head_dim
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2) # ..., num_heads, head_dim
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """Generate 1D sincos positional embedding from grid positions."""
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False):
+ """Generate 1D sincos positional embedding."""
+ grid_t = np.arange(t_size, dtype=np.float32)
+ pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_t)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+class Learnable2DInterpPosEmbDivided_fixed(nn.Module):
+ """2D learnable position embedding with temporal extension."""
+
+ def __init__(
+ self,
+ height: int,
+ width: int,
+ num_frames: int,
+ dim: int,
+ interpolation_mode: str = "bicubic",
+ ) -> None:
+ super().__init__()
+ self.height = height
+ self.width = width
+ self.num_frames = num_frames
+ self.dim = dim
+ self.interpolation_mode = interpolation_mode
+ self.weight = nn.Parameter(torch.empty(height, width, dim))
+ self.register_buffer(
+ "time_weight",
+ torch.from_numpy(get_1d_sincos_pos_embed(self.dim, self.num_frames))
+ .float()
+ .unsqueeze(1),
+ persistent=False,
+ )
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.normal_(self.weight)
+
+ def forward(self, x: torch.Tensor, grid_thws: torch.Tensor) -> torch.Tensor:
+ pos_embs = []
+ for t, h, w in grid_thws.tolist():
+ assert t <= self.num_frames, f"t:{t} > self.num_frames:{self.num_frames}"
+ if (h, w) == self.weight.shape[:-1]:
+ pos_emb_2d = self.weight.flatten(end_dim=1)
+ else:
+ pos_emb_2d = get_rope_shape(
+ self.weight,
+ interpolation_mode=self.interpolation_mode,
+ shape=(h, w),
+ )
+
+ if t == 1:
+ pos_emb_3d = pos_emb_2d
+ else:
+ pos_emb_3d = (
+ pos_emb_2d.unsqueeze(0).repeat(t, 1, 1) + self.time_weight[0:t]
+ )
+
+ pos_embs.append(pos_emb_3d.reshape(-1, pos_emb_3d.shape[-1]))
+
+ out = x + torch.cat(pos_embs)
+ return out
+
+
+class MoonVision3dPatchEmbed(nn.Module):
+ """3D patch embedding for vision tower."""
+
+ def __init__(
+ self,
+ out_dim: int,
+ in_dim: int = 3,
+ patch_size: int | tuple[int, int] = (14, 14),
+ pos_emb_height: int = 14,
+ pos_emb_width: int = 14,
+ pos_emb_time: int = 4,
+ pos_emb_type: str = "divided_fixed",
+ ):
+ super().__init__()
+ assert isinstance(patch_size, int | Sequence), (
+ f"Invalid patch_size type: {type(patch_size)}"
+ )
+ if isinstance(patch_size, int):
+ patch_size = (patch_size, patch_size)
+ assert len(patch_size) == 2, (
+ f"Expected patch_size to be a tuple of 2, got {patch_size}"
+ )
+ self.patch_size = patch_size
+
+ self.proj = nn.Conv2d(
+ in_dim, out_dim, kernel_size=patch_size, stride=patch_size
+ )
+
+ if pos_emb_type == "divided_fixed":
+ self.pos_emb = Learnable2DInterpPosEmbDivided_fixed(
+ height=pos_emb_height,
+ width=pos_emb_width,
+ num_frames=pos_emb_time,
+ dim=out_dim,
+ )
+ else:
+ raise NotImplementedError(f"Not support pos_emb_type: {pos_emb_type}")
+
+ def forward(self, x: torch.Tensor, grid_thws: torch.Tensor) -> torch.Tensor:
+ x = self.proj(x).view(x.size(0), -1)
+ # apply positional embedding
+ x = self.pos_emb(x, grid_thws)
+ return x
+
+
+class Rope2DPosEmbRepeated(nn.Module):
+ """2D rotary position embedding with multi-resolution support."""
+
+ def __init__(self, dim: int, max_height: int, max_width: int, theta_base=10000):
+ super().__init__()
+ self.dim = dim
+ assert self.dim % 4 == 0, "dim must be divisible by 4"
+ self.max_height = max_height
+ self.max_width = max_width
+ self.theta_base = theta_base
+
+ def extra_repr(self):
+ return (
+ f"dim={self.dim}, max_height={self.max_height}, "
+ f"max_width={self.max_width}, theta_base={self.theta_base}"
+ )
+
+ def _precompute_freqs_cis(self, device: torch.device) -> torch.Tensor:
+ """Calculate the cis(freqs) for each position in the 2D grid."""
+ N = self.max_height * self.max_width
+ flat_pos = torch.arange(0, N).float().to(device)
+ x_pos = flat_pos % self.max_width
+ y_pos = flat_pos // self.max_width
+ dim_range = (
+ torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(device)
+ ) # C/4
+ freqs = 1.0 / (self.theta_base ** (dim_range / self.dim))
+ x_freqs = torch.outer(x_pos, freqs).float() # N, C/4
+ y_freqs = torch.outer(y_pos, freqs).float() # N, C/4
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) # N, C/4
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) # N, C/4
+ # N, C/4, 2
+ freqs_cis = torch.cat(
+ [x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1
+ )
+ # max_height, max_width, C/2
+ freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
+ return freqs_cis
+
+ def get_freqs_cis(
+ self, grid_thws: torch.Tensor, device: torch.device
+ ) -> torch.Tensor:
+ """
+ Args:
+ grid_thws (torch.Tensor): grid time, height and width
+
+ Returns:
+ freqs_cis: tensor of shape (sum(t * height * width), dim//2)
+ """
+ if not hasattr(self, "freqs_cis"):
+ self.register_buffer(
+ "freqs_cis", self._precompute_freqs_cis(device), persistent=False
+ )
+
+ shapes = grid_thws.tolist()
+ assert all(
+ 1 <= h <= self.max_height and 1 <= w <= self.max_width for t, h, w in shapes
+ ), (
+ shapes,
+ self.max_height,
+ self.max_width,
+ )
+ freqs_cis = torch.cat(
+ [
+ self.freqs_cis[:h, :w].reshape(-1, self.dim // 2).repeat(t, 1)
+ for t, h, w in shapes
+ ],
+ dim=0,
+ )
+ return freqs_cis
+
+
+class MLP2(nn.Module):
+ """Two-layer MLP with tensor parallel support."""
+
+ def __init__(
+ self,
+ dims: list[int],
+ activation,
+ bias: bool = True,
+ prefix: str = "",
+ use_data_parallel: bool = False,
+ ):
+ super().__init__()
+ assert len(dims) == 3
+ self.use_data_parallel = use_data_parallel
+ self.fc0 = ColumnParallelLinear(
+ dims[0],
+ dims[1],
+ bias=bias,
+ prefix=maybe_prefix(prefix, "fc0"),
+ disable_tp=self.use_data_parallel,
+ )
+ self.fc1 = RowParallelLinear(
+ dims[1],
+ dims[2],
+ bias=bias,
+ prefix=maybe_prefix(prefix, "fc1"),
+ disable_tp=self.use_data_parallel,
+ )
+ self.activation = activation
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x, _ = self.fc0(x)
+ x = self.activation(x)
+ x, _ = self.fc1(x)
+ return x
+
+
+class MoonViTEncoderLayer(nn.Module):
+ """Single encoder layer for MoonViT with TP/DP support."""
+
+ def __init__(
+ self,
+ num_heads: int,
+ hidden_dim: int,
+ mlp_dim: int,
+ prefix: str = "",
+ *,
+ activation=F.gelu,
+ attn_bias: bool = False,
+ ):
+ super().__init__()
+ self.use_data_parallel = is_vit_use_data_parallel()
+
+ self.num_heads = num_heads
+ self.hidden_dim = hidden_dim
+ self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
+ self.tp_size = (
+ 1 if self.use_data_parallel else get_tensor_model_parallel_world_size()
+ )
+ self.num_attention_heads_per_partition = divide(num_heads, self.tp_size)
+
+ self.norm0 = nn.LayerNorm(hidden_dim)
+ self.norm1 = nn.LayerNorm(hidden_dim)
+ self.mlp = MLP2(
+ [hidden_dim, mlp_dim, hidden_dim],
+ activation,
+ prefix=f"{prefix}.mlp",
+ use_data_parallel=self.use_data_parallel,
+ )
+ self.wqkv = QKVParallelLinear(
+ hidden_size=hidden_dim,
+ head_size=self.hidden_size_per_attention_head,
+ total_num_heads=num_heads,
+ total_num_kv_heads=num_heads,
+ bias=attn_bias,
+ prefix=f"{prefix}.wqkv",
+ disable_tp=self.use_data_parallel,
+ )
+ self.wo = RowParallelLinear(
+ hidden_dim,
+ hidden_dim,
+ bias=attn_bias,
+ prefix=f"{prefix}.wo",
+ disable_tp=self.use_data_parallel,
+ )
+ self.attn = MMEncoderAttention(
+ num_heads=self.num_attention_heads_per_partition,
+ head_size=self.hidden_size_per_attention_head,
+ scale=self.hidden_size_per_attention_head**-0.5,
+ prefix=f"{prefix}.attn",
+ )
+
+ def attention_qkvpacked(
+ self,
+ x: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rope_freqs_cis: torch.Tensor | None = None,
+ ):
+ """Compute self-attention with packed QKV.
+
+ Args:
+ x (torch.Tensor): (seqlen, hidden_dim)
+ cu_seqlens (torch.Tensor): cumulative sequence lengths
+ """
+ seq_length = x.size(0)
+ xqkv, _ = self.wqkv(x)
+
+ qkv_shape = xqkv.size()[:-1] + (
+ 3,
+ self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head,
+ )
+ # xqkv: (seqlen, 3, nheads, headdim)
+ xqkv = xqkv.view(*qkv_shape)
+ xq, xk, xv = torch.unbind(xqkv, dim=-3)
+
+ xq, xk = apply_rope(xq, xk, rope_freqs_cis)
+
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+ attn_out = self.attn(
+ xq.unsqueeze(0),
+ xk.unsqueeze(0),
+ xv.unsqueeze(0),
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ )
+ attn_out = attn_out.reshape(
+ seq_length,
+ self.num_attention_heads_per_partition
+ * self.hidden_size_per_attention_head,
+ )
+ attn_out, _ = self.wo(attn_out)
+ return attn_out
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rope_freqs_cis: torch.Tensor | None = None,
+ ):
+ residual = hidden_states
+ hidden_states = self.norm0(hidden_states)
+
+ hidden_states = self.attention_qkvpacked(
+ hidden_states, cu_seqlens, rope_freqs_cis
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+class MoonViT3dEncoder(nn.Module):
+ """Full encoder stack for MoonViT 3D."""
+
+ def __init__(
+ self,
+ hidden_dim: int,
+ num_layers: int,
+ block_cfg: dict,
+ video_attn_type: str = "spatial_temporal",
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+
+ assert video_attn_type == "spatial_temporal", (
+ f'video_attn_type must be "spatial_temporal", got {video_attn_type}'
+ )
+ self.video_attn_type = video_attn_type
+ self.rope_2d = Rope2DPosEmbRepeated(
+ block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512
+ )
+ self.blocks = nn.ModuleList(
+ [
+ MoonViTEncoderLayer(
+ **block_cfg,
+ prefix=f"{prefix}.blocks.{layer_idx}",
+ )
+ for layer_idx in range(num_layers)
+ ]
+ )
+ self.final_layernorm = nn.LayerNorm(hidden_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ grid_thws: torch.Tensor,
+ ) -> torch.Tensor:
+ rope_freqs_cis = self.rope_2d.get_freqs_cis(
+ grid_thws=grid_thws, device=hidden_states.device
+ )
+
+ lengths = torch.cat(
+ (
+ torch.zeros(1, dtype=grid_thws.dtype, device=grid_thws.device),
+ grid_thws[:, 0] * grid_thws[:, 1] * grid_thws[:, 2],
+ )
+ )
+
+ cu_seqlens = lengths.to(hidden_states.device).cumsum(dim=0, dtype=torch.int32)
+
+ for block in self.blocks:
+ hidden_states = block(
+ hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis
+ )
+
+ hidden_states = self.final_layernorm(hidden_states)
+
+ return hidden_states
+
+
+def tpool_patch_merger(
+ x: torch.Tensor,
+ grid_thws: torch.Tensor,
+ merge_kernel_size: tuple[int, int] = (2, 2),
+) -> list[torch.Tensor]:
+ """Temporal pooling patch merger."""
+ kh, kw = merge_kernel_size
+ lengths = (grid_thws[:, 0] * grid_thws[:, 1] * grid_thws[:, 2]).tolist()
+ seqs = x.split(lengths, dim=0)
+
+ outputs = []
+ for seq, (t, h, w) in zip(seqs, grid_thws.tolist()):
+ nh, nw = h // kh, w // kw
+ # Reshape: (t*h*w, d) -> (t, nh, kh, nw, kw, d)
+ v = seq.view(t, nh, kh, nw, kw, -1)
+ # Temporal pooling first (reduces tensor size before permute)
+ v = v.mean(dim=0) # (nh, kh, nw, kw, d)
+ # Spatial rearrangement: (nh, kh, nw, kw, d) -> (nh, nw, kh, kw, d)
+ out = v.permute(0, 2, 1, 3, 4).reshape(nh * nw, kh * kw, -1)
+ outputs.append(out)
+
+ return outputs
+
+
+class MoonViT3dPretrainedModel(nn.Module):
+ """Main vision tower model.
+
+ Uses KimiK25VisionConfig directly from transformers_utils/configs/kimi_k25.py.
+ """
+
+ def __init__(
+ self,
+ config: KimiK25VisionConfig,
+ prefix: str = "",
+ ):
+ super().__init__()
+ config = deepcopy(config)
+ self.config = config # Required for run_dp_sharded_mrope_vision_model
+ self.merge_kernel_size = config.merge_kernel_size
+ self.patch_size = config.patch_size
+ self.merge_type = config.merge_type
+
+ self.patch_embed = MoonVision3dPatchEmbed(
+ out_dim=config.hidden_size,
+ patch_size=config.patch_size,
+ pos_emb_height=config.init_pos_emb_height,
+ pos_emb_width=config.init_pos_emb_width,
+ pos_emb_time=config.init_pos_emb_time,
+ pos_emb_type=config.pos_emb_type,
+ )
+
+ self.encoder = MoonViT3dEncoder(
+ hidden_dim=config.hidden_size,
+ num_layers=config.num_hidden_layers,
+ block_cfg={
+ "num_heads": config.num_attention_heads,
+ "hidden_dim": config.hidden_size,
+ "mlp_dim": config.intermediate_size,
+ "activation": get_act_fn("gelu_pytorch_tanh"),
+ "attn_bias": True,
+ },
+ video_attn_type=config.video_attn_type,
+ prefix=maybe_prefix(prefix, "encoder"),
+ )
+
+ def forward(
+ self, pixel_values: torch.Tensor, grid_thws: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Args:
+ pixel_values (torch.Tensor): The input pixel values.
+ grid_thws (torch.Tensor): Temporal, height and width.
+
+ Returns:
+ torch.Tensor: The output tokens.
+ """
+ hidden_states = self.patch_embed(pixel_values, grid_thws)
+ hidden_states = self.encoder(hidden_states, grid_thws)
+        if self.merge_type == "sd2_tpool":
+            # spatial downsampling 2x with temporal pooling over all frames
+ hidden_states = tpool_patch_merger(
+ hidden_states, grid_thws, merge_kernel_size=self.merge_kernel_size
+ )
+ else:
+ raise NotImplementedError(f"Not support {self.merge_type}")
+
+ return hidden_states
+
+
+@torch.inference_mode()
+def mm_projector_forward(mm_projector: torch.nn.Module, vt_output: list[torch.Tensor]):
+ """Apply MM projector to vision tower outputs."""
+ num_embedding_list = [x.shape[0] for x in vt_output]
+ batched = torch.cat(vt_output, dim=0)
+ proj_out = mm_projector(batched)
+ proj_out = proj_out.reshape(-1, proj_out.shape[-1])
+ proj_out = torch.split(proj_out, num_embedding_list)
+ return proj_out
+
+
+@torch.inference_mode()
+def vision_tower_forward(
+ vision_tower: Any,
+ pixel_values: torch.Tensor,
+ grid_thw: torch.Tensor,
+ mm_projector: Any,
+ use_data_parallel: bool,
+) -> list[torch.Tensor]:
+ """DP-sharded vision tower forward with mrope.
+
+ Uses vLLM's standard data parallelism utility to shard the batch
+ across available GPUs, enabling parallel processing of vision features.
+ """
+ if use_data_parallel:
+ grid_thw_list = grid_thw.tolist()
+ vt_outputs = run_dp_sharded_mrope_vision_model(
+ vision_model=vision_tower,
+ pixel_values=pixel_values,
+ grid_thw_list=grid_thw_list,
+ rope_type="rope_2d",
+ )
+ else:
+ vt_outputs = vision_tower(pixel_values, grid_thw)
+ tensors = mm_projector_forward(mm_projector, list(vt_outputs))
+ return list(tensors)
+
+
+class KimiK25MultiModalProjector(nn.Module):
+ """Multi-modal projector with patch merging for Kimi-K2.5."""
+
+ def __init__(
+ self,
+ config: KimiK25VisionConfig,
+ use_data_parallel: bool = False,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.use_data_parallel = use_data_parallel
+
+ # Hidden size after patch merging
+ merge_h, merge_w = config.merge_kernel_size
+ self.hidden_size = config.hidden_size * merge_h * merge_w
+
+ self.pre_norm = torch.nn.LayerNorm(config.hidden_size, eps=1e-5)
+ self.linear_1 = ReplicatedLinear(
+ self.hidden_size,
+ self.hidden_size,
+ bias=True,
+ prefix=maybe_prefix(prefix, "linear_1"),
+ )
+ self.linear_2 = ReplicatedLinear(
+ self.hidden_size,
+ config.mm_hidden_size,
+ bias=True,
+ prefix=maybe_prefix(prefix, "linear_2"),
+ )
+ self.act = GELUActivation()
+
+ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size)
+ hidden_states, _ = self.linear_1(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states, _ = self.linear_2(hidden_states)
+ return hidden_states
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 25b6e4025..6d70a8b5d 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -359,6 +359,7 @@ _MULTIMODAL_MODELS = {
),
"RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
"KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
+ "KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"), # noqa: E501
"LightOnOCRForConditionalGeneration": (
"lightonocr",
"LightOnOCRForConditionalGeneration",
diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index 7b1215876..8ce1e3587 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -20,6 +20,7 @@ from typing import (
)
import numpy as np
+from PIL.Image import Image
from typing_extensions import NotRequired, TypeVar
from vllm.utils.collection_utils import full_groupby, is_list_of
@@ -29,7 +30,6 @@ from vllm.utils.jsontree import json_map_leaves
if TYPE_CHECKING:
import torch
import torch.types
- from PIL.Image import Image
from transformers.feature_extraction_utils import BatchFeature
from .media import MediaWithBytes
@@ -105,6 +105,28 @@ The number of data items allowed per modality is restricted by
"""
+class VisionChunkImage(TypedDict):
+ """Represents an image wrapped as a vision chunk."""
+
+ type: Literal["image"]
+ image: Image
+ uuid: str | None
+
+
+class VisionChunkVideo(TypedDict):
+ """Represents a video chunk with metadata."""
+
+ type: Literal["video_chunk"]
+ video_chunk: list[Image]
+ uuid: str | None
+ prompt: str
+ video_idx: int
+
+
+VisionChunk = VisionChunkImage | VisionChunkVideo
+"""A vision chunk is either an image or a video chunk."""
+
+
@final
class MultiModalDataBuiltins(TypedDict, total=False):
"""Type annotations for modality types predefined by vLLM."""
@@ -118,6 +140,9 @@ class MultiModalDataBuiltins(TypedDict, total=False):
audio: ModalityData[AudioItem]
"""The input audio(s)."""
+ vision_chunk: ModalityData[VisionChunk]
+ """The input visual atom(s) - unified modality for images and video chunks."""
+
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py
index a8ebd427b..638478125 100644
--- a/vllm/multimodal/parse.py
+++ b/vllm/multimodal/parse.py
@@ -384,6 +384,13 @@ class VideoEmbeddingItems(EmbeddingItems):
super().__init__(data, "video", expected_hidden_size)
+class VisionChunkProcessorItems(ProcessorBatchItems[Any]):
+ """Processor items for vision chunks (unified image and video chunks)."""
+
+ def __init__(self, data: Sequence[Any]) -> None:
+ super().__init__(data, "vision_chunk")
+
+
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
@@ -652,11 +659,23 @@ class MultiModalDataParser:
return VideoProcessorItems(new_videos, metadata=metadata_lst)
+ def _parse_vision_chunk_data(
+ self,
+ data: ModalityData[Any],
+ ) -> ModalityDataItems[Any, Any] | None:
+ """Parse vision chunk data (unified image and video chunks)."""
+ if data is None or self._is_empty(data):
+ return None
+ if self.is_embeddings(data):
+ raise ValueError("Do not support embedding data for vision_chunk right now")
+ return VisionChunkProcessorItems(data)
+
def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
return {
"audio": self._parse_audio_data,
"image": self._parse_image_data,
"video": self._parse_video_data,
+ "vision_chunk": self._parse_vision_chunk_data,
}
def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py
index f123799ca..9c7b9463b 100644
--- a/vllm/multimodal/video.py
+++ b/vllm/multimodal/video.py
@@ -235,6 +235,27 @@ class VideoLoader:
VIDEO_LOADER_REGISTRY = ExtensionManager()
+@VIDEO_LOADER_REGISTRY.register("identity")
+class IdentityVideoLoader(VideoLoader):
+ """IdentityVideoLoader returns raw video bytes without decoding.
+
+ This allows the model processor to handle video decoding and
+ is required for models like Kimi-K2.5 that need custom video chunk splitting.
+
+    NOTE: This backend is not the default; the default video loader backend
+    remains "opencv" and must be overridden explicitly to use this one.
+ """
+
+ @classmethod
+ def load_bytes(
+ cls,
+ data: bytes,
+ num_frames: int = -1,
+ **kwargs: Any,
+ ) -> tuple[Any, Any]:
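+        # Return the encoded bytes untouched with no metadata; decoding and
+        # chunking are deferred to the model's processor.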
+ return data, None
+
+
@VIDEO_LOADER_REGISTRY.register("opencv")
class OpenCVVideoBackend(VideoLoader):
def get_cv2_video_api(self):
diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py
index 7b918d2e3..05bc90e2e 100644
--- a/vllm/reasoning/__init__.py
+++ b/vllm/reasoning/__init__.py
@@ -53,8 +53,8 @@ _REASONING_PARSERS_TO_REGISTER = {
"HunyuanA13BReasoningParser",
),
"kimi_k2": (
- "deepseek_r1_reasoning_parser",
- "DeepSeekR1ReasoningParser",
+ "kimi_k2_reasoning_parser",
+ "KimiK2ReasoningParser",
),
"minimax_m2": (
"minimax_m2_reasoning_parser",
diff --git a/vllm/reasoning/kimi_k2_reasoning_parser.py b/vllm/reasoning/kimi_k2_reasoning_parser.py
new file mode 100644
index 000000000..42869585b
--- /dev/null
+++ b/vllm/reasoning/kimi_k2_reasoning_parser.py
@@ -0,0 +1,80 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any
+
+from transformers import PreTrainedTokenizerBase
+
+from vllm.entrypoints.openai.engine.protocol import DeltaMessage
+from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParser
+from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
+
+from .identity_reasoning_parser import IdentityReasoningParser
+
+if TYPE_CHECKING:
+ from vllm.entrypoints.openai.chat_completion.protocol import (
+ ChatCompletionRequest,
+ )
+else:
+ ChatCompletionRequest = Any
+
+
+logger = init_logger(__name__)
+
+
+class KimiK2ReasoningParser(ReasoningParser):
+ """
+ Kimi K2 parser that delegates to either DeepSeekR1ReasoningParser or
+ IdentityReasoningParser based on `thinking` and `separate_reasoning`.
+
+ Unlike DeepSeekV3ReasoningParser which defaults to NOT thinking,
+ KimiK2ReasoningParser defaults to thinking mode (uses DeepSeekR1ReasoningParser).
+ """
+
+ def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
+ super().__init__(tokenizer, *args, **kwargs)
+
+ chat_kwargs = kwargs.pop("chat_template_kwargs", {}) or {}
+ # Key difference: default to True instead of False
+ thinking = bool(chat_kwargs.pop("thinking", True))
+
+ if thinking:
+ self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs)
+ else:
+ self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs)
+
+ def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
+ return self._parser.is_reasoning_end(input_ids)
+
+ def is_reasoning_end_streaming(
+ self, input_ids: list[int], delta_ids: list[int]
+ ) -> bool:
+ return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)
+
+ def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+ return self._parser.extract_content_ids(input_ids)
+
+ def extract_reasoning(
+ self, model_output: str, request: "ChatCompletionRequest"
+ ) -> tuple[str | None, str | None]:
+ return self._parser.extract_reasoning(model_output, request)
+
+ def extract_reasoning_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ ) -> DeltaMessage | None:
+ return self._parser.extract_reasoning_streaming(
+ previous_text,
+ current_text,
+ delta_text,
+ previous_token_ids,
+ current_token_ids,
+ delta_token_ids,
+ )
diff --git a/vllm/renderers/hf.py b/vllm/renderers/hf.py
index d2252c655..e159a04b9 100644
--- a/vllm/renderers/hf.py
+++ b/vllm/renderers/hf.py
@@ -20,9 +20,11 @@ from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption,
ChatTemplateResolutionError,
ConversationMessage,
+ build_video_prompts_from_mm_data,
load_chat_template,
parse_chat_messages,
parse_chat_messages_async,
+ rebuild_mm_uuids_from_mm_data,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.logger import init_logger
@@ -547,6 +549,40 @@ class HfRenderer(RendererLike):
**kwargs,
)
+        # NOTE: use_unified_vision_chunk is currently specific to the Kimi-K2.5
+        # model, which uses unified vision chunks for both images and videos.
+ if (
+ getattr(model_config.hf_config, "use_unified_vision_chunk", False)
+ and mm_uuids is not None
+ and mm_data is not None
+ ):
+ mm_uuids = rebuild_mm_uuids_from_mm_data(mm_uuids, mm_data)
+
+            # get the video placeholder and splice in runtime video-chunk prompts
+ video_placeholder = getattr(
+ model_config.hf_config, "video_placeholder", None
+ )
+ if video_placeholder and isinstance(prompt_raw, str):
+ video_prompts = build_video_prompts_from_mm_data(mm_data)
+
+ # replace in order
+ prompt_raw_parts = prompt_raw.split(video_placeholder)
+ if len(prompt_raw_parts) == len(video_prompts) + 1:
+ prompt_raw = "".join(
+ [
+ prompt_raw_parts[i] + video_prompts[i]
+ for i in range(len(video_prompts))
+ ]
+ )
+ prompt_raw += prompt_raw_parts[-1]
+ else:
+ logger.warning(
+ "Number of video placeholders (%d) does not match "
+ "number of videos (%d) in the request.",
+ len(prompt_raw_parts) - 1,
+ len(video_prompts),
+ )
+
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
@@ -587,6 +623,40 @@ class HfRenderer(RendererLike):
**kwargs,
)
+        # NOTE: use_unified_vision_chunk is currently specific to the Kimi-K2.5
+        # model, which uses unified vision chunks for both images and videos.
+ if (
+ getattr(model_config.hf_config, "use_unified_vision_chunk", False)
+ and mm_uuids is not None
+ and mm_data is not None
+ ):
+ mm_uuids = rebuild_mm_uuids_from_mm_data(mm_uuids, mm_data)
+
+            # get the video placeholder and splice in runtime video-chunk prompts
+ video_placeholder = getattr(
+ model_config.hf_config, "video_placeholder", None
+ )
+ if video_placeholder and isinstance(prompt_raw, str):
+ video_prompts = build_video_prompts_from_mm_data(mm_data)
+
+ # replace in order
+ prompt_raw_parts = prompt_raw.split(video_placeholder)
+ if len(prompt_raw_parts) == len(video_prompts) + 1:
+ prompt_raw = "".join(
+ [
+ prompt_raw_parts[i] + video_prompts[i]
+ for i in range(len(video_prompts))
+ ]
+ )
+ prompt_raw += prompt_raw_parts[-1]
+ else:
+ logger.warning(
+ "Number of video placeholders (%d) does not match "
+ "number of videos (%d) in the request.",
+ len(prompt_raw_parts) - 1,
+ len(video_prompts),
+ )
+
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
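Both render paths above do the same split-and-stitch: split the rendered prompt on the model's video placeholder, interleave the runtime chunk prompts in order, and fall back to a warning on a count mismatch. A standalone sketch of that behavior (the placeholder and chunk-prompt strings are made up; in the renderer they come from `hf_config.video_placeholder` and `build_video_prompts_from_mm_data`):

    placeholder = "<|kimi_k25_video_placeholder|>"
    prompt_raw = f"Compare {placeholder} with {placeholder} and summarize."
    video_prompts = ["<video 1: 3 chunks>", "<video 2: 1 chunk>"]  # hypothetical

    parts = prompt_raw.split(placeholder)
    if len(parts) == len(video_prompts) + 1:
        # Interleave: parts[0] + video[0] + parts[1] + video[1] + ... + parts[-1]
        prompt_raw = "".join(p + v for p, v in zip(parts, video_prompts)) + parts[-1]
    else:
        print(f"mismatch: {len(parts) - 1} placeholders vs {len(video_prompts)} videos")

    print(prompt_raw)
    # Compare <video 1: 3 chunks> with <video 2: 1 chunk> and summarize.
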
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index a009017e5..277afa7dd 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -81,6 +81,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
isaac="IsaacConfig",
kimi_linear="KimiLinearConfig",
kimi_vl="KimiVLConfig",
+ kimi_k25="KimiK25Config",
RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct)
RefinedWebModel="RWConfig", # For tiiuae/falcon-7b(-instruct)
jais="JAISConfig",
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 00d5ecd25..bfb9c1758 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -38,6 +38,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
"MoonViTConfig": "vllm.transformers_utils.configs.moonvit",
"KimiLinearConfig": "vllm.transformers_utils.configs.kimi_linear",
"KimiVLConfig": "vllm.transformers_utils.configs.kimi_vl",
+ "KimiK25Config": "vllm.transformers_utils.configs.kimi_k25",
"NemotronConfig": "vllm.transformers_utils.configs.nemotron",
"NemotronHConfig": "vllm.transformers_utils.configs.nemotron_h",
"Olmo3Config": "vllm.transformers_utils.configs.olmo3",
@@ -77,6 +78,7 @@ __all__ = [
"MoonViTConfig",
"KimiLinearConfig",
"KimiVLConfig",
+ "KimiK25Config",
"NemotronConfig",
"NemotronHConfig",
"Olmo3Config",
diff --git a/vllm/transformers_utils/configs/kimi_k25.py b/vllm/transformers_utils/configs/kimi_k25.py
new file mode 100644
index 000000000..72f67251d
--- /dev/null
+++ b/vllm/transformers_utils/configs/kimi_k25.py
@@ -0,0 +1,131 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Kimi-K2.5 Model Configuration.
+
+This configuration supports video-chunk as an internal modality type.
+A video-chunk is the smallest independently processable unit of video.
+"""
+
+from transformers import DeepseekV3Config
+from transformers.configuration_utils import PretrainedConfig
+
+
+class KimiK25VisionConfig(PretrainedConfig):
+ model_type = "kimi_k25_vision"
+
+ def __init__(
+ self,
+ # Vision Tower
+ patch_size: int = 14,
+ init_pos_emb_height: int = 64,
+ init_pos_emb_width: int = 64,
+ init_pos_emb_time: int = 4,
+ pos_emb_type: str = "divided_fixed",
+ num_attention_heads: int = 16,
+ num_hidden_layers: int = 27,
+ hidden_size: int = 1152,
+ intermediate_size: int = 4304,
+ merge_kernel_size: tuple[int, int] = (2, 2),
+ video_attn_type: str = "spatial_temporal",
+ merge_type: str = "sd2_tpool",
+ # MM Projector
+ mm_projector_type: str = "patchmerger",
+ mm_hidden_size: int | None = None,
+ projector_hidden_act: str = "gelu",
+ projector_ln_eps: float = 1e-5,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ # Vision Tower
+ self.patch_size = patch_size
+ self.init_pos_emb_height = init_pos_emb_height
+ self.init_pos_emb_width = init_pos_emb_width
+ self.init_pos_emb_time = init_pos_emb_time
+ self.pos_emb_type = pos_emb_type
+ self.num_attention_heads = num_attention_heads
+ self.num_hidden_layers = num_hidden_layers
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.merge_kernel_size = merge_kernel_size
+ self.video_attn_type = video_attn_type
+ self.merge_type = merge_type
+ # MM Projector
+ self.mm_projector_type = mm_projector_type
+ if mm_hidden_size is not None:
+ self.mm_hidden_size = mm_hidden_size
+ else:
+ self.mm_hidden_size = hidden_size
+ self.projector_hidden_act = projector_hidden_act
+ self.projector_ln_eps = projector_ln_eps
+
+
+class KimiK25Config(PretrainedConfig):
+ """Kimi-K2.5 model configuration.
+
+ Kimi-K2.5 extends Kimi-K2 with vision support using video-chunks.
+ A video-chunk consists of multiple consecutive frames
+ that are processed together with temporal pooling.
+
+ Args:
+ vision_config: Configuration for the vision tower and projector.
+ text_config: Configuration for the text model (DeepseekV3).
+ ignore_index: The ignore index for the loss function.
+ media_placeholder_token_id: The token ID for media placeholders.
+        pad_token_id: The token ID for padding.
+        use_unified_vision_chunk: Whether to use unified vision chunks.
+        video_placeholder: Placeholder replaced with video-chunk prompts.
+ """
+
+ model_type = "kimi_k25"
+
+ def __init__(
+ self,
+ vision_config: dict | KimiK25VisionConfig | None = None,
+ text_config: dict | DeepseekV3Config | None = None,
+ ignore_index: int = -100,
+ media_placeholder_token_id: int = 163605,
+ pad_token_id: int = 0,
+ use_unified_vision_chunk: bool = False,
+ video_placeholder: str = "<|kimi_k25_video_placeholder|>",
+ **kwargs,
+ ):
+ # Vision config
+ if vision_config is None:
+ vision_config = KimiK25VisionConfig()
+ elif isinstance(vision_config, dict):
+ vision_config = KimiK25VisionConfig(**vision_config)
+ self.vision_config: KimiK25VisionConfig = vision_config
+
+ # Text config
+ if text_config is None:
+ text_config = DeepseekV3Config()
+ elif isinstance(text_config, dict):
+ text_config = DeepseekV3Config(**text_config)
+ self.text_config: DeepseekV3Config = text_config
+
+ # Set mm_hidden_size to text hidden size if not explicitly set
+ if self.vision_config.mm_hidden_size == self.vision_config.hidden_size:
+ self.vision_config.mm_hidden_size = self.text_config.hidden_size
+
+ # Other config
+ self.ignore_index = ignore_index
+ self.media_placeholder_token_id = media_placeholder_token_id
+ self.use_unified_vision_chunk = use_unified_vision_chunk
+ self.video_placeholder = video_placeholder
+
+ # Propagate quantization config from text model
+ if getattr(self.text_config, "quantization_config", None) is not None:
+ self.quantization_config = self.text_config.quantization_config
+
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+ @property
+ def hidden_size(self) -> int:
+ """Get hidden size from text config for compatibility."""
+ return self.text_config.hidden_size
+
+ @property
+ def vocab_size(self) -> int:
+ """Get vocab size from text config for compatibility."""
+ return self.text_config.vocab_size
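As a usage sketch of the config above (assuming this patch is applied; the field values are illustrative and normally come from a checkpoint's config.json via from_pretrained), constructing it from plain dicts exercises the mm_hidden_size fallback and the convenience properties:

    from vllm.transformers_utils.configs.kimi_k25 import KimiK25Config

    cfg = KimiK25Config(
        vision_config={"hidden_size": 1152, "num_hidden_layers": 27},
        text_config={"hidden_size": 7168, "vocab_size": 163840},
        use_unified_vision_chunk=True,
    )

    # mm_hidden_size was left at its default (== vision hidden_size), so the
    # constructor rewires it to the text model's hidden size for the projector.
    assert cfg.vision_config.mm_hidden_size == 7168
    assert cfg.hidden_size == 7168 and cfg.vocab_size == 163840
    assert cfg.video_placeholder == "<|kimi_k25_video_placeholder|>"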