Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions

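The tooling switch itself lives in this commit's pyproject.toml and CI changes, which are not part of the file excerpt below. As a rough sketch only (assumed settings, not the exact configuration added here), a yapf + isort to ruff migration typically replaces both tools' config with ruff's formatter and import-sorting rules:

    # pyproject.toml -- illustrative sketch, not this commit's exact settings
    [tool.ruff]
    line-length = 80

    [tool.ruff.lint]
    extend-select = ["I"]  # "I" enables ruff's isort-compatible import sorting

    [tool.ruff.format]  # ruff format replaces yapf as the code formatter

    # Reformat and re-sort imports across the tree:
    #   ruff format .
    #   ruff check --select I --fix .

One visible effect in the diff below: the "# yapf: disable" / "# yapf: enable" guards around hand-wrapped import blocks are deleted outright, since ruff's formatter does not honor yapf directives.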

@@ -16,29 +16,44 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import GemmaRMSNorm
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems)
-from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
-                                   MultiModalDataItems)
+from vllm.multimodal.inputs import (
+    MultiModalDataDict,
+    MultiModalFieldConfig,
+    MultiModalKwargsItems,
+)
+from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
-# yapf: disable
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo,
-                                        MultiModalPromptUpdates,
-                                        MultiModalPromptUpdatesApplyResult,
-                                        PlaceholderFeaturesInfo,
-                                        PromptReplacement, PromptUpdate,
-                                        PromptUpdateDetails,
-                                        replace_token_matches)
+from vllm.multimodal.processing import (
+    BaseMultiModalProcessor,
+    BaseProcessingInfo,
+    MultiModalPromptUpdates,
+    MultiModalPromptUpdatesApplyResult,
+    PlaceholderFeaturesInfo,
+    PromptReplacement,
+    PromptUpdate,
+    PromptUpdateDetails,
+    replace_token_matches,
+)
-# yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
-from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
-                         SupportsMultiModal, SupportsPP)
+from .interfaces import (
+    MultiModalEmbeddings,
+    SupportsLoRA,
+    SupportsMultiModal,
+    SupportsPP,
+)
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, WeightsMapper,
-                    init_vllm_registered_model, maybe_prefix)
+from .utils import (
+    AutoWeightsLoader,
+    WeightsMapper,
+    init_vllm_registered_model,
+    maybe_prefix,
+)
 
 logger = init_logger(__name__)
@@ -53,6 +68,7 @@ class Gemma3ImagePixelInputs(TensorSchema):
         - w: Width of each patch
         - bn: Batch size * number of images
     """
+
     type: Literal["pixel_values"] = "pixel_values"
 
     pixel_values: Annotated[torch.Tensor, TensorShape("p", 3, "h", "w")]
@@ -64,7 +80,6 @@ Gemma3ImageInputs = Gemma3ImagePixelInputs
 
 
 class Gemma3ProcessingInfo(BaseProcessingInfo):
-
     def get_hf_config(self):
         return self.ctx.get_hf_config(Gemma3Config)
 
@@ -107,19 +122,21 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
 
         processor = self.get_hf_processor()
         images_kwargs = self._resolve_image_kwargs(
-            processor, {
-                "do_pan_and_scan", "pan_and_scan_min_crop_size",
+            processor,
+            {
+                "do_pan_and_scan",
+                "pan_and_scan_min_crop_size",
                 "pan_and_scan_max_num_crops",
-                "pan_and_scan_min_ratio_to_activate"
-            })
+                "pan_and_scan_min_ratio_to_activate",
+            },
+        )
 
         do_pan_and_scan = images_kwargs["do_pan_and_scan"]
-        pan_and_scan_min_crop_size = images_kwargs[
-            "pan_and_scan_min_crop_size"]
-        pan_and_scan_max_num_crops = images_kwargs[
-            "pan_and_scan_max_num_crops"]
+        pan_and_scan_min_crop_size = images_kwargs["pan_and_scan_min_crop_size"]
+        pan_and_scan_max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
         pan_and_scan_min_ratio_to_activate = images_kwargs[
-            "pan_and_scan_min_ratio_to_activate"]
+            "pan_and_scan_min_ratio_to_activate"
+        ]
 
         if not do_pan_and_scan:
             return 0
@@ -127,7 +144,8 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
         if envs.VLLM_USE_V1:
             logger.warning_once(
                 "`do_pan_and_scan=True` has suboptimal results on V1 "
-                "because of the simplified attention pattern being used.")
+                "because of the simplified attention pattern being used."
+            )
 
         # Based on Gemma3ImageProcessor.pan_and_scan
         if image_width >= image_height:
@@ -187,10 +205,10 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
         crops_image_tokens = " ".join(boi_token for _ in range(num_crops))
         image_text = (
             f"Here is the original image {boi_token} and here are some "
-            f"crops to help you see better {crops_image_tokens}")
+            f"crops to help you see better {crops_image_tokens}"
+        )
 
-        repl_full = image_text.replace(boi_token,
-                                       processor.full_image_sequence)
+        repl_full = image_text.replace(boi_token, processor.full_image_sequence)
 
         tokenizer = processor.tokenizer
         vocab = tokenizer.get_vocab()
@@ -221,7 +239,8 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
         processor = self.get_hf_processor()
         images_kwargs = self._resolve_image_kwargs(
-            processor, {"pan_and_scan_max_num_crops"})
+            processor, {"pan_and_scan_max_num_crops"}
+        )
         max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
 
         # Result in the max possible feature size (h:w = max_num_crops:1)
@@ -229,7 +248,6 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
 
 
 class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
-
     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         num_images = mm_counts.get("image", 0)
@@ -246,22 +264,21 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
     ) -> MultiModalDataDict:
         num_images = mm_counts.get("image", 0)
 
-        target_width, target_height = \
-            self.info.get_image_size_with_most_features()
+        target_width, target_height = self.info.get_image_size_with_most_features()
 
         image_overrides = mm_options.get("image") if mm_options else None
 
         return {
-            "image":
-            self._get_dummy_images(width=target_width,
-                                   height=target_height,
-                                   num_images=num_images,
-                                   overrides=image_overrides)
+            "image": self._get_dummy_images(
+                width=target_width,
+                height=target_height,
+                num_images=num_images,
+                overrides=image_overrides,
+            )
         }
 
 
 class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
-
     def _call_hf_processor(
         self,
         prompt: str,
@@ -278,20 +295,22 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
 
         # HF processor pops the `num_crops` kwarg, which is needed by vLLM
         if (images := mm_data.get("images")) is not None:
-            parsed_images = (self._get_data_parser().parse_mm_data({
-                "image":
-                images
-            }).get_items("image", ImageProcessorItems))
+            parsed_images = (
+                self._get_data_parser()
+                .parse_mm_data({"image": images})
+                .get_items("image", ImageProcessorItems)
+            )
             image_sizes = [
-                parsed_images.get_image_size(i)
-                for i in range(len(parsed_images))
+                parsed_images.get_image_size(i) for i in range(len(parsed_images))
             ]
             hf_processor = self.info.get_hf_processor(**mm_kwargs)
             num_crops = [
-                self.info.get_num_crops(image_width=size.width,
-                                        image_height=size.height,
-                                        processor=hf_processor)
+                self.info.get_num_crops(
+                    image_width=size.width,
+                    image_height=size.height,
+                    processor=hf_processor,
+                )
                 for size in image_sizes
             ]
 
             processed_outputs["num_patches"] = torch.tensor(num_crops) + 1
@@ -306,8 +325,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         num_patches = hf_inputs.get("num_patches", torch.empty(0))
 
         return dict(
-            pixel_values=MultiModalFieldConfig.flat_from_sizes(
-                "image", num_patches),
+            pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
             num_patches=MultiModalFieldConfig.batched("image"),
         )
@@ -343,8 +361,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         prompt: list[int],
         mm_prompt_updates: MultiModalPromptUpdates,
     ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
-        token_ids, res = super()._apply_token_matches(prompt,
-                                                      mm_prompt_updates)
+        token_ids, res = super()._apply_token_matches(prompt, mm_prompt_updates)
 
         # "\n\n\n" and "\n\n\n\n" are single tokens
         # Since our replacement can insert "\n\n" next to "\n"
@@ -403,8 +420,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
             repl_token_ids.extend(repl_toks)
             repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
 
-        repls = super()._find_mm_placeholders(repl_token_ids,
-                                              mm_prompt_updates)
+        repls = super()._find_mm_placeholders(repl_token_ids, mm_prompt_updates)
 
         return {
             modality: [
@@ -414,39 +430,43 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
                     start_idx=repl_orig_idxs[p.start_idx],
                     tokens=p.tokens,
                     is_embed=p.is_embed,
-                ) for p in placeholders
+                )
+                for p in placeholders
             ]
             for modality, placeholders in repls.items()
         }
 
 
 class Gemma3MultiModalProjector(nn.Module):
-
     def __init__(self, config: Gemma3Config):
         super().__init__()
 
         self.mm_input_projection_weight = nn.Parameter(
-            torch.zeros(config.vision_config.hidden_size,
-                        config.text_config.hidden_size))
+            torch.zeros(
+                config.vision_config.hidden_size, config.text_config.hidden_size
+            )
+        )
 
         self.mm_soft_emb_norm = GemmaRMSNorm(
-            config.vision_config.hidden_size,
-            eps=config.vision_config.layer_norm_eps)
+            config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
+        )
 
-        self.patches_per_image = int(config.vision_config.image_size //
-                                     config.vision_config.patch_size)
+        self.patches_per_image = int(
+            config.vision_config.image_size // config.vision_config.patch_size
+        )
         self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
         self.kernel_size = self.patches_per_image // self.tokens_per_side
-        self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size,
-                                     stride=self.kernel_size)
+        self.avg_pool = nn.AvgPool2d(
+            kernel_size=self.kernel_size, stride=self.kernel_size
+        )
 
     def forward(self, vision_outputs: torch.Tensor):
         batch_size, _, seq_length = vision_outputs.shape
 
         reshaped_vision_outputs = vision_outputs.transpose(1, 2)
         reshaped_vision_outputs = reshaped_vision_outputs.reshape(
-            batch_size, seq_length, self.patches_per_image,
-            self.patches_per_image)
+            batch_size, seq_length, self.patches_per_image, self.patches_per_image
+        )
         reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
 
         pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
@@ -456,15 +476,19 @@ class Gemma3MultiModalProjector(nn.Module):
         normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
 
         projected_vision_outputs = torch.matmul(
-            normed_vision_outputs, self.mm_input_projection_weight)
+            normed_vision_outputs, self.mm_input_projection_weight
+        )
         return projected_vision_outputs.type_as(vision_outputs)
 
 
-@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor,
-                                        info=Gemma3ProcessingInfo,
-                                        dummy_inputs=Gemma3DummyInputsBuilder)
-class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
-                                     SupportsLoRA):
+@MULTIMODAL_REGISTRY.register_processor(
+    Gemma3MultiModalProcessor,
+    info=Gemma3ProcessingInfo,
+    dummy_inputs=Gemma3DummyInputsBuilder,
+)
+class Gemma3ForConditionalGeneration(
+    nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
+):
     merge_by_field_config = True
 
     packed_modules_mapping = {
@@ -486,7 +510,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
             "model.vision_tower.": "vision_tower.",
             "model.multi_modal_projector.": "multi_modal_projector.",
             "lm_head.": "language_model.lm_head.",
-        })
+        }
+    )
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -504,10 +529,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         self.quant_config = quant_config
         self.multimodal_config = multimodal_config
 
-        self.vision_tower = SiglipVisionModel(config.vision_config,
-                                              quant_config,
-                                              prefix=maybe_prefix(
-                                                  prefix, "vision_tower"))
+        self.vision_tower = SiglipVisionModel(
+            config.vision_config,
+            quant_config,
+            prefix=maybe_prefix(prefix, "vision_tower"),
+        )
         self.multi_modal_projector = Gemma3MultiModalProjector(config)
 
         self.language_model = init_vllm_registered_model(
@@ -524,14 +550,16 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
             self.language_model.logits_processor.scale *= logit_scale
 
         self.make_empty_intermediate_tensors = (
-            self.language_model.make_empty_intermediate_tensors)
+            self.language_model.make_empty_intermediate_tensors
+        )
 
     @property
     def dtype(self):
         return next(self.parameters()).dtype
 
     def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
+        self, **kwargs: object
+    ) -> Optional[Gemma3ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         num_patches = kwargs.pop("num_patches", None)
         image_embeds = kwargs.pop("image_embeds", None)
@@ -541,12 +569,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
 
             image_size = self.config.vision_config.image_size
 
-            return Gemma3ImagePixelInputs(pixel_values=pixel_values,
-                                          num_patches=num_patches,
-                                          resolve_bindings={
-                                              "h": image_size,
-                                              "w": image_size
-                                          })
+            return Gemma3ImagePixelInputs(
+                pixel_values=pixel_values,
+                num_patches=num_patches,
+                resolve_bindings={"h": image_size, "w": image_size},
+            )
 
     def _image_pixels_to_features(
         self,
@@ -570,35 +597,36 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         )
         image_embeds = self.multi_modal_projector(image_features)
 
-        return [
-            e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
-        ]
+        return [e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())]
 
     def get_language_model(self) -> torch.nn.Module:
        return self.language_model
 
-    def get_multimodal_embeddings(self,
-                                  **kwargs: object) -> MultiModalEmbeddings:
+    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
 
        return self._process_image_input(image_input)
 
-    def forward(self,
-                input_ids: torch.Tensor,
-                positions: torch.Tensor,
-                intermediate_tensors: Optional[IntermediateTensors] = None,
-                inputs_embeds: Optional[torch.Tensor] = None,
-                **kwargs: object) -> IntermediateTensors:
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs: object,
+    ) -> IntermediateTensors:
        if intermediate_tensors is not None:
            inputs_embeds = None
 
-        hidden_states = self.language_model.model(input_ids,
-                                                  positions,
-                                                  intermediate_tensors,
-                                                  inputs_embeds=inputs_embeds,
-                                                  **kwargs)
+        hidden_states = self.language_model.model(
+            input_ids,
+            positions,
+            intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+            **kwargs,
+        )
 
        return hidden_states
@@ -646,7 +674,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
 
             # Consider the bidirectional attention between image tokens.
             img_mask = torch.zeros_like(global_attn_mask)
-            img_pos = (input_token_ids == self.config.image_token_index)
+            img_pos = input_token_ids == self.config.image_token_index
             img_mask[:, :, :, img_pos] += 1
             img_mask[:, :, img_pos, :] += 1
             global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
@@ -656,10 +684,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
             if sliding_window is not None:
                 # Create a local causal mask with sliding window (1024).
                 local_attn_mask = torch.ones_like(global_attn_mask)
-                local_attn_mask = torch.tril(local_attn_mask,
-                                             diagonal=-sliding_window)
-                local_attn_mask = torch.where(local_attn_mask == 0,
-                                              global_attn_mask, float("-inf"))
+                local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
+                local_attn_mask = torch.where(
+                    local_attn_mask == 0, global_attn_mask, float("-inf")
+                )
                 local_attn_masks.append(local_attn_mask)
         kwargs["global_attn_masks"] = global_attn_masks
         kwargs["local_attn_masks"] = local_attn_masks
@@ -671,8 +699,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
     ) -> Optional[torch.Tensor]:
         return self.language_model.compute_logits(hidden_states)
 
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
@@ -683,4 +710,5 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         return MultiModelKeys.from_string_field(
             language_model="language_model",
             connector="multi_modal_projector",
-            tower_model="vision_tower")
+            tower_model="vision_tower",
+        )