[Frontend] Gemma3n audio transcriptions/translations endpoint (#23735)

Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Author: Nicolò Lucchesi
Date: 2025-09-01 12:07:46 +02:00
Committed by: GitHub
Parent: 107284959a
Commit: d46934b229
9 changed files with 189 additions and 63 deletions

View File

@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Optional, TypedDict, Union, cast
+from typing import Any, Literal, Optional, TypedDict, Union, cast
 
+import numpy as np
 import torch
 from torch import nn
 from transformers import AutoModel, BatchFeature
@@ -13,7 +14,8 @@ from transformers.models.gemma3n import (Gemma3nAudioConfig,
                                          Gemma3nVisionConfig)
 from transformers.models.siglip import SiglipImageProcessorFast
 
-from vllm.config import VllmConfig
+from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
+from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import RowParallelLinear
@@ -21,6 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM
 from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -40,7 +43,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal
+from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
+                         SupportsTranscription)
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
@@ -410,7 +414,10 @@ class Gemma3nMultimodalEmbedder(nn.Module):
 @MULTIMODAL_REGISTRY.register_processor(Gemma3nMultiModalProcessor,
                                         info=Gemma3nProcessingInfo,
                                         dummy_inputs=Gemma3nDummyInputsBuilder)
-class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
+class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
+                                      SupportsTranscription):
+    supported_languages = ISO639_1_SUPPORTED_LANGS
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -694,3 +701,53 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
             return "<audio_soft_token>"
         else:
             raise ValueError(f"Unsupported modality: {modality}")
+
+    @classmethod
+    def get_generation_prompt(cls, audio: np.ndarray,
+                              stt_config: SpeechToTextConfig,
+                              model_config: ModelConfig,
+                              language: Optional[str],
+                              task_type: Literal["transcribe", "translate"],
+                              request_prompt: str,
+                              to_language: Optional[str]) -> PromptType:
+        """
+        Gemma3n supports "free-form" transcription.
+        We fix its prompt here to standardize transcription/translation
+        requests.
+        """
+        # Transcribe this audio [into <>] | for transcription
+        # Translate this audio [from <> into <>] | for translation
+        prompt = "<start_of_turn>user\n"
+        prompt += "Transcribe" if task_type == "transcribe" else "Translate"
+        prompt += " this audio"
+
+        # We assume the language is a valid ISO 639-1 code.
+        full_lang_name = cls.supported_languages.get(language, "")
+        # The target language is only used for translation for now.
+        full_lang_name_to = cls.supported_languages.get(to_language, "")
+
+        if task_type == "transcribe" and full_lang_name:
+            prompt += f" into {full_lang_name}"
+        elif task_type == "translate":
+            if full_lang_name:
+                prompt += f" from {full_lang_name}"
+            if full_lang_name_to:
+                prompt += f" into {full_lang_name_to}"
+
+        prompt += ": <audio_soft_token><end_of_turn>\n<start_of_turn>model\n"
+
+        audio = (audio, stt_config.sample_rate)
+        prompts_dict = {"multi_modal_data": {"audio": audio}, "prompt": prompt}
+        return cast(PromptType, prompts_dict)
+
+    @classmethod
+    def get_speech_to_text_config(cls, model_config: ModelConfig,
+                                  task_type: str) -> SpeechToTextConfig:
+        return SpeechToTextConfig(
+            # Let's set this to 30 as suggested in the docs for now, although
+            # the model is only limited by its context length.
+            max_audio_clip_s=30,
+            sample_rate=16000,
+            # TODO: enable chunking after more thorough testing.
+            min_energy_split_window_size=None,
+        )
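
For context, a client-side sketch (not part of this diff) of the transcriptions endpoint this change enables for Gemma3n. The server command, model name, URL and audio file are illustrative assumptions; the request itself uses the stock OpenAI client.

# Sketch only: assumes a vLLM server started with e.g.
#   vllm serve google/gemma-3n-E2B-it
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("sample.wav", "rb") as f:
    transcription = client.audio.transcriptions.create(
        model="google/gemma-3n-E2B-it",  # assumed served model
        file=f,
        # ISO 639-1 code; get_generation_prompt above maps it to the full
        # name, yielding "Transcribe this audio into English: ..."
        language="en",
    )
print(transcription.text)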

View File

@@ -700,8 +700,10 @@ class SupportsTranscription(Protocol):
     def get_generation_prompt(cls, audio: np.ndarray,
                               stt_config: SpeechToTextConfig,
                               model_config: ModelConfig,
-                              language: Optional[str], task_type: str,
-                              request_prompt: str) -> PromptType:
+                              language: Optional[str],
+                              task_type: Literal["transcribe", "translate"],
+                              request_prompt: str,
+                              to_language: Optional[str]) -> PromptType:
         """Get the prompt for the ASR model.
         The model has control over the construction, as long as it
         returns a valid PromptType."""
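
To illustrate the widened protocol, a minimal sketch of an implementer. The class and its prompt format are invented for illustration; only the signature mirrors the hunk above.

from typing import Literal, Optional, cast

import numpy as np

from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.inputs.data import PromptType


class ToyASRModel:  # hypothetical implementer of SupportsTranscription
    @classmethod
    def get_generation_prompt(cls, audio: np.ndarray,
                              stt_config: SpeechToTextConfig,
                              model_config: ModelConfig,
                              language: Optional[str],
                              task_type: Literal["transcribe", "translate"],
                              request_prompt: str,
                              to_language: Optional[str]) -> PromptType:
        # The new to_language argument only matters for translation.
        target = (f" into {to_language}"
                  if task_type == "translate" and to_language else "")
        prompt = f"{task_type.capitalize()}{target}: <audio>"
        return cast(
            PromptType, {
                "prompt": prompt,
                "multi_modal_data": {
                    "audio": (audio, stt_config.sample_rate)
                },
            })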

View File

@@ -5,7 +5,7 @@ import math
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
 from math import ceil
-from typing import Optional, Union, cast
+from typing import Literal, Optional, Union, cast
 
 import numpy as np
 import regex as re
@@ -455,8 +455,10 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_generation_prompt(cls, audio: np.ndarray,
                               model_config: ModelConfig,
                               stt_config: SpeechToTextConfig,
-                              language: Optional[str], task_type: str,
-                              request_prompt: str) -> PromptType:
+                              language: Optional[str],
+                              task_type: Literal["transcribe", "translate"],
+                              request_prompt: str,
+                              to_language: Optional[str]) -> PromptType:
         tokenizer = cached_tokenizer_from_config(model_config)
         audio = Audio(audio, int(stt_config.sample_rate),
                       format="wav")  # lossless

View File

@@ -4,7 +4,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from contextlib import nullcontext
-from typing import Optional, TypedDict, Union, cast
+from typing import Literal, Optional, TypedDict, Union, cast
 
 import numpy as np
 import torch
@@ -783,8 +783,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
                               model_config: ModelConfig,  # not needed here
                               stt_config: SpeechToTextConfig,
                               language: Optional[str],
-                              task_type: str,
-                              request_prompt: str) -> PromptType:
+                              task_type: Literal["transcribe", "translate"],
+                              request_prompt: str,
+                              to_language: Optional[str]) -> PromptType:
         if language is None:
             raise ValueError(
                 "Language must be specified when creating the Whisper prompt")