Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -5,6 +5,7 @@ from typing import Annotated, Any, Literal, Optional, Union, cast
import numpy as np
import torch
# yapf: disable
from torch import nn
from transformers import AutoModel, BatchFeature
@@ -44,10 +45,14 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
SupportsTranscription)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix)
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription
from .utils import (
AutoWeightsLoader,
WeightsMapper,
flatten_bn,
init_vllm_registered_model,
maybe_prefix,
)
logger = init_logger(__name__)
@@ -64,6 +69,7 @@ class Gemma3nImagePixelInputs(TensorSchema):
- h: Height of each patch
- w: Width of each patch
"""
type: Literal["pixel_values"] = "pixel_values"
pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
@@ -75,6 +81,7 @@ class Gemma3nAudioInputs(TensorSchema):
- s: seq_length
- f: num_features
"""
type: Literal["audio"] = "audio"
input_features_padded: Annotated[torch.Tensor, TensorShape("bn", "s", "f")]
input_features_mask: Annotated[torch.Tensor, TensorShape("bn", "s")]
@@ -84,7 +91,6 @@ Gemma3nImageInputs = Gemma3nImagePixelInputs
class Gemma3nProcessingInfo(BaseProcessingInfo):
# Return the result of `self.ctx.get_hf_config(Gemma3nConfig)`.
# NOTE(review): presumably this is the model's HF config, checked/typed as
# `Gemma3nConfig` by the context helper -- confirm against `self.ctx`.
def get_hf_config(self):
return self.ctx.get_hf_config(Gemma3nConfig)
@@ -95,9 +101,8 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
return {"image": None, "audio": None}
def get_max_tokens_per_item(
self, seq_len: int,
mm_counts: Mapping[str, int]) -> Optional[Mapping[str, int]]:
self, seq_len: int, mm_counts: Mapping[str, int]
) -> Optional[Mapping[str, int]]:
return {"image": TOKENS_PER_IMAGE, "audio": TOKENS_PER_AUDIO}
def get_image_repl(
@@ -109,7 +114,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
) -> str:
"""
Get the replacement text for image tokens.
For Gemma3n, this should return the full_image_sequence which includes
BOI token, repeated image tokens, and EOI token.
"""
@@ -117,7 +122,8 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
processor = self.get_hf_processor()
return PromptUpdateDetails.select_token_id(
processor.full_image_sequence, processor.image_token_id)
processor.full_image_sequence, processor.image_token_id
)
def get_audio_repl(
self,
@@ -126,7 +132,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
) -> str:
"""
Get the replacement text for audio tokens.
For Gemma3n, this should return the full_audio_sequence which includes
BOA token, repeated audio tokens, and EOA token.
"""
@@ -135,11 +141,11 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
# Return the full audio sequence as defined by the processor
return PromptUpdateDetails.select_token_id(
processor.full_audio_sequence, processor.audio_token_id)
processor.full_audio_sequence, processor.audio_token_id
)
class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
num_audios = mm_counts.get("audio", 0)
@@ -159,7 +165,9 @@ class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
num_images = mm_counts.get("image", 0)
num_audios = mm_counts.get("audio", 0)
processor = self.info.get_hf_processor()
audio_feature_extractor: Gemma3nAudioFeatureExtractor = processor.feature_extractor # noqa: E501
audio_feature_extractor: Gemma3nAudioFeatureExtractor = (
processor.feature_extractor
) # noqa: E501
audio_len = audio_feature_extractor.fft_length
image_processor: SiglipImageProcessorFast = processor.image_processor
img_width = image_processor.size.get("width", 224)
@@ -169,21 +177,19 @@ class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
audio_overrides = mm_options.get("audio") if mm_options else None
return {
"image":
self._get_dummy_images(width=img_width,
height=img_height,
num_images=num_images,
overrides=image_overrides),
"audio":
self._get_dummy_audios(length=audio_len,
num_audios=num_audios,
overrides=audio_overrides)
"image": self._get_dummy_images(
width=img_width,
height=img_height,
num_images=num_images,
overrides=image_overrides,
),
"audio": self._get_dummy_audios(
length=audio_len, num_audios=num_audios, overrides=audio_overrides
),
}
class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
):
class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]):
# Build the `MultiModalDataParser` used by this processor.
# The parser's `target_sr` is taken from the `sampling_rate` of the HF
# processor's `feature_extractor`.
# NOTE(review): presumably this makes incoming audio get resampled to the rate
# the feature extractor expects -- confirm in `MultiModalDataParser`.
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_hf_processor().feature_extractor
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
@@ -195,12 +201,11 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
# HF Transformers audio processor no longer accepts `audios` key.
# We pop `audios` and replace it with `audio` key to suppress
# the warning.
if 'audios' in mm_data:
mm_data['audio'] = mm_data.pop('audios')
if "audios" in mm_data:
mm_data["audio"] = mm_data.pop("audios")
processed_outputs = super()._call_hf_processor(
prompt,
mm_data,
@@ -208,15 +213,17 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
tok_kwargs,
)
if 'input_features' in processed_outputs:
if "input_features" in processed_outputs:
# Padding enables audio_tower to run in batched mode
processed_outputs["input_features_padded"] = \
processed_outputs["input_features"]
processed_outputs["input_features_padded"] = processed_outputs[
"input_features"
]
# Unpad features here since we need the output of each item to be
# independent of other items for the cache to work correctly
unpadded_features = [
f[mask] for f, mask in zip(
f[mask]
for f, mask in zip(
processed_outputs["input_features"],
processed_outputs["input_features_mask"],
)
@@ -229,7 +236,6 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
input_features_padded=MultiModalFieldConfig.batched("audio"),
@@ -264,21 +270,25 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
modality="image",
target=image_token,
replacement=get_replacement_image,
))
)
)
# Handle audio tokens
if "audio" in mm_items:
audio_token = hf_processor.audio_token
def get_replacement_audio(item_idx: int):
return self.info.get_audio_repl(processor=hf_processor, )
return self.info.get_audio_repl(
processor=hf_processor,
)
prompt_updates.append(
PromptReplacement(
modality="audio",
target=audio_token,
replacement=get_replacement_audio,
))
)
)
return prompt_updates
@@ -287,8 +297,7 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
prompt: list[int],
mm_prompt_updates: MultiModalPromptUpdates,
) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
token_ids, res = super()._apply_token_matches(prompt,
mm_prompt_updates)
token_ids, res = super()._apply_token_matches(prompt, mm_prompt_updates)
# "\n\n\n" and "\n\n\n\n" are single tokens
# Since our replacement can insert "\n\n" next to "\n"
@@ -347,8 +356,7 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
repl_token_ids.extend(repl_toks)
repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
repls = super()._find_mm_placeholders(repl_token_ids,
mm_prompt_updates)
repls = super()._find_mm_placeholders(repl_token_ids, mm_prompt_updates)
return {
modality: [
@@ -358,14 +366,15 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
start_idx=repl_orig_idxs[p.start_idx],
tokens=p.tokens,
is_embed=p.is_embed,
) for p in placeholders
)
for p in placeholders
]
for modality, placeholders in repls.items()
}
class Gemma3nMultimodalEmbedder(nn.Module):
"""Embeds token ids or soft tokens for multimodal content into language
"""Embeds token ids or soft tokens for multimodal content into language
model space."""
def __init__(
@@ -425,7 +434,8 @@ class Gemma3nMultimodalEmbedder(nn.Module):
""" # noqa: E501
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds")
"You must specify exactly one of input_ids or inputs_embeds"
)
if inputs_embeds is not None:
emb_norm = self.soft_embedding_norm(inputs_embeds)
@@ -437,11 +447,14 @@ class Gemma3nMultimodalEmbedder(nn.Module):
return self.embedding_post_projection_norm(emb_norm_proj)
@MULTIMODAL_REGISTRY.register_processor(Gemma3nMultiModalProcessor,
info=Gemma3nProcessingInfo,
dummy_inputs=Gemma3nDummyInputsBuilder)
class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsTranscription):
@MULTIMODAL_REGISTRY.register_processor(
Gemma3nMultiModalProcessor,
info=Gemma3nProcessingInfo,
dummy_inputs=Gemma3nDummyInputsBuilder,
)
class Gemma3nForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsTranscription
):
merge_by_field_config = True
supported_languages = ISO639_1_SUPPORTED_LANGS
@@ -468,7 +481,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
"model.multi_modal_projector.": "multi_modal_projector.",
"lm_head.": "language_model.lm_head.",
"model": "language_model.model",
})
}
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -482,10 +496,12 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.audio_tower = AutoModel.from_config(config=config.audio_config)
self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config,
config.text_config)
self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config,
config.text_config)
self.embed_vision = Gemma3nMultimodalEmbedder(
config.vision_config, config.text_config
)
self.embed_audio = Gemma3nMultimodalEmbedder(
config.audio_config, config.text_config
)
self.language_model: nn.Module = init_vllm_registered_model(
vllm_config=vllm_config,
@@ -501,10 +517,12 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
self.config.text_config.num_hidden_layers,
self.config.text_config.hidden_size_per_layer_input,
device=self.language_model.model.embed_tokens.weight.device,
dtype=self.language_model.model.embed_tokens.weight.dtype)
dtype=self.language_model.model.embed_tokens.weight.dtype,
)
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Gemma3nImageInputs]:
self, **kwargs: object
) -> Optional[Gemma3nImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
# TODO is this the case?
@@ -515,8 +533,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
return Gemma3nImagePixelInputs(pixel_values=pixel_values)
def _parse_and_validate_audio_input(
self, **kwargs: object) -> Optional[Gemma3nAudioInputs]:
self, **kwargs: object
) -> Optional[Gemma3nAudioInputs]:
input_features_padded = kwargs.pop("input_features_padded", None)
if input_features_padded is None:
return None
@@ -536,14 +554,20 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key in ("pixel_values", "image_embeds"
) and "image" not in mm_input_by_modality:
mm_input_by_modality[
"image"] = self._parse_and_validate_image_input(**kwargs)
if input_key == "input_features_padded" \
and "audio" not in mm_input_by_modality:
mm_input_by_modality[
"audio"] = self._parse_and_validate_audio_input(**kwargs)
if (
input_key in ("pixel_values", "image_embeds")
and "image" not in mm_input_by_modality
):
mm_input_by_modality["image"] = self._parse_and_validate_image_input(
**kwargs
)
if (
input_key == "input_features_padded"
and "audio" not in mm_input_by_modality
):
mm_input_by_modality["audio"] = self._parse_and_validate_audio_input(
**kwargs
)
return mm_input_by_modality
def _process_image_input(
@@ -553,16 +577,20 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
assert self.vision_tower is not None
pixel_values = image_input["pixel_values"]
vision_outputs = self.vision_tower(pixel_values=pixel_values,
do_pooling=False,
return_dict=True).last_hidden_state
vision_outputs = self.vision_tower(
pixel_values=pixel_values, do_pooling=False, return_dict=True
).last_hidden_state
# TODO try to avoid copy here
# (batch, channels, height, width) to (batch, height * width, channels)
vision_outputs = vision_outputs.reshape(
vision_outputs.shape[0],
self.config.vision_config.hidden_size,
self.config.vision_soft_tokens_per_image,
).permute(0, 2, 1).contiguous()
vision_outputs = (
vision_outputs.reshape(
vision_outputs.shape[0],
self.config.vision_config.hidden_size,
self.config.vision_soft_tokens_per_image,
)
.permute(0, 2, 1)
.contiguous()
)
# Normalize and embed the soft tokens into language model space.
vision_outputs *= self.config.vision_config.hidden_size**0.5
# Return a list of embeddings instead of a batched tensor
@@ -576,8 +604,9 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
# Run on padded features to enable batching
input_features = audio_input["input_features_padded"].squeeze(1)
input_features_mask = audio_input["input_features_mask"].squeeze(1)
audio_outputs, audio_mask = self.audio_tower(input_features,
~input_features_mask)
audio_outputs, audio_mask = self.audio_tower(
input_features, ~input_features_mask
)
audio_features = self.embed_audio(inputs_embeds=audio_outputs)
# ruff: noqa
@@ -587,30 +616,29 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
# depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
# the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
# TODO precompute and cache padding
audio_padding_toks = torch.tensor([[self.vocab_size - 1]],
dtype=torch.long,
device=audio_features.device)
audio_padding_toks = torch.tensor(
[[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device
)
audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
audio_features = torch.where(audio_mask.unsqueeze(-1),
audio_padding_embs, audio_features)
audio_features = torch.where(
audio_mask.unsqueeze(-1), audio_padding_embs, audio_features
)
audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len # noqa: E501
extra_padding_features = audio_padding_embs.expand(
audio_batch_size, extra_padding_tokens, audio_embed_dim)
audio_batch_size, extra_padding_tokens, audio_embed_dim
)
audio_features = torch.cat((audio_features, extra_padding_features),
dim=1)
audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
# Return a list of embeddings instead of a batched tensor
return audio_features.unbind(0)
# Accessor for the text backbone: returns the `self.language_model` attribute
# as-is (no copy, no wrapping).
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
**kwargs)
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if mm_input_by_modality is None:
return []
@@ -640,12 +668,16 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
# them here, as the model forward has only access to the input_embeds.
if input_ids is not None:
per_layer_inputs = self.language_model.model.get_per_layer_input_embeddings(
input_ids)
input_ids
)
per_layer_inputs = per_layer_inputs.reshape(
-1, self.config.text_config.num_hidden_layers,
self.config.text_config.hidden_size_per_layer_input)
self.per_layer_embeddings[:per_layer_inputs.shape[0]].copy_(
per_layer_inputs)
-1,
self.config.text_config.num_hidden_layers,
self.config.text_config.hidden_size_per_layer_input,
)
self.per_layer_embeddings[: per_layer_inputs.shape[0]].copy_(
per_layer_inputs
)
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
@@ -658,12 +690,14 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object) -> IntermediateTensors:
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> IntermediateTensors:
if intermediate_tensors is not None:
inputs_embeds = None
@@ -672,7 +706,7 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
# select a chunk of pre-allocated PLEs. During normal execution,
# `get_input_embeddings` is called before forward, hence this slice
# will contain PLEs computed from the actual input_ids.
per_layer_inputs = self.per_layer_embeddings[:inputs_embeds.shape[0]]
per_layer_inputs = self.per_layer_embeddings[: inputs_embeds.shape[0]]
hidden_states = self.language_model.model(
input_ids,
@@ -680,7 +714,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
per_layer_inputs=per_layer_inputs,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**kwargs)
**kwargs,
)
return hidden_states
@@ -690,8 +725,7 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
@@ -702,7 +736,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="multi_modal_projector",
tower_model="vision_tower")
tower_model="vision_tower",
)
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -714,16 +749,19 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError(f"Unsupported modality: {modality}")
@classmethod
def get_generation_prompt(cls, audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: Optional[str],
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: Optional[str]) -> PromptType:
def get_generation_prompt(
cls,
audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: Optional[str],
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: Optional[str],
) -> PromptType:
"""
Gemma3n supports "free-form" transcription.
We fix its prompt here to standardize transcriptions/translations
We fix its prompt here to standardize transcriptions/translations
requests.
"""
# Transcribe this audio [into <>] | for transcription
@@ -752,8 +790,9 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
return cast(PromptType, prompts_dict)
@classmethod
def get_speech_to_text_config(cls, model_config: ModelConfig,
task_type: str) -> SpeechToTextConfig:
def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: str
) -> SpeechToTextConfig:
return SpeechToTextConfig(
# Let's set this to 30 as suggested in the docs for now, although
# the model is only limited by its context length.