Signed-off-by: LoganJane <LoganJane73@hotmail.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
486 lines
17 KiB
Python
486 lines
17 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
# ruff: noqa: E501
|
|
"""
|
|
Kimi-K2.5 Model Implementation for vLLM.
|
|
|
|
Kimi-K2.5 extends Kimi-K2 with vision support
|
|
|
|
This module defines:
|
|
- KimiK25ProcessingInfo/KimiK25MultiModalProcessor: Processing logic
|
|
- KimiK25ForConditionalGeneration: Main model class
|
|
"""
|
|
|
|
from collections.abc import Iterable, Mapping, Sequence
|
|
from dataclasses import dataclass
|
|
from typing import Annotated, Any, Literal
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers import BatchFeature
|
|
from transformers.processing_utils import ProcessorMixin
|
|
|
|
from vllm.config import VllmConfig
|
|
from vllm.config.multimodal import BaseDummyOptions
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
|
|
CompressedTensorsConfig,
|
|
)
|
|
from vllm.model_executor.models.interfaces import (
|
|
SupportsMultiModal,
|
|
SupportsPP,
|
|
SupportsQuant,
|
|
)
|
|
from vllm.model_executor.models.kimi_k25_vit import (
|
|
KimiK25MultiModalProjector,
|
|
MoonViT3dPretrainedModel,
|
|
vision_tower_forward,
|
|
)
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
|
from vllm.multimodal.inputs import (
|
|
MultiModalDataDict,
|
|
MultiModalFieldConfig,
|
|
MultiModalKwargsItems,
|
|
NestedTensors,
|
|
VisionChunk,
|
|
VisionChunkImage,
|
|
VisionChunkVideo,
|
|
)
|
|
from vllm.multimodal.parse import MultiModalDataItems, VisionChunkProcessorItems
|
|
from vllm.multimodal.processing import (
|
|
BaseDummyInputsBuilder,
|
|
BaseMultiModalProcessor,
|
|
BaseProcessingInfo,
|
|
InputProcessingContext,
|
|
PromptReplacement,
|
|
PromptUpdate,
|
|
)
|
|
from vllm.platforms import current_platform
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm.transformers_utils.configs import KimiK25Config
|
|
from vllm.transformers_utils.processor import cached_get_image_processor
|
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
|
|
|
from .utils import (
|
|
AutoWeightsLoader,
|
|
WeightsMapper,
|
|
init_vllm_registered_model,
|
|
maybe_prefix,
|
|
)
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
# Image dimensions used when building dummy inputs for profiling.
@dataclass
class MaxImageTokenMeta:
    """Pixel dimensions of the dummy media generated for profiling.

    The defaults are also readable directly on the class
    (``MaxImageTokenMeta.width`` / ``MaxImageTokenMeta.height``), so no
    instance is required to look them up.
    """

    # Dummy image width, in pixels.
    width: int = 3000
    # Dummy image height, in pixels.
    height: int = 3000
|
|
|
|
|
|
class KimiK25MediaPixelInputs(TensorSchema):
    """
    Schema of the media (image / video-chunk) pixel inputs of the K2-VL model.

    Dimensions:
        - np: Number of patches (flattened from all media items)
        - ps: Patch size
        - nm: Number of media items
    """

    # Discriminator tag; the only accepted value is "pixel_values".
    type: Literal["pixel_values"] = "pixel_values"

    # Patches of all media items, declared as (np, 3, ps, ps).
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("np", 3, "ps", "ps"),
    ]

    # One row of three grid sizes per media item, declared as (nm, 3).
    grid_thws: Annotated[torch.Tensor, TensorShape("nm", 3)]
