diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 181d684b8..45465d9c4 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -672,6 +672,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I+ | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | | `DeepseekVLV2ForCausalLM`^ | DeepSeek-VL2 | T + I+ | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | | `DeepseekOCRForCausalLM` | DeepSeek-OCR | T + I+ | `deepseek-ai/DeepSeek-OCR`, etc. | ✅︎ | ✅︎ | +| `DeepseekOCR2ForCausalLM` | DeepSeek-OCR-2 | T + I+ | `deepseek-ai/DeepSeek-OCR-2`, etc. | ✅︎ | ✅︎ | | `Eagle2_5_VLForConditionalGeneration` | Eagle2.5-VL | T + IE+ | `nvidia/Eagle2.5-8B`, etc. | ✅︎ | ✅︎ | | `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I+/ V+ | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. 
| | ✅︎ | diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index c378e6696..c22f2ab3d 100755 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -270,6 +270,49 @@ def run_deepseek_ocr(questions: list[str], modality: str) -> ModelRequestData: ) +def run_deepseek_ocr2(questions: list[str], modality: str) -> ModelRequestData: + from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor + + assert modality == "image" + + model_name = "deepseek-ai/DeepSeek-OCR-2" + + engine_args = EngineArgs( + model=model_name, + limit_mm_per_prompt={modality: 1}, + logits_processors=[NGramPerReqLogitsProcessor], + ) + + # deepseek-ocr use plain prompt template + prompts = [f"<image>\n{question}" for question in questions] + + # The following sampling params config is taken from + # the official Deepseek-OCR inference example. + # (IMPORTANT) Use the custom logits processor and avoid skipping + # special tokens for this model for the optimal OCR performance. 
+ sampling_params = [ + SamplingParams( + temperature=0.0, + max_tokens=8192, + # ngram logit processor args + extra_args=dict( + ngram_size=30, + window_size=90, + # whitelist: <td>, </td> + whitelist_token_ids={128821, 128822}, + ), + skip_special_tokens=False, + ) + for _ in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + sampling_params=sampling_params, + ) + + # Dots-OCR def run_dots_ocr(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -2045,6 +2088,7 @@ model_example_map = { "command_a_vision": run_command_a_vision, "deepseek_vl_v2": run_deepseek_vl2, "deepseek_ocr": run_deepseek_ocr, + "deepseek_ocr2": run_deepseek_ocr2, "dots_ocr": run_dots_ocr, "eagle2_5": run_eagle2_5, "ernie45_vl": run_ernie45_vl, diff --git a/tests/models/registry.py b/tests/models/registry.py index 1bee16c81..0e3d0d312 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -687,6 +687,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { "DeepseekOCRForCausalLM": _HfExamplesInfo( "deepseek-ai/DeepSeek-OCR", ), + "DeepseekOCR2ForCausalLM": _HfExamplesInfo( + "deepseek-ai/DeepSeek-OCR-2", + ), "DotsOCRForCausalLM": _HfExamplesInfo( "rednote-hilab/dots.ocr", trust_remote_code=True ), diff --git a/vllm/model_executor/models/deepencoder.py b/vllm/model_executor/models/deepencoder.py index 651ced896..f7ae4264f 100644 --- a/vllm/model_executor/models/deepencoder.py +++ b/vllm/model_executor/models/deepencoder.py @@ -79,6 +79,7 @@ class ImageEncoderViT(nn.Module): rel_pos_zero_init: bool = True, window_size: int = 0, global_attn_indexes: tuple[int, ...] 
= (), + last_conv_output: int = 1024, ) -> None: """ Args: @@ -155,7 +156,7 @@ class ImageEncoderViT(nn.Module): 256, 512, kernel_size=3, stride=2, padding=1, bias=False ) self.net_3 = Conv2dLayer( - 512, 1024, kernel_size=3, stride=2, padding=1, bias=False + 512, last_conv_output, kernel_size=3, stride=2, padding=1, bias=False ) def get_abs_pos(self, abs_pos: torch.Tensor, tgt_size: int): diff --git a/vllm/model_executor/models/deepencoder2.py b/vllm/model_executor/models/deepencoder2.py new file mode 100644 index 000000000..b50606d47 --- /dev/null +++ b/vllm/model_executor/models/deepencoder2.py @@ -0,0 +1,283 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# adapted from +# https://github.com/deepseek-ai/DeepSeek-OCR-2/blob/main/DeepSeek-OCR2-master/DeepSeek-OCR2-vllm/deepencoderv2/qwen2_d2e.py + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +import torch.nn as nn +import transformers + + +class CustomQwen2Decoder(nn.Module): + """ + Qwen2 visual encoder + non-causal attention + causal attention + token_type_ids :0=non-causal, 1=causal + """ + + def __init__( + self, + decoder_layer: int = 24, + max_position_embeddings: int = 131072, + hidden_dimension: int = 896, + num_attention_heads: int = 14, + num_key_value_heads: int = 2, + intermediate_size: int = 4864, + vocab_size: int = 151936, + attn_implementation: str = "sdpa", # ⭐ + rms_norm_eps: float = 1e-06, + rope_theta: float = 1000000.0, + attention_dropout: float = 0.0, + hidden_act: str = "silu", + initializer_range: float = 0.02, + ): + super().__init__() + + # load + Qwen2Model = transformers.models.qwen2.modeling_qwen2.Qwen2Model + Qwen2Config = transformers.Qwen2Config + + # config + config = Qwen2Config( + hidden_size=hidden_dimension, + num_hidden_layers=decoder_layer, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + intermediate_size=intermediate_size, + max_position_embeddings=max_position_embeddings, + vocab_size=vocab_size, + rms_norm_eps=rms_norm_eps, + rope_theta=rope_theta, + attention_dropout=attention_dropout, + hidden_act=hidden_act, + initializer_range=initializer_range, + _attn_implementation=attn_implementation, # ⭐ + ) + + # + self.model = self._create_custom_model(Qwen2Model, config) + + del self.model.embed_tokens + + def _create_custom_model(self, Qwen2Model, config): + """Qwen2Model""" + + class CustomQwen2ModelInner(Qwen2Model): + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + token_type_ids=None, # ⭐ + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + cache_position=None, + ): + # token_type_ids + self._current_token_type_ids = token_type_ids + causal_mask_mapping = { + "full_attention": self._update_causal_mask( + attention_mask, + 
inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + } + outputs = super().forward( + input_ids=input_ids, + attention_mask=causal_mask_mapping, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + return outputs + + def _update_causal_mask( + self, + attention_mask, + input_tensor, + cache_position, + past_key_values, + output_attentions, + ): + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + batch_size, sequence_length = ( + input_tensor.shape[0], + input_tensor.shape[1], + ) + + token_type_ids = self._current_token_type_ids + + # attention mask + causal_mask = self._create_custom_4d_mask( + sequence_length=sequence_length, + dtype=dtype, + device=device, + batch_size=batch_size, + token_type_ids=token_type_ids, + ) + + # padding mask + if attention_mask is not None and attention_mask.dim() == 2: + padding_mask = attention_mask[:, None, None, :].to(dtype=dtype) + padding_mask = (1.0 - padding_mask) * min_dtype + causal_mask = causal_mask + padding_mask + + return causal_mask + + def _create_custom_4d_mask( + self, + sequence_length, + dtype, + device, + batch_size, + token_type_ids, + ): + min_dtype = torch.finfo(dtype).min + + masks = [] + for b in range(batch_size): + mask = torch.full( + (sequence_length, sequence_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + + type_ids = token_type_ids[b] + + image_positions = (type_ids == 0).nonzero(as_tuple=True)[0] + text_positions = (type_ids == 1).nonzero(as_tuple=True)[0] + + # non-casual + if len(image_positions) > 0: + mask[image_positions[:, None], image_positions] = 0.0 + + # causal + for i, text_pos in enumerate(text_positions): + if len(image_positions) > 0: + mask[text_pos, image_positions] = 0.0 + 
mask[text_pos, text_positions[: i + 1]] = 0.0 + + masks.append(mask) + + mask = torch.stack(masks, dim=0).unsqueeze(1) + return mask + + return CustomQwen2ModelInner(config) + + def forward( + self, + inputs_embeds: torch.Tensor, + token_type_ids: torch.Tensor, + attention_mask: torch.Tensor = None, + **kwargs, + ): + """ + Args: + inputs_embeds: [batch_size, seq_len, hidden_dim] + token_type_ids: [batch_size, seq_len], 0=non-causal, 1=causal + attention_mask: [batch_size, seq_len], optional + """ + return self.model( + inputs_embeds=inputs_embeds, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + **kwargs, + ) + + +class Qwen2Decoder2Encoder(nn.Module): + """ + Decoder based on Multilingual BART + Set the initial weights and configuration with a pretrained multilingual BART model, + and modify the detailed configurations as a Nougat decoder + """ + + def __init__( + self, + decoder_layer: int, + hidden_dimension: int, + num_attention_heads: int, + num_key_value_heads: int, + intermediate_size: int, + ): + super().__init__() + + self.model = CustomQwen2Decoder( + decoder_layer=decoder_layer, + hidden_dimension=hidden_dimension, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + intermediate_size=intermediate_size, + attn_implementation="sdpa", + ) + self.query_768 = nn.Embedding(144, hidden_dimension) + self.query_1024 = nn.Embedding(256, hidden_dimension) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.flatten(2).transpose(1, 2) + + bs, n_query, _ = x.shape + + if n_query == 144: + param_img = self.query_768.weight + elif n_query == 256: + param_img = self.query_1024.weight + + batch_query_imgs = param_img.unsqueeze(0).expand( + bs, -1, -1 + ) # (batch_size, num_queries, hidden_size) + + x_combined = torch.cat([x, batch_query_imgs], dim=1) + + token_type_ids = torch.cat( + [ + torch.zeros(bs, n_query, dtype=torch.long), + torch.ones(bs, n_query, dtype=torch.long), + ], + dim=1, + ) + + y = 
self.model(x_combined, token_type_ids)[0] + + y = y[:, n_query:, :] # causal flow query + + return y + + +def build_qwen2_decoder_as_encoder( + decoder_layer=24, + hidden_dimension=896, + num_attention_heads=14, + num_key_value_heads=2, + intermediate_size=4864, +): + decoder_as_encoder = Qwen2Decoder2Encoder( + decoder_layer=decoder_layer, + hidden_dimension=hidden_dimension, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + intermediate_size=intermediate_size, + ) + + return decoder_as_encoder diff --git a/vllm/model_executor/models/deepseek_ocr2.py b/vllm/model_executor/models/deepseek_ocr2.py new file mode 100644 index 000000000..6541edad2 --- /dev/null +++ b/vllm/model_executor/models/deepseek_ocr2.py @@ -0,0 +1,444 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Deepseek-OCR model compatible with HuggingFace weights.""" + +import math +from collections.abc import Iterable, Mapping, Sequence +from functools import partial + +import torch +import torch.nn as nn +from transformers import BatchFeature + +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.model_executor.models.interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + NestedTensors, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseDummyInputsBuilder, + BaseMultiModalProcessor, + BaseProcessingInfo, + 
PromptReplacement, + PromptUpdate, +) +from vllm.sequence import IntermediateTensors +from vllm.tokenizers import cached_tokenizer_from_config +from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config +from vllm.transformers_utils.processors.deepseek_ocr2 import ( + BASE_SIZE, + CROP_MODE, + IMAGE_SIZE, + DeepseekOCR2Processor, +) + +from ...transformers_utils.processors.deepseek_ocr import count_tiles +from .deepencoder import ImageEncoderViT +from .deepencoder2 import build_qwen2_decoder_as_encoder +from .deepseek_ocr import DeepseekOCRImagePixelInputs +from .deepseek_vl2 import MlpProjector + +# The image token id may be various +_IMAGE_TOKEN = "<image>" + + +class DeepseekOCR2ProcessingInfo(BaseProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config(DeepseekVLV2Config) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(DeepseekOCR2Processor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"image": None} + + def get_num_image_tokens( + self, *, image_width: int, image_height: int, cropping: bool = True + ) -> int: + image_size = IMAGE_SIZE + base_size = BASE_SIZE + patch_size = 16 + downsample_ratio = 4 + + if CROP_MODE: + if image_width <= 768 and image_height <= 768: + crop_ratio = [1, 1] + else: + # find the closest aspect ratio to the target + crop_ratio = count_tiles( + image_width, image_height, image_size=IMAGE_SIZE + ) + + num_width_tiles, num_height_tiles = crop_ratio + else: + num_width_tiles = num_height_tiles = 1 + + h = w = math.ceil((base_size // patch_size) / downsample_ratio) + + h2 = w2 = math.ceil((image_size // patch_size) / downsample_ratio) + + global_views_tokens = h * w + if num_width_tiles > 1 or num_height_tiles > 1: + local_views_tokens = (num_height_tiles * h2) * (num_width_tiles * w2) + else: + local_views_tokens = 0 + + return global_views_tokens + local_views_tokens + 1 + + def get_image_size_with_most_features(self) -> 
ImageSize: + if IMAGE_SIZE == 1024 and BASE_SIZE == 1280: + return ImageSize(width=1024 * 2, height=1024 * 2) + return ImageSize(width=768 * 2, height=768 * 2) + + +class DeepseekOCR2DummyInputsBuilder( + BaseDummyInputsBuilder[DeepseekOCR2ProcessingInfo] +): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + max_image_size = self.info.get_image_size_with_most_features() + + return { + "image": self._get_dummy_images( + width=max_image_size.width, + height=max_image_size.height, + num_images=num_images, + ) + } + + +class DeepseekOCR2MultiModalProcessor( + BaseMultiModalProcessor[DeepseekOCR2ProcessingInfo] +): + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + if mm_data: + processed_outputs = self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), + dict(prompt=prompt, **mm_data), + mm_kwargs, + ) + + else: + tokenizer = self.info.get_tokenizer() + processed_outputs = tokenizer( + prompt, add_special_tokens=True, return_tensors="pt" + ) + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + images_spatial_crop = hf_inputs.get("images_spatial_crop", torch.empty((0, 2))) + is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1) + patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0) + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + 
images_spatial_crop=MultiModalFieldConfig.batched("image"), + images_crop=MultiModalFieldConfig.flat_from_sizes( + "image", patches_per_image + ), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + image_token_id = hf_processor.image_token_id + assert isinstance(image_token_id, int) + + def get_replacement_deepseek_vl2(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + size = images.get_image_size(item_idx) + + num_image_tokens = self.info.get_num_image_tokens( + image_width=size.width, + image_height=size.height, + cropping=CROP_MODE, + ) + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement_deepseek_vl2, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + DeepseekOCR2MultiModalProcessor, + info=DeepseekOCR2ProcessingInfo, + dummy_inputs=DeepseekOCR2DummyInputsBuilder, +) +class DeepseekOCR2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # map prefix for language backbone + "model.embed_tokens.": "language_model.model.embed_tokens.", + "model.layers.": "language_model.model.layers.", + "model.norm.": "language_model.model.norm.", + "lm_head.": "language_model.lm_head.", + # remove "model." 
prefix for other components + "model.": "", + } + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("image"): + return "<image>" + + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config: DeepseekVLV2Config = vllm_config.model_config.hf_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.vision_config = config.vision_config + self.projector_config = config.projector_config + self.text_config = config.text_config + model_config = vllm_config.model_config + tokenizer = cached_tokenizer_from_config(model_config) + self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN] + + with self._mark_tower_model(vllm_config, "image"): + self.sam_model = ImageEncoderViT( + depth=12, + embed_dim=768, + img_size=1024, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=12, + patch_size=16, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=[2, 5, 8, 11], + window_size=14, + out_chans=256, + last_conv_output=896, + ) + self.qwen2_model = build_qwen2_decoder_as_encoder() + + self.projector = MlpProjector(self.projector_config) + self.tile_tag = config.tile_tag + self.global_view_pos = config.global_view_pos + + # special token for image token sequence format + n_embed = self.projector_config.n_embed + embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32)) + if self.tile_tag == "2D": + # This is a typo in original implementation + self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std) + else: + raise ValueError( + f"Only 2D tile_tag is supported currently, got: {self.tile_tag}" + ) + + with self._mark_language_model(vllm_config): + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=self.text_config, + prefix=maybe_prefix(prefix, 
"language_model"), + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> DeepseekOCRImagePixelInputs | None: + pixel_values = kwargs.pop("pixel_values", None) + images_spatial_crop = kwargs.pop("images_spatial_crop", None) + images_crop = kwargs.pop("images_crop", None) + + if pixel_values is None or torch.sum(pixel_values).item() == 0: + return None + + base_size = self.vision_config.image_size + return DeepseekOCRImagePixelInputs( + type="pixel_values", + data=pixel_values, + images_crop=images_crop, + images_spatial_crop=images_spatial_crop, + resolve_bindings={ + "base_size": base_size, + }, + ) + + def _encode_global_features(self, image_tensor: torch.Tensor) -> torch.Tensor: + global_features_1 = self.sam_model(image_tensor) + global_features_2 = self.qwen2_model(global_features_1) + + features = self.projector(global_features_2) + + _, hw, dim = features.shape + + return features.view(-1, dim) + + def _encode_local_features(self, patches: torch.Tensor) -> torch.Tensor | None: + if torch.sum(patches).item() == 0: + return None + + local_features = self.sam_model(patches) + local_features = self.qwen2_model(local_features) + + features = self.projector(local_features) + + _, _, dim = features.shape + + return features.view(-1, dim) + + def _pixel_values_to_embedding( + self, + pixel_values: torch.Tensor, + images_crop: torch.Tensor, + images_spatial_crop: torch.Tensor, + ) -> NestedTensors: + images_in_this_batch = [] + + is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1) + patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0) + images_crop = images_crop.split(patches_per_image.tolist()) + for jdx in range(images_spatial_crop.size(0)): + patches = images_crop[jdx] + image_ori = pixel_values[[jdx]] + + global_features = self._encode_global_features(image_ori) + local_features = 
self._encode_local_features(patches) + + if local_features is not None: + combined = torch.cat( + [local_features, global_features, self.view_seperator[None, :]], + dim=0, + ) + else: + combined = torch.cat( + [global_features, self.view_seperator[None, :]], dim=0 + ) + + images_in_this_batch.append(combined) + + return images_in_this_batch + + def _process_image_input( + self, image_input: DeepseekOCRImagePixelInputs + ) -> torch.Tensor: + pixel_values = image_input.data + images_crop = image_input.images_crop + images_spatial_crop = image_input.images_spatial_crop.to(dtype=torch.long) + + vision_features = self._pixel_values_to_embedding( + pixel_values=pixel_values, + images_crop=images_crop, + images_spatial_crop=images_spatial_crop, + ) + + return vision_features + + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ): + if intermediate_tensors is not None: + inputs_embeds = None + + hidden_states = self.language_model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + return autoloaded_weights + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + 
language_model="language_model", + connector="projector", + tower_model=["sam_model", "qwen2_model"], + ) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 95f6cc065..ed2a39d24 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -308,6 +308,7 @@ _MULTIMODAL_MODELS = { ), "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"), + "DeepseekOCR2ForCausalLM": ("deepseek_ocr2", "DeepseekOCR2ForCausalLM"), "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"), "Eagle2_5_VLForConditionalGeneration": ( "eagle2_5_vl", diff --git a/vllm/transformers_utils/chat_templates/registry.py b/vllm/transformers_utils/chat_templates/registry.py index fe84b6c15..0064cc6d6 100644 --- a/vllm/transformers_utils/chat_templates/registry.py +++ b/vllm/transformers_utils/chat_templates/registry.py @@ -34,6 +34,7 @@ _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja", "clip": CHAT_TEMPLATES_DIR / "template_basic.jinja", "deepseek_ocr": CHAT_TEMPLATES_DIR / "template_deepseek_ocr.jinja", + "deepseek_ocr2": CHAT_TEMPLATES_DIR / "template_deepseek_ocr.jinja", "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja", "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja", "minicpmv": _get_minicpmv_chat_template_fallback, diff --git a/vllm/transformers_utils/processors/deepseek_ocr2.py b/vllm/transformers_utils/processors/deepseek_ocr2.py new file mode 100644 index 000000000..6dbda73d4 --- /dev/null +++ b/vllm/transformers_utils/processors/deepseek_ocr2.py @@ -0,0 +1,320 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# adapted from https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek-OCR-master/DeepSeek-OCR-vllm/process/image_process.py +import math + +import 
torch +from PIL import Image, ImageOps +from transformers import AutoProcessor, BatchFeature, LlamaTokenizerFast +from transformers.processing_utils import ProcessorMixin + +from vllm.transformers_utils.processors.deepseek_ocr import ( + ImageTransform, + dynamic_preprocess, +) + +BASE_SIZE = 1024 +IMAGE_SIZE = 768 +CROP_MODE = True +MIN_CROPS = 2 +MAX_CROPS = 6 + + +class DeepseekOCR2Processor(ProcessorMixin): + tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") + attributes = ["tokenizer"] + + def __init__( + self, + tokenizer: LlamaTokenizerFast, + patch_size: int = 16, + downsample_ratio: int = 4, + image_mean: tuple[float, float, float] = (0.5, 0.5, 0.5), + image_std: tuple[float, float, float] = (0.5, 0.5, 0.5), + normalize: bool = True, + image_token: str = "<image>", + pad_token: str = "<|▁pad▁|>", + add_special_token: bool = False, + sft_format: str = "deepseek", + mask_prompt: bool = True, + ignore_id: int = -100, + **kwargs, + ): + self.image_size = IMAGE_SIZE + self.base_size = BASE_SIZE + self.patch_size = 16 + self.image_mean = image_mean + self.image_std = image_std + self.normalize = normalize + self.downsample_ratio = 4 + + self.image_transform = ImageTransform( + mean=image_mean, std=image_std, normalize=normalize + ) + + self.tokenizer = tokenizer + self.tokenizer.padding_side = "left" # must set this,padding side with make a difference in batch inference # noqa: E501 + + # add the pad_token as special token to use 'tokenizer.pad_token' + # and 'tokenizer.pad_token_id' + if self.tokenizer.pad_token is None: + self.tokenizer.add_special_tokens({"pad_token": pad_token}) + + # add image token + self.image_token_id = self.tokenizer.vocab.get(image_token) + self.image_token = image_token + self.pad_token = pad_token + self.add_special_token = add_special_token + self.sft_format = sft_format + self.mask_prompt = mask_prompt + self.ignore_id = ignore_id + + super().__init__( + tokenizer, + **kwargs, + ) + + @property + def bos_id(self): + return 
self.tokenizer.bos_token_id + + @property + def eos_id(self): + return self.tokenizer.eos_token_id + + @property + def pad_id(self): + return self.tokenizer.pad_token_id + + def encode(self, text: str, bos: bool = True, eos: bool = False): + t = self.tokenizer.encode(text, add_special_tokens=False) + if bos: + t = [self.bos_id] + t + if eos: + t = t + [self.eos_id] + return t + + def decode(self, t: list[int], **kwargs) -> str: + return self.tokenizer.decode(t, **kwargs) + + def process_one( + self, + prompt: str, + images: list[Image.Image], + crop_mode: bool = CROP_MODE, + ): + """ + + Args: + prompt (str): the formatted prompt; + images (List[ImageType]): the list of images; + crop_mode (bool): if True, then crop the image; + + Returns: + outputs (BaseProcessorOutput): the output of the processor, + - input_ids (torch.LongTensor): [N + image tokens] + - target_ids (torch.LongTensor): [N + image tokens] + - pixel_values (torch.FloatTensor): [n_patches, 3, H, W] + - image_id (int): the id of the image token + - num_image_tokens (List[int]): the number of image tokens + """ + + assert prompt is not None and images is not None, ( + "prompt and images must be used at the same time." 
+ ) + + sft_format = prompt + + ( + input_ids, + pixel_values, + images_crop, + images_seq_mask, + images_spatial_crop, + num_image_tokens, + _, + ) = self.tokenize_with_images( + conversation=sft_format, + images=images, + bos=True, + eos=True, + cropping=crop_mode, + ) + + prepare = BatchFeature( + data=dict( + input_ids=input_ids, + pixel_values=pixel_values, + images_crop=images_crop, + images_seq_mask=images_seq_mask, + images_spatial_crop=images_spatial_crop, + num_image_tokens=num_image_tokens, + ), + tensor_type="pt", + ) + return prepare + + def __call__( + self, + *, + prompt: str, + images: list[Image.Image], + crop_mode: bool = CROP_MODE, + **kwargs, + ): + prepare = self.process_one( + prompt=prompt, + images=images, + crop_mode=crop_mode, + ) + + return prepare + + def tokenize_with_images( + self, + conversation: str, + images: list[Image.Image], + bos: bool = True, + eos: bool = True, + cropping: bool = True, + ): + """Tokenize text with tags.""" + + assert conversation.count(self.image_token) == len(images) + text_splits = conversation.split(self.image_token) + images_list, images_crop_list, images_seq_mask, images_spatial_crop = ( + [], + [], + [], + [], + ) + image_shapes = [] + num_image_tokens = [] + tokenized_str = [] + for text_sep, image in zip(text_splits, images): + tokenized_sep = self.encode(text_sep, bos=False, eos=False) + tokenized_str += tokenized_sep + images_seq_mask += [False] * len(tokenized_sep) + + image_shapes.append(image.size) + + images_crop_raw = [] + if image.size[0] <= 768 and image.size[1] <= 768: + crop_ratio = [1, 1] + elif cropping: + images_crop_raw, crop_ratio = dynamic_preprocess( + image, image_size=IMAGE_SIZE + ) + else: + crop_ratio = [1, 1] + + if self.image_size <= 768 and not cropping: + image = image.resize((self.image_size, self.image_size)) + + global_view = ImageOps.pad( + image, + (self.base_size, self.base_size), + color=tuple(int(x * 255) for x in self.image_transform.mean), + ) + 
images_list.append(self.image_transform(global_view)) + + num_width_tiles, num_height_tiles = crop_ratio + images_spatial_crop.append([num_width_tiles, num_height_tiles]) + + if num_width_tiles > 1 or num_height_tiles > 1: + for cropped_image in images_crop_raw: + images_crop_list.append(self.image_transform(cropped_image)) + + num_queries = math.ceil( + (self.image_size // self.patch_size) / self.downsample_ratio + ) + num_queries_base = math.ceil( + (self.base_size // self.patch_size) / self.downsample_ratio + ) + + tokenized_image = ( + [self.image_token_id] * num_queries_base + ) * num_queries_base + tokenized_image += [self.image_token_id] + if num_width_tiles > 1 or num_height_tiles > 1: + local_row = [self.image_token_id] * (num_queries * num_width_tiles) + tokenized_image += local_row * (num_queries * num_height_tiles) + tokenized_str += tokenized_image + images_seq_mask += [True] * len(tokenized_image) + num_image_tokens.append(len(tokenized_image)) + + """process the last text split""" + tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False) + tokenized_str += tokenized_sep + images_seq_mask += [False] * len(tokenized_sep) + + """add the bos and eos tokens""" + if bos: + tokenized_str = [self.bos_id] + tokenized_str + images_seq_mask = [False] + images_seq_mask + if eos: + tokenized_str = tokenized_str + [self.eos_id] + images_seq_mask = images_seq_mask + [False] + + assert len(tokenized_str) == len(images_seq_mask), ( + f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} " + f"is not equal to images_seq_mask's length {len(images_seq_mask)}." 
+ ) + + masked_tokenized_str = [] + for token_index in tokenized_str: + if token_index != self.image_token_id: + masked_tokenized_str.append(token_index) + else: + masked_tokenized_str.append(self.ignore_id) + + assert ( + len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str) + ), ( + f"tokenized_str's length {len(tokenized_str)}, " + f"input_ids' length {len(masked_tokenized_str)}, " + f"images_seq_mask's length {len(images_seq_mask)}, are not equal." + ) + + input_ids = torch.LongTensor(tokenized_str) + target_ids = torch.LongTensor(masked_tokenized_str) + images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool) + + # set input_ids < 0 | input_ids == self.image_token_id as ignore_id + target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = ( + self.ignore_id + ) + input_ids[input_ids < 0] = self.pad_id + + # Remove the ending eos token + assert input_ids[-1] == self.eos_id + input_ids = input_ids[:-1] + target_ids = target_ids[:-1] + images_seq_mask = images_seq_mask[:-1] + + if len(images_list) == 0: + pixel_values = torch.zeros((0, 3, self.base_size, self.base_size)) + images_spatial_crop = torch.zeros((0, 2), dtype=torch.long) + images_crop = torch.zeros((0, 3, self.image_size, self.image_size)) + else: + pixel_values = torch.stack(images_list, dim=0) + images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) + if images_crop_list: + images_crop = torch.stack(images_crop_list, dim=0) + else: + images_crop = torch.zeros((0, 3, self.image_size, self.image_size)) + + input_ids = input_ids.unsqueeze(0) + + return ( + input_ids, + pixel_values, + images_crop, + images_seq_mask, + images_spatial_crop, + num_image_tokens, + image_shapes, + ) + + +AutoProcessor.register("DeepseekOCR2Processor", DeepseekOCR2Processor)