diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index c22f2ab3d..dd442d9e3 100755
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -995,6 +995,31 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
     )
 
 
+# Kimi-K2.5
+def run_kimi_k25(questions: list[str], modality: str) -> ModelRequestData:
+    assert modality == "vision_chunk"
+
+    prompts = [
+        "<|im_user|>user<|media_begin|>image<|media_content|>"
+        f"<|media_pad|><|media_end|>{question}<|im_end|>"
+        "<|im_assistant|>assistant<|im_middle|>"
+        for question in questions
+    ]
+
+    engine_args = EngineArgs(
+        model="moonshotai/Kimi-K2.5",
+        trust_remote_code=True,
+        max_model_len=4096,
+        limit_mm_per_prompt={modality: 1},
+        tensor_parallel_size=4,
+    )
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
+
+
 # LightOnOCR
 def run_lightonocr(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"
@@ -2110,6 +2135,7 @@ model_example_map = {
     "keye_vl": run_keye_vl,
     "keye_vl1_5": run_keye_vl1_5,
     "kimi_vl": run_kimi_vl,
+    "kimi_k25": run_kimi_k25,
     "lightonocr": run_lightonocr,
     "lfm2_vl": run_lfm2_vl,
     "llama4": run_llama4,
@@ -2196,6 +2222,19 @@ def get_multi_modal_input(args):
             "questions": vid_questions,
         }
 
+    if args.modality == "vision_chunk":
+        # Input vision chunks and question
+        image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
+        vision_chunk_questions = [
+            "What is the content of this image chunk?",
+            "Describe the content of this image chunk in detail.",
+        ]
+
+        return {
+            "data": {"type": "image", "image": image},
+            "questions": vision_chunk_questions,
+        }
+
     msg = f"Modality {args.modality} is not supported."
     raise ValueError(msg)
 
@@ -2278,7 +2317,7 @@ def parse_args():
         "--modality",
         type=str,
         default="image",
-        choices=["image", "video"],
+        choices=["image", "video", "vision_chunk"],
         help="Modality of the input.",
     )
     parser.add_argument(
@@ -2355,7 +2394,7 @@ def main(args):
     req_data = model_example_map[model](questions, modality)
 
     # Disable other modalities to save memory
-    default_limits = {"image": 0, "video": 0, "audio": 0}
+    default_limits = {"image": 0, "video": 0, "audio": 0, "vision_chunk": 0}
     req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
         req_data.engine_args.limit_mm_per_prompt or {}
     )
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index bb1c8b478..b228898ff 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -214,6 +214,28 @@ def get_text_token_prompts(
     return text_prompt, token_prompt
 
 
+def random_vision_chunk(
+    rng: np.random.RandomState,
+    min_wh: int,
+    max_wh: int,
+    min_frames: int,
+    max_frames: int,
+) -> dict:
+    num_frames = rng.randint(min_frames, max_frames + 1)
+    if num_frames == 1:
+        # Single image chunk
+        wh = rng.randint(min_wh, max_wh + 1)
+        image = random_image(rng, wh, wh + 1)
+        return {"type": "image", "image": image}
+    frames = []
+    for _ in range(num_frames):
+        wh = rng.randint(min_wh, max_wh + 1)
+        frame = rng.randint(0, 256, size=(wh, wh, 3), dtype=np.uint8)
+        frames.append(frame)
+    video_array = np.stack(frames, axis=0)
+    return {"type": "video_chunk", "video_chunk": video_array}
+
+
 def _test_processing_correctness(
     model_id_or_arch: str,
     hit_rate: float,
@@ -291,6 +313,7 @@ def _test_processing_correctness(
         "image": Image.new("RGB", size=(128, 128)),
         "video": np.zeros((4, 128, 128, 3), dtype=np.uint8),
         "audio": (np.zeros((512,)), 16000),
+        "vision_chunk": {"type": "image", "image": Image.new("RGB", size=(128, 128))},
     }
     input_factory = {
         "image": partial(random_image, rng, min_wh=128, max_wh=256),
         "video": partial(
             random_video, rng, min_frames=2, max_frames=16, min_wh=128, max_wh=256
         ),
         "audio": partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
+        "vision_chunk": partial(
+            random_vision_chunk, rng, min_wh=128, max_wh=256, min_frames=1, max_frames=1
+        ),
     }
 
     for batch_idx in range(num_batches):
@@ -413,11 +439,6 @@ def test_processing_correctness(
             "Qwen-VL tokenizer requires downloading a font file from "
             "servers that often refuse connections in CI"
         )
-    if model_id == "moonshotai/Kimi-K2.5":
-        # FIXME(Isaac): Fix Kimi-K2.5's offline inference about vision chunks.
-        pytest.skip(
-            "Kimi-K2.5's offline inference has issues about vision chunks. Fix later."
-        )
 
     _test_processing_correctness(
         model_id,
diff --git a/vllm/model_executor/models/kimi_k25.py b/vllm/model_executor/models/kimi_k25.py
index 4be79ca95..191aed8e5 100644
--- a/vllm/model_executor/models/kimi_k25.py
+++ b/vllm/model_executor/models/kimi_k25.py
@@ -96,16 +96,20 @@ class MoonshotKimiVAutoProcessor(ProcessorMixin):
     attributes = ["tokenizer"]
     tokenizer_class = "AutoTokenizer"
 
-    def __init__(self, media_processor=None, tokenizer=None):
+    def __init__(
+        self, media_processor=None, tokenizer=None, media_token_id: int | None = None
+    ):
         super().__init__(tokenizer)
         self.media_processor = media_processor
+        self.media_token_id = media_token_id
+        assert self.media_token_id is not None
 
     # We do not support str input for text here
     def __call__(
         self,
         vision_chunks: list[VisionChunk] | None = None,
         *,
-        text: list[int],
+        text: list[int] | str,
         **kwargs,
     ) -> BatchFeature:
         """
@@ -122,13 +126,30 @@ class MoonshotKimiVAutoProcessor(ProcessorMixin):
            - **grid_thws** -- list of image 3D grid in LLM. Returned when
              `vision_chunks` is not `None`.
        """
        mm_inputs = {}
+        input_ids = self.tokenizer.encode(text) if isinstance(text, str) else text
        if vision_chunks is not None:
            assert isinstance(vision_chunks, list)
            mm_inputs = self.media_processor.preprocess(vision_chunks)
+
+            num_tokens_per_chunk = [
+                self.media_processor.media_tokens_calculator(chunk)
+                for chunk in vision_chunks
+            ]
+
+            new_input_ids = []
+            for token in input_ids:
+                if token == self.media_token_id:
+                    new_input_ids.extend(
+                        [self.media_token_id] * num_tokens_per_chunk.pop(0)
+                    )
+                else:
+                    new_input_ids.append(token)
+            input_ids = new_input_ids
+
+        # XXX: _apply_hf_processor_text_mm will call tolist() on input_ids
        return BatchFeature(
            data={
-                "input_ids": torch.tensor([text]),
+                "input_ids": torch.tensor([input_ids]),
                **mm_inputs,
            }
        )
@@ -152,6 +173,7 @@ class KimiK25ProcessingInfo(BaseProcessingInfo):
         self.hf_processor = MoonshotKimiVAutoProcessor(
             media_processor=self.media_processor,
             tokenizer=self.get_tokenizer(),
+            media_token_id=self.media_token_id,
         )
 
         self.media_tokens_calculator = self.media_processor.media_tokens_calculator
@@ -174,9 +196,9 @@ class KimiK25DummyInputsBuilder(BaseDummyInputsBuilder[KimiK25ProcessingInfo]):
         self.media_token_id = self.info.media_token_id
         self.frame_per_chunk = self.info.media_processor.num_frames_per_chunk
 
-    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> list[int]:
+    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         num_media = mm_counts.get("vision_chunk", 0)
-        return [self.media_token_id] * num_media
+        return "<|media_pad|>" * num_media
 
     def get_dummy_mm_items(self):
         dummy_videos = self._get_dummy_images(
diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py
index 638478125..0462ab5de 100644
--- a/vllm/multimodal/parse.py
+++ b/vllm/multimodal/parse.py
@@ -668,6 +668,8 @@ class MultiModalDataParser:
             return None
         if self.is_embeddings(data):
             raise ValueError("Do not support embedding data for vision_chunk right now")
+        if isinstance(data, dict):
+            data = [data]
         return VisionChunkProcessorItems(data)
 
     def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
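
For reference, a minimal offline-inference sketch of the path this patch enables, pieced together from the diff above: the Kimi-K2.5 prompt template from `run_kimi_k25`, the new `vision_chunk` modality, and a single image chunk passed as multi-modal data (the same `{"type": "image", "image": ...}` payload that `get_multi_modal_input` now returns and that `MultiModalDataParser` wraps into a one-element list). The `LLM`/`SamplingParams` calls are the standard vLLM offline API; the engine settings simply mirror the example's `EngineArgs` and may need tuning for your hardware, so treat this as an assumption-laden sketch rather than a tested recipe.

```python
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

# Engine settings mirror the EngineArgs used by run_kimi_k25 in the example script.
llm = LLM(
    model="moonshotai/Kimi-K2.5",
    trust_remote_code=True,
    max_model_len=4096,
    limit_mm_per_prompt={"vision_chunk": 1},
    tensor_parallel_size=4,
)

# Prompt template copied from run_kimi_k25.
question = "What is the content of this image chunk?"
prompt = (
    "<|im_user|>user<|media_begin|>image<|media_content|>"
    f"<|media_pad|><|media_end|>{question}<|im_end|>"
    "<|im_assistant|>assistant<|im_middle|>"
)

# A single image chunk, in the same shape get_multi_modal_input now returns;
# parse.py wraps a bare dict into a one-element list before further processing.
vision_chunk = {"type": "image", "image": ImageAsset("cherry_blossom").pil_image}

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"vision_chunk": vision_chunk}},
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```

The example script itself exercises the same path once run with `--modality vision_chunk` and the new `kimi_k25` entry from `model_example_map`.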