[Model] Support nested structures for TensorSchema (#26212)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-10-04 16:20:32 +08:00
committed by GitHub
parent d3d649efec
commit 44ea85137a
5 changed files with 274 additions and 292 deletions

View File

@@ -2,27 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# copied from : https://github.com/huggingface/transformers
import ast
import sys
from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from itertools import chain
from typing import Any, Literal, Optional, TypedDict, Union
from itertools import accumulate
from typing import Annotated, Any, Literal, Optional, Union
import numpy as np
import PIL
from einops import rearrange
from PIL import Image
if sys.version_info >= (3, 11):
import typing
Unpack = typing.Unpack
else:
import typing_extensions
Unpack = typing_extensions.Unpack
import torch
import torch.nn as nn
from einops import rearrange
from timm.layers import LayerNorm, LayerNorm2d
from timm.models.regnet import RegStage
from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
@@ -42,11 +31,13 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix)
from .vision import get_vision_encoder_info
EOT = "<|endofturn|>"
@@ -69,28 +60,42 @@ def get_num_combined_frames(
return num_canvases + (leftover_frames > 0)
class HCXVisionMultimodalPixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values_images: list[torch.Tensor]
class HCXVisionImagePixelInputs(TensorSchema):
"""
Shape: `[(num_grids, num_channels, height, width), ...]` if anyres
Note that `height` or `width` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
Dimensions:
- n: Number of images
- g: Number of grids
- c: Number of channels (3)
- h: Height
- w: Width
"""
image_sizes_images: list[tuple[Union[int, float]]]
"""
Shape: `[(height, width), ...]`
"""
vision_query_lengths_images: list[Union[int, float]]
pixel_values_videos: list[tuple[Union[int, float]]]
"""
Shape: `[(num_grids, num_channels, height, width), ...]` if anyres
"""
vision_query_lengths_videos: list[Union[int, float]]
type: Literal["pixel_values"] = "pixel_values"
pixel_values_images: Annotated[
list[torch.Tensor],
TensorShape("n", "g", 3, "h", "w", dynamic_dims={"g"})]
image_sizes_images: Annotated[torch.Tensor, TensorShape("n", 2)]
HCXVisionMultimodalInputs = Union[HCXVisionMultimodalPixelInputs]
HCXVisionImageInputs = HCXVisionImagePixelInputs
class HCXVisionVideoPixelInputs(TensorSchema):
"""
Dimensions:
- n: Number of videos
- f: Number of frames
- g: Number of grids
- c: Number of channels (3)
- h: Height
- w: Width
"""
type: Literal["pixel_values_videos"] = "pixel_values_videos"
pixel_values_videos: Annotated[
list[list[torch.Tensor]],
TensorShape("n", "f", "g", 3, "h", "w", dynamic_dims={"f", "g"})]
HCXVisionVideoInputs = HCXVisionVideoPixelInputs
class HCXVisionProcessingInfo(BaseProcessingInfo):
@@ -191,27 +196,9 @@ class HCXVisionMultiModalProcessor(
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
def replace_multimodal_token(
token_ids: torch.Tensor,
target_token: int,
repeats: list[int],
):
output = list[int]()
_repeats_idx = 0
for token_id in token_ids:
if token_id == target_token:
output += [token_id.item()] * repeats[_repeats_idx]
_repeats_idx += 1
else:
output += [token_id.item()]
return torch.tensor(output, device=token_ids.device)
for video_idx, video_arr in enumerate(mm_data.get("videos", [])):
if video_arr.dtype == np.uint8:
continue
mm_data["videos"][video_idx] = video_arr.astype(np.uint8)
if video_arr.dtype != np.uint8:
mm_data["videos"][video_idx] = video_arr.astype(np.uint8)
processed_outputs = self.info.ctx.call_hf_processor(
hf_processor=self.info.get_hf_processor(**mm_kwargs),
@@ -223,20 +210,16 @@ class HCXVisionMultiModalProcessor(
) # text-only
if len(mm_data) > 0:
images = mm_data.get("images")
videos = mm_data.get("videos")
# batchify input as a single item
images = mm_data.get("images", None)
batched_images = None if images is None else [images]
# list of video in single conversation
videos = mm_data.get("videos", None)
batched_videos = None if videos is None else [videos]
_processed_outputs = self.info.ctx.call_hf_processor(
hf_processor=self.info.get_hf_processor(**mm_kwargs),
data=dict(
text=None,
images=batched_images,
videos=batched_videos,
images=None if images is None else [images],
videos=None if videos is None else [videos],
),
) # mm-only
@@ -246,51 +229,43 @@ class HCXVisionMultiModalProcessor(
_processed_outputs[k] = v[0]
if images:
tokenizer = self.info.get_tokenizer()
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
processed_outputs["input_ids"] = torch.stack([
replace_multimodal_token(
token_ids=_input_ids,
target_token=image_token_id,
repeats=_processed_outputs[
"vision_query_lengths_images"],
) for _input_ids in processed_outputs["input_ids"]
],
dim=0)
_processed_outputs["image_sizes_images"] = torch.tensor(
_processed_outputs["image_sizes_images"])
_processed_outputs[
"vision_query_lengths_images"] = torch.tensor(
_processed_outputs["vision_query_lengths_images"])
if videos:
_num_per_videos = [
get_num_combined_frames(len(video)) for video in videos
_idx_per_video = [
0, *accumulate(
get_num_combined_frames(len(video))
for video in videos)
]
_processed_outputs["pixel_values_videos"] = [
_processed_outputs["pixel_values_videos"]
[sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
for _i in range(len(videos))
[_idx_per_video[i]:_idx_per_video[i + 1]]
for i in range(len(videos))
]
_processed_outputs["vision_query_lengths_videos"] = [
_processed_outputs["vision_query_lengths_videos"]
[sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
for _i in range(len(videos))
torch.tensor(
_processed_outputs["vision_query_lengths_videos"]
[_idx_per_video[i]:_idx_per_video[i + 1]])
for i in range(len(videos))
]
tokenizer = self.info.get_tokenizer()
video_token_id = tokenizer.convert_tokens_to_ids(VIDEO_TOKEN)
processed_outputs["input_ids"] = torch.stack([
replace_multimodal_token(
token_ids=_input_ids,
target_token=video_token_id,
repeats=[
sum(lens) for lens in
_processed_outputs["vision_query_lengths_videos"]
],
) for _input_ids in processed_outputs["input_ids"]
],
dim=0)
processed_outputs.update(_processed_outputs)
return processed_outputs
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> bool:
return False
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
@@ -311,11 +286,11 @@ class HCXVisionMultiModalProcessor(
out_item = out_mm_kwargs[modality][item_idx]
if modality == "image":
lens = out_item["vision_query_lengths_images"].data
lens = out_item["vision_query_lengths_images"].data.tolist()
num_tokens = self.info.get_num_image_tokens(
vision_query_length=lens)
elif modality == "video":
lens = out_item["vision_query_lengths_videos"].data
lens = out_item["vision_query_lengths_videos"].data.tolist()
num_tokens = self.info.get_num_video_tokens(
vision_query_length=lens)
else:
@@ -343,26 +318,11 @@ class HCXVisionMultiModalProcessor(
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
# image
pixel_values_images=MultiModalFieldConfig.batched("image"),
image_sizes_images=MultiModalFieldConfig.batched("image"),
vision_query_lengths_images=MultiModalFieldConfig.batched("image"),
num_queries_vis_abstractors_images=MultiModalFieldConfig.batched(
"image"),
num_queries_vis_abstractors_slow_images=MultiModalFieldConfig.
batched("image"),
first_last_frames_slows_images=MultiModalFieldConfig.batched(
"image"),
# video
pixel_values_videos=MultiModalFieldConfig.batched("video"),
image_sizes_videos=MultiModalFieldConfig.batched("video"),
vision_query_lengths_videos=MultiModalFieldConfig.batched("video"),
num_queries_vis_abstractors_videos=MultiModalFieldConfig.batched(
"video"),
num_queries_vis_abstractors_slow_videos=MultiModalFieldConfig.
batched("video"),
first_last_frames_slows_videos=MultiModalFieldConfig.batched(
"video"),
)
@@ -617,6 +577,7 @@ class HCXVisionCAbstractor(nn.Module):
info=_build_hcxvision_hf_info,
dummy_inputs=HCXVisionDummyInputsBuilder)
class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -692,55 +653,94 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Only image or video modality is supported")
def _parse_and_validate_image_input(
self,
**kwargs: object,
) -> Optional[HCXVisionImageInputs]:
pixel_values_images = kwargs.pop("pixel_values_images", None)
if pixel_values_images is None:
return None
image_sizes_images = kwargs.pop("image_sizes_images")
return HCXVisionImagePixelInputs(
pixel_values_images=pixel_values_images,
image_sizes_images=image_sizes_images,
)
def _parse_and_validate_video_input(
self,
**kwargs: object,
) -> Optional[HCXVisionVideoInputs]:
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
if pixel_values_videos is None:
return None
return HCXVisionVideoPixelInputs(
pixel_values_videos=pixel_values_videos, )
def _process_image_input(
self,
image_input: HCXVisionImageInputs,
) -> tuple[torch.Tensor, ...]:
return self.forward_images(
pixel_values_images=image_input["pixel_values_images"],
image_sizes_images=image_input["image_sizes_images"],
)
def _process_video_input(
self,
video_input: HCXVisionVideoInputs,
) -> tuple[torch.Tensor, ...]:
return self.forward_videos(
pixel_values_videos=video_input["pixel_values_videos"], )
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if (input_key == "pixel_values_images"
and "images" not in modalities):
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if (input_key == "pixel_values_videos"
and "videos" not in modalities):
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
return modalities
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self,
**kwargs: Unpack[HCXVisionMultimodalInputs],
**kwargs: object,
) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return []
multimodal_embeddings = list()
if kwargs.get("pixel_values_images") is not None:
for _pixel_values_images, _image_sizes_images in zip(
kwargs["pixel_values_images"],
kwargs["image_sizes_images"]):
_pixel_values_images = _pixel_values_images.unsqueeze(dim=0)
_image_sizes_images = _image_sizes_images.unsqueeze(dim=0)
_len_pixel_values_images = [
len(pixel_value) for pixel_value in _pixel_values_images
]
if isinstance(_image_sizes_images, torch.Tensor):
_image_sizes_images = _image_sizes_images.detach().cpu(
).tolist()
_multimodal_embeddings_images = self.forward_images(
pixel_values_images=_pixel_values_images,
image_sizes_images=_image_sizes_images,
len_pixel_values_images=_len_pixel_values_images,
)
_multimodal_embeddings_images = torch.cat(
_multimodal_embeddings_images, dim=0)
multimodal_embeddings.append(_multimodal_embeddings_images)
# The result multimodal_embeddings is tuple of tensors, with each
# tensor correspoending to a multimodal data item (image or video).
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
multimodal_embeddings += vision_embeddings
if modality == "videos":
video_input = modalities["videos"]
video_embeddings = self._process_video_input(video_input)
multimodal_embeddings += video_embeddings
if kwargs.get("pixel_values_videos") is not None:
for _pixel_values_videos, _vision_query_lengths_videos in zip(
kwargs["pixel_values_videos"],
kwargs["vision_query_lengths_videos"]):
_len_pixel_values_videos = [
len(_vision_query_lengths)
for _vision_query_lengths in _vision_query_lengths_videos
]
_c, _w, _h = _pixel_values_videos.shape[-3:]
_pixel_values_videos = _pixel_values_videos.reshape(
sum(_len_pixel_values_videos), -1, _c, _w,
_h).unsqueeze(dim=0)
_multimodal_embeddings_videos = self.forward_videos(
pixel_values_videos=_pixel_values_videos,
len_pixel_values_videos=_len_pixel_values_videos,
)
_multimodal_embeddings_videos = torch.cat(
_multimodal_embeddings_videos, dim=0)
multimodal_embeddings.append(_multimodal_embeddings_videos)
return multimodal_embeddings
def forward(
@@ -762,28 +762,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def forward_images(
self,
pixel_values_images: list[list[torch.FloatTensor]],
image_sizes_images: list[list[tuple[int, int]]],
len_pixel_values_images: list[int],
) -> list[list[torch.Tensor]]:
if sum(len_pixel_values_images) == 0:
return None
concat_pixel_values_images = torch.cat(list(
chain(*pixel_values_images)),
dim=0)
pixel_values_images: list[torch.Tensor],
image_sizes_images: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
pixel_values_image_flat = flatten_bn(pixel_values_images, concat=True)
visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
image_forward_outs = self.vision_model(
concat_pixel_values_images)[:, visual_token_idx:]
pixel_values_image_flat)[:, visual_token_idx:]
image_forward_outs = image_forward_outs.to(
dtype=self.mm_projector.dtype)
image_forward_outs = self.mm_projector(image_forward_outs) # b (h w) d
split_sizes = [
pixel_value.shape[0] for pixel_value in chain(*pixel_values_images)
]
split_sizes = [len(item) for item in pixel_values_images]
image_forward_outs = torch.split(image_forward_outs,
split_sizes,
dim=0)
@@ -791,10 +783,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# newline for anyres postprocessing
image_features = anyres_postprocessing(
image_forward_outs=image_forward_outs,
image_sizes=[
image_size for image_sizes in image_sizes_images
for image_size in image_sizes
],
image_sizes=image_sizes_images.tolist(),
num_queries_vis_abstractor=self.config.
num_queries_vis_abstractor_image,
unpad=self.config.unpad,
@@ -803,26 +792,21 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
image_newline=self.image_newline,
possible_resolutions=self.config.possible_resolutions,
)
return image_features
return tuple(image_features)
def forward_videos(
self,
pixel_values_videos: list[list[torch.FloatTensor]],
len_pixel_values_videos: list[int],
) -> list[torch.Tensor]:
len_video_grids = sum(len_pixel_values_videos)
if len_video_grids == 0:
return None
# Run Vision Model
concat_pixel_values_videos = torch.cat(list(
chain(*pixel_values_videos)),
dim=0)
pixel_values_videos: list[list[torch.Tensor]],
) -> tuple[torch.Tensor, ...]:
pixel_values_videos_flat = flatten_bn(
[frame for frames in pixel_values_videos for frame in frames],
concat=True,
)
visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
video_forward_outs = self.vision_model(
concat_pixel_values_videos)[:, visual_token_idx:]
pixel_values_videos_flat)[:, visual_token_idx:]
video_forward_outs = video_forward_outs.to(
dtype=self.mm_projector.dtype)
@@ -905,7 +889,11 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
) == 0, f"target_features is not empty!! {target_features}"
assert len(video_groups) == len(video_features)
return video_features
feats_per_video = [len(video) for video in pixel_values_videos]
idxs_per_video = [0, *accumulate(feats_per_video)]
return tuple(
torch.cat(video_features[idxs_per_video[i]:idxs_per_video[i + 1]])
for i in range(len(feats_per_video)))
def _prepare_multimodal_kwargs(self, **kwargs: object):
output = defaultdict(list)
@@ -1111,15 +1099,15 @@ def reshape_and_unpad_image_features(
def anyres_postprocessing(
image_forward_outs: list[torch.FloatTensor],
image_forward_outs: list[torch.Tensor],
image_sizes: list[list[int]],
possible_resolutions: list[tuple[int, int]],
patch_size: int,
grid_size: int,
image_newline: torch.FloatTensor,
image_newline: torch.Tensor,
num_queries_vis_abstractor: int = -1,
unpad: bool = False,
) -> list[torch.FloatTensor]:
) -> list[torch.Tensor]:
height = width = grid_size // patch_size
if num_queries_vis_abstractor > 0:
@@ -1147,26 +1135,5 @@ def anyres_postprocessing(
(image_feature, image_newline[None].to(image_feature.device)),
dim=0)
new_image_features.append(image_feature)
image_features = new_image_features
return image_features
def resize_image(
image: Union[np.ndarray, PIL.Image.Image],
max_side: int = 378,
) -> np.ndarray:
image_arr = image
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
width, height = image.size
cur_max_size = max(width, height)
if cur_max_size <= max_side:
return image_arr
scale = max_side / cur_max_size
width = int(width * scale)
height = int(height * scale)
image = image.resize((width, height), Image.LANCZOS)
image_arr = np.array(image)
return image_arr
return new_image_features