[VLM] Various cleanup and fixes (#14806)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-03-14 20:58:19 +08:00
Committed by: GitHub
Parent: 40253bab44
Commit: ab93f1360f

14 changed files with 283 additions and 273 deletions


@@ -3,8 +3,8 @@
 from abc import abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
-from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
-                    TypedDict, TypeVar, Union, cast)
+from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
+                    TypeVar, Union, cast)
 
 import torch
 import torch.nn as nn
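
Dropping the `List` import reflects the module's move to PEP 585 built-in generics, which are valid in annotations on Python 3.9+. A minimal standalone illustration of the pattern (not vLLM code):

from typing import Union

import torch

# Since Python 3.9, built-in containers are subscriptable in annotations,
# so `list[torch.Tensor]` can replace `typing.List[torch.Tensor]`.
NestedTensors = Union[torch.Tensor, list[torch.Tensor]]

def as_tensor_list(x: NestedTensors) -> list[torch.Tensor]:
    # Normalize either form to a list of per-image tensors.
    return list(x) if isinstance(x, torch.Tensor) else x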
@@ -39,8 +39,7 @@ from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .pixtral import (PixtralHFVisionModel,
-                      get_pixtral_hf_image_feature_grid_size)
+from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
@@ -49,7 +48,7 @@ from .vision import get_vision_encoder_info
 
 class LlavaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
+    pixel_values: torch.Tensor
     """
     Shape: `(batch_size * num_images, num_channels, height, width)`
@@ -57,7 +56,18 @@ class LlavaImagePixelInputs(TypedDict):
     in which case the data is passed as a list instead of a batched tensor.
     """
 
-    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+
+class PixtralHFImagePixelInputs(TypedDict):
+    type: Literal["pixel_values_pixtral"]
+    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    Shape: `(batch_size * num_images, num_channels, height, width)`
+
+    Note that `height` or `width` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
+    """
+
+    feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image features correspond
     to patch tokens.
@@ -65,7 +75,7 @@ class LlavaImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_crops, num_patch)`
     """
 
-    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image embeddings correspond
     to patch tokens.
@@ -73,7 +83,7 @@ class LlavaImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_embeds)`
     """
 
-    num_crops: torch.Tensor
+    num_crops: Union[torch.Tensor, list[torch.Tensor]]
     """Shape: `(batch_size, num_images)`"""
@@ -85,27 +95,9 @@ class LlavaImageEmbeddingInputs(TypedDict):
     `hidden_size` must match the hidden size of language model backbone.
     """
 
-    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
-    """
-    A boolean mask indicating which image features correspond
-    to patch tokens.
-
-    Shape: `(batch_size, num_crops, num_patch)`
-    """
-
-    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
-    """
-    A boolean mask indicating which image embeddings correspond
-    to patch tokens.
-
-    Shape: `(batch_size, num_embeds)`
-    """
-
-    num_crops: torch.Tensor
-    """Shape: `(batch_size, num_images)`"""
-
 
-LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
+LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs,
+                         LlavaImageEmbeddingInputs]
 
 
 class LlavaMultiModalProjector(nn.Module):
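
`LlavaImageInputs` is now a tagged union discriminated by each TypedDict's `type` literal, so downstream code can branch on that key and static type checkers narrow the union accordingly. A minimal self-contained sketch of the pattern (simplified field sets, not the real classes):

from typing import Literal, TypedDict, Union

class PixelInputs(TypedDict):
    type: Literal["pixel_values"]
    pixel_values: object

class PixtralPixelInputs(TypedDict):
    type: Literal["pixel_values_pixtral"]
    pixel_values: object

ImageInputs = Union[PixelInputs, PixtralPixelInputs]

def process(inputs: ImageInputs) -> str:
    # Checking the literal tag narrows the union, mirroring how the model
    # dispatches on image_input["type"] later in this diff.
    if inputs["type"] == "pixel_values_pixtral":
        return "pixtral path"
    return "clip/siglip path"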
@@ -357,13 +349,15 @@ class PixtralHFMultiModalProcessor(
         ]
 
         hf_config = self.info.get_hf_config()
+        vision_config = hf_config.vision_config
+        assert isinstance(vision_config, PixtralVisionConfig)
+        encoder_info = PixtralHFEncoderInfo(vision_config)
+
         tile_sizes = [
-            get_pixtral_hf_image_feature_grid_size(
-                hf_config.vision_config,
+            encoder_info.get_patch_grid_size(
                 image_width=pixel_value.shape[-1],
-                image_height=pixel_value.shape[-2])
-            for pixel_value in processed_outputs["pixel_values"]
+                image_height=pixel_value.shape[-2],
+            ) for pixel_value in processed_outputs["pixel_values"]
         ]
         num_crops = torch.tensor([(ncols + 1) * nrows
                                   for ncols, nrows in tile_sizes])
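
The `(ncols + 1) * nrows` count reflects Pixtral's token layout: each row of `ncols` patch tokens is followed by one image-break token, so every row contributes `ncols + 1` entries. A hedged standalone sketch of the arithmetic (the real grid computation lives in `PixtralHFEncoderInfo.get_patch_grid_size`; this toy version assumes preprocessing already resized the image to multiples of the patch size):

def toy_patch_grid_size(image_width: int, image_height: int,
                        patch_size: int = 16) -> tuple[int, int]:
    # Assumption: both sides divide evenly by patch_size after resizing.
    return image_width // patch_size, image_height // patch_size

ncols, nrows = toy_patch_grid_size(image_width=64, image_height=48)
# 4 columns x 3 rows of patches; +1 per row for the break token.
assert (ncols, nrows) == (4, 3)
assert (ncols + 1) * nrows == 15  # matches the num_crops formula above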
@@ -411,13 +405,13 @@ class PixtralHFMultiModalProcessor(
         vision_config = hf_config.vision_config
         assert isinstance(vision_config, PixtralVisionConfig)
+        encoder_info = PixtralHFEncoderInfo(vision_config)
 
         def get_replacement(item_idx: int):
             images = mm_items.get_items("image", ImageProcessorItems)
             image_size = images.get_image_size(item_idx)
 
-            ncols, nrows = get_pixtral_hf_image_feature_grid_size(
-                vision_config,
+            ncols, nrows = encoder_info.get_patch_grid_size(
                 image_width=image_size.width,
                 image_height=image_size.height,
             )
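
`get_replacement` uses the same grid to emit each image's placeholder tokens. As a hedged sketch of the Pixtral-style layout (token names are illustrative; the real processor uses the tokenizer's ids): each row is `ncols` image tokens plus a break token, with an end token closing the final row:

def toy_replacement(ncols: int, nrows: int) -> list[str]:
    # Illustrative names standing in for [IMG], [IMG_BREAK], [IMG_END].
    tokens: list[str] = []
    for row in range(nrows):
        tokens += ["[IMG]"] * ncols
        tokens.append("[IMG_END]" if row == nrows - 1 else "[IMG_BREAK]")
    return tokens

assert len(toy_replacement(4, 3)) == (4 + 1) * 3  # consistent with num_crops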
@@ -512,7 +506,7 @@ def init_vision_tower_for_llava(
     *,
     require_post_norm: Optional[bool] = None,
     prefix: str = "",
-):
+) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
     vision_config = hf_config.vision_config
 
     # Initialize the vision tower only up to the deepest required feature layer
@@ -627,32 +621,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         if pixel_values is None and image_embeds is None:
             return None
 
-        feat_is_patch = kwargs.pop("feat_is_patch", None)
-        if feat_is_patch is not None and not isinstance(
-                feat_is_patch, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of feat_is_patch. "
-                             f"Got type: {type(feat_is_patch)}")
-
-        embed_is_patch = kwargs.pop("embed_is_patch", None)
-        if embed_is_patch is not None and not isinstance(
-                embed_is_patch, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of embed_is_patch. "
-                             f"Got type: {type(embed_is_patch)}")
-
-        num_crops = kwargs.pop("num_crops", None)
-        if num_crops is not None and not isinstance(num_crops, torch.Tensor):
-            raise ValueError("Incorrect type of num_crops. "
-                             f"Got type: {type(num_crops)}")
-
         if pixel_values is not None:
             if not isinstance(pixel_values, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
             if self.config.vision_config.model_type == "pixtral":
-                return LlavaImagePixelInputs(
-                    type="pixel_values",
-                    data=flatten_bn(pixel_values),
+                feat_is_patch = kwargs.pop("feat_is_patch")
+                if not isinstance(feat_is_patch, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of feat_is_patch. "
+                                     f"Got type: {type(feat_is_patch)}")
+
+                embed_is_patch = kwargs.pop("embed_is_patch")
+                if not isinstance(embed_is_patch, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of embed_is_patch. "
+                                     f"Got type: {type(embed_is_patch)}")
+
+                num_crops = kwargs.pop("num_crops")
+                if not isinstance(num_crops, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of num_crops. "
+                                     f"Got type: {type(num_crops)}")
+
+                return PixtralHFImagePixelInputs(
+                    type="pixel_values_pixtral",
+                    pixel_values=flatten_bn(pixel_values),
                     feat_is_patch=feat_is_patch,
                     embed_is_patch=embed_is_patch,
                     num_crops=num_crops,
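
Both return paths feed `pixel_values` through `flatten_bn`, which collapses the batch and per-prompt image dimensions. A rough standalone imitation of its list-case behavior (the real helper lives in vLLM's model utils and also handles already-batched tensors):

import torch

def toy_flatten_bn(batch: list[list[torch.Tensor]], concat: bool = False):
    # Collapse (batch, num_images, ...) into (batch * num_images, ...).
    flat = [img for prompt_images in batch for img in prompt_images]
    return torch.stack(flat) if concat else flat

# Two prompts with one and two images respectively -> three images total.
batch = [[torch.rand(3, 4, 4)], [torch.rand(3, 4, 4), torch.rand(3, 4, 4)]]
assert len(toy_flatten_bn(batch)) == 3
assert toy_flatten_bn(batch, concat=True).shape == (3, 3, 4, 4)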
@@ -660,11 +652,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
             return LlavaImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(
+                pixel_values=self._validate_pixel_values(
                     flatten_bn(pixel_values, concat=True)),
-                feat_is_patch=feat_is_patch,
-                embed_is_patch=embed_is_patch,
-                num_crops=num_crops,
             )
 
         if image_embeds is not None:
@@ -672,12 +661,12 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
 
+            if self.config.vision_config.model_type == "pixtral":
+                raise ValueError("Pixtral-HF does not support image_embeds.")
+
             return LlavaImageEmbeddingInputs(
                 type="image_embeds",
                 data=flatten_bn(image_embeds, concat=True),
-                feat_is_patch=feat_is_patch,
-                embed_is_patch=embed_is_patch,
-                num_crops=num_crops,
             )
 
         raise AssertionError("This line should be unreachable.")
@@ -696,7 +685,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                             PixtralHFVisionModel],
-        pixel_values: torch.Tensor,
+        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
     ) -> torch.Tensor:
 
         # NOTE: we skip the step to select the vision feature layer since
@@ -708,17 +697,20 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
             strategy=self.config.vision_feature_select_strategy,
         )
 
-    def _process_image_pixels(self,
-                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
+    def _process_image_pixels(
+        self,
+        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
+    ) -> torch.Tensor:
         assert self.vision_tower is not None
 
-        pixel_values = inputs["data"]
+        pixel_values = inputs["pixel_values"]
 
         return self._image_pixels_to_features(self.vision_tower, pixel_values)
 
-    def _process_image_input(self,
-                             image_input: LlavaImageInputs) -> torch.Tensor:
+    def _process_image_input(
+        self,
+        image_input: LlavaImageInputs,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         if image_input["type"] == "image_embeds":
             return image_input["data"]
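
On the Pixtral path, the two boolean masks reconcile encoder output with prompt placeholders: `feat_is_patch` selects the real patch features out of the encoder output, and `embed_is_patch` marks where those features land among the placeholder embeddings. A hedged sketch of the masking idea, with made-up sizes:

import torch

hidden = 8
# Suppose the encoder emits 15 feature vectors per image, of which 12 are
# patch features (the remaining 3 sit at break/end positions).
feats = torch.rand(15, hidden)
feat_is_patch = torch.tensor([True] * 4 + [False]).repeat(3)  # 3 rows of 4+1

patch_feats = feats[feat_is_patch]  # (12, hidden)

# Prompt side: 15 placeholder embeddings with the same 12 patch slots.
embeds = torch.zeros(15, hidden)
embed_is_patch = feat_is_patch.clone()
embeds[embed_is_patch] = patch_feats  # scatter features into patch positions
assert embeds[embed_is_patch].shape == (12, hidden)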
@@ -783,11 +775,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
         vision_embeddings = self._process_image_input(image_input)
 
-        if kwargs.get("v0_path", False) or \
-            image_input.get("feat_is_patch") is None or \
-            image_input.get("embed_is_patch") is None:
+        if (kwargs.get("v0_path", False)
+                or image_input["type"] != "pixel_values_pixtral"):
             # The path is used for pixtral (V0 only) and llava (V0/V1)
             return vision_embeddings
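
The rewritten condition replaces two `.get(...) is None` probes with a single check of the union's discriminator, which is clearer now that only `PixtralHFImagePixelInputs` carries the masks. Restating the before/after predicate side by side (standalone functions, not the vLLM method):

# Before: probing optional keys, which also matched any llava input that
# happened to lack the masks.
def old_check(image_input, kwargs) -> bool:
    return (kwargs.get("v0_path", False)
            or image_input.get("feat_is_patch") is None
            or image_input.get("embed_is_patch") is None)

# After: one authoritative check on the tagged union's `type` field.
def new_check(image_input, kwargs) -> bool:
    return (kwargs.get("v0_path", False)
            or image_input["type"] != "pixel_values_pixtral")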