Migrate Interns1 inputs to TensorSchema (#23510)

Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
Benji Beck
2025-09-01 21:35:45 -07:00
committed by GitHub
parent 7be0cb8e9e
commit 56d04089ef

@@ -7,7 +7,7 @@
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 
 import regex as re
 import torch
@@ -32,6 +32,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
@@ -62,51 +63,60 @@ class InternS1MultiModalProjector(nn.Module):
         return hidden_states
 
 
-class InternS1ImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    """
-    Shape:
-    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
-    """
+class InternS1ImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bnp: Batch size * number of images * (1 + num_patches)
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+        - bn: Batch size * number of images
+    """
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
+    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
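For reference, a minimal usage sketch of the new schema. The image size, patch counts, and batch sizes below are made-up values; the construction pattern mirrors _parse_and_validate_image_input further down, and the assumption is that TensorSchema checks each Annotated field against its TensorShape when the object is created:

import torch

# Hypothetical sizes: 2 images, 1 + 4 patches each, 448x448 inputs.
# InternS1ImagePixelInputs is the schema class defined above.
dummy_image_inputs = InternS1ImagePixelInputs(
    type="pixel_values",
    pixel_values=torch.zeros(2 * 5, 3, 448, 448),  # ("bnp", 3, "h", "w")
    num_patches=torch.tensor([5, 5]),              # ("bn",)
    resolve_bindings={"h": 448, "w": 448},         # pin symbolic dims to concrete values
)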
 
 
-class InternS1ImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
-    or a list of tensors of shape `(total_image_feature_size, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
-    """
+class InternS1ImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - ni: Number of images
+        - tifs: Total image feature size
+        - hs: Hidden size (must match language model backbone)
+    """
+    type: Literal["image_embeds"] = "image_embeds"
+    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                    TensorShape("ni", "tifs", "hs")]
 
 
 InternS1ImageInputs = Union[InternS1ImagePixelInputs,
                             InternS1ImageEmbeddingInputs]
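The embedding variant accepts either a single stacked tensor or a per-image list. A sketch with made-up sizes, assuming a list is matched element-wise against the trailing dimensions of TensorShape("ni", "tifs", "hs"):

# Hypothetical: two images with different feature counts, hidden size 4096.
dummy_image_embeds = InternS1ImageEmbeddingInputs(
    type="image_embeds",
    data=[torch.zeros(256, 4096), torch.zeros(512, 4096)],  # per-image (tifs, hs)
)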
 
 
-class InternS1VideoPixelInputs(TypedDict):
-    type: Literal["pixel_values_videos"]
-    pixel_values: torch.Tensor
-    """
-    Shape:
-    `(batch_size * num_video * num_frames, num_channels, height, width)`
-    """
-
-    num_patches: torch.Tensor
-    """Shape: `(batch_size * num_images)`"""
+class InternS1VideoPixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bnv: Batch size * number of videos * number of frames
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+    """
+    type: Literal["pixel_values_videos"] = "pixel_values_videos"
+    pixel_values: Annotated[torch.Tensor, TensorShape("bnv", 3, "h", "w")]
+    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
 
 
-class InternS1VideoEmbeddingInputs(TypedDict):
-    type: Literal["video_embeds"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    A tensor of shape `(num_videos, total_video_feature_size, hidden_size)`
-    or a list of tensors of shape `(total_video_feature_size, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
-    """
+class InternS1VideoEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - nv: Number of videos
+        - tvfs: Total video feature size
+        - hs: Hidden size (must match language model backbone)
+    """
+    type: Literal["video_embeds"] = "video_embeds"
+    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                    TensorShape("nv", "tvfs", "hs")]
 
 
 InternS1VideoInputs = Union[InternS1VideoPixelInputs,
@@ -572,26 +582,6 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
         vit_embeds = self.multi_modal_projector(vit_embeds)
         return vit_embeds
 
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-
-        h, w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape)
-
-            if actual_dims != expected_dims:
-                expected_expr = str(expected_dims)
-                raise ValueError(
-                    "The expected shape of pixel values per image per batch "
-                    f" per patch is {expected_expr}. "
-                    f"You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[InternS1ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -627,10 +617,15 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
             pixel_values = flatten_bn(pixel_values, concat=True)
             image_num_patches = flatten_bn(image_num_patches, concat=True)
+            h, w = self.config.vision_config.image_size
 
             return InternS1ImagePixelInputs(
                 type="pixel_values",
-                pixel_values=self._validate_pixel_values(pixel_values),
+                pixel_values=pixel_values,
                 num_patches=image_num_patches,
+                resolve_bindings={
+                    "h": h,
+                    "w": w,
+                },
             )
 
         raise AssertionError("This line should be unreachable.")
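With the schema carrying the shape contract, the per-patch check that _validate_pixel_values used to perform is expected to happen inside TensorSchema itself: resolve_bindings pins the "h" and "w" symbols of TensorShape("bnp", 3, "h", "w") to the configured image size, so a mismatched input should be rejected at construction. A rough sketch of the behaviour this relies on (the sizes and the exact exception type are assumptions):

h = w = 448  # hypothetical configured image size

try:
    InternS1ImagePixelInputs(
        type="pixel_values",
        pixel_values=torch.zeros(4, 3, 224, 224),  # wrong spatial size
        num_patches=torch.tensor([4]),
        resolve_bindings={"h": h, "w": w},
    )
except Exception as exc:  # assumed to fail the "h"/"w" bindings
    print(f"rejected by TensorSchema: {exc}")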
@@ -671,11 +666,15 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
                                                  concat=True)
             video_num_patches = flatten_bn(video_num_patches, concat=True)
+            h, w = self.config.vision_config.image_size
 
             return InternS1VideoPixelInputs(
                 type="pixel_values_videos",
-                pixel_values=self._validate_pixel_values(
-                    pixel_values_flat_video),
                 num_patches=video_num_patches,
+                pixel_values=pixel_values_flat_video,
+                resolve_bindings={
+                    "h": h,
+                    "w": w,
+                },
             )
 
         raise AssertionError("This line should be unreachable.")
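To make the video symbols concrete, a toy construction for a single 8-frame video under the new schema. All sizes are hypothetical and the num_patches value is a placeholder; only the shapes declared by the TensorShape annotations are being illustrated:

dummy_video_inputs = InternS1VideoPixelInputs(
    type="pixel_values_videos",
    pixel_values=torch.zeros(1 * 1 * 8, 3, 448, 448),  # bnv = batch * videos * frames
    num_patches=torch.tensor([8]),                     # ("bn",); value is a placeholder
    resolve_bindings={"h": 448, "w": 448},
)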