Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
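The migration consolidates two tools, yapf (code formatting) and isort (import ordering), into ruff, which covers both via `ruff format` and its isort-compatible `I` lint rules. Below is a minimal sketch of the kind of pyproject.toml change such a switch typically involves; the keys are real ruff options, but the values are illustrative and not taken from this PR:

    [tool.ruff]
    line-length = 88  # ruff's formatter default (black-compatible); illustrative

    [tool.ruff.lint]
    # "I" enables ruff's isort-compatible import-sorting rules,
    # replacing the standalone isort tool
    extend-select = ["I"]

    [tool.ruff.lint.isort]
    # name the project's own package so its imports group as first-party;
    # "vllm" is inferred from the imports in the diff below
    known-first-party = ["vllm"]

With a config along these lines, `ruff check --select I --fix .` rewrites import blocks into the parenthesized one-name-per-line style seen in the diff below, and `ruff format .` applies the black-style layout (hugging parentheses, magic trailing commas) visible at the reformatted call sites.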
@@ -7,21 +7,30 @@ from typing import Annotated, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
-from transformers import (BatchFeature, LlavaNextVideoConfig,
-                          LlavaNextVideoProcessor)
+from transformers import BatchFeature, LlavaNextVideoConfig, LlavaNextVideoProcessor
 
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems)
-from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
-                                   VideoEmbeddingItems, VideoProcessorItems)
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate)
+from vllm.multimodal.inputs import (
+    MultiModalDataDict,
+    MultiModalFieldConfig,
+    MultiModalKwargsItems,
+)
+from vllm.multimodal.parse import (
+    ImageSize,
+    MultiModalDataItems,
+    VideoEmbeddingItems,
+    VideoProcessorItems,
+)
+from vllm.multimodal.processing import (
+    BaseMultiModalProcessor,
+    BaseProcessingInfo,
+    PromptReplacement,
+    PromptUpdate,
+)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
@@ -30,13 +39,17 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .llava import init_vision_tower_for_llava
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, WeightsMapper,
-                    init_vllm_registered_model, maybe_prefix)
+from .utils import (
+    AutoWeightsLoader,
+    WeightsMapper,
+    init_vllm_registered_model,
+    maybe_prefix,
+)
 from .vision import get_vision_encoder_info
 
 
 class LlavaNextVideoPixelInputs(TensorSchema):
     """
     Dimensions:
         - bs: Batch size
         - nv: Number of videos
@@ -50,14 +63,16 @@ class LlavaNextVideoPixelInputs(TensorSchema):
 
     Note that it only supports one video input for one batch.
     """
 
     type: Literal["pixel_values_videos"] = "pixel_values_videos"
 
-    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
-                    TensorShape("bs", "nv", "nf", 3, "h", "w")]
+    data: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bs", "nv", "nf", 3, "h", "w"),
+    ]
 
 
 class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
 
     def get_hf_config(self):
         return self.ctx.get_hf_config(LlavaNextVideoConfig)
@@ -137,8 +152,8 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
 
 
 class LlavaNextVideoDummyInputsBuilder(
-        BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]):
-
+    BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]
+):
     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         num_videos = mm_counts.get("video", 0)
 
@@ -155,16 +170,15 @@ class LlavaNextVideoDummyInputsBuilder(
     ) -> MultiModalDataDict:
         num_videos = mm_counts.get("video", 0)
 
-        target_width, target_height = \
-            self.info.get_image_size_with_most_features()
-        target_num_frames = \
-            self.info.get_num_frames_with_most_features(seq_len, mm_counts)
+        target_width, target_height = self.info.get_image_size_with_most_features()
+        target_num_frames = self.info.get_num_frames_with_most_features(
+            seq_len, mm_counts
+        )
 
         video_overrides = mm_options.get("video") if mm_options else None
 
         return {
-            "video":
-            self._get_dummy_videos(
+            "video": self._get_dummy_videos(
                 width=target_width,
                 height=target_height,
                 num_frames=target_num_frames,
@@ -175,8 +189,8 @@ class LlavaNextVideoDummyInputsBuilder(
 
 
 class LlavaNextVideoMultiModalProcessor(
-        BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]):
-
+    BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]
+):
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
@@ -195,7 +209,8 @@ class LlavaNextVideoMultiModalProcessor(
 
         def get_replacement(item_idx: int):
             videos = mm_items.get_items(
-                "video", (VideoEmbeddingItems, VideoProcessorItems))
+                "video", (VideoEmbeddingItems, VideoProcessorItems)
+            )
 
             if isinstance(videos, VideoEmbeddingItems):
                 num_video_tokens = videos.get_feature_size(item_idx)
@@ -220,7 +235,6 @@ class LlavaNextVideoMultiModalProcessor(
 
 # adopted from transformers modeling_llava_next_video.py
 class LlavaNextVideoPooler(nn.Module):
-
     def __init__(self, config: LlavaNextVideoConfig):
         super().__init__()
 
@@ -237,36 +251,41 @@ class LlavaNextVideoPooler(nn.Module):
         else:
             # TODO: Support Conv2d pooling layer, need to load weights
             raise ValueError(
-                f"Unknown pooling mode: {mode}. Expected [`average`, `max`]")
+                f"Unknown pooling mode: {mode}. Expected [`average`, `max`]"
+            )
 
     def forward(self, image_features: torch.Tensor):
         ori_width = int(
-            math.sqrt(image_features.shape[1] * self.image_size //
-                      self.image_size))
+            math.sqrt(image_features.shape[1] * self.image_size // self.image_size)
+        )
         ori_height = int(ori_width * self.image_size // self.image_size)
 
         batch_size, _, dim = image_features.shape
-        image_features_spatial = image_features \
-            .view(batch_size, ori_height, ori_height, dim) \
-            .permute(0, 3, 1, 2)
+        image_features_spatial = image_features.view(
+            batch_size, ori_height, ori_height, dim
+        ).permute(0, 3, 1, 2)
         image_features_spatial = self.pool(image_features_spatial)
 
         return image_features_spatial.flatten(2).transpose(1, 2).contiguous()
 
 
 class LlavaNextMultiModalProjector(nn.Module):
-
-    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
-                 projector_hidden_act: str, multimodal_projector_bias: bool):
+    def __init__(
+        self,
+        vision_hidden_size: int,
+        text_hidden_size: int,
+        projector_hidden_act: str,
+        multimodal_projector_bias: bool,
+    ):
         super().__init__()
 
-        self.linear_1 = nn.Linear(vision_hidden_size,
-                                  text_hidden_size,
-                                  bias=multimodal_projector_bias)
+        self.linear_1 = nn.Linear(
+            vision_hidden_size, text_hidden_size, bias=multimodal_projector_bias
+        )
         self.act = get_act_fn(projector_hidden_act)
-        self.linear_2 = nn.Linear(text_hidden_size,
-                                  text_hidden_size,
-                                  bias=multimodal_projector_bias)
+        self.linear_2 = nn.Linear(
+            text_hidden_size, text_hidden_size, bias=multimodal_projector_bias
+        )
 
     def forward(self, image_features: torch.Tensor) -> torch.Tensor:
         hidden_states = self.linear_1(image_features)
@@ -280,9 +299,7 @@ class LlavaNextMultiModalProjector(nn.Module):
     info=LlavaNextVideoProcessingInfo,
     dummy_inputs=LlavaNextVideoDummyInputsBuilder,
 )
-class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                             SupportsPP):
-
+class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             # mapping for new names in checkpoint saved after transformers v4.52
@@ -291,7 +308,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
             "model.multi_modal_projector.": "multi_modal_projector.",
             "model.image_newline": "image_newline",
             "lm_head.": "language_model.lm_head.",
-        })
+        }
+    )
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -316,13 +334,15 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
             config,
             quant_config,
             require_post_norm=False,
-            prefix=maybe_prefix(prefix, "vision_tower"))
+            prefix=maybe_prefix(prefix, "vision_tower"),
+        )
         self.vision_resampler = LlavaNextVideoPooler(config)
         self.multi_modal_projector = LlavaNextMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
             projector_hidden_act=config.projector_hidden_act,
-            multimodal_projector_bias=config.multimodal_projector_bias)
+            multimodal_projector_bias=config.multimodal_projector_bias,
+        )
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
             hf_config=config.text_config,
@@ -330,14 +350,16 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
         )
 
         self.make_empty_intermediate_tensors = (
-            self.language_model.model.make_empty_intermediate_tensors)
+            self.language_model.model.make_empty_intermediate_tensors
+        )
 
     def _parse_and_validate_video_input(
-            self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]:
+        self, **kwargs: object
+    ) -> Optional[LlavaNextVideoPixelInputs]:
         """
         A legal video input should have the following dimensions:
         {
             "pixel_values_videos" :
                 list[b, Tensor(nb_frames, nb_channels, height, width)]
         }
         """
@@ -347,12 +369,14 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None
 
         expected_h = expected_w = self.config.vision_config.image_size
-        return LlavaNextVideoPixelInputs(type="pixel_values_videos",
-                                         data=pixel_values_videos,
-                                         resolve_bindings={
-                                             "h": expected_h,
-                                             "w": expected_w,
-                                         })
+        return LlavaNextVideoPixelInputs(
+            type="pixel_values_videos",
+            data=pixel_values_videos,
+            resolve_bindings={
+                "h": expected_h,
+                "w": expected_w,
+            },
+        )
 
     def _video_pixels_to_features(
         self,
@@ -377,31 +401,31 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
         if isinstance(video_pixels, torch.Tensor):
             # TODO: support multiple videos per input
             b, num_videos, num_frames, c, h, w = video_pixels.shape
-            assert (num_videos == 1)
-            stacked_pixels = video_pixels.view(b * num_videos * num_frames, c,
-                                               h, w)
+            assert num_videos == 1
+            stacked_pixels = video_pixels.view(b * num_videos * num_frames, c, h, w)
             stacked_embeddings = self._video_pixels_to_features(
-                self.vision_tower, stacked_pixels)
-            embeds = stacked_embeddings.view(b, num_frames,
-                                             *stacked_embeddings.shape[1:])
+                self.vision_tower, stacked_pixels
+            )
+            embeds = stacked_embeddings.view(
+                b, num_frames, *stacked_embeddings.shape[1:]
+            )
 
         elif is_list_of(video_pixels, torch.Tensor):
             frames_per_videos = [v.shape[0] for v in video_pixels]
             stacked_pixels = torch.cat(video_pixels, dim=0)
             stacked_embeddings = self._video_pixels_to_features(
-                self.vision_tower, stacked_pixels)
+                self.vision_tower, stacked_pixels
+            )
             embeds = torch.split(stacked_embeddings, frames_per_videos, dim=0)
         else:
-            raise ValueError(
-                f"Unsupported type of video input {type(video_pixels)}")
+            raise ValueError(f"Unsupported type of video input {type(video_pixels)}")
 
         return [e.flatten(0, 1) for e in embeds]
 
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(self,
-                                  **kwargs: object) -> MultiModalEmbeddings:
+    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
         video_input = self._parse_and_validate_video_input(**kwargs)
         if video_input is None:
             return []
@@ -425,10 +449,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
         if intermediate_tensors is not None:
             inputs_embeds = None
 
-        hidden_states = self.language_model.model(input_ids,
-                                                  positions,
-                                                  intermediate_tensors,
-                                                  inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(
+            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
+        )
 
         return hidden_states
 
@@ -438,8 +461,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
     ) -> Optional[torch.Tensor]:
         return self.language_model.compute_logits(hidden_states)
 
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(
             self,
             # This model doesn't support images for now