Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -3,14 +3,15 @@
from abc import abstractmethod
from collections.abc import Iterable, Mapping
from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
Union)
from typing import Annotated, Final, Literal, Optional, Protocol, TypeVar, Union
import torch
import torch.nn as nn
from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape, unpad_image)
get_anyres_image_grid_shape,
unpad_image,
)
from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -21,12 +22,22 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
LlavaDummyInputsBuilder, LlavaLikeConfig,
LlavaMultiModalProjector, init_vision_tower_for_llava)
from .llava import (
BaseLlavaMultiModalProcessor,
BaseLlavaProcessingInfo,
LlavaDummyInputsBuilder,
LlavaLikeConfig,
LlavaMultiModalProjector,
init_vision_tower_for_llava,
)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix)
from .utils import (
AutoWeightsLoader,
WeightsMapper,
flatten_bn,
init_vllm_registered_model,
maybe_prefix,
)
from .vision import get_num_selected_vision_tokens
@@ -38,14 +49,16 @@ class LlavaNextImagePixelInputs(TensorSchema):
- c: Number of channels (3)
- h: Height
- w: Width
Note that `num_patches` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""
type: Literal["pixel_values"] = "pixel_values"
pixel_values: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"})]
TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"}),
]
image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
# This should be in `(height, width)` format.
@@ -58,12 +71,12 @@ class LlavaNextImageEmbeddingInputs(TensorSchema):
- ifs: Image feature size
- hs: Hidden size (must match language model backbone)
"""
type: Literal["image_embeds"] = "image_embeds"
data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
LlavaNextImageEmbeddingInputs]
LlavaNextImageInputs = Union[LlavaNextImagePixelInputs, LlavaNextImageEmbeddingInputs]
class LlavaNextLikeConfig(LlavaLikeConfig, Protocol):
@@ -71,7 +84,6 @@ class LlavaNextLikeConfig(LlavaLikeConfig, Protocol):
class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
def get_hf_config(self) -> LlavaNextLikeConfig:
return self.ctx.get_hf_config(LlavaNextConfig)
@@ -141,12 +153,14 @@ class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
if aspect_ratio > current_aspect_ratio:
new_height = int(
round(original_height * (current_width / original_width), 7))
round(original_height * (current_width / original_width), 7)
)
padding = (current_height - new_height) // 2
current_height = current_height - (2 * padding)
else:
new_width = int(
round(original_width * (current_height / original_height), 7))
round(original_width * (current_height / original_height), 7)
)
padding = (current_width - new_width) // 2
current_width = current_width - (2 * padding)
@@ -159,13 +173,13 @@ class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
hf_config = self.get_hf_config()
largest_feature_size, largest_feature_pinpoint = 0, None
for (height, width) in hf_config.image_grid_pinpoints:
feat_size = self.get_num_image_tokens(image_width=width,
image_height=height)
for height, width in hf_config.image_grid_pinpoints:
feat_size = self.get_num_image_tokens(
image_width=width, image_height=height
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = ImageSize(width=width,
height=height)
largest_feature_pinpoint = ImageSize(width=width, height=height)
if largest_feature_size == 0 or largest_feature_pinpoint is None:
raise ValueError("Cannot have a largest feature size of 0!")
@@ -177,7 +191,6 @@ _I = TypeVar("_I", bound=LlavaNextProcessingInfo)
class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]):
# Copied from BaseMultiModalProcessor
@abstractmethod
def _get_mm_fields_config(
@@ -189,8 +202,8 @@ class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]):
class LlavaNextMultiModalProcessor(
BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]):
BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]
):
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
@@ -203,12 +216,12 @@ class LlavaNextMultiModalProcessor(
)
@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor,
info=LlavaNextProcessingInfo,
dummy_inputs=LlavaDummyInputsBuilder)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
@MULTIMODAL_REGISTRY.register_processor(
LlavaNextMultiModalProcessor,
info=LlavaNextProcessingInfo,
dummy_inputs=LlavaDummyInputsBuilder,
)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52
@@ -217,7 +230,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
"model.multi_modal_projector.": "multi_modal_projector.",
"model.image_newline": "image_newline",
"lm_head.": "language_model.lm_head.",
})
}
)
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -240,12 +254,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
# Used for multimodal granite models to control encoder outputs
elif isinstance(vision_feature_layer, (list, tuple)):
vision_hidden_size = config.vision_config.hidden_size * len(
vision_feature_layer)
vision_feature_layer
)
self.select_layers = vision_feature_layer
else:
raise TypeError(
f"vision_layer_feature type: {type(vision_feature_layer)}"
" is not supported")
" is not supported"
)
self.config = config
self.multimodal_config = multimodal_config
@@ -255,14 +271,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
config,
quant_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"))
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size))
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=vision_hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act,
multimodal_projector_bias=config.multimodal_projector_bias)
multimodal_projector_bias=config.multimodal_projector_bias,
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
@@ -271,10 +288,12 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
self.language_model.make_empty_intermediate_tensors
)
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
self, **kwargs: object
) -> Optional[LlavaNextImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)
@@ -284,12 +303,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
raise ValueError(
f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
)
if not isinstance(image_sizes, (torch.Tensor, list)):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
raise ValueError(
f"Incorrect type of image sizes. Got type: {type(image_sizes)}"
)
expected_h = expected_w = self.config.vision_config.image_size
return LlavaNextImagePixelInputs(
@@ -299,12 +320,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
resolve_bindings={
"h": expected_h,
"w": expected_w,
})
},
)
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeds. "
f"Got type: {type(image_embeds)}")
raise ValueError(
f"Incorrect type of image embeds. Got type: {type(image_embeds)}"
)
return LlavaNextImageEmbeddingInputs(
type="image_embeds",
@@ -327,21 +350,23 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
)
# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
patch_embeddings: torch.Tensor, *,
strategy: str) -> torch.Tensor:
def _merge_image_patch_embeddings(
self, image_size: torch.Tensor, patch_embeddings: torch.Tensor, *, strategy: str
) -> torch.Tensor:
if strategy == "flat":
return patch_embeddings.flatten(0, 1)
if strategy.startswith("spatial"):
height = width = self.config.vision_config.image_size \
height = width = (
self.config.vision_config.image_size
// self.config.vision_config.patch_size
)
base_patch_embeds = patch_embeddings[0]
if height * width != base_patch_embeds.shape[0]:
raise ValueError(
"The number of patches is not consistent with the "
"image size.")
"The number of patches is not consistent with the image size."
)
if patch_embeddings.shape[0] > 1:
other_patch_embeds = patch_embeddings[1:]
@@ -358,37 +383,51 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
num_patches = num_patch_height * num_patch_width
# Image patches might be padded for batch processing
other_patch_embeds = other_patch_embeds[:num_patches] \
.view(num_patch_height, num_patch_width, height, width, -1)
other_patch_embeds = other_patch_embeds[:num_patches].view(
num_patch_height, num_patch_width, height, width, -1
)
if "unpad" in strategy:
other_patch_embeds = other_patch_embeds \
.permute(4, 0, 2, 1, 3).contiguous() \
.flatten(1, 2).flatten(2, 3)
other_patch_embeds = unpad_image(other_patch_embeds,
(orig_height, orig_width))
other_patch_embeds = torch.cat((
other_patch_embeds,
self.image_newline[:, None, None] \
.expand(*other_patch_embeds.shape[:-1], 1) \
other_patch_embeds = (
other_patch_embeds.permute(4, 0, 2, 1, 3)
.contiguous()
.flatten(1, 2)
.flatten(2, 3)
)
other_patch_embeds = unpad_image(
other_patch_embeds, (orig_height, orig_width)
)
other_patch_embeds = torch.cat(
(
other_patch_embeds,
self.image_newline[:, None, None]
.expand(*other_patch_embeds.shape[:-1], 1)
.to(other_patch_embeds.device),
), dim=-1)
other_patch_embeds = other_patch_embeds \
.flatten(1, 2).transpose(0, 1)
),
dim=-1,
)
other_patch_embeds = other_patch_embeds.flatten(1, 2).transpose(
0, 1
)
else:
other_patch_embeds = other_patch_embeds \
.permute(0, 2, 1, 3, 4).contiguous() \
other_patch_embeds = (
other_patch_embeds.permute(0, 2, 1, 3, 4)
.contiguous()
.flatten(0, 3)
)
merged_patch_embeddings = torch.cat(
(base_patch_embeds, other_patch_embeds), dim=0)
(base_patch_embeds, other_patch_embeds), dim=0
)
else:
if "unpad" in strategy:
merged_patch_embeddings = torch.cat(
(base_patch_embeds,
self.image_newline[None] \
.to(base_patch_embeds.device)
), dim=0)
(
base_patch_embeds,
self.image_newline[None].to(base_patch_embeds.device),
),
dim=0,
)
else:
merged_patch_embeddings = base_patch_embeds
@@ -408,20 +447,25 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
b, num_patches, c, h, w = pixel_values.shape
stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values)
self.vision_tower, stacked_pixel_values
)
stacked_patch_embeddings = self.multi_modal_projector(
stacked_image_features)
stacked_image_features
)
return stacked_patch_embeddings.view(
b, num_patches, *stacked_patch_embeddings.shape[1:])
b, num_patches, *stacked_patch_embeddings.shape[1:]
)
num_patches_per_batch = [v.shape[0] for v in pixel_values]
stacked_pixel_values = torch.cat(pixel_values)
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values)
self.vision_tower, stacked_pixel_values
)
return torch.split(self.multi_modal_projector(stacked_image_features),
num_patches_per_batch)
return torch.split(
self.multi_modal_projector(stacked_image_features), num_patches_per_batch
)
def _process_image_input(
self,
@@ -437,21 +481,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
batch_size = len(image_input["data"])
vision_config = self.config.vision_config
default_height = default_width = vision_config.image_size
image_sizes = torch.as_tensor([[default_height, default_width]
for _ in range(batch_size)])
image_sizes = torch.as_tensor(
[[default_height, default_width] for _ in range(batch_size)]
)
return [
self._merge_image_patch_embeddings(image_sizes[i],
patch_features_batch,
strategy="spatial_unpad")
self._merge_image_patch_embeddings(
image_sizes[i], patch_features_batch, strategy="spatial_unpad"
)
for i, patch_features_batch in enumerate(patch_embeddings)
]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
@@ -535,10 +579,9 @@ model_executor.models.llava_next.LlavaNextProcessingInfo.get_num_image_tokens].
if intermediate_tensors is not None:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds)
hidden_states = self.language_model.model(
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
)
return hidden_states
def compute_logits(
@@ -547,7 +590,6 @@ model_executor.models.llava_next.LlavaNextProcessingInfo.get_num_image_tokens].
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)