[Model] Use merge_by_field_config for MM models (D-F) (#26076)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable, Mapping
|
||||
from typing import Literal, Optional, TypedDict, Union
|
||||
from typing import Annotated, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -42,34 +42,38 @@ from vllm.platforms import _Backend
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
|
||||
DotsVisionConfig)
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .vision import run_dp_sharded_mrope_vision_model
|
||||
|
||||
IMAGE_TOKEN = "<|imgpad|>"
|
||||
|
||||
|
||||
class DotsOCRImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values", "image_grid_thw"]
|
||||
|
||||
pixel_values: torch.Tensor
|
||||
image_grid_thw: torch.Tensor
|
||||
|
||||
|
||||
class DotsOCRImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds", "image_grid_thw"]
|
||||
image_embeds: torch.Tensor
|
||||
"""Supported types:
|
||||
- List[`torch.Tensor`]: A list of tensors holding all images' features.
|
||||
Each tensor holds an image's features.
|
||||
- `torch.Tensor`: A tensor holding all images' features
|
||||
(concatenation of all images' feature tensors).
|
||||
Tensor shape: `(num_image_features, hidden_size)`
|
||||
- `num_image_features` varies based on
|
||||
the number and resolution of the images.
|
||||
- `hidden_size` must match the hidden size of language model backbone.
|
||||
class DotsOCRImagePixelInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- np: The total number of patches over each image over each prompt in
|
||||
the batch
|
||||
- ni: Number of images
|
||||
- cps: Number of channels * patch_size * patch_size
|
||||
"""
|
||||
type: Literal["pixel_values"]
|
||||
|
||||
image_grid_thw: torch.Tensor
|
||||
pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
|
||||
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
|
||||
|
||||
|
||||
class DotsOCRImageEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- nf: Number of image features
|
||||
- hs: Hidden size
|
||||
- ni: Number of images
|
||||
"""
|
||||
type: Literal["image_embeds"]
|
||||
|
||||
image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
|
||||
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
|
||||
|
||||
|
||||
DotsOCRImageInputs = Union[DotsOCRImagePixelInputs,
|
||||
@@ -654,6 +658,8 @@ class DotsVisionTransformer(nn.Module):
|
||||
)
|
||||
class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
SupportsLoRA):
|
||||
merge_by_field_config = True
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
".attn.qkv_proj.": ".attn.qkv.",
|
||||
@@ -709,22 +715,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
architectures=["Qwen2ForCausalLM"],
|
||||
)
|
||||
|
||||
def _validate_and_reshape_mm_tensor(self, mm_input: object,
|
||||
name: str) -> torch.Tensor:
|
||||
if not isinstance(mm_input, (torch.Tensor, list)):
|
||||
raise ValueError(f"Incorrect type of {name}. "
|
||||
f"Got type: {type(mm_input)}")
|
||||
if isinstance(mm_input, torch.Tensor):
|
||||
if mm_input.ndim == 2:
|
||||
return mm_input
|
||||
if mm_input.ndim != 3:
|
||||
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
|
||||
f"Got ndim: {mm_input.ndim} "
|
||||
f"(shape={mm_input.shape})")
|
||||
return torch.concat(list(mm_input))
|
||||
else:
|
||||
return torch.concat(mm_input)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[DotsOCRImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
@@ -735,28 +725,11 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
return None
|
||||
|
||||
if pixel_values is not None:
|
||||
pixel_values = self._validate_and_reshape_mm_tensor(
|
||||
pixel_values, "image pixel values")
|
||||
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||
image_grid_thw, "image grid_thw")
|
||||
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
return DotsOCRImagePixelInputs(type="pixel_values",
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw)
|
||||
|
||||
if image_embeds is not None:
|
||||
image_embeds = self._validate_and_reshape_mm_tensor(
|
||||
image_embeds, "image embeds")
|
||||
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||
image_grid_thw, "image grid_thw")
|
||||
|
||||
if not isinstance(image_embeds, torch.Tensor):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
return DotsOCRImageEmbeddingInputs(type="image_embeds",
|
||||
image_embeds=image_embeds,
|
||||
image_grid_thw=image_grid_thw)
|
||||
|
||||
Reference in New Issue
Block a user