Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions
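
The change is mechanical: import blocks wrapped by isort and call sites wrapped by yapf are reformatted into ruff's black-compatible style. A minimal sketch of the two styles (the function below is hypothetical, not taken from this commit):

# yapf style (before): continuation lines are aligned to the opening parenthesis
def build_projector(vision_hidden_size: int, text_hidden_size: int,
                    spatial_merge_size: int, patch_size: int,
                    projector_hidden_act: str):
    ...


# ruff format style (after): a signature that exceeds the configured line
# length is exploded to one parameter per line, with a trailing comma and the
# closing parenthesis on its own line
def build_projector(
    vision_hidden_size: int,
    text_hidden_size: int,
    spatial_merge_size: int,
    patch_size: int,
    projector_hidden_act: str,
):
    ...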

@@ -3,43 +3,59 @@
 from abc import abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
-from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
-                    Union)
+from typing import Annotated, Final, Literal, Optional, Protocol, TypeVar, Union
 
 import torch
 import torch.nn as nn
-from transformers import (BatchFeature, Mistral3Config, PixtralVisionConfig,
-                          PretrainedConfig)
+from transformers import (
+    BatchFeature,
+    Mistral3Config,
+    PixtralVisionConfig,
+    PretrainedConfig,
+)
 from transformers.models.pixtral import PixtralProcessor
 
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.cache import BaseMultiModalProcessorCache
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems)
-from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
-                                   MultiModalDataItems)
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo,
-                                        InputProcessingContext,
-                                        PromptReplacement, PromptUpdate,
-                                        PromptUpdateDetails)
+from vllm.multimodal.inputs import (
+    MultiModalDataDict,
+    MultiModalFieldConfig,
+    MultiModalKwargsItems,
+)
+from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
+from vllm.multimodal.processing import (
+    BaseMultiModalProcessor,
+    BaseProcessingInfo,
+    InputProcessingContext,
+    PromptReplacement,
+    PromptUpdate,
+    PromptUpdateDetails,
+)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
-from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
-                         SupportsMultiModal, SupportsPP)
+from .interfaces import (
+    MultiModalEmbeddings,
+    SupportsLoRA,
+    SupportsMultiModal,
+    SupportsPP,
+)
 from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
-                    init_vllm_registered_model, maybe_prefix)
+from .utils import (
+    AutoWeightsLoader,
+    WeightsMapper,
+    flatten_bn,
+    init_vllm_registered_model,
+    maybe_prefix,
+)
 from .vision import get_vision_encoder_info
@@ -67,38 +83,43 @@ class Mistral3PatchMerger(nn.Module):
     """
     Learned merging of spatial_merge_size ** 2 patches
     """
 
-    def __init__(self, vision_hidden_size: int, spatial_merge_size: int,
-                 patch_size: int):
+    def __init__(
+        self, vision_hidden_size: int, spatial_merge_size: int, patch_size: int
+    ):
         super().__init__()
         self.vision_hidden_size = vision_hidden_size
         self.spatial_merge_size = spatial_merge_size
         self.patch_size = patch_size
 
-        self.merging_layer = nn.Linear(vision_hidden_size *
-                                       self.spatial_merge_size**2,
-                                       vision_hidden_size,
-                                       bias=False)
+        self.merging_layer = nn.Linear(
+            vision_hidden_size * self.spatial_merge_size**2,
+            vision_hidden_size,
+            bias=False,
+        )
 
-    def forward(self, image_features: torch.Tensor,
-                image_sizes: torch.Tensor) -> torch.Tensor:
-        image_sizes = [(image_size[0] // self.patch_size,
-                        image_size[1] // self.patch_size)
-                       for image_size in image_sizes]
+    def forward(
+        self, image_features: torch.Tensor, image_sizes: torch.Tensor
+    ) -> torch.Tensor:
+        image_sizes = [
+            (image_size[0] // self.patch_size, image_size[1] // self.patch_size)
+            for image_size in image_sizes
+        ]
 
         tokens_per_image = [h * w for h, w in image_sizes]
         d = image_features.shape[-1]
         permuted_tensor = []
 
         for image_index, image_tokens in enumerate(
-                image_features.split(tokens_per_image)):
+            image_features.split(tokens_per_image)
+        ):
             # Reshape image_tokens into a 2D grid
             h, w = image_sizes[image_index]
-            image_grid = image_tokens.view(h, w, d).permute(2, 0,
-                                                            1).unsqueeze(0)
+            image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
             grid = torch.nn.functional.unfold(
                 image_grid,
                 kernel_size=self.spatial_merge_size,
-                stride=self.spatial_merge_size)
+                stride=self.spatial_merge_size,
+            )
             grid = grid.view(d * self.spatial_merge_size**2, -1).t()
             permuted_tensor.append(grid)
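
Aside from formatting, the unfold call in the hunk above is the functional core of the patch merger: every non-overlapping spatial_merge_size x spatial_merge_size window of patch embeddings is flattened into a single row, which the linear merging_layer then projects. A self-contained sketch with toy sizes (the numbers are illustrative, not from any model config):

import torch

d, h, w, merge = 16, 4, 6, 2  # hidden size, patch grid, spatial_merge_size
image_tokens = torch.randn(h * w, d)  # flat (num_patches, hidden) sequence
# (1, d, h, w): channels-first grid, as in the forward pass above
image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
# unfold gathers each non-overlapping merge x merge window into one column,
# giving shape (1, d * merge**2, (h // merge) * (w // merge))
grid = torch.nn.functional.unfold(image_grid, kernel_size=merge, stride=merge)
merged = grid.view(d * merge**2, -1).t()  # one row per merged patch
assert merged.shape == (h * w // merge**2, d * merge**2)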
@@ -108,38 +129,45 @@ class Mistral3PatchMerger(nn.Module):
 
 
 class Mistral3MultiModalProjector(nn.Module):
-
-    def __init__(self,
-                 vision_hidden_size: int,
-                 text_hidden_size: int,
-                 spatial_merge_size: int,
-                 patch_size: int,
-                 projector_hidden_act: str,
-                 multimodal_projector_bias: bool,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        vision_hidden_size: int,
+        text_hidden_size: int,
+        spatial_merge_size: int,
+        patch_size: int,
+        projector_hidden_act: str,
+        multimodal_projector_bias: bool,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
         super().__init__()
 
         self.norm = RMSNorm(vision_hidden_size, eps=1e-5)
         self.patch_merger = Mistral3PatchMerger(
             vision_hidden_size=vision_hidden_size,
             spatial_merge_size=spatial_merge_size,
-            patch_size=patch_size)
+            patch_size=patch_size,
+        )
 
-        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
-                                             text_hidden_size,
-                                             bias=multimodal_projector_bias,
-                                             quant_config=quant_config,
-                                             prefix=f"{prefix}.linear_1")
+        self.linear_1 = ColumnParallelLinear(
+            vision_hidden_size,
+            text_hidden_size,
+            bias=multimodal_projector_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.linear_1",
+        )
         self.act = get_act_fn(projector_hidden_act)
-        self.linear_2 = RowParallelLinear(text_hidden_size,
-                                          text_hidden_size,
-                                          bias=multimodal_projector_bias,
-                                          quant_config=quant_config,
-                                          prefix=f"{prefix}.linear_2")
+        self.linear_2 = RowParallelLinear(
+            text_hidden_size,
+            text_hidden_size,
+            bias=multimodal_projector_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.linear_2",
+        )
 
-    def forward(self, image_features: torch.Tensor,
-                image_sizes: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, image_features: torch.Tensor, image_sizes: torch.Tensor
+    ) -> torch.Tensor:
         image_features = self.norm(image_features)
         image_features = self.patch_merger(image_features, image_sizes)
         hidden_states, _ = self.linear_1(image_features)
@@ -160,7 +188,6 @@ class LlavaLikeProcessor(Protocol):
 
 
 class BaseLlavaProcessingInfo(BaseProcessingInfo):
-
     def get_hf_config(self) -> LlavaLikeConfig:
         return self.ctx.get_hf_config(Mistral3Config)
 
@@ -196,7 +223,6 @@ _I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
 
 
 class Mistral3DummyInputsBuilder(BaseDummyInputsBuilder[_I]):
-
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
 
@@ -213,29 +239,26 @@ class Mistral3DummyInputsBuilder(BaseDummyInputsBuilder[_I]):
     ) -> MultiModalDataDict:
         num_images = mm_counts.get("image", 0)
 
-        target_width, target_height = \
-            self.info.get_image_size_with_most_features()
+        target_width, target_height = self.info.get_image_size_with_most_features()
 
         image_overrides = mm_options.get("image") if mm_options else None
 
         return {
-            "image":
-            self._get_dummy_images(width=target_width,
-                                   height=target_height,
-                                   num_images=num_images,
-                                   overrides=image_overrides)
+            "image": self._get_dummy_images(
+                width=target_width,
+                height=target_height,
+                num_images=num_images,
+                overrides=image_overrides,
+            )
         }
 
 
 class Mistral3ProcessingInfo(BaseLlavaProcessingInfo):
-
     def get_hf_processor(self, **kwargs: object):
         return self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
 
 
-class Mistral3MultiModalProcessor(
-        BaseMultiModalProcessor[Mistral3ProcessingInfo]):
-
+class Mistral3MultiModalProcessor(BaseMultiModalProcessor[Mistral3ProcessingInfo]):
     def _call_hf_processor(
         self,
         prompt: str,
@@ -252,7 +275,6 @@ class Mistral3MultiModalProcessor(
 
         pixel_values = processed_outputs.get("pixel_values")
         if pixel_values is not None:
-
             # Avoid padding since we need the output for each image to be
             # independent of other images for the cache to work correctly
             image_sizes = processed_outputs["image_sizes"]
@@ -316,7 +338,8 @@ class Mistral3MultiModalProcessor(
 
 
 def _build_mistral3_info(
-        ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
+    ctx: InputProcessingContext,
+) -> BaseLlavaProcessingInfo:
     hf_config = ctx.get_hf_config(Mistral3Config)
     assert isinstance(hf_config.vision_config, PixtralVisionConfig)
     return Mistral3ProcessingInfo(ctx)
@@ -339,7 +362,7 @@ def _build_mistral3_processor(
 def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
     """Determine the number of hidden layers to initialize up to in the
     visual encoder.
-    
+
     Args:
         hf_config: Model config with vision feature layer(s).
     """
@@ -350,10 +373,10 @@ def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
         return _get_layer_index(feature_layers, num_hidden_layers)
     # If we have multiple feature layers, initialize up to the deepest one
     elif isinstance(feature_layers, (list, tuple)):
-        return max(
-            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
-    raise TypeError(f"vision_layer_feature type: {type(feature_layers)}"
-                    " is not supported")
+        return max(_get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
+    raise TypeError(
+        f"vision_layer_feature type: {type(feature_layers)} is not supported"
+    )
 
 
 def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
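
In the hunk above, max(...) picks the deepest requested feature layer so the vision encoder is initialized only up to that depth. A toy illustration, assuming _get_layer_index resolves negative indices relative to the end of the encoder the way Python sequences do (its body is not part of this diff, so the stand-in below is hypothetical):

def get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    # Hypothetical stand-in for _get_layer_index; the real helper may differ.
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index

# feature_layers = [-2, 5] with a 24-layer encoder -> initialize up to layer 23
assert max(get_layer_index(i, 24) for i in [-2, 5]) == 23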
@@ -396,13 +419,14 @@ def init_vision_tower_for_llava(
 
 
 @MULTIMODAL_REGISTRY.register_processor(
     _build_mistral3_processor,
     info=_build_mistral3_info,
-    dummy_inputs=Mistral3DummyInputsBuilder)
-class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
-                                       SupportsMultiModal, SupportsPP):
-
+    dummy_inputs=Mistral3DummyInputsBuilder,
+)
+class Mistral3ForConditionalGeneration(
+    nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP
+):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-        "gate_up_proj": ["gate_proj", "up_proj"]
+        "gate_up_proj": ["gate_proj", "up_proj"],
     }
@@ -412,7 +436,8 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
             "model.vision_tower.": "vision_tower.",
             "model.multi_modal_projector.": "multi_modal_projector.",
             "lm_head.": "language_model.lm_head.",
-        })
+        }
+    )
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -433,11 +458,15 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
 
         # NOTE: These are special cases for Pixtral-12B in the HF-format
         # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa
-        if (config.text_config.architectures is None
-                and config.text_config.model_type == "mistral"):
+        if (
+            config.text_config.architectures is None
+            and config.text_config.model_type == "mistral"
+        ):
             config.text_config.architectures = ["MistralForCausalLM"]
-        if (config.projector_hidden_act is None
-                and config.vision_config.hidden_act == "gelu"):
+        if (
+            config.projector_hidden_act is None
+            and config.vision_config.hidden_act == "gelu"
+        ):
             config.projector_hidden_act = "gelu"
 
         # TODO: Optionally initializes this for supporting embeddings.
@@ -446,7 +475,8 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
                 config,
                 quant_config,
                 require_post_norm=False,
-                prefix=maybe_prefix(prefix, "vision_tower"))
+                prefix=maybe_prefix(prefix, "vision_tower"),
+            )
             self.multi_modal_projector = Mistral3MultiModalProjector(
                 vision_hidden_size=config.vision_config.hidden_size,
                 text_hidden_size=config.text_config.hidden_size,
@@ -455,7 +485,8 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
                 patch_size=config.vision_config.patch_size,
                 multimodal_projector_bias=config.multimodal_projector_bias,
                 quant_config=quant_config,
-                prefix=maybe_prefix(prefix, "multi_modal_projector"))
+                prefix=maybe_prefix(prefix, "multi_modal_projector"),
+            )
         else:
             self.vision_tower = None
             self.multi_modal_projector = None
@@ -467,10 +498,12 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
         )
 
         self.make_empty_intermediate_tensors = (
-            self.language_model.make_empty_intermediate_tensors)
+            self.language_model.make_empty_intermediate_tensors
+        )
 
     def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[Mistral3ImagePixelInputs]:
+        self, **kwargs: object
+    ) -> Optional[Mistral3ImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_embeds = kwargs.pop("image_embeds", None)
 
@@ -479,8 +512,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
 
             assert pixel_values is not None
             if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
+                raise ValueError(
+                    f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
+                )
 
             return Mistral3ImagePixelInputs(
                 type="pixel_values_pixtral",
@@ -494,8 +528,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
         if image_input["type"] == "image_embeds":
             return image_input["data"]
 
-        image_sizes = [(img.shape[-2], img.shape[-1])
-                       for img in image_input["pixel_values"]]
+        image_sizes = [
+            (img.shape[-2], img.shape[-1]) for img in image_input["pixel_values"]
+        ]
 
         image_features = self.vision_tower(image_input["pixel_values"])
 
@@ -507,19 +542,19 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
             for image_feature in image_features
         ]
-        image_embeds = self.multi_modal_projector(torch.cat(image_features),
-                                                  image_sizes)
+        image_embeds = self.multi_modal_projector(
+            torch.cat(image_features), image_sizes
+        )
 
         if len(feature_sizes) > 1:
             image_embeds = torch.split(image_embeds, feature_sizes)
         else:
-            image_embeds = (image_embeds, )
+            image_embeds = (image_embeds,)
         return image_embeds
 
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(self,
-                                  **kwargs: object) -> MultiModalEmbeddings:
+    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
@@ -576,10 +611,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
         if intermediate_tensors is not None:
             inputs_embeds = None
 
-        hidden_states = self.language_model.model(input_ids,
-                                                  positions,
-                                                  intermediate_tensors,
-                                                  inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(
+            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
+        )
 
         return hidden_states
 
@@ -589,8 +623,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
     ) -> Optional[torch.Tensor]:
         return self.language_model.compute_logits(hidden_states)
 
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         skip_prefixes = []
         if self.vision_tower is None and self.multi_modal_projector is None:
             skip_prefixes = ["vision_tower.", "multi_modal_projector."]
@@ -605,4 +638,5 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
         return MultiModelKeys.from_string_field(
             language_model="language_model",
             connector="multi_modal_projector",
-            tower_model="vision_tower")
+            tower_model="vision_tower",
+        )