[Bugfix] Fix deepseek-ocr multi-image inference and add merge_by_field_config=True with tensor schema support (#27361)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2025-10-23 08:15:38 +08:00
committed by GitHub
parent b4fda58a2d
commit 2566dca2a9
4 changed files with 112 additions and 66 deletions

View File

@@ -4,6 +4,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal
import torch
import torch.nn as nn
@@ -53,6 +54,7 @@ from vllm.transformers_utils.processors.deepseek_ocr import (
count_tiles,
)
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.v1.sample.logits_processor import (
AdapterLogitsProcessor,
RequestLogitsProcessor,
@@ -65,6 +67,28 @@ from .deepseek_vl2 import MlpProjector
_IMAGE_TOKEN = "<image>"
class DeepseekOCRImagePixelInputs(TensorSchema):
    """Pixel-level multimodal inputs for DeepSeek-OCR.

    Dimensions:
        - bn: Batch size * number of images
        - bnp: Batch size * number of images * number of patches
          (patch count varies per image, so this dim is dynamic)
        - base_size: Base (global-view) resolution of the processor
        - image_size: Crop (local-tile) resolution of the processor
    """

    # Discriminator so consumers can tell this input variant apart.
    type: Literal["pixel_values"]

    # Global whole-image views, one per image: (bn, 3, base_size, base_size).
    # FIX: dropped the original dynamic_dims={"bnp"} here — "bnp" is not a
    # dimension of this shape, so declaring it dynamic was a no-op at best.
    data: Annotated[
        torch.Tensor,
        TensorShape("bn", 3, "base_size", "base_size"),
    ]

    # Local tile crops flattened across all images. The number of patches
    # differs per image, hence "bnp" is genuinely dynamic here.
    images_crop: Annotated[
        torch.Tensor,
        TensorShape("bnp", 3, "image_size", "image_size", dynamic_dims={"bnp"}),
    ]

    # Per-image crop grid, 2 values per image: (bn, 2). The two entries are
    # the tiling factors tested elsewhere via `crop[:, 0] > 1 | crop[:, 1] > 1`
    # (presumably width/height tile counts — confirm against the processor).
    images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
class NoRepeatNGramLogitsProcessor:
def __init__(
self,
@@ -260,10 +284,15 @@ class DeepseekOCRMultiModalProcessor(
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
images_spatial_crop = hf_inputs.get("images_spatial_crop", torch.empty((0, 2)))
is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
images_spatial_crop=MultiModalFieldConfig.batched("image"),
images_crop=MultiModalFieldConfig.batched("image"),
images_crop=MultiModalFieldConfig.flat_from_sizes(
"image", patches_per_image
),
)
def _get_prompt_updates(
@@ -302,35 +331,6 @@ class DeepseekOCRMultiModalProcessor(
)
]
# TODO(Isotr0py): Check if we still need this workaround for
# deepseek-ocr processor.
# def _cached_apply_hf_processor(
# self,
# prompt: str | list[int],
# mm_data_items: MultiModalDataItems,
# hf_processor_mm_kwargs: Mapping[str, object],
# tokenization_kwargs: Mapping[str, object],
# mm_uuids: MultiModalUUIDDict | None = None,
# ) -> tuple[list[int], MultiModalKwargs, bool]:
# # The processor logic is different for len(images) <= 2 vs > 2
# # Since the processing cache assumes that the processor output is
# # invariant of how many images are passed per prompt, we only
# # perform caching for the most common case
# if mm_data_items.get_count("image", strict=False) > 2:
# # This code path corresponds to the cache being disabled
# return self._apply_hf_processor_main(
# prompt=prompt,
# mm_items=mm_data_items,
# hf_processor_mm_kwargs=hf_processor_mm_kwargs,
# enable_hf_prompt_update=True,
# )
# return super()._cached_apply_hf_processor(
# prompt=prompt,
# mm_data_items=mm_data_items,
# hf_processor_mm_kwargs=hf_processor_mm_kwargs,
# )
@MULTIMODAL_REGISTRY.register_processor(
DeepseekOCRMultiModalProcessor,
@@ -338,6 +338,8 @@ class DeepseekOCRMultiModalProcessor(
dummy_inputs=DeepseekOCRDummyInputsBuilder,
)
class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# map prefix for language backbone
@@ -389,6 +391,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.vision_model = DeepCLIPVisionTransformer(
config=clip_vision_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.projector = MlpProjector(self.projector_config)
@@ -426,7 +429,9 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model.make_empty_intermediate_tensors
)
def _parse_and_validate_image_input(self, **kwargs: object):
def _parse_and_validate_image_input(
self, **kwargs: object
) -> DeepseekOCRImagePixelInputs | None:
pixel_values = kwargs.pop("pixel_values", None)
images_spatial_crop = kwargs.pop("images_spatial_crop", None)
images_crop = kwargs.pop("images_crop", None)
@@ -435,23 +440,16 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
)
if not isinstance(images_spatial_crop, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of image sizes. "
f"Got type: {type(images_spatial_crop)}"
)
if not isinstance(images_crop, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of image crop. Got type: {type(images_crop)}"
)
return [pixel_values, images_crop, images_spatial_crop]
base_size = self.vision_config.image_size
return DeepseekOCRImagePixelInputs(
type="pixel_values",
data=pixel_values,
images_crop=images_crop,
images_spatial_crop=images_spatial_crop,
resolve_bindings={
"base_size": base_size,
},
)
raise AssertionError("This line should be unreachable.")
@@ -518,10 +516,13 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
) -> NestedTensors:
images_in_this_batch = []
is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
images_crop = images_crop.split(patches_per_image.tolist())
for jdx in range(images_spatial_crop.size(0)):
patches = images_crop[jdx][0].to(torch.bfloat16)
image_ori = pixel_values[jdx]
crop_shape = images_spatial_crop[jdx][0]
patches = images_crop[jdx]
image_ori = pixel_values[[jdx]]
crop_shape = images_spatial_crop[jdx]
global_features = self._encode_global_features(image_ori)
local_features = self._encode_local_features(patches, crop_shape)
@@ -540,10 +541,12 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return images_in_this_batch
def _process_image_input(self, image_input) -> torch.Tensor:
pixel_values = image_input[0].to(torch.bfloat16)
images_crop = image_input[1]
images_spatial_crop = image_input[2].to(dtype=torch.long)
def _process_image_input(
self, image_input: DeepseekOCRImagePixelInputs
) -> torch.Tensor:
pixel_values = image_input.data
images_crop = image_input.images_crop
images_spatial_crop = image_input.images_spatial_crop.to(dtype=torch.long)
vision_features = self._pixel_values_to_embedding(
pixel_values=pixel_values,