diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 297d98142..3159d3bd1 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -615,6 +615,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | ✅︎ | | `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I+ | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ | | `DeepseekVLV2ForCausalLM`^ | DeepSeek-VL2 | T + I+ | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ | +| `DonutForConditionalGeneration`^ | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. | | | | | `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | diff --git a/examples/offline_inference/dolphin.py b/examples/offline_inference/dolphin.py new file mode 100644 index 000000000..d2ba27cd1 --- /dev/null +++ b/examples/offline_inference/dolphin.py @@ -0,0 +1,311 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import copy +import os +from dataclasses import dataclass + +import cv2 +import numpy as np +import regex as re +from PIL import Image +from transformers import DonutProcessor + +from vllm import LLM, SamplingParams +from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt +from vllm.multimodal.utils import fetch_image + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +@dataclass +class ImageDimensions: + original_w: int + original_h: int + padded_w: int + padded_h: int + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def map_to_original_coordinates( + x1, y1, x2, y2, dims: ImageDimensions +) -> tuple[int, int, int, int]: + try: + top = (dims.padded_h - dims.original_h) // 2 + left = (dims.padded_w - dims.original_w) // 2 + orig_x1 = max(0, x1 - left) + orig_y1 = max(0, y1 - top) + orig_x2 = min(dims.original_w, x2 - left) + orig_y2 = min(dims.original_h, y2 - top) + if orig_x2 <= orig_x1: + orig_x2 = min(orig_x1 + 1, dims.original_w) + if orig_y2 <= orig_y1: + orig_y2 = min(orig_y1 + 1, dims.original_h) + return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2) + except Exception as e: + print(f"map_to_original_coordinates error: {str(e)}") + return 0, 0, min(100, dims.original_w), min(100, dims.original_h) + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def adjust_box_edges(image, boxes: list[list[float]], max_pixels=15, threshold=0.2): + if isinstance(image, str): + image = cv2.imread(image) + img_h, img_w = image.shape[:2] + new_boxes = [] + for box in boxes: + best_box = copy.deepcopy(box) + + def check_edge(img, current_box, i, is_vertical): + edge = current_box[i] + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + _, binary = cv2.threshold( + gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU + ) + if is_vertical: + line = binary[current_box[1] : current_box[3] + 1, edge] + else: + line = binary[edge, current_box[0] : current_box[2] + 1] + transitions = np.abs(np.diff(line)) + return np.sum(transitions) / len(transitions) + + edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)] + current_box = copy.deepcopy(box) + current_box[0] = min(max(current_box[0], 0), img_w - 1) + current_box[1] = min(max(current_box[1], 0), img_h - 1) + current_box[2] = min(max(current_box[2], 0), img_w - 1) + current_box[3] = min(max(current_box[3], 0), img_h - 1) + + for i, direction, is_vertical in edges: + best_score = check_edge(image, current_box, i, is_vertical) + if best_score <= threshold: + continue + for step in range(max_pixels): + current_box[i] += direction + if i == 0 or i == 2: + current_box[i] = min(max(current_box[i], 0), img_w - 1) + else: + current_box[i] = min(max(current_box[i], 0), img_h - 1) + score = check_edge(image, current_box, i, is_vertical) + if score < best_score: + best_score = score + best_box = copy.deepcopy(current_box) + if score <= threshold: + break + new_boxes.append(best_box) + return new_boxes + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None): + try: + x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h) + x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h) + x1, y1, x2, y2 = ( + max(0, min(x1, dims.padded_w - 1)), + max(0, min(y1, dims.padded_h - 1)), + max(0, min(x2, dims.padded_w)), + max(0, min(y2, dims.padded_h)), + ) + if x2 <= x1: + x2 = min(x1 + 1, dims.padded_w) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]]) + x1, y1, x2, y2 = new_boxes[0] + x1, y1, x2, y2 = ( + max(0, min(x1, dims.padded_w - 1)), + max(0, min(y1, dims.padded_h - 1)), + max(0, min(x2, dims.padded_w)), + max(0, min(y2, dims.padded_h)), + ) + if x2 <= x1: + x2 = min(x1 + 1, dims.padded_w) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + if previous_box is not None: + prev_x1, prev_y1, prev_x2, prev_y2 = previous_box + if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1): + y1 = prev_y2 + y1 = min(y1, dims.padded_h - 1) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + new_previous_box = [x1, y1, x2, y2] + orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates( + x1, y1, x2, y2, dims + ) + return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box + except Exception as e: + print(f"process_coordinates error: {str(e)}") + orig_x1, orig_y1, orig_x2, orig_y2 = ( + 0, + 0, + min(100, dims.original_w), + min(100, dims.original_h), + ) + return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100] + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def prepare_image(image) -> tuple[np.ndarray, ImageDimensions]: + try: + image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + original_h, original_w = image_cv.shape[:2] + max_size = max(original_h, original_w) + top = (max_size - original_h) // 2 + bottom = max_size - original_h - top + left = (max_size - original_w) // 2 + right = max_size - original_w - left + padded_image = cv2.copyMakeBorder( + image_cv, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0) + ) + padded_h, padded_w = padded_image.shape[:2] + dimensions = ImageDimensions( + original_w=original_w, + original_h=original_h, + padded_w=padded_w, + padded_h=padded_h, + ) + return padded_image, dimensions + except Exception as e: + print(f"prepare_image error: {str(e)}") + h, w = image.height, image.width + dimensions = ImageDimensions(original_w=w, original_h=h, padded_w=w, padded_h=h) + return np.zeros((h, w, 3), dtype=np.uint8), dimensions + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def parse_layout_string(bbox_str): + """Parse layout string using regular expressions""" + pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)" + matches = re.finditer(pattern, bbox_str) + + parsed_results = [] + for match in matches: + coords = [float(match.group(i)) for i in range(1, 5)] + label = match.group(5).strip() + parsed_results.append((coords, label)) + + return parsed_results + + +model_id = "ByteDance/Dolphin" + +# The input image size for Dolphin is 896 x 896, +# and the patch_size is 4 x 4. +# Therefore, the initial number of patches is: +# Height: 896 / 4 = 224 patches +# Width: 896 / 4 = 224 patches + +# The Dolphin model uses a staged downsampling approach, +# defined by the "depths": [2, 2, 14, 2] configuration. +# Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed, +# which halves the feature map's dimensions (dividing both height and width by 2). +# Before Stage 2: The size changes from 224 x 224 to (224/2) x (224/2) = 112 x 112. +# Before Stage 3: The size changes from 112 x 112 to (112/2) x (112/2) = 56 x 56. +# Before Stage 4: The size changes from 56 x 56 to (56/2) x (56/2) = 28 x 28. + +# Because vLLM needs to fill the image features with an encoder_prompt, +# and the encoder_prompt will have `` tokens added when tokenized, +# we need to construct an encoder_prompt with a length of 28 x 28 - 1 = 783. +encoder_prompt = "".join(["0"] * 783) +sampling_params = SamplingParams( + temperature=0.0, + max_tokens=2048, +) + +processor = DonutProcessor.from_pretrained(model_id) +llm = LLM( + model=model_id, + dtype="float16", + max_num_seqs=8, + hf_overrides={"architectures": ["DonutForConditionalGeneration"]}, +) + +parser = argparse.ArgumentParser() +parser.add_argument( + "--image_path", type=str, default=None, help="Path to a local image file." +) +args = parser.parse_args() + +if args.image_path: + if not os.path.exists(args.image_path): + raise FileNotFoundError(f"Error: File not found at {args.image_path}") + image = Image.open(args.image_path).convert("RGB") +else: + image = fetch_image( + "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg" + ) + + +prompt = "Parse the reading order of this document. " +decoder_prompt = f"{prompt}" +decoder_prompt_tokens = TokensPrompt( + prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[ + "input_ids" + ] +) +enc_dec_prompt = ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}), + decoder_prompt=decoder_prompt_tokens, +) +layout_outputs = llm.generate(prompts=enc_dec_prompt, sampling_params=sampling_params) +layout_result_str = layout_outputs[0].outputs[0].text +print(f"Layout analysis output:\n{layout_result_str}") + +padded_image, dims = prepare_image(image) +layout_results = parse_layout_string(layout_result_str) +text_table_elements = [] +previous_box = None +reading_order = 0 +for bbox_coords, label in layout_results: + if label == "fig": + continue + try: + x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = ( + process_coordinates(bbox_coords, padded_image, dims, previous_box) + ) + cropped = padded_image[y1:y2, x1:x2] + if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3: + pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)) + prompt_ocr = ( + "Parse the table in the image. " + if label == "tab" + else "Read text in the image. " + ) + text_table_elements.append( + { + "crop": pil_crop, + "prompt": prompt_ocr, + "reading_order": reading_order, + } + ) + reading_order += 1 + except Exception as e: + print(f"Error processing bbox (label: {label}): {str(e)}") + continue + +if text_table_elements: + batch_prompts = [] + for elem in text_table_elements: + decoder_prompt_str = f"{elem['prompt']}" + decoder_prompt_tokens = TokensPrompt( + prompt_token_ids=processor.tokenizer( + decoder_prompt_str, add_special_tokens=False + )["input_ids"] + ) + enc_dec_prompt = ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt( + prompt=encoder_prompt, multi_modal_data={"image": elem["crop"]} + ), + decoder_prompt=decoder_prompt_tokens, + ) + batch_prompts.append(enc_dec_prompt) + batch_outputs = llm.generate(prompts=batch_prompts, sampling_params=sampling_params) + for i, output in enumerate(batch_outputs): + text_table_elements[i]["text"] = output.outputs[0].text.strip() + +print("------" * 8) +text_table_elements.sort(key=lambda x: x["reading_order"]) +for elem in text_table_elements: + print(elem.get("text", "")) diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index d27a902ed..655f9f3fc 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -13,6 +13,7 @@ from typing import NamedTuple from vllm import LLM, EngineArgs, PromptType, SamplingParams from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset +from vllm.multimodal.utils import fetch_image from vllm.utils import FlexibleArgumentParser @@ -21,6 +22,50 @@ class ModelRequestData(NamedTuple): prompts: Sequence[PromptType] +def run_donut(): + engine_args = EngineArgs( + model="naver-clova-ix/donut-base-finetuned-docvqa", + max_num_seqs=2, + limit_mm_per_prompt={"image": 1}, + dtype="float16", + hf_overrides={"architectures": ["DonutForConditionalGeneration"]}, + ) + + # The input image size for donut-base-finetuned-docvqa is 2560 x 1920, + # and the patch_size is 4 x 4. + # Therefore, the initial number of patches is: + # Height: 1920 / 4 = 480 patches + # Width: 2560 / 4 = 640 patches + # The Swin model uses a staged downsampling approach, + # defined by the "depths": [2, 2, 14, 2] configuration. + # Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed, + # which halves the feature map's dimensions (dividing both height and width by 2). + # Before Stage 2: The size changes from 480 x 640 to (480/2) x (640/2) = 240 x 320. + # Before Stage 3: The size changes from 240 x 320 to (240/2) x (320/2) = 120 x 160. + # Before Stage 4: The size changes from 120 x 160 to (120/2) x (160/2) = 60 x 80. + # Because vLLM needs to fill the image features with an encoder_prompt, + # and the encoder_prompt will have `` tokens added when tokenized, + # we need to construct an encoder_prompt with a length of 60 x 80 - 1 = 4799. + prompts = [ + { + "encoder_prompt": { + "prompt": "".join(["$"] * 4799), + "multi_modal_data": { + "image": fetch_image( + "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg" + ) # noqa: E501 + }, + }, + "decoder_prompt": "What time is the coffee break?", # noqa: E501 + }, + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + def run_florence2(): engine_args = EngineArgs( model="microsoft/Florence-2-large", @@ -118,6 +163,7 @@ def run_whisper(): model_example_map = { + "donut": run_donut, "florence2": run_florence2, "mllama": run_mllama, "whisper": run_whisper, diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index adc8b2510..a604d11f0 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -160,6 +160,7 @@ def _test_processing_correctness( # incorrect token ids. So we need use `add_special_tokens=False` here # to leave bos_token to be added by the processor. _ADD_SPECIAL_TOKENS_OVERRIDES = { + "donut": False, "mllama": False, "ovis": False, "ovis2_5": False, @@ -270,6 +271,7 @@ def _test_processing_correctness_one( "facebook/chameleon-7b", "CohereLabs/command-a-vision-07-2025", "deepseek-ai/deepseek-vl2-tiny", + "naver-clova-ix/donut-base-finetuned-docvqa", "microsoft/Florence-2-base", "adept/fuyu-8b", "google/gemma-3-4b-it", diff --git a/tests/models/registry.py b/tests/models/registry.py index 25dbbd7fa..b34c6f2e5 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -513,6 +513,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { is_available_online=False, ), # [Encoder-decoder] + "DonutForConditionalGeneration": _HfExamplesInfo("naver-clova-ix/donut-base-finetuned-docvqa", # noqa: E501 + hf_overrides={"architectures": ["DonutForConditionalGeneration"], "model_type": "donut"}, # noqa: E501 + extras={"dolphin": "ByteDance/Dolphin"}), # noqa: E501 # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer # Therefore, we borrow the BartTokenizer from the original Bart model "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index bbe958351..dbf8d3ba5 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1822,7 +1822,7 @@ class LLMEngine: assert isinstance(mm_processor, EncDecMultiModalProcessor) if mm_processor.pad_dummy_encoder_prompt: - return # Skip encoder length check for Whisper + return # Skip encoder length check for Whisper and Donut if model_config.is_multimodal_model: suggestion = ( diff --git a/vllm/model_executor/models/donut.py b/vllm/model_executor/models/donut.py new file mode 100644 index 000000000..b1f6a0af6 --- /dev/null +++ b/vllm/model_executor/models/donut.py @@ -0,0 +1,398 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Literal, Optional, TypedDict, Union + +import torch +import torch.nn as nn +from transformers import BatchFeature, NougatProcessor + +from vllm.config import VllmConfig +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.bart import BartParallelLMHead, MBartDecoder +from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, + SupportsMultiModal, + SupportsV0Only) +from vllm.model_executor.models.swin import SwinModel +from vllm.model_executor.models.utils import (AutoWeightsLoader, + _flatten_embeddings, flatten_bn) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseProcessingInfo, + EncDecMultiModalProcessor, + PromptIndexTargets, PromptInsertion, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder + + +class MBartDecoderWrapper(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.decoder = MBartDecoder(config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.decoder") + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + + self.config = config + self.model = MBartDecoderWrapper(vllm_config=vllm_config, + prefix=f"{prefix}.model") + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.vocab_size = config.vocab_size + self.lm_head = BartParallelLMHead(self.vocab_size, + config.d_model, + embed_scale=embed_scale) + + self.logits_processor = LogitsProcessor(self.vocab_size, + config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + Returns: + Output torch.Tensor + """ + + return self.model(decoder_input_ids=input_ids, + decoder_positions=positions, + encoder_hidden_states=inputs_embeds) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if "final_logits_bias" in name: + continue + # if self.config.tie_word_embeddings and "embed_tokens" in name: + # continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class DonutImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: (batch_size, num_channel, height, width)""" + + +class DonutProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self): + return self.ctx.get_hf_processor() + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_num_image_tokens(self) -> int: + return 1 + + +class DonutDummyInputsBuilder(BaseDummyInputsBuilder[DonutProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_hf_config( + ).encoder.image_size + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + +class DonutMultiModalProcessor(EncDecMultiModalProcessor[DonutProcessingInfo]): + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def create_encoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return prompt + + def create_decoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return prompt + + @property + def pad_dummy_encoder_prompt(self) -> bool: + return True + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + hf_processor = self.info.get_hf_processor() + if mm_data: + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs, tok_kwargs) + if isinstance(hf_processor, NougatProcessor): + processed_outputs["input_ids"] = processed_outputs["labels"] + else: + tokenizer = hf_processor.tokenizer + processed_outputs = tokenizer(prompt, + add_special_tokens=False, + return_tensors="pt") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor() + tokenizer = hf_processor.tokenizer + pad_token_id = tokenizer.pad_token_id + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [pad_token_id] * num_image_tokens + + return [ + PromptInsertion( + modality="image", + target=PromptIndexTargets.start(), + insertion=image_tokens, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor(DonutMultiModalProcessor, + info=DonutProcessingInfo, + dummy_inputs=DonutDummyInputsBuilder) +class DonutForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + processor_config = vllm_config.model_config.hf_image_processor_config + + self.config = config + self.vision_config = config.encoder + self.processor_config = processor_config + self.encoder = SwinModel(config=config.encoder) + + self.decoder = DonutLanguageForConditionalGeneration( + vllm_config=vllm_config.with_hf_config(config.decoder), + prefix=f"{prefix}.decoder", + ) + self.pad_token_id = config.pad_token_id + + def _validate_pixel_values( + self, data: Union[torch.Tensor, list[torch.Tensor]] + ) -> Union[torch.Tensor, list[torch.Tensor]]: + + # size = self.processor_config["size"] + h, w = self.config.encoder.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + raise ValueError( + "The expected shape of pixel values per batch " + f"is {expected_dims}. You supplied {actual_dims}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input(self, **kwargs: object): + pixel_values: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "pixel_values", None) + image_embeds: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None and image_embeds is not None: + raise ValueError( + "Both pixel values and image embeds are provided.") + + if pixel_values is not None: + return DonutImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), + ) + + if image_embeds is not None: + raise NotImplementedError + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, image_input: DonutImagePixelInputs) -> torch.Tensor: + assert image_input["type"] == "pixel_values" + pixel_values = image_input["data"] + dtype = next(self.encoder.parameters()).dtype + pixel_values = pixel_values.to(dtype) + return self.encoder(pixel_values) + + def get_language_model(self) -> torch.nn.Module: + return self.decoder + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings, + ) -> torch.Tensor: + return _flatten_embeddings(multimodal_embeddings) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + *, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + encoder_input_ids + torch.Tensor of *encoder* input token ids. + encoder_positions + torch.Tensor of *encoder* position indices + Returns: + Output torch.Tensor + """ + + inputs_embeds = None + if encoder_input_ids.numel() > 0: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(encoder_input_ids, + vision_embeddings) + + hidden_states = self.decoder(input_ids, + positions, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.decoder.compute_logits(hidden_states, sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 465c25f09..ebf78771e 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -252,6 +252,7 @@ _MULTIMODAL_MODELS = { "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501 "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501 # [Encoder-decoder] + "DonutForConditionalGeneration": ("donut", "DonutForConditionalGeneration"), "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501 diff --git a/vllm/model_executor/models/swin.py b/vllm/model_executor/models/swin.py new file mode 100644 index 000000000..30b441f5b --- /dev/null +++ b/vllm/model_executor/models/swin.py @@ -0,0 +1,475 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from transformers import SwinConfig +from transformers.models.swin.modeling_swin import SwinEmbeddings +from transformers.models.swin.modeling_swin import SwinLayer as HFSwinLayer +from transformers.models.swin.modeling_swin import SwinPatchMerging +from transformers.pytorch_utils import meshgrid + +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + + +class SwinSelfAttention(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + num_heads: int, + window_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of " + f"attention heads ({num_heads})") + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.window_size = (window_size if isinstance(window_size, Iterable) + else (window_size, window_size)) + self.scale = self.attention_head_size**-0.5 + + self.relative_position_bias_table = nn.Parameter( + torch.zeros( + (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), + num_heads)) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, + None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + + self.relative_position_index = nn.Parameter(relative_position_index, + requires_grad=False) + + self.qkv = QKVParallelLinear( + hidden_size=dim, + head_size=self.attention_head_size, + total_num_heads=self.num_attention_heads, + bias=config.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def _get_rel_pos_bias(self) -> torch.Tensor: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1) + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() + return relative_position_bias.unsqueeze(0) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor, ...]: + batch_size, dim, num_channels = hidden_states.shape + + qkv_output, _ = self.qkv(hidden_states) + query_layer, key_layer, value_layer = qkv_output.chunk(3, dim=-1) + + key_layer = self.transpose_for_scores(key_layer) + value_layer = self.transpose_for_scores(value_layer) + query_layer = self.transpose_for_scores(query_layer) + + attention_scores = self._get_rel_pos_bias() + if attention_mask is not None: + mask_shape = attention_mask.shape[0] + attention_mask_expanded = attention_mask.view( + 1, mask_shape, 1, dim, + dim).expand(batch_size // mask_shape, mask_shape, + self.num_attention_heads, dim, dim) + attention_scores = attention_scores + \ + attention_mask_expanded.unsqueeze( + 1).unsqueeze(0) + attention_scores = attention_scores.view(-1, + self.num_attention_heads, + dim, dim) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_scores, + dropout_p=0., + ) + attention_probs = None + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + return outputs + + +class SwinSelfOutput(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.dense = RowParallelLinear( + input_size=dim, + output_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + + return hidden_states + + +class SwinAttention(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + num_heads: int, + window_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.self = SwinSelfAttention(config, + dim, + num_heads, + window_size, + quant_config=quant_config, + prefix=f"{prefix}.self") + self.output = SwinSelfOutput(config, + dim, + quant_config=quant_config, + prefix=f"{prefix}.output") + self.pruned_heads = set() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, attention_mask, head_mask, + output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, ) + self_outputs[1:] + return outputs + + +class SwinIntermediate(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.dense = ColumnParallelLinear(dim, + int(config.mlp_ratio * dim), + quant_config=quant_config, + prefix=f"{prefix}.dense") + self.intermediate_act_fn = get_act_fn(config.hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class SwinOutput(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.dense = RowParallelLinear(int(config.mlp_ratio * dim), + dim, + quant_config=quant_config, + prefix=f"{prefix}.dense") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + return hidden_states + + +class SwinLayer(HFSwinLayer): + + def __init__( + self, + config: SwinConfig, + dim: int, + input_resolution: int, + num_heads: int, + drop_path_rate: float = 0.0, + shift_size: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path_rate, + shift_size=shift_size, + ) + + self.attention = SwinAttention(config, + dim, + num_heads, + window_size=self.window_size, + quant_config=quant_config, + prefix=f"{prefix}.attention") + self.intermediate = SwinIntermediate(config, + dim, + quant_config=quant_config, + prefix=f"{prefix}.intermediate") + self.output = SwinOutput(config, + dim, + quant_config=quant_config, + prefix=f"{prefix}.output") + + +class SwinStage(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + input_resolution: int, + depth: int, + num_heads: int, + drop_path: list[float], + downsample: Optional[SwinPatchMerging] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.dim = dim + self.blocks = nn.ModuleList([ + SwinLayer(config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path[layer_idx], + shift_size=0 if + (layer_idx % 2 == 0) else config.window_size // 2, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, + dim=dim, + norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module(hidden_states, input_dimensions, + layer_head_mask, output_attentions, + always_partition) + + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = hidden_states + if self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + + 1) // 2 + output_dimensions = (height, width, height_downsampled, + width_downsampled) + hidden_states = self.downsample(hidden_states_before_downsampling, + input_dimensions) + else: + output_dimensions = (height, width, height, width) + + stage_outputs = (hidden_states, hidden_states_before_downsampling, + output_dimensions) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + +class SwinEncoder(nn.Module): + + def __init__( + self, + config: SwinConfig, + grid_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.num_layers = len(config.depths) + self.config = config + dpr = [ + x.item() for x in torch.linspace( + 0, config.drop_path_rate, sum(config.depths), device="cpu") + ] + self.layers = nn.ModuleList([ + SwinStage(config=config, + dim=int(config.embed_dim * 2**layer_idx), + input_resolution=(grid_size[0] // (2**layer_idx), + grid_size[1] // (2**layer_idx)), + depth=config.depths[layer_idx], + num_heads=config.num_heads[layer_idx], + drop_path=dpr[sum(config.depths[:layer_idx] + ):sum(config.depths[:layer_idx + 1])], + downsample=SwinPatchMerging if + (layer_idx < self.num_layers - 1) else None, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(self.num_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module(hidden_states, input_dimensions, + layer_head_mask, output_attentions, + always_partition) + + hidden_states = layer_outputs[0] + output_dimensions = layer_outputs[2] + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + + return hidden_states + + +class SwinModel(nn.Module): + config_class: SwinConfig + + def __init__( + self, + config: SwinConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2**(self.num_layers - 1)) + + self.embeddings = SwinEmbeddings(config) + self.encoder = SwinEncoder(config, + self.embeddings.patch_grid, + quant_config=quant_config, + prefix=f"{prefix}.encoder") + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + ) -> tuple[torch.Tensor]: + embedding_output, input_dimensions = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + ) + + return encoder_outputs + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv", "query", "q"), + ("qkv", "key", "k"), + ("qkv", "value", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 2da9b4c72..ea2efbdd8 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -209,7 +209,7 @@ class MultiModalProfiler(Generic[_I]): if processor.pad_dummy_encoder_prompt: num_tokens_to_pad = max(total_len, seq_len) - total_len encoder_prompt_token_ids.extend([0] * num_tokens_to_pad) - # NOTE: Whisper allows total_len > seq_len. + # NOTE: Whisper and Donut allows total_len > seq_len. elif total_len > seq_len and not envs.VLLM_USE_V1: # `max_num_batched_tokens` is defined by `SchedulerConfig` logger.warning_once( diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 69f8e531e..219857dc7 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -389,7 +389,7 @@ class Processor: assert isinstance(mm_processor, EncDecMultiModalProcessor) if mm_processor.pad_dummy_encoder_prompt: - return # Skip encoder length check for Whisper + return # Skip encoder length check for Whisper and Donut if model_config.is_multimodal_model: suggestion = (