diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index a38a88ce3..6d5ba168b 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -686,6 +686,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ | | `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ | | `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I+ | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | +| `KimiK25ForConditionalGeneration` | Kimi-K2.5 | T + I+ | `moonshotai/Kimi-K2.5` | | ✅︎ | | `LightOnOCRForConditionalGeneration` | LightOnOCR-1B | T + I+ | `lightonai/LightOnOCR-1B`, etc | ✅︎ | ✅︎ | | `Lfm2VlForConditionalGeneration` | LFM2-VL | T + I+ | `LiquidAI/LFM2-VL-450M`, `LiquidAI/LFM2-VL-3B`, `LiquidAI/LFM2-VL-8B-A1B`, etc. | ✅︎ | ✅︎ | | `Llama4ForConditionalGeneration` | Llama 4 | T + I+ | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index 5c6db71b1..bea86c081 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -771,6 +771,11 @@ _MULTIMODAL_EXAMPLE_MODELS = { ) }, ), + "KimiK25ForConditionalGeneration": _HfExamplesInfo( + "moonshotai/Kimi-K2.5", + trust_remote_code=True, + is_available_online=False, + ), "LightOnOCRForConditionalGeneration": _HfExamplesInfo( "lightonai/LightOnOCR-1B-1025" ), diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index b4f2e9dbe..bcc8d3c65 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -46,6 +46,9 @@ from vllm.multimodal.inputs import ( MultiModalBatchedField, MultiModalFlatField, MultiModalSharedField, + VisionChunk, + VisionChunkImage, + VisionChunkVideo, ) from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector @@ -336,7 +339,9 @@ ChatTemplateContentFormatOption = Literal["auto", "string", "openai"] ChatTemplateContentFormat = Literal["string", "openai"] -ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"] +ModalityStr = Literal[ + "image", "audio", "video", "image_embeds", "audio_embeds", "vision_chunk" +] _T = TypeVar("_T") @@ -449,6 +454,78 @@ def _get_embeds_data( raise NotImplementedError(type(data_items)) +def rebuild_mm_uuids_from_mm_data( + mm_uuids: MultiModalUUIDDict, + mm_data: MultiModalDataDict, +) -> MultiModalUUIDDict: + """Rebuild mm_uuids after vision_chunk processing. + + When videos are split into chunks, the original UUIDs need to be updated + to reflect the new UUIDs generated for each chunk. 
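+
+    For example (illustrative values only): a request with one image
+    ("img-0") and one video ("vid-0") that is split into two chunks would
+    turn {"image": ["img-0"], "video": ["vid-0"]} into a dictionary that
+    additionally contains {"vision_chunk": ["img-0", "vid-0-0", "vid-0-1"]},
+    with the chunk UUIDs taken from the processed vision_chunk items.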
+ + Args: + mm_uuids: Original UUIDs dictionary + mm_data: Processed multimodal data with vision_chunk items + + Returns: + Updated UUIDs dictionary with chunk UUIDs + """ + vision_chunks = mm_data.get("vision_chunk") + if vision_chunks is None: + return mm_uuids + + new_uuids = dict(mm_uuids) + vision_chunk_uuids = [] + + for item in vision_chunks: + # vision_chunk items are always dicts (VisionChunkImage/VisionChunkVideo) + assert isinstance(item, dict) + uuid_val = item.get("uuid") + if uuid_val is not None: + vision_chunk_uuids.append(uuid_val) + + if vision_chunk_uuids: + new_uuids["vision_chunk"] = vision_chunk_uuids + + return new_uuids + + +def build_video_prompts_from_mm_data( + mm_data: MultiModalDataDict, +) -> list[str]: + """Build video prompts from vision_chunk data. + + Collects prompts from video chunks and groups them by video_idx. + + Args: + mm_data: Processed multimodal data with vision_chunk items + + Returns: + List of video prompts, one per video. + """ + vision_chunks = mm_data.get("vision_chunk") + if vision_chunks is None: + return [] + + # Group chunks by video_idx + video_prompts_dict: dict[int, list[str]] = defaultdict(list) + + for item in vision_chunks: + # vision_chunk items are always dicts (VisionChunkImage/VisionChunkVideo) + assert isinstance(item, dict) + if item.get("type") == "video_chunk": + video_idx = item.get("video_idx", 0) + prompt = item.get("prompt", "") + video_prompts_dict[video_idx].append(prompt) + + # Build prompts in video order + video_prompts = [] + for video_idx in sorted(video_prompts_dict.keys()): + video_prompts.append("".join(video_prompts_dict[video_idx])) + + return video_prompts + + class BaseMultiModalItemTracker(ABC, Generic[_T]): """ Tracks multi-modal items in a given request and ensures that the number @@ -462,6 +539,13 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): self._model_config = model_config self._items_by_modality = defaultdict[str, list[_T]](list) + # Track original modality for each vision_chunk item (image or video) + self._modality_order = defaultdict[str, list[str]](list) + + @cached_property + def use_unified_vision_chunk_modality(self) -> bool: + """Check if model uses unified vision_chunk modality for images/videos.""" + return getattr(self._model_config.hf_config, "use_unified_vision_chunk", False) @property def model_config(self) -> ModelConfig: @@ -499,11 +583,31 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): media. """ input_modality = modality.replace("_embeds", "") - num_items = len(self._items_by_modality[modality]) + 1 + original_modality = modality + use_vision_chunk = ( + self.use_unified_vision_chunk_modality + and original_modality in ["video", "image"] + ) + + # If use_unified_vision_chunk_modality is enabled, + # map image/video to vision_chunk + if use_vision_chunk: + # To avoid validation fail + # because models with use_unified_vision_chunk_modality=True + # will only accept vision_chunk modality. 
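+            # Example (illustrative): a request with one image and one video
+            # places both items in self._items_by_modality["vision_chunk"],
+            # while self._modality_order["vision_chunk"] records
+            # ["image", "video"] so that _resolve_items can recover each
+            # item's original modality later.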
+ input_modality = "vision_chunk" + num_items = len(self._items_by_modality[input_modality]) + 1 + else: + num_items = len(self._items_by_modality[original_modality]) + 1 self.mm_processor.validate_num_items(input_modality, num_items) - self._items_by_modality[modality].append(item) + # Track original modality for vision_chunk items + if use_vision_chunk: + self._items_by_modality[input_modality].append(item) # type: ignore + self._modality_order["vision_chunk"].append(original_modality) + else: + self._items_by_modality[original_modality].append(item) return self.model_cls.get_placeholder_str(modality, num_items) @@ -515,6 +619,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): def _resolve_items( items_by_modality: dict[str, list[tuple[object, str | None]]], mm_processor: BaseMultiModalProcessor, + vision_chunk_modality_order: dict[str, list[str]], ) -> tuple[MultiModalDataDict, MultiModalUUIDDict]: if "image" in items_by_modality and "image_embeds" in items_by_modality: raise ValueError("Mixing raw image and embedding inputs is not allowed") @@ -546,6 +651,74 @@ def _resolve_items( if "video" in items_by_modality: mm_data["video"] = [data for data, uuid in items_by_modality["video"]] mm_uuids["video"] = [uuid for data, uuid in items_by_modality["video"]] + if "vision_chunk" in items_by_modality: + # Process vision_chunk items - extract from (data, modality) tuples + # and convert to VisionChunk types with proper UUID handling + vision_chunk_items = items_by_modality["vision_chunk"] + modality_order = vision_chunk_modality_order.get("vision_chunk", []) + mm_uuids["vision_chunk"] = [ + uuid for data, uuid in items_by_modality["vision_chunk"] + ] + + # Filter out None items (from asyncio.sleep(0) placeholders) + filtered_items = [ + (idx, item) + for idx, item in enumerate(vision_chunk_items) + if item is not None + ] + + assert len(filtered_items) == len(modality_order), ( + f"vision_chunk items ({len(filtered_items)}) and " + f"modality_order ({len(modality_order)}) must have same length" + ) + + processed_chunks: list[VisionChunk] = [] + video_idx = 0 + for i, (idx, item) in enumerate(filtered_items): + inner_modality = modality_order[i] + data, uuid = item + uuid_val = uuid if idx < len(mm_uuids["vision_chunk"]) else None + if inner_modality == "image": + # Cast data to proper type for image + # Use .media (PIL.Image) directly to avoid redundant + # bytes→PIL conversion in media_processor + if hasattr(data, "media"): + image_data = data.media # type: ignore[union-attr] + processed_chunks.append( + VisionChunkImage(type="image", image=image_data, uuid=uuid_val) + ) + else: + processed_chunks.append(data) # type: ignore[arg-type] + elif inner_modality == "video": + # For video, we may need to split into chunks + # if processor supports it + # For now, just wrap as a video chunk placeholder + if hasattr(mm_processor, "split_video_chunks") and data is not None: + try: + video_uuid = uuid_val or random_uuid() + # video await result is (video_data, video_meta) tuple + if isinstance(data, tuple) and len(data) >= 1: + video_data = data[0] + else: + video_data = data + video_chunks = mm_processor.split_video_chunks(video_data) + for i, vc in enumerate(video_chunks): + processed_chunks.append( + VisionChunkVideo( + type="video_chunk", + video_chunk=vc["video_chunk"], + uuid=f"{video_uuid}-{i}", + video_idx=video_idx, + prompt=vc["prompt"], + ) + ) + video_idx += 1 + except Exception as e: + logger.warning("Failed to split video chunks: %s", e) + processed_chunks.append(data) # type: 
ignore[arg-type] + else: + processed_chunks.append(data) # type: ignore[arg-type] + mm_data["vision_chunk"] = processed_chunks return mm_data, mm_uuids @@ -557,7 +730,9 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[tuple[object, str | None]] if not self._items_by_modality: return None, None - return _resolve_items(dict(self._items_by_modality), self.mm_processor) + return _resolve_items( + dict(self._items_by_modality), self.mm_processor, self._modality_order + ) def create_parser(self) -> "BaseMultiModalContentParser": return MultiModalContentParser(self) @@ -577,7 +752,9 @@ class AsyncMultiModalItemTracker( for modality, coros in self._items_by_modality.items() } - return _resolve_items(resolved_items_by_modality, self.mm_processor) + return _resolve_items( + resolved_items_by_modality, self.mm_processor, self._modality_order + ) def create_parser(self) -> "BaseMultiModalContentParser": return AsyncMultiModalContentParser(self) diff --git a/vllm/envs.py b/vllm/envs.py index d6243c02d..4541de72c 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -782,6 +782,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ), # Backend for Video IO # - "opencv": Default backend that uses OpenCV stream buffered backend. + # - "identity": Returns raw video bytes for model processor to handle. # # Custom backend implementations can be registered # via `@VIDEO_LOADER_REGISTRY.register("my_custom_video_loader")` and diff --git a/vllm/model_executor/models/kimi_k25.py b/vllm/model_executor/models/kimi_k25.py new file mode 100644 index 000000000..dccf05c14 --- /dev/null +++ b/vllm/model_executor/models/kimi_k25.py @@ -0,0 +1,581 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 +""" +Kimi-K2.5 Model Implementation for vLLM. 
+ +Kimi-K2.5 extends Kimi-K2 with vision support + +This module defines: +- KimiK25ProcessingInfo/KimiK25MultiModalProcessor: Processing logic +- KimiK25ForConditionalGeneration: Main model class +""" + +import copy +from collections.abc import Iterable, Mapping, Sequence +from dataclasses import dataclass +from typing import Annotated, Any, Literal + +import torch +from torch import nn +from transformers import BatchFeature +from transformers.processing_utils import ProcessorMixin + +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model +from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP +from vllm.model_executor.models.kimi_k25_vit import ( + KimiK25MultiModalProjector, + MoonViT3dPretrainedModel, + vision_tower_forward, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + NestedTensors, + VisionChunk, + VisionChunkImage, + VisionChunkVideo, +) +from vllm.multimodal.parse import MultiModalDataItems, VisionChunkProcessorItems +from vllm.multimodal.processing import ( + BaseDummyInputsBuilder, + BaseMultiModalProcessor, + BaseProcessingInfo, + InputProcessingContext, + PromptReplacement, + PromptUpdate, +) +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import KimiK25Config +from vllm.transformers_utils.processor import cached_get_image_processor +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix + +logger = init_logger(__name__) + + +# Dummy input dimensions for profiling. +@dataclass +class MaxImageTokenMeta: + width: int = 3000 + height: int = 3000 + + +class KimiK25MediaPixelInputs(TensorSchema): + """ + Media input schema for K2-VL model. + + Dimensions: + - np: Number of patches (flattened from all media items) + - ps: Patch size + - nm: Number of media items + """ + + type: Literal["pixel_values"] = "pixel_values" + + pixel_values: Annotated[ + torch.Tensor | list[torch.Tensor], + TensorShape("np", 3, "ps", "ps"), + ] + + grid_thws: Annotated[torch.Tensor, TensorShape("nm", 3)] + + +class MoonshotKimiVAutoProcessor(ProcessorMixin): + attributes = ["tokenizer"] + tokenizer_class = "AutoTokenizer" + + def __init__(self, media_processor=None, tokenizer=None): + super().__init__(tokenizer) + self.media_processor = media_processor + + # We do not support str input for text here + def __call__( + self, + vision_chunks: list[VisionChunk] | None = None, + *, + text: list[int], + **kwargs, + ) -> BatchFeature: + """ + Args: + vision_chunks: List of VisionChunk items to be processed. + For image: VisionChunkImage with type='image', image=PIL.Image + For video_chunk: VisionChunkVideo with type='video_chunk', video_chunk=list[PIL.Image] + text: The token ids to be fed to a model (required). 
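+
+        Example (illustrative; assumes a loaded media_processor/tokenizer,
+        with pil_image and token_ids as placeholders):
+
+            processor = MoonshotKimiVAutoProcessor(media_processor, tokenizer)
+            out = processor(
+                vision_chunks=[
+                    VisionChunkImage(type="image", image=pil_image, uuid=None)
+                ],
+                text=token_ids,
+            )
+            # out contains "input_ids", plus "pixel_values" and "grid_thws"
+            # when vision_chunks is not None.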
+ Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- list of token ids to be fed to a model. + - **pixel_values** -- Pixel values to be fed to a model. Returned when `vision_chunks` is not `None`. + - **grid_thws** -- list of image 3D grid in LLM. Returned when `vision_chunks` is not `None`. + """ + mm_inputs = {} + if vision_chunks is not None: + assert isinstance(vision_chunks, list) + mm_inputs = self.media_processor.preprocess(vision_chunks) + # XXX: _apply_hf_processor_text_mm will call tolist() on input_ids + return BatchFeature( + data={ + "input_ids": torch.tensor([text]), + **mm_inputs, + } + ) + + +class KimiK25ProcessingInfo(BaseProcessingInfo): + """Processing information for Kimi-K2.5 model. + + Provides configuration and utilities for processing both + images and video-chunks. + """ + + def __init__(self, ctx: InputProcessingContext) -> None: + super().__init__(ctx) + self.hf_config = self.get_hf_config() + self.media_token_id = self.hf_config.media_placeholder_token_id + media_processor = cached_get_image_processor( + self.ctx.model_config.model, trust_remote_code=True + ) + self.media_processor = media_processor + self.hf_processor = MoonshotKimiVAutoProcessor( + media_processor=self.media_processor, + tokenizer=self.get_tokenizer(), + ) + self.media_tokens_calculator = self.media_processor.media_tokens_calculator + + def get_hf_processor(self): + return self.hf_processor + + def get_hf_config(self): + return self.ctx.get_hf_config(KimiK25Config) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + # None means unlimited + return {"vision_chunk": None} + + +class KimiK25DummyInputsBuilder(BaseDummyInputsBuilder[KimiK25ProcessingInfo]): + """Builds dummy inputs for Kimi-K2.5 model profiling.""" + + def __init__(self, info: KimiK25ProcessingInfo) -> None: + super().__init__(info) + self.media_token_id = self.info.media_token_id + self.frame_per_chunk = self.info.media_processor.num_frames_per_chunk + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> list[int]: + num_media = mm_counts.get("vision_chunk", 0) + return [self.media_token_id] * num_media + + def get_dummy_mm_items(self): + dummy_videos = self._get_dummy_images( + height=MaxImageTokenMeta.height, + width=MaxImageTokenMeta.width, + num_images=self.frame_per_chunk, + ) + + video_chunk_dummy_item = VisionChunkVideo( + type="video_chunk", video_chunk=dummy_videos + ) + video_chunk_num_tokens = self.info.media_tokens_calculator( + video_chunk_dummy_item + ) + + image_dummy_item = VisionChunkImage( + type="image", + image=self._get_dummy_images( + height=MaxImageTokenMeta.height, + width=MaxImageTokenMeta.width, + num_images=1, + )[0], + ) + image_num_tokens = self.info.media_tokens_calculator(image_dummy_item) + # return the larger one + if video_chunk_num_tokens >= image_num_tokens: + return [video_chunk_dummy_item] + else: + return [image_dummy_item] + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + # TODO: Support mm_options for vision_chunk to allow user configuration + dummy_items = self.get_dummy_mm_items() + return {"vision_chunk": dummy_items} + + +class KimiK25MultiModalProcessor(BaseMultiModalProcessor[KimiK25ProcessingInfo]): + """Multi-modal processor for Kimi-K2.5. + + Handles both image and video-chunk modalities. 
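+
+    Each vision_chunk item contributes its patches to pixel_values and one row
+    to grid_thws, and its single media placeholder token in the prompt is
+    expanded to media_tokens_calculator(item) copies via PromptReplacement.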
+ """ + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + """Indicates how to slice media input into multiple items. + + pixel_values: [N, 3, patch_size, patch_size], all patches collected from B medias + grid_thws: [B,3], each item: [N_t, N_h ,N_w], indicates the grid size in time/height/width direction + for current item. + + by multiplying [N_t, N_h ,N_w], we get the number of patches for each media item, thus we can slice + pixel_values by pixel_values[start:start + N_t*N_h*N_w] to get patches of one item. + + """ + grid_thws = hf_inputs.get("grid_thws", torch.empty((0, 3))) + grid_sizes = grid_thws.prod(-1) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "vision_chunk", grid_sizes + ), + grid_thws=MultiModalFieldConfig.batched("vision_chunk"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_config = self.info.get_hf_config() + media_token_id = hf_config.media_placeholder_token_id + + def get_replacement(item_idx: int): + media = mm_items.get_items("vision_chunk", (VisionChunkProcessorItems,)) + num_media_token = self.info.media_tokens_calculator(media[item_idx]) + return [media_token_id] * num_media_token + + return [ + PromptReplacement( + modality="vision_chunk", + target=[media_token_id], + replacement=get_replacement, + ), + ] + + def split_video_chunks(self, video): + return self.info.media_processor.split_video_chunks(video) + + +@MULTIMODAL_REGISTRY.register_processor( + KimiK25MultiModalProcessor, + info=KimiK25ProcessingInfo, + dummy_inputs=KimiK25DummyInputsBuilder, +) +class KimiK25ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + """Kimi-K2.5 model for conditional generation. + + Supports both image and video-chunk modalities. + Video-chunks are temporal segments (typically 4 frames) that are + processed with temporal pooling. + """ + + supports_encoder_tp_data = True + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + # Kimi-K2.5 uses video_chunk for all media types + if modality == "image": + return "<|media_begin|>image<|media_content|><|media_pad|><|media_end|>" + elif modality == "video": + # return a placeholder, to be replaced in the future. 
+ return "<|kimi_k25_video_placeholder|>" + + raise ValueError(f"Unsupported modality: {modality}") + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: + super().__init__() + model_config = vllm_config.model_config + config: KimiK25Config = model_config.hf_config + self.config = config + quant_config = vllm_config.quant_config + + # Check for MoonViT config compatibility + self.use_data_parallel = ( + model_config.multimodal_config.mm_encoder_tp_mode == "data" + ) + self.hidden_size = config.text_config.hidden_size + self.device = torch.cuda.current_device() + # Build vision tower directly with KimiK25VisionConfig + self.vision_tower = MoonViT3dPretrainedModel( + config.vision_config, + prefix=maybe_prefix(prefix, "vision_tower"), + ) + self.vision_tower = self.vision_tower.to( + device=self.device, dtype=model_config.dtype + ) + + self.mm_projector = KimiK25MultiModalProjector( + config=config.vision_config, + use_data_parallel=self.use_data_parallel, + prefix=maybe_prefix(prefix, "mm_projector"), + ) + self.mm_projector = self.mm_projector.to( + device=self.device, dtype=model_config.dtype + ) + + self.quant_config = quant_config + sub_vllm_config = copy.deepcopy(vllm_config) + sub_vllm_config.model_config.hf_config = ( + sub_vllm_config.model_config.hf_config.text_config + ) + self.language_model = DeepseekV2Model( + vllm_config=sub_vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.text_config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale) + self.media_placeholder: int = self.config.media_placeholder_token_id + + def _parse_and_validate_media_input( + self, **kwargs: object + ) -> KimiK25MediaPixelInputs | None: + pixel_values = kwargs.pop("pixel_values", None) + grid_thws = kwargs.pop("grid_thws", None) + if pixel_values is None: + return None + + if isinstance(pixel_values, list): + pixel_values = torch.cat(pixel_values, dim=0) + + if len(pixel_values.shape) == 5 or len(pixel_values.shape) == 3: + pixel_values = pixel_values.reshape( + pixel_values.shape[0] * pixel_values.shape[1], *pixel_values.shape[2:] + ) + + # The batch dimension of pixel_values has been flattened into shape[0] + target_dtype = next(self.vision_tower.parameters()).dtype + pixel_values = pixel_values.to(target_dtype) + assert isinstance(grid_thws, torch.Tensor), ( + f"expect grid_thws to be a tensor, get {type(grid_thws)}" + ) + # In some cases (e.g. 
with merger), grid_thws has an extra middle dimension + grid_thws = grid_thws.reshape(-1, grid_thws.shape[-1]) + assert grid_thws.ndim == 2 and grid_thws.size(1) == 3, ( + f"unexpected shape for grid_thws: {grid_thws.shape}" + ) + + return KimiK25MediaPixelInputs( + type="pixel_values", + pixel_values=pixel_values, + grid_thws=grid_thws, + ) + + def _process_media_input( + self, media_input: KimiK25MediaPixelInputs + ) -> list[torch.Tensor]: + # NOTE(moyan): This forward will automatically batch the forward pass internally + media_features = vision_tower_forward( + self.vision_tower, + media_input["pixel_values"], + media_input["grid_thws"], + mm_projector=self.mm_projector, + use_data_parallel=self.use_data_parallel, + ) + return media_features + + def embed_multimodal(self, **kwargs: object) -> NestedTensors | None: + # Validate the multimodal input keyword arguments + media_input = self._parse_and_validate_media_input(**kwargs) + if media_input is None: + return None + + # Run multimodal inputs through encoder and projector + vision_embeddings = self._process_media_input(media_input) + return vision_embeddings + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + hidden_states = self.language_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, **kwargs) + return logits + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + config = self.config.text_config + if not getattr(config, "n_routed_experts", None): + return [] + return SharedFusedMoE.make_expert_params_mapping( + self, + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=config.n_routed_experts, + num_redundant_experts=0, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + config = self.config.text_config + _KEYS_TO_MODIFY_MAPPING = { + "language_model.lm_head": "lm_head", + "language_model.model": "language_model", + # mm_projector -> mm_projector mapping + # "mm_projector": "mm_projector", + "mm_projector.proj.0": "mm_projector.linear_1", + "mm_projector.proj.2": "mm_projector.linear_2", + } + stacked_params_mapping = [ + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + if getattr(config, "kv_lora_rank", None) and getattr( + config, "q_lora_rank", None + ): + stacked_params_mapping += [ + (".fused_qkv_a_proj", ".q_a_proj", 0), + (".fused_qkv_a_proj", ".kv_a_proj_with_mqa", 1), + ] + expert_params_mapping = self.get_expert_mapping() + + params_dict = dict(self.named_parameters()) + + for args in weights: + name, loaded_weight = args[:2] + kwargs = args[2] if len(args) > 2 else {} + if "rotary_emb.inv_freq" in name: + continue + + spec_layer = get_spec_layer_idx_from_weight_name(config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model + + if "rotary_emb.cos_cached" in name or 
"rotary_emb.sin_cached" in name: + continue + + for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in name: + name = name.replace(key_to_modify, new_key) + + use_default_weight_loading = False + if "vision" in name: + if self.vision_tower is not None: + use_default_weight_loading = True + else: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if ("mlp.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id, **kwargs) + break + else: + for _, ( + param_name, + weight_name, + expert_id, + shard_id, + ) in enumerate(expert_params_mapping): + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + expert_id=expert_id, + shard_id=shard_id, + **kwargs, + ) + break + else: + use_default_weight_loading = True + + if use_default_weight_loading: + if name.endswith(".bias") and name not in params_dict: + continue + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight, **kwargs) + + +def get_spec_layer_idx_from_weight_name( + config: KimiK25Config, weight_name: str +) -> int | None: + if hasattr(config, "num_nextn_predict_layers") and ( + config.num_nextn_predict_layers > 0 + ): + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + # might start with language_model.model.layers + if f"model.layers.{layer_idx + i}." in weight_name: + return layer_idx + i + return None diff --git a/vllm/model_executor/models/kimi_k25_vit.py b/vllm/model_executor/models/kimi_k25_vit.py new file mode 100644 index 000000000..650ff7d21 --- /dev/null +++ b/vllm/model_executor/models/kimi_k25_vit.py @@ -0,0 +1,678 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Vision tower implementation for Kimi-K2.5 model. + +This module provides the vision encoder components for Kimi-K2.5, +including 3D patch embedding, RoPE position embedding, and +temporal pooling for video chunks. 
+""" + +from collections.abc import Sequence +from copy import deepcopy +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.activations import GELUActivation + +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.models.utils import maybe_prefix +from vllm.model_executor.models.vision import ( + is_vit_use_data_parallel, + run_dp_sharded_mrope_vision_model, +) +from vllm.transformers_utils.configs.kimi_k25 import KimiK25VisionConfig + +logger = init_logger(__name__) + + +def _apply_rope_input_validation(x, freqs_cis): + assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape) + assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape) + assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape) + assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype + + +def get_rope_shape_decorate(func): + _get_rope_shape_first_call_flag = set() + + def wrapper(org, interpolation_mode, shape): + key = (org.requires_grad, torch.is_grad_enabled(), interpolation_mode) + if key not in _get_rope_shape_first_call_flag: + _get_rope_shape_first_call_flag.add(key) + _ = func(org, interpolation_mode, shape=(64, 64)) + return func(org, interpolation_mode, shape) + + return wrapper + + +@get_rope_shape_decorate +@torch.compile(dynamic=True) +def get_rope_shape(org, interpolation_mode, shape): + return ( + F.interpolate( + org.permute((2, 0, 1)).unsqueeze(0), + size=shape, + mode=interpolation_mode, + ) + .squeeze(0) + .permute((1, 2, 0)) + .flatten(end_dim=1) + ) + + +def apply_rope( + xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Args: (The leading dimensions of all inputs should be the same) + xq: query, tensor of shape (..., num_heads, head_dim) + xk: key, tensor of shape (..., num_heads, head_dim) + freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64. 
+ Returns: + xq_out, xk_out: tensors of shape (..., num_heads, head_dim) + """ + _apply_rope_input_validation(xq, freqs_cis) + _apply_rope_input_validation(xk, freqs_cis) + + freqs_cis = freqs_cis.unsqueeze(-2) # ..., 1, head_dim/2 + # ..., num_heads, head_dim/2 + xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().view(*xq.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2) # ..., num_heads, head_dim + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2) # ..., num_heads, head_dim + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """Generate 1D sincos positional embedding from grid positions.""" + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False): + """Generate 1D sincos positional embedding.""" + grid_t = np.arange(t_size, dtype=np.float32) + pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_t) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +class Learnable2DInterpPosEmbDivided_fixed(nn.Module): + """2D learnable position embedding with temporal extension.""" + + def __init__( + self, + height: int, + width: int, + num_frames: int, + dim: int, + interpolation_mode: str = "bicubic", + ) -> None: + super().__init__() + self.height = height + self.width = width + self.num_frames = num_frames + self.dim = dim + self.interpolation_mode = interpolation_mode + self.weight = nn.Parameter(torch.empty(height, width, dim)) + self.register_buffer( + "time_weight", + torch.from_numpy(get_1d_sincos_pos_embed(self.dim, self.num_frames)) + .float() + .unsqueeze(1), + persistent=False, + ) + + self.reset_parameters() + + def reset_parameters(self): + nn.init.normal_(self.weight) + + def forward(self, x: torch.Tensor, grid_thws: torch.Tensor) -> torch.Tensor: + pos_embs = [] + for t, h, w in grid_thws.tolist(): + assert t <= self.num_frames, f"t:{t} > self.num_frames:{self.num_frames}" + if (h, w) == self.weight.shape[:-1]: + pos_emb_2d = self.weight.flatten(end_dim=1) + else: + pos_emb_2d = get_rope_shape( + self.weight, + interpolation_mode=self.interpolation_mode, + shape=(h, w), + ) + + if t == 1: + pos_emb_3d = pos_emb_2d + else: + pos_emb_3d = ( + pos_emb_2d.unsqueeze(0).repeat(t, 1, 1) + self.time_weight[0:t] + ) + + pos_embs.append(pos_emb_3d.reshape(-1, pos_emb_3d.shape[-1])) + + out = x + torch.cat(pos_embs) + return out + + +class MoonVision3dPatchEmbed(nn.Module): + """3D patch embedding for vision tower.""" + + def __init__( + self, + out_dim: int, + in_dim: int = 3, + patch_size: int | tuple[int, int] = (14, 14), + pos_emb_height: int = 14, + pos_emb_width: int = 14, + pos_emb_time: int = 4, + pos_emb_type: str = "divided_fixed", + ): + super().__init__() + assert isinstance(patch_size, int | Sequence), ( + f"Invalid patch_size type: {type(patch_size)}" + ) + if isinstance(patch_size, int): + patch_size = (patch_size, patch_size) + assert len(patch_size) == 2, ( + f"Expected patch_size to be a tuple of 2, got {patch_size}" + ) + self.patch_size = patch_size + + 
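+        # The projection below uses kernel_size == stride == patch_size, so
+        # every (patch_size x patch_size) patch maps to a single embedding.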
self.proj = nn.Conv2d( + in_dim, out_dim, kernel_size=patch_size, stride=patch_size + ) + + if pos_emb_type == "divided_fixed": + self.pos_emb = Learnable2DInterpPosEmbDivided_fixed( + height=pos_emb_height, + width=pos_emb_width, + num_frames=pos_emb_time, + dim=out_dim, + ) + else: + raise NotImplementedError(f"Not support pos_emb_type: {pos_emb_type}") + + def forward(self, x: torch.Tensor, grid_thws: torch.Tensor) -> torch.Tensor: + x = self.proj(x).view(x.size(0), -1) + # apply positional embedding + x = self.pos_emb(x, grid_thws) + return x + + +class Rope2DPosEmbRepeated(nn.Module): + """2D rotary position embedding with multi-resolution support.""" + + def __init__(self, dim: int, max_height: int, max_width: int, theta_base=10000): + super().__init__() + self.dim = dim + assert self.dim % 4 == 0, "dim must be divisible by 4" + self.max_height = max_height + self.max_width = max_width + self.theta_base = theta_base + + def extra_repr(self): + return ( + f"dim={self.dim}, max_height={self.max_height}, " + f"max_width={self.max_width}, theta_base={self.theta_base}" + ) + + def _precompute_freqs_cis(self, device: torch.device) -> torch.Tensor: + """Calculate the cis(freqs) for each position in the 2D grid.""" + N = self.max_height * self.max_width + flat_pos = torch.arange(0, N).float().to(device) + x_pos = flat_pos % self.max_width + y_pos = flat_pos // self.max_width + dim_range = ( + torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(device) + ) # C/4 + freqs = 1.0 / (self.theta_base ** (dim_range / self.dim)) + x_freqs = torch.outer(x_pos, freqs).float() # N, C/4 + y_freqs = torch.outer(y_pos, freqs).float() # N, C/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) # N, C/4 + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) # N, C/4 + # N, C/4, 2 + freqs_cis = torch.cat( + [x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1 + ) + # max_height, max_width, C/2 + freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1) + return freqs_cis + + def get_freqs_cis( + self, grid_thws: torch.Tensor, device: torch.device + ) -> torch.Tensor: + """ + Args: + grid_thws (torch.Tensor): grid time, height and width + + Returns: + freqs_cis: tensor of shape (sum(t * height * width), dim//2) + """ + if not hasattr(self, "freqs_cis"): + self.register_buffer( + "freqs_cis", self._precompute_freqs_cis(device), persistent=False + ) + + shapes = grid_thws.tolist() + assert all( + 1 <= h <= self.max_height and 1 <= w <= self.max_width for t, h, w in shapes + ), ( + shapes, + self.max_height, + self.max_width, + ) + freqs_cis = torch.cat( + [ + self.freqs_cis[:h, :w].reshape(-1, self.dim // 2).repeat(t, 1) + for t, h, w in shapes + ], + dim=0, + ) + return freqs_cis + + +class MLP2(nn.Module): + """Two-layer MLP with tensor parallel support.""" + + def __init__( + self, + dims: list[int], + activation, + bias: bool = True, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + assert len(dims) == 3 + self.use_data_parallel = use_data_parallel + self.fc0 = ColumnParallelLinear( + dims[0], + dims[1], + bias=bias, + prefix=maybe_prefix(prefix, "fc0"), + disable_tp=self.use_data_parallel, + ) + self.fc1 = RowParallelLinear( + dims[1], + dims[2], + bias=bias, + prefix=maybe_prefix(prefix, "fc1"), + disable_tp=self.use_data_parallel, + ) + self.activation = activation + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.fc0(x) + x = self.activation(x) + x, _ = self.fc1(x) + return x + + +class MoonViTEncoderLayer(nn.Module): + 
"""Single encoder layer for MoonViT with TP/DP support.""" + + def __init__( + self, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + prefix: str = "", + *, + activation=F.gelu, + attn_bias: bool = False, + ): + super().__init__() + self.use_data_parallel = is_vit_use_data_parallel() + + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads + self.tp_size = ( + 1 if self.use_data_parallel else get_tensor_model_parallel_world_size() + ) + self.num_attention_heads_per_partition = divide(num_heads, self.tp_size) + + self.norm0 = nn.LayerNorm(hidden_dim) + self.norm1 = nn.LayerNorm(hidden_dim) + self.mlp = MLP2( + [hidden_dim, mlp_dim, hidden_dim], + activation, + prefix=f"{prefix}.mlp", + use_data_parallel=self.use_data_parallel, + ) + self.wqkv = QKVParallelLinear( + hidden_size=hidden_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=attn_bias, + prefix=f"{prefix}.wqkv", + disable_tp=self.use_data_parallel, + ) + self.wo = RowParallelLinear( + hidden_dim, + hidden_dim, + bias=attn_bias, + prefix=f"{prefix}.wo", + disable_tp=self.use_data_parallel, + ) + self.attn = MMEncoderAttention( + num_heads=self.num_attention_heads_per_partition, + head_size=self.hidden_size_per_attention_head, + scale=self.hidden_size_per_attention_head**-0.5, + prefix=f"{prefix}.attn", + ) + + def attention_qkvpacked( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rope_freqs_cis: torch.Tensor | None = None, + ): + """Compute self-attention with packed QKV. + + Args: + x (torch.Tensor): (seqlen, hidden_dim) + cu_seqlens (torch.Tensor): cumulative sequence lengths + """ + seq_length = x.size(0) + xqkv, _ = self.wqkv(x) + + qkv_shape = xqkv.size()[:-1] + ( + 3, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + # xqkv: (seqlen, 3, nheads, headdim) + xqkv = xqkv.view(*qkv_shape) + xq, xk, xv = torch.unbind(xqkv, dim=-3) + + xq, xk = apply_rope(xq, xk, rope_freqs_cis) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_out = self.attn( + xq.unsqueeze(0), + xk.unsqueeze(0), + xv.unsqueeze(0), + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + attn_out = attn_out.reshape( + seq_length, + self.num_attention_heads_per_partition + * self.hidden_size_per_attention_head, + ) + attn_out, _ = self.wo(attn_out) + return attn_out + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rope_freqs_cis: torch.Tensor | None = None, + ): + residual = hidden_states + hidden_states = self.norm0(hidden_states) + + hidden_states = self.attention_qkvpacked( + hidden_states, cu_seqlens, rope_freqs_cis + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class MoonViT3dEncoder(nn.Module): + """Full encoder stack for MoonViT 3D.""" + + def __init__( + self, + hidden_dim: int, + num_layers: int, + block_cfg: dict, + video_attn_type: str = "spatial_temporal", + prefix: str = "", + ) -> None: + super().__init__() + + assert video_attn_type == "spatial_temporal", ( + f'video_attn_type must be "spatial_temporal", got {video_attn_type}' + ) + self.video_attn_type = video_attn_type + self.rope_2d = Rope2DPosEmbRepeated( + block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512 + ) + self.blocks = nn.ModuleList( + [ + 
MoonViTEncoderLayer( + **block_cfg, + prefix=f"{prefix}.blocks.{layer_idx}", + ) + for layer_idx in range(num_layers) + ] + ) + self.final_layernorm = nn.LayerNorm(hidden_dim) + + def forward( + self, + hidden_states: torch.Tensor, + grid_thws: torch.Tensor, + ) -> torch.Tensor: + rope_freqs_cis = self.rope_2d.get_freqs_cis( + grid_thws=grid_thws, device=hidden_states.device + ) + + lengths = torch.cat( + ( + torch.zeros(1, dtype=grid_thws.dtype, device=grid_thws.device), + grid_thws[:, 0] * grid_thws[:, 1] * grid_thws[:, 2], + ) + ) + + cu_seqlens = lengths.to(hidden_states.device).cumsum(dim=0, dtype=torch.int32) + + for block in self.blocks: + hidden_states = block( + hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis + ) + + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +def tpool_patch_merger( + x: torch.Tensor, + grid_thws: torch.Tensor, + merge_kernel_size: tuple[int, int] = (2, 2), +) -> list[torch.Tensor]: + """Temporal pooling patch merger.""" + kh, kw = merge_kernel_size + lengths = (grid_thws[:, 0] * grid_thws[:, 1] * grid_thws[:, 2]).tolist() + seqs = x.split(lengths, dim=0) + + outputs = [] + for seq, (t, h, w) in zip(seqs, grid_thws.tolist()): + nh, nw = h // kh, w // kw + # Reshape: (t*h*w, d) -> (t, nh, kh, nw, kw, d) + v = seq.view(t, nh, kh, nw, kw, -1) + # Temporal pooling first (reduces tensor size before permute) + v = v.mean(dim=0) # (nh, kh, nw, kw, d) + # Spatial rearrangement: (nh, kh, nw, kw, d) -> (nh, nw, kh, kw, d) + out = v.permute(0, 2, 1, 3, 4).reshape(nh * nw, kh * kw, -1) + outputs.append(out) + + return outputs + + +class MoonViT3dPretrainedModel(nn.Module): + """Main vision tower model. + + Uses KimiK25VisionConfig directly from transformers_utils/configs/kimi_k25.py. + """ + + def __init__( + self, + config: KimiK25VisionConfig, + prefix: str = "", + ): + super().__init__() + config = deepcopy(config) + self.config = config # Required for run_dp_sharded_mrope_vision_model + self.merge_kernel_size = config.merge_kernel_size + self.patch_size = config.patch_size + self.merge_type = config.merge_type + + self.patch_embed = MoonVision3dPatchEmbed( + out_dim=config.hidden_size, + patch_size=config.patch_size, + pos_emb_height=config.init_pos_emb_height, + pos_emb_width=config.init_pos_emb_width, + pos_emb_time=config.init_pos_emb_time, + pos_emb_type=config.pos_emb_type, + ) + + self.encoder = MoonViT3dEncoder( + hidden_dim=config.hidden_size, + num_layers=config.num_hidden_layers, + block_cfg={ + "num_heads": config.num_attention_heads, + "hidden_dim": config.hidden_size, + "mlp_dim": config.intermediate_size, + "activation": get_act_fn("gelu_pytorch_tanh"), + "attn_bias": True, + }, + video_attn_type=config.video_attn_type, + prefix=maybe_prefix(prefix, "encoder"), + ) + + def forward( + self, pixel_values: torch.Tensor, grid_thws: torch.Tensor + ) -> torch.Tensor: + """ + Args: + pixel_values (torch.Tensor): The input pixel values. + grid_thws (torch.Tensor): Temporal, height and width. + + Returns: + torch.Tensor: The output tokens. 
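+
+        Note:
+            With merge_type "sd2_tpool", the output is actually a list of
+            per-item tensors of shape (num_merged_patches, kh * kw, hidden_size),
+            where (kh, kw) is merge_kernel_size, as produced by
+            tpool_patch_merger.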
+ """ + hidden_states = self.patch_embed(pixel_values, grid_thws) + hidden_states = self.encoder(hidden_states, grid_thws) + if ( + self.merge_type == "sd2_tpool" + ): # spatial downsampling 2x with temporal pooling all + hidden_states = tpool_patch_merger( + hidden_states, grid_thws, merge_kernel_size=self.merge_kernel_size + ) + else: + raise NotImplementedError(f"Not support {self.merge_type}") + + return hidden_states + + +@torch.inference_mode() +def mm_projector_forward(mm_projector: torch.nn.Module, vt_output: list[torch.Tensor]): + """Apply MM projector to vision tower outputs.""" + num_embedding_list = [x.shape[0] for x in vt_output] + batched = torch.cat(vt_output, dim=0) + proj_out = mm_projector(batched) + proj_out = proj_out.reshape(-1, proj_out.shape[-1]) + proj_out = torch.split(proj_out, num_embedding_list) + return proj_out + + +@torch.inference_mode() +def vision_tower_forward( + vision_tower: Any, + pixel_values: torch.Tensor, + grid_thw: torch.Tensor, + mm_projector: Any, + use_data_parallel: bool, +) -> list[torch.Tensor]: + """DP-sharded vision tower forward with mrope. + + Uses vLLM's standard data parallelism utility to shard the batch + across available GPUs, enabling parallel processing of vision features. + """ + if use_data_parallel: + grid_thw_list = grid_thw.tolist() + vt_outputs = run_dp_sharded_mrope_vision_model( + vision_model=vision_tower, + pixel_values=pixel_values, + grid_thw_list=grid_thw_list, + rope_type="rope_2d", + ) + else: + vt_outputs = vision_tower(pixel_values, grid_thw) + tensors = mm_projector_forward(mm_projector, list(vt_outputs)) + return list(tensors) + + +class KimiK25MultiModalProjector(nn.Module): + """Multi-modal projector with patch merging for Kimi-K2.5.""" + + def __init__( + self, + config: KimiK25VisionConfig, + use_data_parallel: bool = False, + prefix: str = "", + ): + super().__init__() + self.use_data_parallel = use_data_parallel + + # Hidden size after patch merging + merge_h, merge_w = config.merge_kernel_size + self.hidden_size = config.hidden_size * merge_h * merge_w + + self.pre_norm = torch.nn.LayerNorm(config.hidden_size, eps=1e-5) + self.linear_1 = ReplicatedLinear( + self.hidden_size, + self.hidden_size, + bias=True, + prefix=maybe_prefix(prefix, "linear_1"), + ) + self.linear_2 = ReplicatedLinear( + self.hidden_size, + config.mm_hidden_size, + bias=True, + prefix=maybe_prefix(prefix, "linear_2"), + ) + self.act = GELUActivation() + + def forward(self, image_features: torch.Tensor) -> torch.Tensor: + hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size) + hidden_states, _ = self.linear_1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.linear_2(hidden_states) + return hidden_states diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 25b6e4025..6d70a8b5d 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -359,6 +359,7 @@ _MULTIMODAL_MODELS = { ), "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"), "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501 + "KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"), # noqa: E501 "LightOnOCRForConditionalGeneration": ( "lightonocr", "LightOnOCRForConditionalGeneration", diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 7b1215876..8ce1e3587 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -20,6 +20,7 
@@ from typing import ( ) import numpy as np +from PIL.Image import Image from typing_extensions import NotRequired, TypeVar from vllm.utils.collection_utils import full_groupby, is_list_of @@ -29,7 +30,6 @@ from vllm.utils.jsontree import json_map_leaves if TYPE_CHECKING: import torch import torch.types - from PIL.Image import Image from transformers.feature_extraction_utils import BatchFeature from .media import MediaWithBytes @@ -105,6 +105,28 @@ The number of data items allowed per modality is restricted by """ +class VisionChunkImage(TypedDict): + """Represents an image wrapped as a vision chunk.""" + + type: Literal["image"] + image: Image + uuid: str | None + + +class VisionChunkVideo(TypedDict): + """Represents a video chunk with metadata.""" + + type: Literal["video_chunk"] + video_chunk: list[Image] + uuid: str | None + prompt: str + video_idx: int + + +VisionChunk = VisionChunkImage | VisionChunkVideo +"""A vision chunk is either an image or a video chunk.""" + + @final class MultiModalDataBuiltins(TypedDict, total=False): """Type annotations for modality types predefined by vLLM.""" @@ -118,6 +140,9 @@ class MultiModalDataBuiltins(TypedDict, total=False): audio: ModalityData[AudioItem] """The input audio(s).""" + vision_chunk: ModalityData[VisionChunk] + """The input visual atom(s) - unified modality for images and video chunks.""" + MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]] """ diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index a8ebd427b..638478125 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -384,6 +384,13 @@ class VideoEmbeddingItems(EmbeddingItems): super().__init__(data, "video", expected_hidden_size) +class VisionChunkProcessorItems(ProcessorBatchItems[Any]): + """Processor items for vision chunks (unified image and video chunks).""" + + def __init__(self, data: Sequence[Any]) -> None: + super().__init__(data, "vision_chunk") + + _D = TypeVar("_D", bound=ModalityDataItems[Any, Any]) @@ -652,11 +659,23 @@ class MultiModalDataParser: return VideoProcessorItems(new_videos, metadata=metadata_lst) + def _parse_vision_chunk_data( + self, + data: ModalityData[Any], + ) -> ModalityDataItems[Any, Any] | None: + """Parse vision chunk data (unified image and video chunks).""" + if data is None or self._is_empty(data): + return None + if self.is_embeddings(data): + raise ValueError("Do not support embedding data for vision_chunk right now") + return VisionChunkProcessorItems(data) + def _get_subparsers(self) -> Mapping[str, ModalityDataParser]: return { "audio": self._parse_audio_data, "image": self._parse_image_data, "video": self._parse_video_data, + "vision_chunk": self._parse_vision_chunk_data, } def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems: diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index f123799ca..9c7b9463b 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -235,6 +235,27 @@ class VideoLoader: VIDEO_LOADER_REGISTRY = ExtensionManager() +@VIDEO_LOADER_REGISTRY.register("identity") +class IdentityVideoLoader(VideoLoader): + """IdentityVideoLoader returns raw video bytes without decoding. + + This allows the model processor to handle video decoding and + is required for models like Kimi-K2.5 that need custom video chunk splitting. + + NOTE: This is temporary for Kimi-K2.5 testing. Remember to change back + to opencv before release if needed. 
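+
+    A typical way to select this backend (assuming the standard video-loader
+    backend environment variable) is:
+
+        VLLM_VIDEO_LOADER_BACKEND=identity vllm serve <model>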
+ """ + + @classmethod + def load_bytes( + cls, + data: bytes, + num_frames: int = -1, + **kwargs: Any, + ) -> tuple[Any, Any]: + return data, None + + @VIDEO_LOADER_REGISTRY.register("opencv") class OpenCVVideoBackend(VideoLoader): def get_cv2_video_api(self): diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index 7b918d2e3..05bc90e2e 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -53,8 +53,8 @@ _REASONING_PARSERS_TO_REGISTER = { "HunyuanA13BReasoningParser", ), "kimi_k2": ( - "deepseek_r1_reasoning_parser", - "DeepSeekR1ReasoningParser", + "kimi_k2_reasoning_parser", + "KimiK2ReasoningParser", ), "minimax_m2": ( "minimax_m2_reasoning_parser", diff --git a/vllm/reasoning/kimi_k2_reasoning_parser.py b/vllm/reasoning/kimi_k2_reasoning_parser.py new file mode 100644 index 000000000..42869585b --- /dev/null +++ b/vllm/reasoning/kimi_k2_reasoning_parser.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.engine.protocol import DeltaMessage +from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser +from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser + +from .identity_reasoning_parser import IdentityReasoningParser + +if TYPE_CHECKING: + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) +else: + ChatCompletionRequest = Any + + +logger = init_logger(__name__) + + +class KimiK2ReasoningParser(ReasoningParser): + """ + Kimi K2 parser that delegates to either DeepSeekR1ReasoningParser or + IdentityReasoningParser based on `thinking` and `separate_reasoning`. + + Unlike DeepSeekV3ReasoningParser which defaults to NOT thinking, + KimiK2ReasoningParser defaults to thinking mode (uses DeepSeekR1ReasoningParser). 
+ """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + chat_kwargs = kwargs.pop("chat_template_kwargs", {}) or {} + # Key difference: default to True instead of False + thinking = bool(chat_kwargs.pop("thinking", True)) + + if thinking: + self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs) + else: + self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs) + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + return self._parser.is_reasoning_end(input_ids) + + def is_reasoning_end_streaming( + self, input_ids: list[int], delta_ids: list[int] + ) -> bool: + return self._parser.is_reasoning_end_streaming(input_ids, delta_ids) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + return self._parser.extract_content_ids(input_ids) + + def extract_reasoning( + self, model_output: str, request: "ChatCompletionRequest" + ) -> tuple[str | None, str | None]: + return self._parser.extract_reasoning(model_output, request) + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + return self._parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + ) diff --git a/vllm/renderers/hf.py b/vllm/renderers/hf.py index d2252c655..e159a04b9 100644 --- a/vllm/renderers/hf.py +++ b/vllm/renderers/hf.py @@ -20,9 +20,11 @@ from vllm.entrypoints.chat_utils import ( ChatTemplateContentFormatOption, ChatTemplateResolutionError, ConversationMessage, + build_video_prompts_from_mm_data, load_chat_template, parse_chat_messages, parse_chat_messages_async, + rebuild_mm_uuids_from_mm_data, ) from vllm.inputs import TextPrompt, TokensPrompt from vllm.logger import init_logger @@ -547,6 +549,40 @@ class HfRenderer(RendererLike): **kwargs, ) + # NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5 + # model which uses unified vision chunks for both images and videos. + if ( + getattr(model_config.hf_config, "use_unified_vision_chunk", False) + and mm_uuids is not None + and mm_data is not None + ): + mm_uuids = rebuild_mm_uuids_from_mm_data(mm_uuids, mm_data) + + # get video placehoder, replace it with runtime video-chunk prompts + video_placeholder = getattr( + model_config.hf_config, "video_placeholder", None + ) + if video_placeholder and isinstance(prompt_raw, str): + video_prompts = build_video_prompts_from_mm_data(mm_data) + + # replace in order + prompt_raw_parts = prompt_raw.split(video_placeholder) + if len(prompt_raw_parts) == len(video_prompts) + 1: + prompt_raw = "".join( + [ + prompt_raw_parts[i] + video_prompts[i] + for i in range(len(video_prompts)) + ] + ) + prompt_raw += prompt_raw_parts[-1] + else: + logger.warning( + "Number of video placeholders (%d) does not match " + "number of videos (%d) in the request.", + len(prompt_raw_parts) - 1, + len(video_prompts), + ) + prompt = ( TextPrompt(prompt=prompt_raw) if isinstance(prompt_raw, str) @@ -587,6 +623,40 @@ class HfRenderer(RendererLike): **kwargs, ) + # NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5 + # model which uses unified vision chunks for both images and videos. 
+        if (
+            getattr(model_config.hf_config, "use_unified_vision_chunk", False)
+            and mm_uuids is not None
+            and mm_data is not None
+        ):
+            mm_uuids = rebuild_mm_uuids_from_mm_data(mm_uuids, mm_data)
+
+            # Get the video placeholder; replace it with runtime video-chunk prompts
+            video_placeholder = getattr(
+                model_config.hf_config, "video_placeholder", None
+            )
+            if video_placeholder and isinstance(prompt_raw, str):
+                video_prompts = build_video_prompts_from_mm_data(mm_data)
+
+                # Replace placeholders in order
+                prompt_raw_parts = prompt_raw.split(video_placeholder)
+                if len(prompt_raw_parts) == len(video_prompts) + 1:
+                    prompt_raw = "".join(
+                        [
+                            prompt_raw_parts[i] + video_prompts[i]
+                            for i in range(len(video_prompts))
+                        ]
+                    )
+                    prompt_raw += prompt_raw_parts[-1]
+                else:
+                    logger.warning(
+                        "Number of video placeholders (%d) does not match "
+                        "number of videos (%d) in the request.",
+                        len(prompt_raw_parts) - 1,
+                        len(video_prompts),
+                    )
+
         prompt = (
             TextPrompt(prompt=prompt_raw)
             if isinstance(prompt_raw, str)
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index a009017e5..277afa7dd 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -81,6 +81,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
     isaac="IsaacConfig",
     kimi_linear="KimiLinearConfig",
     kimi_vl="KimiVLConfig",
+    kimi_k25="KimiK25Config",
     RefinedWeb="RWConfig",  # For tiiuae/falcon-40b(-instruct)
     RefinedWebModel="RWConfig",  # For tiiuae/falcon-7b(-instruct)
     jais="JAISConfig",
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 00d5ecd25..bfb9c1758 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -38,6 +38,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
     "MoonViTConfig": "vllm.transformers_utils.configs.moonvit",
     "KimiLinearConfig": "vllm.transformers_utils.configs.kimi_linear",
     "KimiVLConfig": "vllm.transformers_utils.configs.kimi_vl",
+    "KimiK25Config": "vllm.transformers_utils.configs.kimi_k25",
     "NemotronConfig": "vllm.transformers_utils.configs.nemotron",
     "NemotronHConfig": "vllm.transformers_utils.configs.nemotron_h",
     "Olmo3Config": "vllm.transformers_utils.configs.olmo3",
@@ -77,6 +78,7 @@ __all__ = [
     "MoonViTConfig",
     "KimiLinearConfig",
     "KimiVLConfig",
+    "KimiK25Config",
     "NemotronConfig",
     "NemotronHConfig",
     "Olmo3Config",
diff --git a/vllm/transformers_utils/configs/kimi_k25.py b/vllm/transformers_utils/configs/kimi_k25.py
new file mode 100644
index 000000000..72f67251d
--- /dev/null
+++ b/vllm/transformers_utils/configs/kimi_k25.py
@@ -0,0 +1,129 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Kimi-K2.5 Model Configuration.
+
+This configuration supports video-chunk as an internal modality type.
+A video-chunk is the smallest independently processable unit of video.
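+
+When ``use_unified_vision_chunk`` is enabled, images and videos are both
+handled through this chunk-based path, and occurrences of
+``video_placeholder`` in the prompt are replaced at render time with
+per-chunk video prompts.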
+""" + +from transformers import DeepseekV3Config +from transformers.configuration_utils import PretrainedConfig + + +class KimiK25VisionConfig(PretrainedConfig): + model_type = "kimi_k25_vision" + + def __init__( + self, + # Vision Tower + patch_size: int = 14, + init_pos_emb_height: int = 64, + init_pos_emb_width: int = 64, + init_pos_emb_time: int = 4, + pos_emb_type: str = "divided_fixed", + num_attention_heads: int = 16, + num_hidden_layers: int = 27, + hidden_size: int = 1152, + intermediate_size: int = 4304, + merge_kernel_size: tuple[int, int] = (2, 2), + video_attn_type: str = "spatial_temporal", + merge_type: str = "sd2_tpool", + # MM Projector + mm_projector_type: str = "patchmerger", + mm_hidden_size: int | None = None, + projector_hidden_act: str = "gelu", + projector_ln_eps: float = 1e-5, + **kwargs, + ): + super().__init__(**kwargs) + # Vision Tower + self.patch_size = patch_size + self.init_pos_emb_height = init_pos_emb_height + self.init_pos_emb_width = init_pos_emb_width + self.init_pos_emb_time = init_pos_emb_time + self.pos_emb_type = pos_emb_type + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_hidden_layers + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.merge_kernel_size = merge_kernel_size + self.video_attn_type = video_attn_type + self.merge_type = merge_type + # MM Projector + self.mm_projector_type = mm_projector_type + if mm_hidden_size is not None: + self.mm_hidden_size = mm_hidden_size + else: + self.mm_hidden_size = hidden_size + self.projector_hidden_act = projector_hidden_act + self.projector_ln_eps = projector_ln_eps + + +class KimiK25Config(PretrainedConfig): + """Kimi-K2.5 model configuration. + + Kimi-K2.5 extends Kimi-K2 with vision support using video-chunks. + A video-chunk consists of multiple consecutive frames + that are processed together with temporal pooling. + + Args: + vision_config: Configuration for the vision tower and projector. + text_config: Configuration for the text model (DeepseekV3). + ignore_index: The ignore index for the loss function. + media_placeholder_token_id: The token ID for media placeholders. + pad_token_id: The token ID for padding. 
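+        use_unified_vision_chunk: Whether images and videos are exposed
+            through the unified ``vision_chunk`` modality.
+        video_placeholder: Placeholder string in the prompt that is replaced
+            at render time with the runtime video-chunk prompts.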
+ """ + + model_type = "kimi_k25" + + def __init__( + self, + vision_config: dict | KimiK25VisionConfig | None = None, + text_config: dict | DeepseekV3Config | None = None, + ignore_index: int = -100, + media_placeholder_token_id: int = 163605, + pad_token_id: int = 0, + use_unified_vision_chunk: bool = False, + video_placeholder: str = "<|kimi_k25_video_placeholder|>", + **kwargs, + ): + # Vision config + if vision_config is None: + vision_config = KimiK25VisionConfig() + elif isinstance(vision_config, dict): + vision_config = KimiK25VisionConfig(**vision_config) + self.vision_config: KimiK25VisionConfig = vision_config + + # Text config + if text_config is None: + text_config = DeepseekV3Config() + elif isinstance(text_config, dict): + text_config = DeepseekV3Config(**text_config) + self.text_config: DeepseekV3Config = text_config + + # Set mm_hidden_size to text hidden size if not explicitly set + if self.vision_config.mm_hidden_size == self.vision_config.hidden_size: + self.vision_config.mm_hidden_size = self.text_config.hidden_size + + # Other config + self.ignore_index = ignore_index + self.media_placeholder_token_id = media_placeholder_token_id + self.use_unified_vision_chunk = use_unified_vision_chunk + self.video_placeholder = video_placeholder + + # Propagate quantization config from text model + if getattr(self.text_config, "quantization_config", None) is not None: + self.quantization_config = self.text_config.quantization_config + + super().__init__(pad_token_id=pad_token_id, **kwargs) + + @property + def hidden_size(self) -> int: + """Get hidden size from text config for compatibility.""" + return self.text_config.hidden_size + + @property + def vocab_size(self) -> int: + """Get vocab size from text config for compatibility.""" + return self.text_config.vocab_size