Migrate Qwen2 inputs to TensorSchema (#23475)
Signed-off-by: Benji Beck <benjibeck@meta.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@@ -27,7 +27,7 @@
 """Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
 from collections.abc import Iterable, Mapping
 from functools import lru_cache, partial
-from typing import Callable, Literal, Optional, TypedDict, Union
+from typing import Annotated, Callable, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -64,6 +64,7 @@ from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
 from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP, SupportsQuant)
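A note on the import swap: `TypedDict` annotations exist only for static checkers, while `Annotated` metadata is introspectable at runtime, which is what allows `TensorSchema` to validate shapes. A minimal sketch of that mechanism (`ShapeTag` is a hypothetical stand-in, not vLLM's `TensorShape`):

from typing import Annotated, get_type_hints

import torch

# Hypothetical stand-in for vllm.utils.tensor_schema.TensorShape, used only
# to show that Annotated metadata survives to runtime for a schema to check.
class ShapeTag:
    def __init__(self, *dims):
        self.dims = dims

class Example:
    pixel_values: Annotated[torch.Tensor, ShapeTag("np", "cps")]

hints = get_type_hints(Example, include_extras=True)
print(hints["pixel_values"].__metadata__[0].dims)  # ('np', 'cps')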
@@ -80,84 +81,125 @@ logger = init_logger(__name__)
 
 # === Vision Inputs === #
 
 
-class Qwen2_5_VLImagePixelInputs(TypedDict):
+class Qwen2_5_VLImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: Number of patches
+        - ni: Number of images
+        - cps: Number of channels * patch_size * patch_size
+
+    Historical context:
+        - pixel_values shape: (num_patches, num_channels * patch_size *
+          patch_size)
+        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
+          format
+    """
     type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    """Shape:
-    `(num_patches, num_channels * patch_size * patch_size)`
-    """
 
-    image_grid_thw: torch.Tensor
-    """Shape: `(num_images, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
+    pixel_values: Annotated[
+        torch.Tensor,
+        TensorShape("np", "cps"),
+    ]
+
+    image_grid_thw: Annotated[
+        torch.Tensor,
+        TensorShape("ni", 3),
+    ]
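To make the schema above concrete, here is a self-contained sketch of construction-time validation. The `Shape` class and the `__post_init__` logic are hypothetical stand-ins for what `vllm.utils.tensor_schema` is assumed to do; only the field names and dimension symbols come from the diff:

from dataclasses import dataclass
from typing import Annotated, get_type_hints

import torch

class Shape:
    """Hypothetical stand-in for TensorShape: strings are symbolic sizes
    resolved at runtime, ints are fixed sizes."""
    def __init__(self, *dims):
        self.dims = dims

@dataclass
class ImagePixelSketch:
    # Mirrors Qwen2_5_VLImagePixelInputs above: ("np", "cps") and ("ni", 3).
    pixel_values: Annotated[torch.Tensor, Shape("np", "cps")]
    image_grid_thw: Annotated[torch.Tensor, Shape("ni", 3)]

    def __post_init__(self):
        sizes = {}  # symbolic dimension name -> resolved size
        hints = get_type_hints(type(self), include_extras=True)
        for name, hint in hints.items():
            dims = hint.__metadata__[0].dims
            tensor = getattr(self, name)
            assert tensor.ndim == len(dims), f"{name}: wrong rank"
            for dim, actual in zip(dims, tensor.shape):
                if isinstance(dim, int):
                    assert actual == dim, f"{name}: expected dim of {dim}"
                else:  # symbolic: must be consistent across all fields
                    assert sizes.setdefault(dim, actual) == actual, (
                        f"{name}: inconsistent size for '{dim}'")

# 3 channels, 14x14 patches -> cps = 3 * 14 * 14 = 588; one 1x4x4 grid
ImagePixelSketch(pixel_values=torch.zeros(16, 588),
                 image_grid_thw=torch.tensor([[1, 4, 4]]))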
 
 
-class Qwen2_5_VLImageEmbeddingInputs(TypedDict):
+class Qwen2_5_VLImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - nf: Number of image features
+        - hs: Hidden size
+        - ni: Number of images
+
+    Historical context:
+        - image_embeds shape: (num_image_features, hidden_size)
+        - num_image_features varies based on the number and resolution of
+          the images.
+        - hidden_size must match the hidden size of language model backbone.
+        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
+          format
+    """
     type: Literal["image_embeds"]
-    image_embeds: torch.Tensor
-    """Supported types:
-    - list[`torch.Tensor`]: A list of tensors holding all images' features.
-      Each tensor holds an image's features.
-    - `torch.Tensor`: A tensor holding all images' features
-      (concatenation of all images' feature tensors).
-
-    Tensor shape: `(num_image_features, hidden_size)`
-    - `num_image_features` varies based on
-      the number and resolution of the images.
-    - `hidden_size` must match the hidden size of language model backbone.
-    """
 
-    image_grid_thw: torch.Tensor
-    """Shape: `(num_images, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
+    image_embeds: Annotated[
+        torch.Tensor,
+        TensorShape("nf", "hs"),
+    ]
+
+    image_grid_thw: Annotated[
+        torch.Tensor,
+        TensorShape("ni", 3),
+    ]
 
 
 Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs,
                               Qwen2_5_VLImageEmbeddingInputs]
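One practical consequence of the shared `hs` symbol: precomputed `image_embeds` must be exactly as wide as the language model backbone. A sketch with an assumed hidden size:

import torch

# Illustrative check implied by the "hs" docstring note above. The 3584
# width is an assumed example value, not read from any model config.
backbone_hidden_size = 3584
image_embeds = torch.zeros(256, 3584)  # ("nf", "hs") per the schema
assert image_embeds.shape[-1] == backbone_hidden_size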
 
 
-class Qwen2_5_VLVideoPixelInputs(TypedDict):
+class Qwen2_5_VLVideoPixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: Number of patches
+        - nv: Number of videos
+        - ctps: Number of channels * temporal_patch_size * patch_size *
+          patch_size
+
+    Historical context:
+        - pixel_values_videos shape: (num_patches, num_channels *
+          temporal_patch_size * patch_size * patch_size)
+        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
+          format
+        - second_per_grid_ts: The video time interval (in seconds) for each
+          grid along the temporal dimension in the 3D position IDs. Returned
+          when `videos` is not `None`.
+    """
     type: Literal["pixel_values_videos"]
-    pixel_values_videos: torch.Tensor
-    """Shape:
-    `(num_patches,
-     num_channels * temporal_patch_size * patch_size * patch_size)`
-    """
 
-    video_grid_thw: torch.Tensor
-    """Shape: `(num_videos, 3)`
-
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
-
-    second_per_grid_ts: torch.Tensor
-    """
-    The video time interval (in seconds) for each grid along the temporal
-    dimension in the 3D position IDs. Returned when `videos` is not `None`.
-    """
+    pixel_values_videos: Annotated[
+        torch.Tensor,
+        TensorShape("np", "ctps"),
+    ]
+
+    video_grid_thw: Annotated[
+        torch.Tensor,
+        TensorShape("nv", 3),
+    ]
+
+    second_per_grid_ts: Annotated[
+        Optional[torch.Tensor],
+        TensorShape("nv"),
+    ]
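A short worked example of the dimension symbols above, using assumed patch settings rather than values read from a model config; `second_per_grid_ts` is the only `Optional` field in the schema:

import torch

channels, temporal_patch_size, patch_size = 3, 2, 14  # assumed settings
ctps = channels * temporal_patch_size * patch_size * patch_size  # 1176

pixel_values_videos = torch.zeros(64, ctps)   # ("np", "ctps")
video_grid_thw = torch.tensor([[2, 8, 4]])    # ("nv", 3); 2 * 8 * 4 = 64
second_per_grid_ts = torch.tensor([0.5])      # ("nv",); may also be None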
 
 
-class Qwen2_5_VLVideoEmbeddingInputs(TypedDict):
+class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - nf: Number of video features
+        - hs: Hidden size
+        - nv: Number of videos
+
+    Historical context:
+        - video_embeds shape: (num_video_features, hidden_size)
+        - num_video_features varies based on the number and resolution of
+          the videos.
+        - hidden_size must match the hidden size of language model backbone.
+        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
+          format
+    """
     type: Literal["video_embeds"]
-    video_embeds: torch.Tensor
-    """Supported types:
-    - list[`torch.Tensor`]: A list of tensors holding all videos' features.
-      Each tensor holds a video's features.
-    - `torch.Tensor`: A tensor holding all videos' features
-      (concatenation of all videos' feature tensors).
-
-    Tensor shape: `(num_video_features, hidden_size)`
-    - `num_video_features` varies based on
-      the number and resolution of the videos.
-    - `hidden_size` must match the hidden size of language model backbone.
-    """
 
-    video_grid_thw: torch.Tensor
-    """Shape: `(num_videos, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
+    video_embeds: Annotated[
+        torch.Tensor,
+        TensorShape("nf", "hs"),
+    ]
+
+    video_grid_thw: Annotated[
+        torch.Tensor,
+        TensorShape("nv", 3),
+    ]
 
 
 Qwen2_5_VLVideoInputs = Union[Qwen2_5_VLVideoPixelInputs,
                               Qwen2_5_VLVideoEmbeddingInputs]
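Downstream code still needs to branch on which member of the union it received. A hypothetical dispatch sketch (`embed_videos_sketch` is illustrative and not a function in this file):

def embed_videos_sketch(video_input: "Qwen2_5_VLVideoInputs"):
    # isinstance() distinguishes the two schema classes; the `type`
    # literal retained from the TypedDict era serves the same role.
    if isinstance(video_input, Qwen2_5_VLVideoPixelInputs):
        ...  # raw patches: run the vision tower
    else:
        ...  # precomputed features: pass video_embeds straight through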
@@ -936,10 +978,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         image_grid_thw = self._validate_and_reshape_mm_tensor(
             image_grid_thw, "image grid_thw")
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of image pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
         return Qwen2_5_VLImagePixelInputs(type="pixel_values",
                                           pixel_values=pixel_values,
                                           image_grid_thw=image_grid_thw)
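The `isinstance` guard deleted here is the point of the migration: constructing `Qwen2_5_VLImagePixelInputs` is now assumed to perform type and shape validation itself. A hedged sketch of the resulting failure mode (the exact exception type is an assumption):

import torch

bad_pixel_values = "not a tensor"  # would previously trip the removed guard
try:
    Qwen2_5_VLImagePixelInputs(type="pixel_values",
                               pixel_values=bad_pixel_values,
                               image_grid_thw=torch.zeros((1, 3),
                                                          dtype=torch.int64))
except Exception as exc:  # TensorSchema is expected to reject this input
    print(f"schema rejected input: {exc}")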
@@ -950,9 +988,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         image_grid_thw = self._validate_and_reshape_mm_tensor(
             image_grid_thw, "image grid_thw")
 
-        if not isinstance(image_embeds, torch.Tensor):
-            raise ValueError("Incorrect type of image embeddings. "
-                             f"Got type: {type(image_embeds)}")
         return Qwen2_5_VLImageEmbeddingInputs(
             type="image_embeds",
             image_embeds=image_embeds,
@@ -973,7 +1008,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             pixel_values_videos, "video pixel values")
         video_grid_thw = self._validate_and_reshape_mm_tensor(
             video_grid_thw, "video grid_thw")
-
+        if second_per_grid_ts is not None and second_per_grid_ts.ndim == 2:
+            second_per_grid_ts = second_per_grid_ts.squeeze(-1)
         return Qwen2_5_VLVideoPixelInputs(
             type="pixel_values_videos",
             pixel_values_videos=pixel_values_videos,
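The two added lines normalize a `(num_videos, 1)` tensor to `(num_videos,)` so it matches the TensorShape("nv") declared for `second_per_grid_ts`. A standalone illustration:

import torch

second_per_grid_ts = torch.tensor([[0.5], [1.0]])  # shape (2, 1), ndim == 2
if second_per_grid_ts is not None and second_per_grid_ts.ndim == 2:
    second_per_grid_ts = second_per_grid_ts.squeeze(-1)
assert second_per_grid_ts.shape == (2,)  # now 1-D, as the schema expects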
@@ -987,9 +1023,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         video_grid_thw = self._validate_and_reshape_mm_tensor(
             video_grid_thw, "video grid_thw")
 
-        if not isinstance(video_embeds, torch.Tensor):
-            raise ValueError("Incorrect type of video embeddings. "
-                             f"Got type: {type(video_embeds)}")
         return Qwen2_5_VLVideoEmbeddingInputs(
             type="video_embeds",
             video_embeds=video_embeds,