[VLM] Various cleanup and fixes (#14806)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -3,8 +3,8 @@
 from abc import abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
-from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
-                    TypedDict, TypeVar, Union, cast)
+from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
+                    TypeVar, Union, cast)
 
 import torch
 import torch.nn as nn
@@ -39,8 +39,7 @@ from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .pixtral import (PixtralHFVisionModel,
-                      get_pixtral_hf_image_feature_grid_size)
+from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
@@ -49,7 +48,7 @@ from .vision import get_vision_encoder_info
 
 class LlavaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
+    pixel_values: torch.Tensor
     """
     Shape: `(batch_size * num_images, num_channels, height, width)`
 
@@ -57,7 +56,18 @@ class LlavaImagePixelInputs(TypedDict):
     in which case the data is passed as a list instead of a batched tensor.
     """
 
-    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+
+class PixtralHFImagePixelInputs(TypedDict):
+    type: Literal["pixel_values_pixtral"]
+    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    Shape: `(batch_size * num_images, num_channels, height, width)`
+
+    Note that `height` or `width` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
+    """
+
+    feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image features correspond
     to patch tokens.
@@ -65,7 +75,7 @@ class LlavaImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_crops, num_patch)`
     """
 
-    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
    """
     A boolean mask indicating which image embeddings correspond
     to patch tokens.
@@ -73,7 +83,7 @@ class LlavaImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_embeds)`
     """
 
-    num_crops: torch.Tensor
+    num_crops: Union[torch.Tensor, list[torch.Tensor]]
     """Shape: `(batch_size, num_images)`"""
 
 
@@ -85,27 +95,9 @@ class LlavaImageEmbeddingInputs(TypedDict):
     `hidden_size` must match the hidden size of language model backbone.
     """
 
-    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
-    """
-    A boolean mask indicating which image features correspond
-    to patch tokens.
-
-    Shape: `(batch_size, num_crops, num_patch)`
-    """
-
-    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
-    """
-    A boolean mask indicating which image embeddings correspond
-    to patch tokens.
-
-    Shape: `(batch_size, num_embeds)`
-    """
-
-    num_crops: torch.Tensor
-    """Shape: `(batch_size, num_images)`"""
 
 
-LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
+LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs,
+                         LlavaImageEmbeddingInputs]
 
 
 class LlavaMultiModalProjector(nn.Module):
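Each member of this union carries a distinct `type` literal, so both static checkers and the runtime dispatch in `_process_image_input` can tell the variants apart. A minimal, self-contained sketch of the pattern (names abbreviated for illustration; not code from the patch):

    from typing import Literal, TypedDict, Union

    import torch

    class PixelInputs(TypedDict):
        type: Literal["pixel_values"]
        pixel_values: torch.Tensor

    class PixtralPixelInputs(TypedDict):
        type: Literal["pixel_values_pixtral"]
        pixel_values: Union[torch.Tensor, list[torch.Tensor]]

    ImageInputs = Union[PixelInputs, PixtralPixelInputs]

    def describe(inputs: ImageInputs) -> str:
        # Comparing the literal tag narrows the union, mirroring the
        # `image_input["type"] == ...` checks in the model code below.
        if inputs["type"] == "pixel_values_pixtral":
            return "pixtral path (per-image tensors may differ in shape)"
        return "llava path (one batched tensor)"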
@@ -357,13 +349,15 @@ class PixtralHFMultiModalProcessor(
         ]
 
         hf_config = self.info.get_hf_config()
+        vision_config = hf_config.vision_config
+        assert isinstance(vision_config, PixtralVisionConfig)
+        encoder_info = PixtralHFEncoderInfo(vision_config)
+
         tile_sizes = [
-            get_pixtral_hf_image_feature_grid_size(
-                hf_config.vision_config,
+            encoder_info.get_patch_grid_size(
                 image_width=pixel_value.shape[-1],
-                image_height=pixel_value.shape[-2])
-            for pixel_value in processed_outputs["pixel_values"]
+                image_height=pixel_value.shape[-2],
+            ) for pixel_value in processed_outputs["pixel_values"]
         ]
         num_crops = torch.tensor([(ncols + 1) * nrows
                                   for ncols, nrows in tile_sizes])
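The `(ncols + 1) * nrows` count reserves one extra position per patch row, consistent with Pixtral emitting a break token at the end of each row; the grid helper's internals are not shown in this diff. A rough worked example, assuming ceil-division by a 16-pixel patch size (both numbers are illustrative):

    import math

    patch_size = 16                        # assumed Pixtral-HF default
    image_width, image_height = 512, 384   # example image

    ncols = math.ceil(image_width / patch_size)   # 32 patches per row
    nrows = math.ceil(image_height / patch_size)  # 24 rows
    num_crops = (ncols + 1) * nrows               # (32 + 1) * 24 = 792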
@@ -411,13 +405,13 @@ class PixtralHFMultiModalProcessor(
 
         vision_config = hf_config.vision_config
         assert isinstance(vision_config, PixtralVisionConfig)
+        encoder_info = PixtralHFEncoderInfo(vision_config)
 
         def get_replacement(item_idx: int):
             images = mm_items.get_items("image", ImageProcessorItems)
             image_size = images.get_image_size(item_idx)
 
-            ncols, nrows = get_pixtral_hf_image_feature_grid_size(
-                vision_config,
+            ncols, nrows = encoder_info.get_patch_grid_size(
                 image_width=image_size.width,
                 image_height=image_size.height,
             )
@@ -512,7 +506,7 @@ def init_vision_tower_for_llava(
     *,
     require_post_norm: Optional[bool] = None,
     prefix: str = "",
-):
+) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
     vision_config = hf_config.vision_config
 
     # Initialize the vision tower only up to the deepest required feature layer
@@ -627,32 +621,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         if pixel_values is None and image_embeds is None:
             return None
 
-        feat_is_patch = kwargs.pop("feat_is_patch", None)
-        if feat_is_patch is not None and not isinstance(
-                feat_is_patch, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of feat_is_patch. "
-                             f"Got type: {type(feat_is_patch)}")
-
-        embed_is_patch = kwargs.pop("embed_is_patch", None)
-        if embed_is_patch is not None and not isinstance(
-                embed_is_patch, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of embed_is_patch. "
-                             f"Got type: {type(embed_is_patch)}")
-
-        num_crops = kwargs.pop("num_crops", None)
-        if num_crops is not None and not isinstance(num_crops, torch.Tensor):
-            raise ValueError("Incorrect type of num_crops. "
-                             f"Got type: {type(num_crops)}")
-
         if pixel_values is not None:
             if not isinstance(pixel_values, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
             if self.config.vision_config.model_type == "pixtral":
-                return LlavaImagePixelInputs(
-                    type="pixel_values",
-                    data=flatten_bn(pixel_values),
+                feat_is_patch = kwargs.pop("feat_is_patch")
+                if not isinstance(feat_is_patch, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of feat_is_patch. "
+                                     f"Got type: {type(feat_is_patch)}")
+
+                embed_is_patch = kwargs.pop("embed_is_patch")
+                if not isinstance(embed_is_patch, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of embed_is_patch. "
+                                     f"Got type: {type(embed_is_patch)}")
+
+                num_crops = kwargs.pop("num_crops")
+                if not isinstance(num_crops, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of num_crops. "
+                                     f"Got type: {type(num_crops)}")
+
+                return PixtralHFImagePixelInputs(
+                    type="pixel_values_pixtral",
+                    pixel_values=flatten_bn(pixel_values),
                     feat_is_patch=feat_is_patch,
                     embed_is_patch=embed_is_patch,
                     num_crops=num_crops,
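Unlike the old code, which popped these kwargs with a `None` default and only validated them when present, the pixtral branch now calls `kwargs.pop(key)` with no default, so a missing key fails fast with a `KeyError`. The shared shape of the three checks could be factored as below (a sketch; `pop_and_check` is not a helper in the patch):

    from typing import Union

    import torch

    def pop_and_check(kwargs: dict, key: str) -> Union[torch.Tensor, list]:
        value = kwargs.pop(key)  # raises KeyError if the processor omitted it
        if not isinstance(value, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {key}. "
                             f"Got type: {type(value)}")
        return value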
@@ -660,11 +652,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 
             return LlavaImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(
+                pixel_values=self._validate_pixel_values(
                     flatten_bn(pixel_values, concat=True)),
-                feat_is_patch=feat_is_patch,
-                embed_is_patch=embed_is_patch,
-                num_crops=num_crops,
             )
 
         if image_embeds is not None:
@@ -672,12 +661,12 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
 
+            if self.config.vision_config.model_type == "pixtral":
+                raise ValueError("Pixtral-HF does not support image_embeds.")
+
             return LlavaImageEmbeddingInputs(
                 type="image_embeds",
                 data=flatten_bn(image_embeds, concat=True),
-                feat_is_patch=feat_is_patch,
-                embed_is_patch=embed_is_patch,
-                num_crops=num_crops,
             )
 
         raise AssertionError("This line should be unreachable.")
@@ -696,7 +685,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                             PixtralHFVisionModel],
-        pixel_values: torch.Tensor,
+        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
     ) -> torch.Tensor:
 
         # NOTE: we skip the step to select the vision feature layer since
@@ -708,17 +697,20 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
             strategy=self.config.vision_feature_select_strategy,
         )
 
-    def _process_image_pixels(self,
-                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
+    def _process_image_pixels(
+        self,
+        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
+    ) -> torch.Tensor:
         assert self.vision_tower is not None
 
-        pixel_values = inputs["data"]
+        pixel_values = inputs["pixel_values"]
 
         return self._image_pixels_to_features(self.vision_tower, pixel_values)
 
-    def _process_image_input(self,
-                             image_input: LlavaImageInputs) -> torch.Tensor:
-
+    def _process_image_input(
+        self,
+        image_input: LlavaImageInputs,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         if image_input["type"] == "image_embeds":
             return image_input["data"]
 
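The widened `pixel_values` annotations reflect that Pixtral-HF keeps images at their native resolutions, so a batch can arrive as a list of differently shaped tensors rather than one stacked tensor. An illustrative consumer that accepts both forms (not code from the patch):

    from typing import Union

    import torch

    def total_pixels(
            pixel_values: Union[torch.Tensor, list[torch.Tensor]]) -> int:
        if isinstance(pixel_values, torch.Tensor):
            return pixel_values.numel()          # one batched tensor
        return sum(t.numel() for t in pixel_values)  # per-image tensors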
@@ -783,11 +775,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
 
         vision_embeddings = self._process_image_input(image_input)
 
-        if kwargs.get("v0_path", False) or \
-            image_input.get("feat_is_patch") is None or \
-            image_input.get("embed_is_patch") is None:
+        if (kwargs.get("v0_path", False)
+                or image_input["type"] != "pixel_values_pixtral"):
             # The path is used for pixtral (V0 only) and llava (V0/V1)
             return vision_embeddings
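On the path that does not take this early return, presumably the pixtral V1 case, the `embed_is_patch` mask defined above lets the model keep only the embeddings that line up with patch tokens. A self-contained illustration of that kind of boolean masking (toy sizes; not the model code):

    import torch

    embeds = torch.randn(6, 4)  # 6 image embeddings with hidden size 4
    embed_is_patch = torch.tensor(
        [True, True, False, True, True, False])  # False marks non-patch slots
    patch_embeds = embeds[embed_is_patch]        # shape: (4, 4)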