[VLM] Various cleanup and fixes (#14806)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -42,7 +42,7 @@ _MAX_FRAMES_PER_VIDEO = 16
|
||||
|
||||
class LlavaOnevisionVideoPixelInputs(TypedDict):
|
||||
type: Literal["pixel_values_videos"]
|
||||
- data: Union[torch.Tensor, List[torch.Tensor]]
|
||||
+ pixel_values_videos: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
Shape: `(batch_size, num_videos, num_frames, num_channels, height, width)`
|
||||
|
||||
@@ -54,7 +54,7 @@ class LlavaOnevisionVideoPixelInputs(TypedDict):
|
||||
|
||||
class LlavaOnevisionImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
- data: Union[torch.Tensor, List[torch.Tensor]]
|
||||
+ pixel_values: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
|
||||
@@ -521,7 +521,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
return LlavaOnevisionImagePixelInputs(
|
||||
type="pixel_values",
|
||||
- data=self._validate_image_pixel_values(
|
||||
+ pixel_values=self._validate_image_pixel_values(
|
||||
flatten_bn(pixel_values)),
|
||||
image_sizes=self._validate_image_sizes(
|
||||
flatten_bn(image_sizes, concat=True)),
|
||||
@@ -570,21 +570,20 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
List[b, Tensor(nb_frames, nb_channels, height, width)]
|
||||
}
|
||||
"""
|
||||
- pixel_values = kwargs.pop("pixel_values_videos", None)
|
||||
|
||||
- if pixel_values is None:
|
||||
+ pixel_values_videos = kwargs.pop("pixel_values_videos", None)
|
||||
+ if pixel_values_videos is None:
|
||||
return None
|
||||
|
||||
- if not (is_list_of(pixel_values,
|
||||
- (torch.Tensor)) # different shape videos
|
||||
- or isinstance(pixel_values,
|
||||
+ if not (is_list_of(pixel_values_videos,
|
||||
+ torch.Tensor) # different shape videos
|
||||
+ or isinstance(pixel_values_videos,
|
||||
torch.Tensor)): # same shape videos
|
||||
- raise ValueError("Incorrect type of pixel values. "
|
||||
- f"Got type: {type(pixel_values)}")
|
||||
+ raise ValueError("Incorrect type of pixel_values_videos. "
|
||||
+ f"Got type: {type(pixel_values_videos)}")
|
||||
|
||||
return LlavaOnevisionVideoPixelInputs(
|
||||
type="pixel_values_videos",
|
||||
- data=pixel_values,
|
||||
+ pixel_values_videos=pixel_values_videos,
|
||||
)
|
||||
|
||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||
@@ -723,7 +722,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
assert self.vision_tower is not None
|
||||
|
||||
- pixel_values = inputs["data"]
|
||||
+ pixel_values = inputs["pixel_values"]
|
||||
|
||||
if isinstance(pixel_values, torch.Tensor):
|
||||
b, num_patches, c, h, w = pixel_values.shape
|
||||
@@ -757,7 +756,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
image_sizes = image_input.get("image_sizes")
|
||||
if image_sizes is None:
|
||||
- batch_size = len(image_input["data"])
|
||||
+ batch_size = len(image_input["pixel_values"])
|
||||
vision_config = self.config.vision_config
|
||||
default_height = default_width = vision_config.image_size
|
||||
image_sizes = torch.as_tensor([[default_height, default_width]
|
||||
@@ -808,7 +807,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs):
|
||||
assert self.vision_tower is not None
|
||||
|
||||
- video_pixels = inputs["data"]
|
||||
+ video_pixels = inputs["pixel_values_videos"]
|
||||
|
||||
if isinstance(video_pixels, torch.Tensor):
|
||||
b, num_videos, frames, c, h, w = video_pixels.shape
|
||||
|
||||
Reference in New Issue
Block a user