Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
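The change below is mechanical: yapf's paren-aligned continuation style is replaced by ruff format's black-style layout (hanging indents, magic trailing commas, single-line calls where they fit the line limit), and the accompanying lint fixes drop things like redundant tuple parentheses in `for (x, y) in ...`. As a minimal illustrative sketch only — the helper name below is hypothetical and not taken from this diff:

    # Hypothetical helper, used only to illustrate the formatting change.
    def compute_features(pixel_values, image_sizes, strategy):
        return pixel_values, image_sizes, strategy


    # Before (yapf + isort): continuation arguments aligned with the opening parenthesis.
    result = compute_features(pixel_values=[1, 2, 3],
                              image_sizes=[(336, 336)],
                              strategy="spatial_unpad")

    # After (ruff format): hanging indent with a magic trailing comma,
    # closing parenthesis on its own line.
    result = compute_features(
        pixel_values=[1, 2, 3],
        image_sizes=[(336, 336)],
        strategy="spatial_unpad",
    )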
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -3,14 +3,15 @@

 from abc import abstractmethod
 from collections.abc import Iterable, Mapping
-from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
-                    Union)
+from typing import Annotated, Final, Literal, Optional, Protocol, TypeVar, Union

 import torch
 import torch.nn as nn
 from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
 from transformers.models.llava_next.modeling_llava_next import (
-    get_anyres_image_grid_shape, unpad_image)
+    get_anyres_image_grid_shape,
+    unpad_image,
+)

 from vllm.config import VllmConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -21,12 +22,22 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
-                    LlavaDummyInputsBuilder, LlavaLikeConfig,
-                    LlavaMultiModalProjector, init_vision_tower_for_llava)
+from .llava import (
+    BaseLlavaMultiModalProcessor,
+    BaseLlavaProcessingInfo,
+    LlavaDummyInputsBuilder,
+    LlavaLikeConfig,
+    LlavaMultiModalProjector,
+    init_vision_tower_for_llava,
+)
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
-                    init_vllm_registered_model, maybe_prefix)
+from .utils import (
+    AutoWeightsLoader,
+    WeightsMapper,
+    flatten_bn,
+    init_vllm_registered_model,
+    maybe_prefix,
+)
 from .vision import get_num_selected_vision_tokens


@@ -38,14 +49,16 @@ class LlavaNextImagePixelInputs(TensorSchema):
         - c: Number of channels (3)
         - h: Height
         - w: Width

     Note that `num_patches` may be different per batch and image,
     in which case the data is passed as a list instead of a batched tensor.
     """

     type: Literal["pixel_values"] = "pixel_values"
     pixel_values: Annotated[
         Union[torch.Tensor, list[torch.Tensor]],
-        TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"})]
+        TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"}),
+    ]

     image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
     # This should be in `(height, width)` format.
@@ -58,12 +71,12 @@ class LlavaNextImageEmbeddingInputs(TensorSchema):
         - ifs: Image feature size
         - hs: Hidden size (must match language model backbone)
     """

     type: Literal["image_embeds"] = "image_embeds"
     data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]


-LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
-                             LlavaNextImageEmbeddingInputs]
+LlavaNextImageInputs = Union[LlavaNextImagePixelInputs, LlavaNextImageEmbeddingInputs]


 class LlavaNextLikeConfig(LlavaLikeConfig, Protocol):
@@ -71,7 +84,6 @@ class LlavaNextLikeConfig(LlavaLikeConfig, Protocol):


 class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
-
     def get_hf_config(self) -> LlavaNextLikeConfig:
         return self.ctx.get_hf_config(LlavaNextConfig)

@@ -141,12 +153,14 @@ class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):

         if aspect_ratio > current_aspect_ratio:
             new_height = int(
-                round(original_height * (current_width / original_width), 7))
+                round(original_height * (current_width / original_width), 7)
+            )
             padding = (current_height - new_height) // 2
             current_height = current_height - (2 * padding)
         else:
             new_width = int(
-                round(original_width * (current_height / original_height), 7))
+                round(original_width * (current_height / original_height), 7)
+            )
             padding = (current_width - new_width) // 2
             current_width = current_width - (2 * padding)

@@ -159,13 +173,13 @@ class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
         hf_config = self.get_hf_config()

         largest_feature_size, largest_feature_pinpoint = 0, None
-        for (height, width) in hf_config.image_grid_pinpoints:
-            feat_size = self.get_num_image_tokens(image_width=width,
-                                                  image_height=height)
+        for height, width in hf_config.image_grid_pinpoints:
+            feat_size = self.get_num_image_tokens(
+                image_width=width, image_height=height
+            )
             if feat_size > largest_feature_size:
                 largest_feature_size = feat_size
-                largest_feature_pinpoint = ImageSize(width=width,
-                                                     height=height)
+                largest_feature_pinpoint = ImageSize(width=width, height=height)

         if largest_feature_size == 0 or largest_feature_pinpoint is None:
             raise ValueError("Cannot have a largest feature size of 0!")
@@ -177,7 +191,6 @@ _I = TypeVar("_I", bound=LlavaNextProcessingInfo)


 class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]):
-
     # Copied from BaseMultiModalProcessor
     @abstractmethod
     def _get_mm_fields_config(
@@ -189,8 +202,8 @@ class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]):


 class LlavaNextMultiModalProcessor(
-        BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]):
-
+    BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]
+):
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
@@ -203,12 +216,12 @@ class LlavaNextMultiModalProcessor(
         )


-@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor,
-                                        info=LlavaNextProcessingInfo,
-                                        dummy_inputs=LlavaDummyInputsBuilder)
-class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                        SupportsPP):
-
+@MULTIMODAL_REGISTRY.register_processor(
+    LlavaNextMultiModalProcessor,
+    info=LlavaNextProcessingInfo,
+    dummy_inputs=LlavaDummyInputsBuilder,
+)
+class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             # mapping for new names in checkpoint saved after transformers v4.52
@@ -217,7 +230,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
             "model.multi_modal_projector.": "multi_modal_projector.",
             "model.image_newline": "image_newline",
             "lm_head.": "language_model.lm_head.",
-        })
+        }
+    )

     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -240,12 +254,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
         # Used for multimodal granite models to control encoder outputs
         elif isinstance(vision_feature_layer, (list, tuple)):
             vision_hidden_size = config.vision_config.hidden_size * len(
-                vision_feature_layer)
+                vision_feature_layer
+            )
             self.select_layers = vision_feature_layer
         else:
             raise TypeError(
                 f"vision_layer_feature type: {type(vision_feature_layer)}"
-                " is not supported")
+                " is not supported"
+            )

         self.config = config
         self.multimodal_config = multimodal_config
@@ -255,14 +271,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
             config,
             quant_config,
             require_post_norm=False,
-            prefix=maybe_prefix(prefix, "vision_tower"))
-        self.image_newline = nn.Parameter(
-            torch.empty(config.text_config.hidden_size))
+            prefix=maybe_prefix(prefix, "vision_tower"),
+        )
+        self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size))
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=vision_hidden_size,
             text_hidden_size=config.text_config.hidden_size,
             projector_hidden_act=config.projector_hidden_act,
-            multimodal_projector_bias=config.multimodal_projector_bias)
+            multimodal_projector_bias=config.multimodal_projector_bias,
+        )

         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -271,10 +288,12 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
         )

         self.make_empty_intermediate_tensors = (
-            self.language_model.make_empty_intermediate_tensors)
+            self.language_model.make_empty_intermediate_tensors
+        )

     def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
+        self, **kwargs: object
+    ) -> Optional[LlavaNextImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_sizes = kwargs.pop("image_sizes", None)
         image_embeds = kwargs.pop("image_embeds", None)
@@ -284,12 +303,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,

         if pixel_values is not None:
             if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
+                raise ValueError(
+                    f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
+                )

             if not isinstance(image_sizes, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image sizes. "
-                                 f"Got type: {type(image_sizes)}")
+                raise ValueError(
+                    f"Incorrect type of image sizes. Got type: {type(image_sizes)}"
+                )

             expected_h = expected_w = self.config.vision_config.image_size
             return LlavaNextImagePixelInputs(
@@ -299,12 +320,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                 resolve_bindings={
                     "h": expected_h,
                     "w": expected_w,
-                })
+                },
+            )

         if image_embeds is not None:
             if not isinstance(image_embeds, torch.Tensor):
-                raise ValueError("Incorrect type of image embeds. "
-                                 f"Got type: {type(image_embeds)}")
+                raise ValueError(
+                    f"Incorrect type of image embeds. Got type: {type(image_embeds)}"
+                )

             return LlavaNextImageEmbeddingInputs(
                 type="image_embeds",
@@ -327,21 +350,23 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
         )

     # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
-    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
-                                      patch_embeddings: torch.Tensor, *,
-                                      strategy: str) -> torch.Tensor:
+    def _merge_image_patch_embeddings(
+        self, image_size: torch.Tensor, patch_embeddings: torch.Tensor, *, strategy: str
+    ) -> torch.Tensor:
         if strategy == "flat":
             return patch_embeddings.flatten(0, 1)

         if strategy.startswith("spatial"):
-            height = width = self.config.vision_config.image_size \
-                // self.config.vision_config.patch_size
+            height = width = (
+                self.config.vision_config.image_size
+                // self.config.vision_config.patch_size
+            )

             base_patch_embeds = patch_embeddings[0]
             if height * width != base_patch_embeds.shape[0]:
                 raise ValueError(
-                    "The number of patches is not consistent with the "
-                    "image size.")
+                    "The number of patches is not consistent with the image size."
+                )

             if patch_embeddings.shape[0] > 1:
                 other_patch_embeds = patch_embeddings[1:]
@@ -358,37 +383,51 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                 num_patches = num_patch_height * num_patch_width

                 # Image patches might be padded for batch processing
-                other_patch_embeds = other_patch_embeds[:num_patches] \
-                    .view(num_patch_height, num_patch_width, height, width, -1)
+                other_patch_embeds = other_patch_embeds[:num_patches].view(
+                    num_patch_height, num_patch_width, height, width, -1
+                )

                 if "unpad" in strategy:
-                    other_patch_embeds = other_patch_embeds \
-                        .permute(4, 0, 2, 1, 3).contiguous() \
-                        .flatten(1, 2).flatten(2, 3)
-                    other_patch_embeds = unpad_image(other_patch_embeds,
-                                                     (orig_height, orig_width))
-                    other_patch_embeds = torch.cat((
-                        other_patch_embeds,
-                        self.image_newline[:, None, None] \
-                            .expand(*other_patch_embeds.shape[:-1], 1) \
+                    other_patch_embeds = (
+                        other_patch_embeds.permute(4, 0, 2, 1, 3)
+                        .contiguous()
+                        .flatten(1, 2)
+                        .flatten(2, 3)
+                    )
+                    other_patch_embeds = unpad_image(
+                        other_patch_embeds, (orig_height, orig_width)
+                    )
+                    other_patch_embeds = torch.cat(
+                        (
+                            other_patch_embeds,
+                            self.image_newline[:, None, None]
+                            .expand(*other_patch_embeds.shape[:-1], 1)
                             .to(other_patch_embeds.device),
-                    ), dim=-1)
-                    other_patch_embeds = other_patch_embeds \
-                        .flatten(1, 2).transpose(0, 1)
+                        ),
+                        dim=-1,
+                    )
+                    other_patch_embeds = other_patch_embeds.flatten(1, 2).transpose(
+                        0, 1
+                    )
                 else:
-                    other_patch_embeds = other_patch_embeds \
-                        .permute(0, 2, 1, 3, 4).contiguous() \
+                    other_patch_embeds = (
+                        other_patch_embeds.permute(0, 2, 1, 3, 4)
+                        .contiguous()
                         .flatten(0, 3)
+                    )

                 merged_patch_embeddings = torch.cat(
-                    (base_patch_embeds, other_patch_embeds), dim=0)
+                    (base_patch_embeds, other_patch_embeds), dim=0
+                )
             else:
                 if "unpad" in strategy:
                     merged_patch_embeddings = torch.cat(
-                        (base_patch_embeds,
-                         self.image_newline[None] \
-                             .to(base_patch_embeds.device)
-                        ), dim=0)
+                        (
+                            base_patch_embeds,
+                            self.image_newline[None].to(base_patch_embeds.device),
+                        ),
+                        dim=0,
+                    )
                 else:
                     merged_patch_embeddings = base_patch_embeds

@@ -408,20 +447,25 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
             b, num_patches, c, h, w = pixel_values.shape
             stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
             stacked_image_features = self._image_pixels_to_features(
-                self.vision_tower, stacked_pixel_values)
+                self.vision_tower, stacked_pixel_values
+            )
             stacked_patch_embeddings = self.multi_modal_projector(
-                stacked_image_features)
+                stacked_image_features
+            )

             return stacked_patch_embeddings.view(
-                b, num_patches, *stacked_patch_embeddings.shape[1:])
+                b, num_patches, *stacked_patch_embeddings.shape[1:]
+            )

         num_patches_per_batch = [v.shape[0] for v in pixel_values]
         stacked_pixel_values = torch.cat(pixel_values)
         stacked_image_features = self._image_pixels_to_features(
-            self.vision_tower, stacked_pixel_values)
+            self.vision_tower, stacked_pixel_values
+        )

-        return torch.split(self.multi_modal_projector(stacked_image_features),
-                           num_patches_per_batch)
+        return torch.split(
+            self.multi_modal_projector(stacked_image_features), num_patches_per_batch
+        )

     def _process_image_input(
         self,
@@ -437,21 +481,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
             batch_size = len(image_input["data"])
             vision_config = self.config.vision_config
             default_height = default_width = vision_config.image_size
-            image_sizes = torch.as_tensor([[default_height, default_width]
-                                           for _ in range(batch_size)])
+            image_sizes = torch.as_tensor(
+                [[default_height, default_width] for _ in range(batch_size)]
+            )

         return [
-            self._merge_image_patch_embeddings(image_sizes[i],
-                                               patch_features_batch,
-                                               strategy="spatial_unpad")
+            self._merge_image_patch_embeddings(
+                image_sizes[i], patch_features_batch, strategy="spatial_unpad"
+            )
             for i, patch_features_batch in enumerate(patch_embeddings)
         ]

     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(self,
-                                  **kwargs: object) -> MultiModalEmbeddings:
+    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
@@ -535,10 +579,9 @@ model_executor.models.llava_next.LlavaNextProcessingInfo.get_num_image_tokens].
         if intermediate_tensors is not None:
            inputs_embeds = None

-        hidden_states = self.language_model.model(input_ids,
-                                                  positions,
-                                                  intermediate_tensors,
-                                                  inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(
+            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
+        )
         return hidden_states

     def compute_logits(
@@ -547,7 +590,6 @@ model_executor.models.llava_next.LlavaNextProcessingInfo.get_num_image_tokens].
     ) -> Optional[torch.Tensor]:
         return self.language_model.compute_logits(hidden_states)

-    def load_weights(self, weights: Iterable[tuple[str,
-                                             torch.Tensor]]) -> set[str]:
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)