Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor authored on 2025-10-05 15:06:22 +01:00; committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

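The pyproject.toml and pre-commit changes that drive this migration are not part of this excerpt. As a rough sketch (assumed, not copied from the commit), a ruff setup that replaces yapf + isort usually amounts to something like:

    [tool.ruff]
    line-length = 88  # for example; the project's actual value may differ

    [tool.ruff.lint]
    # "I" enables isort-style import sorting, so ruff also replaces isort
    select = ["E", "F", "I"]

with `ruff format` and `ruff check --fix` taking over from `yapf --in-place` and `isort` in pre-commit and CI. The hunks below show the effect on one of the changed files.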

@@ -7,21 +7,30 @@ from typing import Annotated, Literal, Optional, Union
import torch
import torch.nn as nn
-from transformers import (BatchFeature, LlavaNextVideoConfig,
-                          LlavaNextVideoProcessor)
+from transformers import BatchFeature, LlavaNextVideoConfig, LlavaNextVideoProcessor
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems)
-from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
-                                   VideoEmbeddingItems, VideoProcessorItems)
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate)
+from vllm.multimodal.inputs import (
+    MultiModalDataDict,
+    MultiModalFieldConfig,
+    MultiModalKwargsItems,
+)
+from vllm.multimodal.parse import (
+    ImageSize,
+    MultiModalDataItems,
+    VideoEmbeddingItems,
+    VideoProcessorItems,
+)
+from vllm.multimodal.processing import (
+    BaseMultiModalProcessor,
+    BaseProcessingInfo,
+    PromptReplacement,
+    PromptUpdate,
+)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@@ -30,13 +39,17 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava
from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, WeightsMapper,
-                    init_vllm_registered_model, maybe_prefix)
+from .utils import (
+    AutoWeightsLoader,
+    WeightsMapper,
+    init_vllm_registered_model,
+    maybe_prefix,
+)
from .vision import get_vision_encoder_info
class LlavaNextVideoPixelInputs(TensorSchema):
"""
"""
Dimensions:
- bs: Batch size
- nv: Number of videos
@@ -50,14 +63,16 @@ class LlavaNextVideoPixelInputs(TensorSchema):
Note that it only supports one video input for one batch.
"""
type: Literal["pixel_values_videos"] = "pixel_values_videos"
-    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
-                    TensorShape("bs", "nv", "nf", 3, "h", "w")]
+    data: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bs", "nv", "nf", 3, "h", "w"),
+    ]
class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(LlavaNextVideoConfig)
@@ -137,8 +152,8 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
class LlavaNextVideoDummyInputsBuilder(
-        BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]):
+    BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]
+):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_videos = mm_counts.get("video", 0)
@@ -155,16 +170,15 @@ class LlavaNextVideoDummyInputsBuilder(
) -> MultiModalDataDict:
num_videos = mm_counts.get("video", 0)
-        target_width, target_height = \
-            self.info.get_image_size_with_most_features()
-        target_num_frames = \
-            self.info.get_num_frames_with_most_features(seq_len, mm_counts)
+        target_width, target_height = self.info.get_image_size_with_most_features()
+        target_num_frames = self.info.get_num_frames_with_most_features(
+            seq_len, mm_counts
+        )
video_overrides = mm_options.get("video") if mm_options else None
return {
"video":
self._get_dummy_videos(
"video": self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=target_num_frames,
@@ -175,8 +189,8 @@ class LlavaNextVideoDummyInputsBuilder(
class LlavaNextVideoMultiModalProcessor(
-        BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]):
+    BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]
+):
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
@@ -195,7 +209,8 @@ class LlavaNextVideoMultiModalProcessor(
def get_replacement(item_idx: int):
videos = mm_items.get_items(
"video", (VideoEmbeddingItems, VideoProcessorItems))
"video", (VideoEmbeddingItems, VideoProcessorItems)
)
if isinstance(videos, VideoEmbeddingItems):
num_video_tokens = videos.get_feature_size(item_idx)
@@ -220,7 +235,6 @@ class LlavaNextVideoMultiModalProcessor(
# adopted from transformers modeling_llava_next_video.py
class LlavaNextVideoPooler(nn.Module):
def __init__(self, config: LlavaNextVideoConfig):
super().__init__()
@@ -237,36 +251,41 @@ class LlavaNextVideoPooler(nn.Module):
else:
# TODO: Support Conv2d pooling layer, need to load weights
raise ValueError(
f"Unknown pooling mode: {mode}. Expected [`average`, `max`]")
f"Unknown pooling mode: {mode}. Expected [`average`, `max`]"
)
def forward(self, image_features: torch.Tensor):
ori_width = int(
-            math.sqrt(image_features.shape[1] * self.image_size //
-                      self.image_size))
+            math.sqrt(image_features.shape[1] * self.image_size // self.image_size)
+        )
ori_height = int(ori_width * self.image_size // self.image_size)
batch_size, _, dim = image_features.shape
-        image_features_spatial = image_features \
-            .view(batch_size, ori_height, ori_height, dim) \
-            .permute(0, 3, 1, 2)
+        image_features_spatial = image_features.view(
+            batch_size, ori_height, ori_height, dim
+        ).permute(0, 3, 1, 2)
image_features_spatial = self.pool(image_features_spatial)
return image_features_spatial.flatten(2).transpose(1, 2).contiguous()
class LlavaNextMultiModalProjector(nn.Module):
-    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
-                 projector_hidden_act: str, multimodal_projector_bias: bool):
+    def __init__(
+        self,
+        vision_hidden_size: int,
+        text_hidden_size: int,
+        projector_hidden_act: str,
+        multimodal_projector_bias: bool,
+    ):
super().__init__()
-        self.linear_1 = nn.Linear(vision_hidden_size,
-                                  text_hidden_size,
-                                  bias=multimodal_projector_bias)
+        self.linear_1 = nn.Linear(
+            vision_hidden_size, text_hidden_size, bias=multimodal_projector_bias
+        )
self.act = get_act_fn(projector_hidden_act)
-        self.linear_2 = nn.Linear(text_hidden_size,
-                                  text_hidden_size,
-                                  bias=multimodal_projector_bias)
+        self.linear_2 = nn.Linear(
+            text_hidden_size, text_hidden_size, bias=multimodal_projector_bias
+        )
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear_1(image_features)
@@ -280,9 +299,7 @@ class LlavaNextMultiModalProjector(nn.Module):
info=LlavaNextVideoProcessingInfo,
dummy_inputs=LlavaNextVideoDummyInputsBuilder,
)
-class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                             SupportsPP):
+class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52
@@ -291,7 +308,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
"model.multi_modal_projector.": "multi_modal_projector.",
"model.image_newline": "image_newline",
"lm_head.": "language_model.lm_head.",
-        })
+        }
+    )
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -316,13 +334,15 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
config,
quant_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"))
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.vision_resampler = LlavaNextVideoPooler(config)
self.multi_modal_projector = LlavaNextMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act,
-            multimodal_projector_bias=config.multimodal_projector_bias)
+            multimodal_projector_bias=config.multimodal_projector_bias,
+        )
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
@@ -330,14 +350,16 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
)
self.make_empty_intermediate_tensors = (
-            self.language_model.model.make_empty_intermediate_tensors)
+            self.language_model.model.make_empty_intermediate_tensors
+        )
def _parse_and_validate_video_input(
-            self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]:
+        self, **kwargs: object
+    ) -> Optional[LlavaNextVideoPixelInputs]:
"""
A legal video input should have the following dimensions:
{
"pixel_values_videos" :
"pixel_values_videos" :
list[b, Tensor(nb_frames, nb_channels, height, width)]
}
"""
@@ -347,12 +369,14 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
return None
expected_h = expected_w = self.config.vision_config.image_size
return LlavaNextVideoPixelInputs(type="pixel_values_videos",
data=pixel_values_videos,
resolve_bindings={
"h": expected_h,
"w": expected_w,
})
return LlavaNextVideoPixelInputs(
type="pixel_values_videos",
data=pixel_values_videos,
resolve_bindings={
"h": expected_h,
"w": expected_w,
},
)
def _video_pixels_to_features(
self,
@@ -377,31 +401,31 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
if isinstance(video_pixels, torch.Tensor):
# TODO: support multiple videos per input
b, num_videos, num_frames, c, h, w = video_pixels.shape
-            assert (num_videos == 1)
-            stacked_pixels = video_pixels.view(b * num_videos * num_frames, c,
-                                               h, w)
+            assert num_videos == 1
+            stacked_pixels = video_pixels.view(b * num_videos * num_frames, c, h, w)
stacked_embeddings = self._video_pixels_to_features(
-                self.vision_tower, stacked_pixels)
-            embeds = stacked_embeddings.view(b, num_frames,
-                                             *stacked_embeddings.shape[1:])
+                self.vision_tower, stacked_pixels
+            )
+            embeds = stacked_embeddings.view(
+                b, num_frames, *stacked_embeddings.shape[1:]
+            )
elif is_list_of(video_pixels, torch.Tensor):
frames_per_videos = [v.shape[0] for v in video_pixels]
stacked_pixels = torch.cat(video_pixels, dim=0)
stacked_embeddings = self._video_pixels_to_features(
-                self.vision_tower, stacked_pixels)
+                self.vision_tower, stacked_pixels
+            )
embeds = torch.split(stacked_embeddings, frames_per_videos, dim=0)
else:
-            raise ValueError(
-                f"Unsupported type of video input {type(video_pixels)}")
+            raise ValueError(f"Unsupported type of video input {type(video_pixels)}")
return [e.flatten(0, 1) for e in embeds]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
-    def get_multimodal_embeddings(self,
-                                  **kwargs: object) -> MultiModalEmbeddings:
+    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
video_input = self._parse_and_validate_video_input(**kwargs)
if video_input is None:
return []
@@ -425,10 +449,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
if intermediate_tensors is not None:
inputs_embeds = None
-        hidden_states = self.language_model.model(input_ids,
-                                                  positions,
-                                                  intermediate_tensors,
-                                                  inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(
+            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
+        )
return hidden_states
@@ -438,8 +461,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states)
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
# This model doesn't support images for now