[Model] Define merge_by_field_config MM interface (R-T) (#26260)
Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -4,7 +4,7 @@ import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from itertools import product
|
||||
from math import ceil, sqrt
|
||||
from typing import Any, Literal, Optional, TypedDict, Union
|
||||
from typing import Annotated, Any, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -44,28 +44,48 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs import Step3VisionEncoderConfig
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
WeightsMapper,
|
||||
flatten_bn,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
)
|
||||
from .vision import run_dp_sharded_vision_model
|
||||
|
||||
|
||||
class Step3VLImagePixelInputs(TypedDict):
|
||||
class Step3VLImagePixelInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- bn: Batch size * number of images
|
||||
- c: Number of channels (3)
|
||||
- h: Height
|
||||
- w: Width
|
||||
- bnp: Batch size * number of images * number of patches
|
||||
- hp: Height of patch
|
||||
- wp: Width of patch
|
||||
"""
|
||||
|
||||
type: Literal["pixel_values"]
|
||||
pixel_values: torch.Tensor
|
||||
patch_pixel_values: Optional[torch.Tensor]
|
||||
num_patches: list[int]
|
||||
pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
|
||||
patch_pixel_values: Annotated[
|
||||
Optional[torch.Tensor], TensorShape("bnp", 3, "hp", "wp")
|
||||
]
|
||||
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
|
||||
|
||||
|
||||
class Step3VLImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
image_embeds: torch.Tensor
|
||||
class Step3VLImageEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- bn: Batch size * number of images
|
||||
- f: Image feature size
|
||||
- h: Hidden size (must match the hidden size of language model backbone)
|
||||
"""
|
||||
|
||||
type: Literal["image_embeds"] = "image_embeds"
|
||||
data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]
|
||||
|
||||
|
||||
Step3VLImageInputs = Union[Step3VLImagePixelInputs, Step3VLImageEmbeddingInputs]
|
||||
@@ -895,6 +915,8 @@ class Step3VisionTransformer(nn.Module):
|
||||
dummy_inputs=Step3VLDummyInputsBuilder,
|
||||
)
|
||||
class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
merge_by_field_config = True
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"model.": "language_model.model.",
|
||||
@@ -982,41 +1004,22 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
|
||||
return None
|
||||
|
||||
if pixel_values is not None:
|
||||
pixel_values = flatten_bn(pixel_values, concat=True)
|
||||
if pixel_values.dim() >= 3:
|
||||
pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
|
||||
if patch_pixel_values is not None:
|
||||
patch_pixel_values = flatten_bn(patch_pixel_values, concat=True)
|
||||
patch_pixel_values = patch_pixel_values.view(
|
||||
-1, *patch_pixel_values.shape[-3:]
|
||||
)
|
||||
# Handle empty patch_pixel_values by setting to None
|
||||
if patch_pixel_values.shape[0] == 0:
|
||||
patch_pixel_values = None
|
||||
num_patches = flatten_bn(num_patches, concat=True).tolist()
|
||||
|
||||
return Step3VLImagePixelInputs(
|
||||
type="pixel_values",
|
||||
pixel_values=pixel_values.to(self.dtype).to(self.device),
|
||||
patch_pixel_values=patch_pixel_values.to(self.dtype).to(self.device)
|
||||
pixel_values=pixel_values.to(self.dtype),
|
||||
patch_pixel_values=patch_pixel_values.to(self.dtype)
|
||||
if patch_pixel_values is not None
|
||||
else None,
|
||||
num_patches=num_patches,
|
||||
)
|
||||
|
||||
if image_embeds is not None:
|
||||
if image_embeds.dim() == 2 or image_embeds.dim() >= 3:
|
||||
image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected shape for image_embeds: {image_embeds.shape}"
|
||||
)
|
||||
|
||||
return Step3VLImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
image_embeds=image_embeds.to(self.dtype).to(self.device),
|
||||
image_embeds=image_embeds.to(self.dtype),
|
||||
)
|
||||
return None
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
|
||||
B, P = image_features.shape[:2]
|
||||
|
||||
Reference in New Issue
Block a user