Migrate DonutImagePixelInputs to TensorSchema (#23509)

Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
Benji Beck
2025-08-24 22:02:15 -07:00
committed by GitHub
parent a5203d04df
commit 787cdb3829

View File

@@ -3,7 +3,7 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union from typing import Annotated, Literal, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -29,6 +29,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
PromptIndexTargets, PromptInsertion, PromptIndexTargets, PromptInsertion,
PromptUpdate) PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.utils.tensor_schema import TensorSchema, TensorShape
class MBartDecoderWrapper(nn.Module): class MBartDecoderWrapper(nn.Module):
@@ -132,10 +133,16 @@ class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only):
return loaded_params return loaded_params
class DonutImagePixelInputs(TypedDict): class DonutImagePixelInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- c: Number of channels (3)
- h: Height
- w: Width
"""
type: Literal["pixel_values"] type: Literal["pixel_values"]
data: torch.Tensor data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")]
"""Shape: (batch_size, num_channel, height, width)"""
class DonutProcessingInfo(BaseProcessingInfo): class DonutProcessingInfo(BaseProcessingInfo):
@@ -275,27 +282,6 @@ class DonutForConditionalGeneration(nn.Module, SupportsMultiModal,
) )
self.pad_token_id = config.pad_token_id self.pad_token_id = config.pad_token_id
def _validate_pixel_values(
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
# size = self.processor_config["size"]
h, w = self.config.encoder.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
raise ValueError(
"The expected shape of pixel values per batch "
f"is {expected_dims}. You supplied {actual_dims}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(self, **kwargs: object): def _parse_and_validate_image_input(self, **kwargs: object):
pixel_values: Optional[Union[list[list[torch.Tensor]], pixel_values: Optional[Union[list[list[torch.Tensor]],
list[torch.Tensor], list[torch.Tensor],
@@ -314,11 +300,14 @@ class DonutForConditionalGeneration(nn.Module, SupportsMultiModal,
"Both pixel values and image embeds are provided.") "Both pixel values and image embeds are provided.")
if pixel_values is not None: if pixel_values is not None:
return DonutImagePixelInputs( h, w = self.config.encoder.image_size
type="pixel_values", return DonutImagePixelInputs(type="pixel_values",
data=self._validate_pixel_values( data=flatten_bn(pixel_values,
flatten_bn(pixel_values, concat=True)), concat=True),
) resolve_bindings={
"h": h,
"w": w,
})
if image_embeds is not None: if image_embeds is not None:
raise NotImplementedError raise NotImplementedError