[Bugfix] Fix deepseek-ocr multi-image inference and add merge_by_field_config=True with tensor schema support (#27361)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
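This fixes multi-image prompts for deepseek-ocr: previously `images_crop` was registered as a plain batched field and indexed with an extra `[0]`, which only lined up for single-image prompts. The sketch below is a minimal, hypothetical repro of the multi-image path this commit exercises; it assumes a vLLM build containing this change, and the prompt template, image files, and sampling settings are illustrative rather than taken from the diff.

```python
# Hypothetical repro/usage sketch for the multi-image path fixed here.
# Assumes a vLLM build with this commit; model ID is the HF repo, but the
# prompt wording and file names are illustrative.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-OCR",
    limit_mm_per_prompt={"image": 2},  # allow two images in one prompt
)

images = [Image.open("page1.png"), Image.open("page2.png")]
outputs = llm.generate(
    {
        # One "<image>" placeholder per image, matching _IMAGE_TOKEN below.
        "prompt": "<image>\n<image>\nFree OCR.",
        "multi_modal_data": {"image": images},
    },
    SamplingParams(temperature=0.0, max_tokens=512),
)
print(outputs[0].outputs[0].text)
```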
@@ -4,6 +4,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
+from typing import Annotated, Literal

 import torch
 import torch.nn as nn
@@ -53,6 +54,7 @@ from vllm.transformers_utils.processors.deepseek_ocr import (
     count_tiles,
 )
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from vllm.v1.sample.logits_processor import (
     AdapterLogitsProcessor,
     RequestLogitsProcessor,
@@ -65,6 +67,28 @@ from .deepseek_vl2 import MlpProjector
 _IMAGE_TOKEN = "<image>"


+class DeepseekOCRImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - b: Batch size
+        - n: Number of images
+        - p: Number of patches
+        - base_size: Base size of the processor
+        - image_size: Image size of the processor
+    """
+
+    type: Literal["pixel_values"]
+    data: Annotated[
+        torch.Tensor,
+        TensorShape("bn", 3, "base_size", "base_size"),
+    ]
+    images_crop: Annotated[
+        torch.Tensor,
+        TensorShape("bnp", 3, "image_size", "image_size", dynamic_dims={"bnp"}),
+    ]
+    images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
+
+
 class NoRepeatNGramLogitsProcessor:
     def __init__(
         self,
@@ -260,10 +284,15 @@ class DeepseekOCRMultiModalProcessor(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
+        images_spatial_crop = hf_inputs.get("images_spatial_crop", torch.empty((0, 2)))
+        is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
+        patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
             images_spatial_crop=MultiModalFieldConfig.batched("image"),
-            images_crop=MultiModalFieldConfig.batched("image"),
+            images_crop=MultiModalFieldConfig.flat_from_sizes(
+                "image", patches_per_image
+            ),
         )

     def _get_prompt_updates(
@@ -302,35 +331,6 @@ class DeepseekOCRMultiModalProcessor(
             )
         ]

-    # TODO(Isotr0py): Check if we still need this workaround for
-    # deepseek-ocr processor.
-    # def _cached_apply_hf_processor(
-    #     self,
-    #     prompt: str | list[int],
-    #     mm_data_items: MultiModalDataItems,
-    #     hf_processor_mm_kwargs: Mapping[str, object],
-    #     tokenization_kwargs: Mapping[str, object],
-    #     mm_uuids: MultiModalUUIDDict | None = None,
-    # ) -> tuple[list[int], MultiModalKwargs, bool]:
-    #     # The processor logic is different for len(images) <= 2 vs > 2
-    #     # Since the processing cache assumes that the processor output is
-    #     # invariant of how many images are passed per prompt, we only
-    #     # perform caching for the most common case
-    #     if mm_data_items.get_count("image", strict=False) > 2:
-    #         # This code path corresponds to the cache being disabled
-    #         return self._apply_hf_processor_main(
-    #             prompt=prompt,
-    #             mm_items=mm_data_items,
-    #             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-    #             enable_hf_prompt_update=True,
-    #         )
-
-    #     return super()._cached_apply_hf_processor(
-    #         prompt=prompt,
-    #         mm_data_items=mm_data_items,
-    #         hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-    #     )
-

 @MULTIMODAL_REGISTRY.register_processor(
     DeepseekOCRMultiModalProcessor,
@@ -338,6 +338,8 @@ class DeepseekOCRMultiModalProcessor(
     dummy_inputs=DeepseekOCRDummyInputsBuilder,
 )
 class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             # map prefix for language backbone
@@ -389,6 +391,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         self.vision_model = DeepCLIPVisionTransformer(
             config=clip_vision_config,
             quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "vision_model"),
         )

         self.projector = MlpProjector(self.projector_config)
@@ -426,7 +429,9 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
             self.language_model.make_empty_intermediate_tensors
         )

-    def _parse_and_validate_image_input(self, **kwargs: object):
+    def _parse_and_validate_image_input(
+        self, **kwargs: object
+    ) -> DeepseekOCRImagePixelInputs | None:
         pixel_values = kwargs.pop("pixel_values", None)
         images_spatial_crop = kwargs.pop("images_spatial_crop", None)
         images_crop = kwargs.pop("images_crop", None)
@@ -435,23 +440,16 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
             return None

        if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
-                )
-
-            if not isinstance(images_spatial_crop, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of image sizes. "
-                    f"Got type: {type(images_spatial_crop)}"
-                )
-
-            if not isinstance(images_crop, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of image crop. Got type: {type(images_crop)}"
-                )
-
-            return [pixel_values, images_crop, images_spatial_crop]
+            base_size = self.vision_config.image_size
+            return DeepseekOCRImagePixelInputs(
+                type="pixel_values",
+                data=pixel_values,
+                images_crop=images_crop,
+                images_spatial_crop=images_spatial_crop,
+                resolve_bindings={
+                    "base_size": base_size,
+                },
+            )

        raise AssertionError("This line should be unreachable.")
@@ -518,10 +516,13 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     ) -> NestedTensors:
         images_in_this_batch = []

+        is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
+        patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
+        images_crop = images_crop.split(patches_per_image.tolist())
        for jdx in range(images_spatial_crop.size(0)):
-            patches = images_crop[jdx][0].to(torch.bfloat16)
-            image_ori = pixel_values[jdx]
-            crop_shape = images_spatial_crop[jdx][0]
+            patches = images_crop[jdx]
+            image_ori = pixel_values[[jdx]]
+            crop_shape = images_spatial_crop[jdx]

            global_features = self._encode_global_features(image_ori)
            local_features = self._encode_local_features(patches, crop_shape)
@@ -540,10 +541,12 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):

         return images_in_this_batch

-    def _process_image_input(self, image_input) -> torch.Tensor:
-        pixel_values = image_input[0].to(torch.bfloat16)
-        images_crop = image_input[1]
-        images_spatial_crop = image_input[2].to(dtype=torch.long)
+    def _process_image_input(
+        self, image_input: DeepseekOCRImagePixelInputs
+    ) -> torch.Tensor:
+        pixel_values = image_input.data
+        images_crop = image_input.images_crop
+        images_spatial_crop = image_input.images_spatial_crop.to(dtype=torch.long)

         vision_features = self._pixel_values_to_embedding(
             pixel_values=pixel_values,
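The heart of the fix is the `patches_per_image` bookkeeping: `images_crop` is now registered via `flat_from_sizes` and split back per image, so untiled images contribute zero crop rows instead of shifting every later image's crops. Below is a standalone sketch of that arithmetic with toy values; the 640x640 tile size is an assumption for illustration, not taken from this diff.

```python
# Standalone sketch of the patches_per_image bookkeeping used in the diff.
# Tensor sizes are toy values; the 640x640 tile size is assumed.
import torch

# One row per image: (width_tiles, height_tiles) from the processor.
images_spatial_crop = torch.tensor([[1, 1], [2, 3]])  # image 0 untiled, image 1 tiled 2x3

# Untiled images contribute zero local crops; tiled ones contribute w*h crops.
is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
assert patches_per_image.tolist() == [0, 6]

# The flattened crop tensor can now be split back per image, so image 0
# gets an empty slice instead of stealing image 1's crops.
images_crop = torch.randn(6, 3, 640, 640)
per_image = images_crop.split(patches_per_image.tolist())
assert [t.shape[0] for t in per_image] == [0, 6]
```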
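With `merge_by_field_config = True`, parsed multimodal kwargs arrive already merged per field, and `DeepseekOCRImagePixelInputs` replaces the hand-rolled isinstance checks. A hedged illustration of what construction now buys: it presumes vLLM's TensorSchema validates shapes at construction time, the import path is assumed, and the 1024 base size and 640 crop size are made-up values.

```python
# Hedged illustration only: assumes TensorShape checks run when the schema
# is constructed, and uses assumed sizes (base_size=1024, crops 640x640).
import torch

# Import path assumed from the commit's file layout.
from vllm.model_executor.models.deepseek_ocr import DeepseekOCRImagePixelInputs

inputs = DeepseekOCRImagePixelInputs(
    type="pixel_values",
    data=torch.zeros(2, 3, 1024, 1024),        # "bn" = 2 images at base size
    images_crop=torch.zeros(6, 3, 640, 640),   # "bnp" = 6 local tiles in total
    images_spatial_crop=torch.tensor([[1, 1], [2, 3]]),
    resolve_bindings={"base_size": 1024},
)
# A wrong shape (say, data with 4 channels) should now fail loudly here,
# instead of surfacing as a cryptic error inside the vision encoder.
```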