Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions
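For context, ruff replaces both tools in one step: `ruff format` takes over code formatting from yapf, and the `I` lint rules take over import sorting from isort. The exact configuration vLLM adopted is not shown in this excerpt; the sketch below only illustrates the kind of pyproject.toml settings such a switch typically involves (the 88-character line length is an assumption, consistent with the reformatted lines in this diff fitting within 88 columns).

    # pyproject.toml (illustrative sketch, not the actual vLLM configuration)
    [tool.ruff]
    line-length = 88          # assumed; the reformatted lines below fit in 88 columns

    [tool.ruff.lint]
    select = ["E", "F", "I"]  # "I" enables import sorting, replacing isort

Formatting is then applied with `ruff format` instead of yapf, and import sorting with `ruff check --fix` instead of isort.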


@@ -7,12 +7,13 @@ import torch
 import torch.nn as nn
 from transformers import BatchFeature, PretrainedConfig
-from transformers.models.llava_next.modeling_llava_next import (
-    get_anyres_image_grid_shape, unpad_image)
+from transformers.models.llava_next.modeling_llava_next import (
+    get_anyres_image_grid_shape,
+    unpad_image,
+)
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalFieldConfig
@@ -21,13 +22,20 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .llava import (BaseLlavaMultiModalProcessor, LlavaDummyInputsBuilder,
-                    init_vision_tower_for_llava)
+from .llava import (
+    BaseLlavaMultiModalProcessor,
+    LlavaDummyInputsBuilder,
+    init_vision_tower_for_llava,
+)
 from .llava_next import LlavaNextProcessingInfo
 from .pixtral import PixtralHFVisionModel
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix)
+from .utils import (
+    AutoWeightsLoader,
+    flatten_bn,
+    init_vllm_registered_model,
+    maybe_prefix,
+)


 class MiniMaxVL01ImagePixelInputs(TensorSchema):
@@ -42,10 +50,12 @@ class MiniMaxVL01ImagePixelInputs(TensorSchema):
     Note that `num_patches` may be different per batch and image,
     in which case the data is passed as a list instead of a batched tensor.
     """

     type: Literal["pixel_values"] = "pixel_values"

     pixel_values: Annotated[
         Union[torch.Tensor, list[torch.Tensor]],
-        TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np", "h", "w"})]
+        TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np", "h", "w"}),
+    ]

     image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
     # This should be in `(height, width)` format.
@@ -58,36 +68,43 @@ class MiniMaxVL01ImageEmbeddingInputs(TensorSchema):
         - ifs: Image feature size
         - hs: Hidden size (must match language model backbone)
     """

     type: Literal["image_embeds"] = "image_embeds"
     data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]


-MiniMaxVL01ImageInputs = Union[MiniMaxVL01ImagePixelInputs,
-                               MiniMaxVL01ImageEmbeddingInputs]
+MiniMaxVL01ImageInputs = Union[
+    MiniMaxVL01ImagePixelInputs, MiniMaxVL01ImageEmbeddingInputs
+]


 class MiniMaxVL01MultiModalProjector(nn.Module):

-    def __init__(self,
-                 vision_hidden_size: int,
-                 text_hidden_size: int,
-                 projector_hidden_act: str,
-                 multimodal_projector_bias: bool,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        vision_hidden_size: int,
+        text_hidden_size: int,
+        projector_hidden_act: str,
+        multimodal_projector_bias: bool,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
         super().__init__()
-        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
-                                             text_hidden_size,
-                                             bias=multimodal_projector_bias,
-                                             quant_config=quant_config,
-                                             prefix=f"{prefix}.linear_1")
+        self.linear_1 = ColumnParallelLinear(
+            vision_hidden_size,
+            text_hidden_size,
+            bias=multimodal_projector_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.linear_1",
+        )
         self.act = get_act_fn(projector_hidden_act)
-        self.linear_2 = RowParallelLinear(text_hidden_size,
-                                          text_hidden_size,
-                                          bias=multimodal_projector_bias,
-                                          quant_config=quant_config,
-                                          prefix=f"{prefix}.linear_2")
+        self.linear_2 = RowParallelLinear(
+            text_hidden_size,
+            text_hidden_size,
+            bias=multimodal_projector_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.linear_2",
+        )

     def forward(self, image_features: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.linear_1(image_features)
@@ -101,15 +118,13 @@ class MiniMaxVL01DummyInputsBuilder(LlavaDummyInputsBuilder):
 class MiniMaxVL01ProcessingInfo(LlavaNextProcessingInfo):
     def get_hf_config(self):  # Need to override the config type
         return self.ctx.get_hf_config(PretrainedConfig)

     def get_hf_processor(self, **kwargs: object):
         hf_processor = self.ctx.get_hf_processor(**kwargs)
         image_processor = hf_processor.image_processor
-        image_processor.anyres_preprocess = (
-            image_processor.anyres_for_vllm_preprocess)
+        image_processor.anyres_preprocess = image_processor.anyres_for_vllm_preprocess
         return hf_processor
@@ -118,8 +133,8 @@ class MiniMaxVL01ProcessingInfo(LlavaNextProcessingInfo):
 class MiniMaxVL01MultiModalProcessor(
-        BaseLlavaMultiModalProcessor[MiniMaxVL01ProcessingInfo]):
+    BaseLlavaMultiModalProcessor[MiniMaxVL01ProcessingInfo]
+):
     def _call_hf_processor(
         self,
         prompt: str,
@@ -162,13 +177,12 @@ class MiniMaxVL01MultiModalProcessor(
 @MULTIMODAL_REGISTRY.register_processor(
     MiniMaxVL01MultiModalProcessor,
     info=MiniMaxVL01ProcessingInfo,
-    dummy_inputs=MiniMaxVL01DummyInputsBuilder)
-class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                          SupportsPP):
+    dummy_inputs=MiniMaxVL01DummyInputsBuilder,
+)
+class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-        "gate_up_proj": ["gate_proj", "up_proj"]
+        "gate_up_proj": ["gate_proj", "up_proj"],
     }

     @classmethod
@@ -193,16 +207,17 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
             config,
             quant_config,
             require_post_norm=False,
-            prefix=maybe_prefix(prefix, "vision_tower"))
+            prefix=maybe_prefix(prefix, "vision_tower"),
+        )
         self.multi_modal_projector = MiniMaxVL01MultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
             projector_hidden_act=config.projector_hidden_act,
             multimodal_projector_bias=True,
             quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "multi_modal_projector"))
-        self.image_newline = nn.Parameter(
-            torch.empty(config.text_config.hidden_size))
+            prefix=maybe_prefix(prefix, "multi_modal_projector"),
+        )
+        self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size))
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
             hf_config=config.text_config,
@@ -215,15 +230,15 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.pad_token_id = self.config.pad_token_id
         self.make_empty_intermediate_tensors = (
-            self.language_model.make_empty_intermediate_tensors)
+            self.language_model.make_empty_intermediate_tensors
+        )

     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

     def _image_pixels_to_features(
         self,
-        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
-                            PixtralHFVisionModel],
+        vision_tower: Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel],
         pixel_values: Union[torch.Tensor, list[torch.Tensor]],
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         # NOTE: we skip the step to select the vision feature layer since
@@ -231,55 +246,55 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
         feature_select_strategy = self.config.vision_feature_select_strategy
         return tuple(
             vision_tower(p, feature_select_strategy=feature_select_strategy)
-            for p in pixel_values)
+            for p in pixel_values
+        )

     # adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631
-    def pack_image_features(self, image_features: list[torch.Tensor],
-                            image_sizes: torch.Tensor):
+    def pack_image_features(
+        self, image_features: list[torch.Tensor], image_sizes: torch.Tensor
+    ):
         new_image_features = []
         for image_idx, image_feature in enumerate(image_features):
             if image_feature.shape[0] > 1:
                 base_image_feature = image_feature[0]
                 image_feature = image_feature[1:]
-                height = width = (self.config.vision_config.image_size //
-                                  self.config.vision_config.patch_size)
+                height = width = (
+                    self.config.vision_config.image_size
+                    // self.config.vision_config.patch_size
+                )
                 if height * width != base_image_feature.shape[0]:
                     raise ValueError(
-                        "The number of patches is not consistent with "
-                        "the image size.")
+                        "The number of patches is not consistent with the image size."
+                    )
                 num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                     image_sizes[image_idx],
                     self.config.image_grid_pinpoints,
                     self.config.vision_config.image_size,
                 )
-                image_feature = image_feature.view(num_patch_height,
-                                                   num_patch_width, height,
-                                                   width, -1)
-                image_feature = image_feature.permute(4, 0, 2, 1,
-                                                      3).contiguous()
+                image_feature = image_feature.view(
+                    num_patch_height, num_patch_width, height, width, -1
+                )
+                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                 image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                image_feature = unpad_image(image_feature,
-                                            image_sizes[image_idx])
+                image_feature = unpad_image(image_feature, image_sizes[image_idx])
                 image_feature = torch.cat(
                     (
                         image_feature,
-                        self.image_newline[:, None, None].expand(
-                            *image_feature.shape[:-1], 1).to(
-                                image_feature.dtype),
+                        self.image_newline[:, None, None]
+                        .expand(*image_feature.shape[:-1], 1)
+                        .to(image_feature.dtype),
                     ),
                     dim=-1,
                 )
                 image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                image_feature = torch.cat((base_image_feature, image_feature),
-                                          dim=0)
+                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
             else:
                 image_feature = image_feature[0]
                 image_feature = torch.cat(
-                    (image_feature,
-                     self.image_newline[None].to(image_feature)),
-                    dim=0)
+                    (image_feature, self.image_newline[None].to(image_feature)), dim=0
+                )
             new_image_features.append(image_feature)
         return new_image_features
@@ -305,9 +320,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
         if isinstance(image_features, torch.Tensor):
             return self.multi_modal_projector(image_features)

-        feature_sizes = [
-            image_feature.shape[0] for image_feature in image_features
-        ]
+        feature_sizes = [image_feature.shape[0] for image_feature in image_features]
         image_embeds = self.multi_modal_projector(torch.cat(image_features))
         image_embeds = torch.split(image_embeds, feature_sizes)
@@ -315,7 +328,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
             return self.pack_image_features(image_embeds, image_sizes)

     def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[MiniMaxVL01ImageInputs]:
+        self, **kwargs: object
+    ) -> Optional[MiniMaxVL01ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_sizes = kwargs.pop("image_sizes", None)
         image_embeds = kwargs.pop("image_embeds", None)
@@ -325,12 +339,14 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
         if pixel_values is not None and image_sizes is not None:
             if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
+                raise ValueError(
+                    f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
+                )

             if not isinstance(image_sizes, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image sizes. "
-                                 f"Got type: {type(image_sizes)}")
+                raise ValueError(
+                    f"Incorrect type of image sizes. Got type: {type(image_sizes)}"
+                )

             return MiniMaxVL01ImagePixelInputs(
                 type="pixel_values",
@@ -340,8 +356,10 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
         if image_embeds is not None:
             if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
+                raise ValueError(
+                    "Incorrect type of image embeddings. "
+                    f"Got type: {type(image_embeds)}"
+                )

             return MiniMaxVL01ImageEmbeddingInputs(
                 type="image_embeds",
@@ -350,8 +368,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
         raise AssertionError("This line should be unreachable.")

-    def get_multimodal_embeddings(self,
-                                  **kwargs: object) -> MultiModalEmbeddings:
+    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
@@ -366,7 +383,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-
         if intermediate_tensors is not None:
             inputs_embeds = None
         elif inputs_embeds is None:
@@ -378,10 +394,9 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
             )
             input_ids = None

-        hidden_states = self.language_model.model(input_ids,
-                                                  positions,
-                                                  intermediate_tensors,
-                                                  inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(
+            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
+        )

         return hidden_states
@@ -391,7 +406,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
     ) -> Optional[torch.Tensor]:
         return self.language_model.compute_logits(hidden_states)

-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)