|
|
|
|
|
|
class MoonshotKimiVAutoProcessor(ProcessorMixin):
    """Processor that pairs a tokenizer with the Kimi media processor.

    Calling the processor expands every media placeholder token in the
    prompt to the number of tokens reported by
    ``media_processor.media_tokens_calculator`` for the matching vision
    chunk, and returns the token ids together with whatever
    ``media_processor.preprocess`` produced.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self, media_processor=None, tokenizer=None, media_token_id: int | None = None
    ):
        """
        Args:
            media_processor: Object providing ``preprocess`` and
                ``media_tokens_calculator``; only used when vision chunks
                are passed to ``__call__``.
            tokenizer: Tokenizer forwarded to ``ProcessorMixin``.
            media_token_id: Id of the media placeholder token (required).

        Raises:
            ValueError: If ``media_token_id`` is ``None``.
        """
        super().__init__(tokenizer)
        # Explicit check instead of ``assert`` so that it survives ``python -O``.
        if media_token_id is None:
            raise ValueError("media_token_id must not be None")
        self.media_processor = media_processor
        self.media_token_id = media_token_id

    # `text` is normally a list of token ids; a raw string is also accepted
    # and is tokenized with `self.tokenizer.encode` first.
    def __call__(
        self,
        vision_chunks: list[VisionChunk] | None = None,
        *,
        text: list[int] | str,
        **kwargs,
    ) -> BatchFeature:
        """
        Args:
            vision_chunks: List of VisionChunk items to be processed.
                For image: VisionChunkImage with type='image', image=PIL.Image
                For video_chunk: VisionChunkVideo with type='video_chunk', video_chunk=list[PIL.Image]
            text: The token ids to be fed to a model (required). A string
                is encoded with the tokenizer.
        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- list of token ids to be fed to a model.
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `vision_chunks` is not `None`.
            - **grid_thws** -- list of image 3D grid in LLM. Returned when `vision_chunks` is not `None`.
        Raises:
            TypeError: If `vision_chunks` is given but is not a list.
            ValueError: If the prompt contains more media placeholder
                tokens than there are vision chunks.
        """
        mm_inputs = {}
        input_ids = self.tokenizer.encode(text) if isinstance(text, str) else text
        if vision_chunks is not None:
            if not isinstance(vision_chunks, list):
                raise TypeError(
                    f"vision_chunks must be a list, got {type(vision_chunks)}"
                )
            mm_inputs = self.media_processor.preprocess(vision_chunks)

            # Token counts are computed eagerly for every chunk, then
            # consumed in order through an iterator (no O(n) `pop(0)`).
            num_tokens_per_chunk = [
                self.media_processor.media_tokens_calculator(chunk)
                for chunk in vision_chunks
            ]
            remaining_counts = iter(num_tokens_per_chunk)

            new_input_ids = []
            for token in input_ids:
                if token != self.media_token_id:
                    new_input_ids.append(token)
                    continue
                try:
                    num_media_tokens = next(remaining_counts)
                except StopIteration:
                    raise ValueError(
                        "The prompt contains more media placeholder tokens "
                        f"than the {len(vision_chunks)} vision chunk(s) provided"
                    ) from None
                new_input_ids.extend([self.media_token_id] * num_media_tokens)
            input_ids = new_input_ids

        # XXX: _apply_hf_processor_text_mm will call tolist() on input_ids
        return BatchFeature(
            data={
                "input_ids": torch.tensor([input_ids]),
                **mm_inputs,
            }
        )
|
|
|
|
|
|
class KimiK25ProcessingInfo(BaseProcessingInfo):
    """Processing information for the Kimi-K2.5 model.

    Holds the HF config, the media processor and the combined
    tokenizer + media processor used for images and video-chunks.
    """

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__(ctx)

        hf_config = self.get_hf_config()
        self.hf_config = hf_config
        self.media_token_id = hf_config.media_placeholder_token_id

        # The image processor is loaded with trust_remote_code=True.
        self.media_processor = cached_get_image_processor(
            self.ctx.model_config.model, trust_remote_code=True
        )
        self.hf_processor = MoonshotKimiVAutoProcessor(
            media_processor=self.media_processor,
            tokenizer=self.get_tokenizer(),
            media_token_id=self.media_token_id,
        )
        # Shortcut to the per-item token counter of the media processor.
        self.media_tokens_calculator = self.media_processor.media_tokens_calculator

    def get_hf_processor(self):
        """Return the processor instance built in ``__init__``."""
        return self.hf_processor

    def get_hf_config(self):
        """Return the HF config obtained from the context as ``KimiK25Config``."""
        return self.ctx.get_hf_config(KimiK25Config)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality item limits (``None`` means unlimited)."""
        limits: dict[str, int | None] = {"vision_chunk": None}
        return limits
