[Model] Use merge_by_field_config for MM models (D-F) (#26076)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-10-02 23:17:35 +08:00
committed by GitHub
parent 7d6fb905d9
commit cc253b73d3
4 changed files with 102 additions and 180 deletions

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping
from typing import Literal, Optional, TypedDict, Union
from typing import Annotated, Literal, Optional, Union
import torch
import torch.nn as nn
@@ -42,34 +42,38 @@ from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
DotsVisionConfig)
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .vision import run_dp_sharded_mrope_vision_model
IMAGE_TOKEN = "<|imgpad|>"
class DotsOCRImagePixelInputs(TypedDict):
    """Image input variant that carries raw pixel values.

    Instances are built with ``type="pixel_values"`` together with the
    ``pixel_values`` and ``image_grid_thw`` tensors.
    """
    # Discriminator tag. Only "pixel_values" is ever assigned for this
    # variant, so the literal is restricted to that single value.
    type: Literal["pixel_values"]

    # Pixel data for the images.
    pixel_values: torch.Tensor
    # Per-image grid descriptor; the name suggests (t, h, w) rows —
    # NOTE(review): confirm against the image processor.
    image_grid_thw: torch.Tensor
class DotsOCRImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds", "image_grid_thw"]
image_embeds: torch.Tensor
"""Supported types:
- List[`torch.Tensor`]: A list of tensors holding all images' features.
Each tensor holds an image's features.
- `torch.Tensor`: A tensor holding all images' features
(concatenation of all images' feature tensors).
Tensor shape: `(num_image_features, hidden_size)`
- `num_image_features` varies based on
the number and resolution of the images.
- `hidden_size` must match the hidden size of language model backbone.
class DotsOCRImagePixelInputs(TensorSchema):
"""
Dimensions:
- np: The total number of patches over each image over each prompt in
the batch
- ni: Number of images
- cps: Number of channels * patch_size * patch_size
"""
type: Literal["pixel_values"]
image_grid_thw: torch.Tensor
pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
class DotsOCRImageEmbeddingInputs(TensorSchema):
    """
    Image input variant that carries image embeddings instead of pixels.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images
    """
    # Discriminator tag identifying this variant.
    type: Literal["image_embeds"]

    # Declared shape: (nf, hs).
    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
    # Declared shape: (ni, 3), one row per image; the name suggests
    # (t, h, w) — NOTE(review): confirm against the image processor.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
DotsOCRImageInputs = Union[DotsOCRImagePixelInputs,
@@ -654,6 +658,8 @@ class DotsVisionTransformer(nn.Module):
)
class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
".attn.qkv_proj.": ".attn.qkv.",
@@ -709,22 +715,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
architectures=["Qwen2ForCausalLM"],
)
def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                    name: str) -> torch.Tensor:
    """Validate a multimodal input and merge it into a single tensor.

    Args:
        mm_input: One of
            - a 2D tensor, returned unchanged;
            - a batched 3D tensor, whose first two dimensions are merged;
            - a list of tensors, concatenated along dimension 0.
        name: Human-readable name of the input, used in error messages.

    Returns:
        The merged tensor.

    Raises:
        ValueError: If ``mm_input`` is neither a tensor nor a list, or if
            it is a tensor whose rank is not 2 or 3.
    """
    if isinstance(mm_input, torch.Tensor):
        if mm_input.ndim == 2:
            return mm_input
        if mm_input.ndim != 3:
            raise ValueError(f"{name} should be 2D or batched 3D tensor. "
                             f"Got ndim: {mm_input.ndim} "
                             f"(shape={mm_input.shape})")
        # Equivalent to concatenating the per-item slices along dim 0, but
        # without splitting into a Python list first; it also yields an
        # empty (0, d) result when the batch dimension is empty instead of
        # failing inside torch.concat.
        return mm_input.flatten(0, 1)

    if isinstance(mm_input, list):
        return torch.concat(mm_input)

    raise ValueError(f"Incorrect type of {name}. "
                     f"Got type: {type(mm_input)}")
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[DotsOCRImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
@@ -735,28 +725,11 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
return None
if pixel_values is not None:
pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, "image pixel values")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of image pixel values. "
f"Got type: {type(pixel_values)}")
return DotsOCRImagePixelInputs(type="pixel_values",
pixel_values=pixel_values,
image_grid_thw=image_grid_thw)
if image_embeds is not None:
image_embeds = self._validate_and_reshape_mm_tensor(
image_embeds, "image embeds")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return DotsOCRImageEmbeddingInputs(type="image_embeds",
image_embeds=image_embeds,
image_grid_thw=image_grid_thw)