Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import Annotated, Any, Literal, Optional, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# yapf: disable
|
||||
from torch import nn
|
||||
from transformers import AutoModel, BatchFeature
|
||||
@@ -44,10 +45,14 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
|
||||
SupportsTranscription)
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix)
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
WeightsMapper,
|
||||
flatten_bn,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -64,6 +69,7 @@ class Gemma3nImagePixelInputs(TensorSchema):
|
||||
- h: Height of each patch
|
||||
- w: Width of each patch
|
||||
"""
|
||||
|
||||
type: Literal["pixel_values"] = "pixel_values"
|
||||
pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
|
||||
|
||||
@@ -75,6 +81,7 @@ class Gemma3nAudioInputs(TensorSchema):
|
||||
- s: seq_length
|
||||
- f: num_features
|
||||
"""
|
||||
|
||||
type: Literal["audio"] = "audio"
|
||||
input_features_padded: Annotated[torch.Tensor, TensorShape("bn", "s", "f")]
|
||||
input_features_mask: Annotated[torch.Tensor, TensorShape("bn", "s")]
|
||||
@@ -84,7 +91,6 @@ Gemma3nImageInputs = Gemma3nImagePixelInputs
|
||||
|
||||
|
||||
class Gemma3nProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config(Gemma3nConfig)
|
||||
|
||||
@@ -95,9 +101,8 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
|
||||
return {"image": None, "audio": None}
|
||||
|
||||
def get_max_tokens_per_item(
|
||||
self, seq_len: int,
|
||||
mm_counts: Mapping[str, int]) -> Optional[Mapping[str, int]]:
|
||||
|
||||
self, seq_len: int, mm_counts: Mapping[str, int]
|
||||
) -> Optional[Mapping[str, int]]:
|
||||
return {"image": TOKENS_PER_IMAGE, "audio": TOKENS_PER_AUDIO}
|
||||
|
||||
def get_image_repl(
|
||||
@@ -109,7 +114,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
|
||||
) -> str:
|
||||
"""
|
||||
Get the replacement text for image tokens.
|
||||
|
||||
|
||||
For Gemma3n, this should return the full_image_sequence which includes
|
||||
BOI token, repeated image tokens, and EOI token.
|
||||
"""
|
||||
@@ -117,7 +122,8 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
|
||||
processor = self.get_hf_processor()
|
||||
|
||||
return PromptUpdateDetails.select_token_id(
|
||||
processor.full_image_sequence, processor.image_token_id)
|
||||
processor.full_image_sequence, processor.image_token_id
|
||||
)
|
||||
|
||||
def get_audio_repl(
|
||||
self,
|
||||
@@ -126,7 +132,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
|
||||
) -> str:
|
||||
"""
|
||||
Get the replacement text for audio tokens.
|
||||
|
||||
|
||||
For Gemma3n, this should return the full_audio_sequence which includes
|
||||
BOA token, repeated audio tokens, and EOA token.
|
||||
"""
|
||||
@@ -135,11 +141,11 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
# Return the full audio sequence as defined by the processor
|
||||
return PromptUpdateDetails.select_token_id(
|
||||
processor.full_audio_sequence, processor.audio_token_id)
|
||||
processor.full_audio_sequence, processor.audio_token_id
|
||||
)
|
||||
|
||||
|
||||
class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
@@ -159,7 +165,9 @@ class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
|
||||
num_images = mm_counts.get("image", 0)
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
processor = self.info.get_hf_processor()
|
||||
audio_feature_extractor: Gemma3nAudioFeatureExtractor = processor.feature_extractor # noqa: E501
|
||||
audio_feature_extractor: Gemma3nAudioFeatureExtractor = (
|
||||
processor.feature_extractor
|
||||
) # noqa: E501
|
||||
audio_len = audio_feature_extractor.fft_length
|
||||
image_processor: SiglipImageProcessorFast = processor.image_processor
|
||||
img_width = image_processor.size.get("width", 224)
|
||||
@@ -169,21 +177,19 @@ class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
|
||||
audio_overrides = mm_options.get("audio") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=img_width,
|
||||
height=img_height,
|
||||
num_images=num_images,
|
||||
overrides=image_overrides),
|
||||
"audio":
|
||||
self._get_dummy_audios(length=audio_len,
|
||||
num_audios=num_audios,
|
||||
overrides=audio_overrides)
|
||||
"image": self._get_dummy_images(
|
||||
width=img_width,
|
||||
height=img_height,
|
||||
num_images=num_images,
|
||||
overrides=image_overrides,
|
||||
),
|
||||
"audio": self._get_dummy_audios(
|
||||
length=audio_len, num_audios=num_audios, overrides=audio_overrides
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
|
||||
):
|
||||
|
||||
class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]):
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
|
||||
feature_extractor = self.info.get_hf_processor().feature_extractor
|
||||
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
|
||||
@@ -195,12 +201,11 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
|
||||
# HF Transformers audio processor no longer accepts `audios` key.
|
||||
# We pop `audios` and replace it with `audio` key to suppress
|
||||
# the warning.
|
||||
if 'audios' in mm_data:
|
||||
mm_data['audio'] = mm_data.pop('audios')
|
||||
if "audios" in mm_data:
|
||||
mm_data["audio"] = mm_data.pop("audios")
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt,
|
||||
mm_data,
|
||||
@@ -208,15 +213,17 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
|
||||
tok_kwargs,
|
||||
)
|
||||
|
||||
if 'input_features' in processed_outputs:
|
||||
if "input_features" in processed_outputs:
|
||||
# Padding enables audio_tower to run in batched mode
|
||||
processed_outputs["input_features_padded"] = \
|
||||
processed_outputs["input_features"]
|
||||
processed_outputs["input_features_padded"] = processed_outputs[
|
||||
"input_features"
|
||||
]
|
||||
|
||||
# Unpad features here since we need the output of each item to be
|
||||
# independent of other items for the cache to work correctly
|
||||
unpadded_features = [
|
||||
f[mask] for f, mask in zip(
|
||||
f[mask]
|
||||
for f, mask in zip(
|
||||
processed_outputs["input_features"],
|
||||
processed_outputs["input_features_mask"],
|
||||
)
|
||||
@@ -229,7 +236,6 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
input_features_padded=MultiModalFieldConfig.batched("audio"),
|
||||
@@ -264,21 +270,25 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
|
||||
modality="image",
|
||||
target=image_token,
|
||||
replacement=get_replacement_image,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# Handle audio tokens
|
||||
if "audio" in mm_items:
|
||||
audio_token = hf_processor.audio_token
|
||||
|
||||
def get_replacement_audio(item_idx: int):
|
||||
return self.info.get_audio_repl(processor=hf_processor, )
|
||||
return self.info.get_audio_repl(
|
||||
processor=hf_processor,
|
||||
)
|
||||
|
||||
prompt_updates.append(
|
||||
PromptReplacement(
|
||||
modality="audio",
|
||||
target=audio_token,
|
||||
replacement=get_replacement_audio,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
return prompt_updates
|
||||
|
||||
@@ -287,8 +297,7 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
|
||||
prompt: list[int],
|
||||
mm_prompt_updates: MultiModalPromptUpdates,
|
||||
) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
|
||||
token_ids, res = super()._apply_token_matches(prompt,
|
||||
mm_prompt_updates)
|
||||
token_ids, res = super()._apply_token_matches(prompt, mm_prompt_updates)
|
||||
|
||||
# "\n\n\n" and "\n\n\n\n" are single tokens
|
||||
# Since our replacement can insert "\n\n" next to "\n"
|
||||
@@ -347,8 +356,7 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
|
||||
repl_token_ids.extend(repl_toks)
|
||||
repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
|
||||
|
||||
repls = super()._find_mm_placeholders(repl_token_ids,
|
||||
mm_prompt_updates)
|
||||
repls = super()._find_mm_placeholders(repl_token_ids, mm_prompt_updates)
|
||||
|
||||
return {
|
||||
modality: [
|
||||
@@ -358,14 +366,15 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
|
||||
start_idx=repl_orig_idxs[p.start_idx],
|
||||
tokens=p.tokens,
|
||||
is_embed=p.is_embed,
|
||||
) for p in placeholders
|
||||
)
|
||||
for p in placeholders
|
||||
]
|
||||
for modality, placeholders in repls.items()
|
||||
}
|
||||
|
||||
|
||||
class Gemma3nMultimodalEmbedder(nn.Module):
|
||||
"""Embeds token ids or soft tokens for multimodal content into language
|
||||
"""Embeds token ids or soft tokens for multimodal content into language
|
||||
model space."""
|
||||
|
||||
def __init__(
|
||||
@@ -425,7 +434,8 @@ class Gemma3nMultimodalEmbedder(nn.Module):
|
||||
""" # noqa: E501
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
"You must specify exactly one of input_ids or inputs_embeds")
|
||||
"You must specify exactly one of input_ids or inputs_embeds"
|
||||
)
|
||||
|
||||
if inputs_embeds is not None:
|
||||
emb_norm = self.soft_embedding_norm(inputs_embeds)
|
||||
@@ -437,11 +447,14 @@ class Gemma3nMultimodalEmbedder(nn.Module):
|
||||
return self.embedding_post_projection_norm(emb_norm_proj)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(Gemma3nMultiModalProcessor,
|
||||
info=Gemma3nProcessingInfo,
|
||||
dummy_inputs=Gemma3nDummyInputsBuilder)
|
||||
class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsTranscription):
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Gemma3nMultiModalProcessor,
|
||||
info=Gemma3nProcessingInfo,
|
||||
dummy_inputs=Gemma3nDummyInputsBuilder,
|
||||
)
|
||||
class Gemma3nForConditionalGeneration(
|
||||
nn.Module, SupportsMultiModal, SupportsTranscription
|
||||
):
|
||||
merge_by_field_config = True
|
||||
supported_languages = ISO639_1_SUPPORTED_LANGS
|
||||
|
||||
@@ -468,7 +481,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
"model.multi_modal_projector.": "multi_modal_projector.",
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
"model": "language_model.model",
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
@@ -482,10 +496,12 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
self.vision_tower = AutoModel.from_config(config=config.vision_config)
|
||||
self.audio_tower = AutoModel.from_config(config=config.audio_config)
|
||||
self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config,
|
||||
config.text_config)
|
||||
self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config,
|
||||
config.text_config)
|
||||
self.embed_vision = Gemma3nMultimodalEmbedder(
|
||||
config.vision_config, config.text_config
|
||||
)
|
||||
self.embed_audio = Gemma3nMultimodalEmbedder(
|
||||
config.audio_config, config.text_config
|
||||
)
|
||||
|
||||
self.language_model: nn.Module = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
@@ -501,10 +517,12 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.config.text_config.num_hidden_layers,
|
||||
self.config.text_config.hidden_size_per_layer_input,
|
||||
device=self.language_model.model.embed_tokens.weight.device,
|
||||
dtype=self.language_model.model.embed_tokens.weight.dtype)
|
||||
dtype=self.language_model.model.embed_tokens.weight.dtype,
|
||||
)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[Gemma3nImageInputs]:
|
||||
self, **kwargs: object
|
||||
) -> Optional[Gemma3nImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
# TODO is this the case?
|
||||
@@ -515,8 +533,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
return Gemma3nImagePixelInputs(pixel_values=pixel_values)
|
||||
|
||||
def _parse_and_validate_audio_input(
|
||||
self, **kwargs: object) -> Optional[Gemma3nAudioInputs]:
|
||||
|
||||
self, **kwargs: object
|
||||
) -> Optional[Gemma3nAudioInputs]:
|
||||
input_features_padded = kwargs.pop("input_features_padded", None)
|
||||
if input_features_padded is None:
|
||||
return None
|
||||
@@ -536,14 +554,20 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# Preserve the order of modalities if there are multiple of them
|
||||
# from the order of kwargs.
|
||||
for input_key in kwargs:
|
||||
if input_key in ("pixel_values", "image_embeds"
|
||||
) and "image" not in mm_input_by_modality:
|
||||
mm_input_by_modality[
|
||||
"image"] = self._parse_and_validate_image_input(**kwargs)
|
||||
if input_key == "input_features_padded" \
|
||||
and "audio" not in mm_input_by_modality:
|
||||
mm_input_by_modality[
|
||||
"audio"] = self._parse_and_validate_audio_input(**kwargs)
|
||||
if (
|
||||
input_key in ("pixel_values", "image_embeds")
|
||||
and "image" not in mm_input_by_modality
|
||||
):
|
||||
mm_input_by_modality["image"] = self._parse_and_validate_image_input(
|
||||
**kwargs
|
||||
)
|
||||
if (
|
||||
input_key == "input_features_padded"
|
||||
and "audio" not in mm_input_by_modality
|
||||
):
|
||||
mm_input_by_modality["audio"] = self._parse_and_validate_audio_input(
|
||||
**kwargs
|
||||
)
|
||||
return mm_input_by_modality
|
||||
|
||||
def _process_image_input(
|
||||
@@ -553,16 +577,20 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
assert self.vision_tower is not None
|
||||
|
||||
pixel_values = image_input["pixel_values"]
|
||||
vision_outputs = self.vision_tower(pixel_values=pixel_values,
|
||||
do_pooling=False,
|
||||
return_dict=True).last_hidden_state
|
||||
vision_outputs = self.vision_tower(
|
||||
pixel_values=pixel_values, do_pooling=False, return_dict=True
|
||||
).last_hidden_state
|
||||
# TODO try to avoid copy here
|
||||
# (batch, channels, height, width) to (batch, height * width, channels)
|
||||
vision_outputs = vision_outputs.reshape(
|
||||
vision_outputs.shape[0],
|
||||
self.config.vision_config.hidden_size,
|
||||
self.config.vision_soft_tokens_per_image,
|
||||
).permute(0, 2, 1).contiguous()
|
||||
vision_outputs = (
|
||||
vision_outputs.reshape(
|
||||
vision_outputs.shape[0],
|
||||
self.config.vision_config.hidden_size,
|
||||
self.config.vision_soft_tokens_per_image,
|
||||
)
|
||||
.permute(0, 2, 1)
|
||||
.contiguous()
|
||||
)
|
||||
# Normalize and embed the soft tokens into language model space.
|
||||
vision_outputs *= self.config.vision_config.hidden_size**0.5
|
||||
# Return a list of embeddings instead of a batched tensor
|
||||
@@ -576,8 +604,9 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# Run on padded features to enable batching
|
||||
input_features = audio_input["input_features_padded"].squeeze(1)
|
||||
input_features_mask = audio_input["input_features_mask"].squeeze(1)
|
||||
audio_outputs, audio_mask = self.audio_tower(input_features,
|
||||
~input_features_mask)
|
||||
audio_outputs, audio_mask = self.audio_tower(
|
||||
input_features, ~input_features_mask
|
||||
)
|
||||
audio_features = self.embed_audio(inputs_embeds=audio_outputs)
|
||||
|
||||
# ruff: noqa
|
||||
@@ -587,30 +616,29 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
|
||||
# the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
|
||||
# TODO precompute and cache padding
|
||||
audio_padding_toks = torch.tensor([[self.vocab_size - 1]],
|
||||
dtype=torch.long,
|
||||
device=audio_features.device)
|
||||
audio_padding_toks = torch.tensor(
|
||||
[[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device
|
||||
)
|
||||
audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
|
||||
audio_features = torch.where(audio_mask.unsqueeze(-1),
|
||||
audio_padding_embs, audio_features)
|
||||
audio_features = torch.where(
|
||||
audio_mask.unsqueeze(-1), audio_padding_embs, audio_features
|
||||
)
|
||||
|
||||
audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
|
||||
extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len # noqa: E501
|
||||
extra_padding_features = audio_padding_embs.expand(
|
||||
audio_batch_size, extra_padding_tokens, audio_embed_dim)
|
||||
audio_batch_size, extra_padding_tokens, audio_embed_dim
|
||||
)
|
||||
|
||||
audio_features = torch.cat((audio_features, extra_padding_features),
|
||||
dim=1)
|
||||
audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
|
||||
# Return a list of embeddings instead of a batched tensor
|
||||
return audio_features.unbind(0)
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self,
|
||||
**kwargs: object) -> MultiModalEmbeddings:
|
||||
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
|
||||
**kwargs)
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if mm_input_by_modality is None:
|
||||
return []
|
||||
|
||||
@@ -640,12 +668,16 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# them here, as the model forward has only access to the input_embeds.
|
||||
if input_ids is not None:
|
||||
per_layer_inputs = self.language_model.model.get_per_layer_input_embeddings(
|
||||
input_ids)
|
||||
input_ids
|
||||
)
|
||||
per_layer_inputs = per_layer_inputs.reshape(
|
||||
-1, self.config.text_config.num_hidden_layers,
|
||||
self.config.text_config.hidden_size_per_layer_input)
|
||||
self.per_layer_embeddings[:per_layer_inputs.shape[0]].copy_(
|
||||
per_layer_inputs)
|
||||
-1,
|
||||
self.config.text_config.num_hidden_layers,
|
||||
self.config.text_config.hidden_size_per_layer_input,
|
||||
)
|
||||
self.per_layer_embeddings[: per_layer_inputs.shape[0]].copy_(
|
||||
per_layer_inputs
|
||||
)
|
||||
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
@@ -658,12 +690,14 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object) -> IntermediateTensors:
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
) -> IntermediateTensors:
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
@@ -672,7 +706,7 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# select a chunk of pre-allocated PLEs. During normal execution,
|
||||
# `get_input_embeddings` is called before forward, hence this slice
|
||||
# will contain PLEs computed from the actual input_ids.
|
||||
per_layer_inputs = self.per_layer_embeddings[:inputs_embeds.shape[0]]
|
||||
per_layer_inputs = self.per_layer_embeddings[: inputs_embeds.shape[0]]
|
||||
|
||||
hidden_states = self.language_model.model(
|
||||
input_ids,
|
||||
@@ -680,7 +714,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
per_layer_inputs=per_layer_inputs,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**kwargs)
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -690,8 +725,7 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.language_model.compute_logits(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
@@ -702,7 +736,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
connector="multi_modal_projector",
|
||||
tower_model="vision_tower")
|
||||
tower_model="vision_tower",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
@@ -714,16 +749,19 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
raise ValueError(f"Unsupported modality: {modality}")
|
||||
|
||||
@classmethod
|
||||
def get_generation_prompt(cls, audio: np.ndarray,
|
||||
stt_config: SpeechToTextConfig,
|
||||
model_config: ModelConfig,
|
||||
language: Optional[str],
|
||||
task_type: Literal["transcribe", "translate"],
|
||||
request_prompt: str,
|
||||
to_language: Optional[str]) -> PromptType:
|
||||
def get_generation_prompt(
|
||||
cls,
|
||||
audio: np.ndarray,
|
||||
stt_config: SpeechToTextConfig,
|
||||
model_config: ModelConfig,
|
||||
language: Optional[str],
|
||||
task_type: Literal["transcribe", "translate"],
|
||||
request_prompt: str,
|
||||
to_language: Optional[str],
|
||||
) -> PromptType:
|
||||
"""
|
||||
Gemma3n supports "free-form" transcription.
|
||||
We fix its prompt here to standardize transcriptions/translations
|
||||
We fix its prompt here to standardize transcriptions/translations
|
||||
requests.
|
||||
"""
|
||||
# Transcribe this audio [into <>] | for transcription
|
||||
@@ -752,8 +790,9 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
return cast(PromptType, prompts_dict)
|
||||
|
||||
@classmethod
|
||||
def get_speech_to_text_config(cls, model_config: ModelConfig,
|
||||
task_type: str) -> SpeechToTextConfig:
|
||||
def get_speech_to_text_config(
|
||||
cls, model_config: ModelConfig, task_type: str
|
||||
) -> SpeechToTextConfig:
|
||||
return SpeechToTextConfig(
|
||||
# Let's set this to 30 as suggested in the docs for now, although
|
||||
# the model is only limited by its context length.
|
||||
|
||||
Reference in New Issue
Block a user