|
|
|
|
|
|
class KimiK25DummyInputsBuilder(BaseDummyInputsBuilder[KimiK25ProcessingInfo]):
    """Creates the dummy prompt and media used to profile Kimi-K2.5."""

    def __init__(self, info: KimiK25ProcessingInfo) -> None:
        super().__init__(info)
        self.media_token_id = self.info.media_token_id
        self.frame_per_chunk = self.info.media_processor.num_frames_per_chunk

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<|media_pad|>`` placeholder per requested vision chunk."""
        placeholder = "<|media_pad|>"
        return placeholder * mm_counts.get("vision_chunk", 0)

    def get_dummy_mm_items(self):
        """Return a one-element list with the dummy item that needs more tokens.

        A dummy video chunk (``frame_per_chunk`` frames) and a dummy single
        image are built at the ``MaxImageTokenMeta`` size; whichever the
        media token calculator rates higher is returned.
        """
        count_tokens = self.info.media_tokens_calculator
        max_height = MaxImageTokenMeta.height
        max_width = MaxImageTokenMeta.width

        # Candidate 1: a full video chunk.
        frames = self._get_dummy_images(
            height=max_height,
            width=max_width,
            num_images=self.frame_per_chunk,
        )
        video_item = VisionChunkVideo(type="video_chunk", video_chunk=frames)
        video_tokens = count_tokens(video_item)

        # Candidate 2: a single image.
        single_image = self._get_dummy_images(
            height=max_height,
            width=max_width,
            num_images=1,
        )[0]
        image_item = VisionChunkImage(type="image", image=single_image)
        image_tokens = count_tokens(image_item)

        # On a tie the video chunk is kept.
        return [image_item] if image_tokens > video_tokens else [video_item]

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalDataDict:
        """Return the dummy multimodal data under the ``vision_chunk`` key.

        None of the arguments is consulted: the result always holds the
        single item chosen by ``get_dummy_mm_items``.
        """
        # TODO: Support mm_options for vision_chunk to allow user configuration
        return {"vision_chunk": self.get_dummy_mm_items()}
|
|
|
|
|
|
class KimiK25MultiModalProcessor(BaseMultiModalProcessor[KimiK25ProcessingInfo]):
    """Multi-modal processor for Kimi-K2.5.

    Images and video-chunks are both handled through the single
    ``vision_chunk`` modality.
    """

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how the processor outputs are split into media items.

        pixel_values: [N, 3, patch_size, patch_size], all patches collected from B medias
        grid_thws: [B, 3], each row [N_t, N_h, N_w] gives the grid size in the
            time/height/width direction of one item.

        The product N_t * N_h * N_w is the number of patches of an item, so
        consecutive runs of that length in pixel_values belong to one item.
        """
        # Without "grid_thws" an empty (0, 3) tensor is used, i.e. no items.
        grid_thws = hf_inputs.get("grid_thws", torch.empty((0, 3)))
        patches_per_item = grid_thws.prod(-1)

        return {
            "pixel_values": MultiModalFieldConfig.flat_from_sizes(
                "vision_chunk", patches_per_item
            ),
            "grid_thws": MultiModalFieldConfig.batched("vision_chunk"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each media placeholder token by the item's full token run."""
        media_token_id = self.info.get_hf_config().media_placeholder_token_id

        def get_replacement(item_idx: int):
            # Looked up lazily, only when a replacement is actually requested.
            chunks = mm_items.get_items("vision_chunk", (VisionChunkProcessorItems,))
            token_count = self.info.media_tokens_calculator(chunks[item_idx])
            return [media_token_id] * token_count

        placeholder_update = PromptReplacement(
            modality="vision_chunk",
            target=[media_token_id],
            replacement=get_replacement,
        )
        return [placeholder_update]

    def split_video_chunks(self, video):
        """Delegate to ``split_video_chunks`` of the media processor."""
        return self.info.media_processor.split_video_chunks(video)
|
|
|
|
|
|
@MULTIMODAL_REGISTRY.register_processor(
    KimiK25MultiModalProcessor,
    info=KimiK25ProcessingInfo,
    dummy_inputs=KimiK25DummyInputsBuilder,
)
class KimiK25ForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant
):
    """Kimi-K2.5 model for conditional generation.

    Supports both image and video-chunk modalities.
    Video-chunks are temporal segments (typically 4 frames) that are
    processed with temporal pooling.
    """

    supports_encoder_tp_data = True

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # For legacy NVFP4 checkpoint compatibility:
            # see https://github.com/vllm-project/vllm/pull/33346#issuecomment-3851475033
            "language_model.layers.": "language_model.model.layers.",
            # mm projector
            "mm_projector.proj.0": "mm_projector.linear_1",
            "mm_projector.proj.2": "mm_projector.linear_2",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder text for ``modality``.

        Raises:
            ValueError: If ``modality`` is neither "image" nor "video".
        """
        # Kimi-K2.5 uses video_chunk for all media types
        if modality == "image":
            return "<|media_begin|>image<|media_content|><|media_pad|><|media_end|>"
        if modality == "video":
            # return a placeholder, to be replaced in the future.
            return "<|kimi_k25_video_placeholder|>"
        raise ValueError(f"Unsupported modality: {modality}")

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        """Build the vision tower, the projector and the language model."""
        super().__init__()
        model_config = vllm_config.model_config
        quant_config = vllm_config.quant_config
        config: KimiK25Config = model_config.hf_config
        self.config = config

        self.use_data_parallel = (
            model_config.multimodal_config.mm_encoder_tp_mode == "data"
        )
        self.hidden_size = config.text_config.hidden_size
        self.device = current_platform.current_device()

        # Quantization config handed to the vision modules (None for
        # compressed-tensors, see _maybe_ignore_quant_config).
        vision_quant_config = self._maybe_ignore_quant_config(quant_config)

        # Build vision tower directly with KimiK25VisionConfig
        # NOTE(review): indentation was lost in the reviewed source; the
        # projector is assumed to live inside the tower context - confirm.
        with self._mark_tower_model(vllm_config, "vision_chunk"):
            vision_tower = MoonViT3dPretrainedModel(
                config.vision_config,
                quant_config=vision_quant_config,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.vision_tower = vision_tower.to(
                device=self.device, dtype=model_config.dtype
            )

            mm_projector = KimiK25MultiModalProjector(
                config=config.vision_config,
                use_data_parallel=self.use_data_parallel,
                quant_config=vision_quant_config,
                prefix=maybe_prefix(prefix, "mm_projector"),
            )
            self.mm_projector = mm_projector.to(
                device=self.device, dtype=model_config.dtype
            )

        self.quant_config = quant_config
        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["DeepseekV2ForCausalLM"],
            )
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )
        self.media_placeholder: int = self.config.media_placeholder_token_id

    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
        """Return ``None`` for compressed-tensors configs, else the config itself."""
        return (
            None if isinstance(quant_config, CompressedTensorsConfig) else quant_config
        )

    def _parse_and_validate_media_input(
        self, **kwargs: object
    ) -> KimiK25MediaPixelInputs | None:
        """Extract ``pixel_values`` / ``grid_thws`` from kwargs and normalise them.

        Returns ``None`` when no ``pixel_values`` were passed.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        grid_thws = kwargs.pop("grid_thws", None)
        if pixel_values is None:
            return None

        # A list of tensors is concatenated along the first dimension.
        if isinstance(pixel_values, list):
            pixel_values = torch.cat(pixel_values, dim=0)

        # 3-D and 5-D inputs get their two leading dimensions merged.
        if pixel_values.ndim in (3, 5):
            first, second, *trailing = pixel_values.shape
            pixel_values = pixel_values.reshape(first * second, *trailing)

        # The batch dimension of pixel_values has been flattened into shape[0]
        vision_dtype = next(self.vision_tower.parameters()).dtype
        pixel_values = pixel_values.to(vision_dtype)

        assert isinstance(grid_thws, torch.Tensor), (
            f"expect grid_thws to be a tensor, get {type(grid_thws)}"
        )
        # In some cases (e.g. with merger), grid_thws has an extra middle dimension
        grid_thws = grid_thws.reshape(-1, grid_thws.shape[-1])
        assert grid_thws.ndim == 2 and grid_thws.size(1) == 3, (
            f"unexpected shape for grid_thws: {grid_thws.shape}"
        )

        return KimiK25MediaPixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            grid_thws=grid_thws,
        )

    def _process_media_input(
        self, media_input: KimiK25MediaPixelInputs
    ) -> list[torch.Tensor]:
        """Run the vision tower and projector on validated media inputs."""
        # NOTE(moyan): This forward will automatically batch the forward pass internally
        return vision_tower_forward(
            self.vision_tower,
            media_input["pixel_values"],
            media_input["grid_thws"],
            mm_projector=self.mm_projector,
            use_data_parallel=self.use_data_parallel,
        )

    def embed_multimodal(self, **kwargs: object) -> NestedTensors | None:
        """Return the media embeddings, or ``None`` without media input."""
        # Validate the multimodal input keyword arguments
        media_input = self._parse_and_validate_media_input(**kwargs)
        if media_input is None:
            return None
        # Run multimodal inputs through encoder and projector
        return self._process_media_input(media_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> IntermediateTensors:
        """Run the language model and return its output unchanged.

        ``inputs_embeds`` is discarded whenever ``intermediate_tensors``
        is provided.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        return self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load ``weights`` via ``AutoWeightsLoader`` using ``hf_to_vllm_mapper``."""
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )
|