[Frontend] Gemma3n audio transcriptions/translations endpoint (#23735)
Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Optional, TypedDict, Union, cast
+from typing import Any, Literal, Optional, TypedDict, Union, cast

 import numpy as np
 import torch
 from torch import nn
 from transformers import AutoModel, BatchFeature
@@ -13,7 +14,8 @@ from transformers.models.gemma3n import (Gemma3nAudioConfig,
                                          Gemma3nVisionConfig)
 from transformers.models.siglip import SiglipImageProcessorFast

-from vllm.config import VllmConfig
+from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
+from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import RowParallelLinear
@@ -21,6 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM
 from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -40,7 +43,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors

-from .interfaces import MultiModalEmbeddings, SupportsMultiModal
+from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
+                         SupportsTranscription)
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
@@ -410,7 +414,10 @@ class Gemma3nMultimodalEmbedder(nn.Module):
 @MULTIMODAL_REGISTRY.register_processor(Gemma3nMultiModalProcessor,
                                         info=Gemma3nProcessingInfo,
                                         dummy_inputs=Gemma3nDummyInputsBuilder)
-class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
+class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
+                                      SupportsTranscription):
+    supported_languages = ISO639_1_SUPPORTED_LANGS
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
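The new `supported_languages` attribute reuses Whisper's ISO 639-1 table (imported above), so language codes are validated and spelled out the same way across vLLM's ASR-capable models. For reference, the mapping has this shape (a trimmed sketch; the full table lives in vllm/model_executor/models/whisper.py):

```python
# Trimmed sketch of ISO639_1_SUPPORTED_LANGS: ISO 639-1 code -> English name.
# The real table in vllm/model_executor/models/whisper.py covers many more codes.
ISO639_1_SUPPORTED_LANGS = {
    "en": "English",
    "de": "German",
    "fr": "French",
    "it": "Italian",
}
```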
@@ -694,3 +701,53 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
             return "<audio_soft_token>"
         else:
             raise ValueError(f"Unsupported modality: {modality}")
+
+    @classmethod
+    def get_generation_prompt(cls, audio: np.ndarray,
+                              stt_config: SpeechToTextConfig,
+                              model_config: ModelConfig,
+                              language: Optional[str],
+                              task_type: Literal["transcribe", "translate"],
+                              request_prompt: str,
+                              to_language: Optional[str]) -> PromptType:
+        """
+        Gemma3n supports "free-form" transcription.
+        We fix its prompt here to standardize transcription/translation
+        requests.
+        """
+        # Transcribe this audio [into <>] | for transcription
+        # Translate this audio [from <> into <>] | for translation
+        prompt = "<start_of_turn>user\n"
+        prompt += "Transcribe" if task_type == "transcribe" else "Translate"
+        prompt += " this audio"
+
+        # We assume the language is a valid ISO 639-1 code.
+        full_lang_name = cls.supported_languages.get(language, "")
+        # Translation only for now
+        full_lang_name_to = cls.supported_languages.get(to_language, "")
+
+        if task_type == "transcribe" and full_lang_name:
+            prompt += f" into {full_lang_name}"
+        elif task_type == "translate":
+            if full_lang_name:
+                prompt += f" from {full_lang_name}"
+            if full_lang_name_to:
+                prompt += f" into {full_lang_name_to}"
+
+        prompt += ": <audio_soft_token><end_of_turn>\n<start_of_turn>model\n"
+
+        audio = (audio, stt_config.sample_rate)
+        prompts_dict = {"multi_modal_data": {"audio": audio}, "prompt": prompt}
+        return cast(PromptType, prompts_dict)
+
+    @classmethod
+    def get_speech_to_text_config(cls, model_config: ModelConfig,
+                                  task_type: str) -> SpeechToTextConfig:
+        return SpeechToTextConfig(
+            # Let's set this to 30 as suggested in the docs for now, although
+            # the model is only limited by its context length.
+            max_audio_clip_s=30,
+            sample_rate=16000,
+            # TODO enable chunking after more thorough testing.
+            min_energy_split_window_size=None,
+        )
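To make the prompt construction concrete: for an English transcription request, `get_generation_prompt` yields the string sketched below. This is a minimal example, assuming the language table maps "en" to "English", that `SpeechToTextConfig` can be instantiated with just a sample rate, and that the module path in the import is correct; `model_config` is not inspected by this implementation, so `None` is passed for brevity.

```python
import numpy as np

from vllm.config import SpeechToTextConfig
from vllm.model_executor.models.gemma3n_mm import (  # module path assumed
    Gemma3nForConditionalGeneration)

# Hypothetical 1-second, 16 kHz mono clip standing in for real audio.
audio = np.zeros(16000, dtype=np.float32)

prompt = Gemma3nForConditionalGeneration.get_generation_prompt(
    audio=audio,
    stt_config=SpeechToTextConfig(sample_rate=16000),
    model_config=None,  # unused by this implementation
    language="en",
    task_type="transcribe",
    request_prompt="",  # ignored: the prompt is standardized above
    to_language=None,
)
# prompt["prompt"] ==
#   "<start_of_turn>user\nTranscribe this audio into English: "
#   "<audio_soft_token><end_of_turn>\n<start_of_turn>model\n"
# prompt["multi_modal_data"]["audio"] == (audio, 16000)
```

For `task_type="translate"` with `language="it"` and `to_language="en"`, the same logic produces "Translate this audio from Italian into English: ..." instead.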
@@ -700,8 +700,10 @@ class SupportsTranscription(Protocol):
     def get_generation_prompt(cls, audio: np.ndarray,
                               stt_config: SpeechToTextConfig,
                               model_config: ModelConfig,
-                              language: Optional[str], task_type: str,
-                              request_prompt: str) -> PromptType:
+                              language: Optional[str],
+                              task_type: Literal["transcribe", "translate"],
+                              request_prompt: str,
+                              to_language: Optional[str]) -> PromptType:
         """Get the prompt for the ASR model.
         The model has control over the construction, as long as it
         returns a valid PromptType."""
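For other models implementing `SupportsTranscription`, the widened protocol means the prompt builder now also receives the target language. A minimal conforming sketch (the class name and prompt format here are hypothetical, not part of this diff):

```python
from typing import Literal, Optional, cast

import numpy as np

from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.inputs.data import PromptType


class MyASRModel:  # hypothetical; a real model would also subclass nn.Module

    @classmethod
    def get_generation_prompt(cls, audio: np.ndarray,
                              stt_config: SpeechToTextConfig,
                              model_config: ModelConfig,
                              language: Optional[str],
                              task_type: Literal["transcribe", "translate"],
                              request_prompt: str,
                              to_language: Optional[str]) -> PromptType:
        # Any construction is valid as long as a PromptType comes back.
        prompt = f"Task: {task_type}"
        if to_language is not None:
            prompt += f" into {to_language}"
        return cast(PromptType, {
            "prompt": prompt,
            "multi_modal_data": {"audio": (audio, stt_config.sample_rate)},
        })
```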
@@ -5,7 +5,7 @@ import math
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
 from math import ceil
-from typing import Optional, Union, cast
+from typing import Literal, Optional, Union, cast

 import numpy as np
 import regex as re
@@ -455,8 +455,10 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_generation_prompt(cls, audio: np.ndarray,
                               model_config: ModelConfig,
                               stt_config: SpeechToTextConfig,
-                              language: Optional[str], task_type: str,
-                              request_prompt: str) -> PromptType:
+                              language: Optional[str],
+                              task_type: Literal["transcribe", "translate"],
+                              request_prompt: str,
+                              to_language: Optional[str]) -> PromptType:
         tokenizer = cached_tokenizer_from_config(model_config)
         audio = Audio(audio, int(stt_config.sample_rate),
                       format="wav")  # lossless
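These per-model prompt builders back vLLM's OpenAI-compatible `/v1/audio/transcriptions` and `/v1/audio/translations` routes. A client-side sketch with the `openai` SDK, assuming a locally served instance (base URL, API key, file name, and model name are placeholders):

```python
from openai import OpenAI

# Placeholder base URL / key for a locally served vLLM instance.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("sample.wav", "rb") as f:  # placeholder audio file
    transcription = client.audio.transcriptions.create(
        model="google/gemma-3n-E2B-it",  # placeholder served model name
        file=f,
        language="en",
    )
print(transcription.text)
```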
@@ -4,7 +4,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from contextlib import nullcontext
-from typing import Optional, TypedDict, Union, cast
+from typing import Literal, Optional, TypedDict, Union, cast

 import numpy as np
 import torch
@@ -783,8 +783,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
                               model_config: ModelConfig,  # not needed here
                               stt_config: SpeechToTextConfig,
                               language: Optional[str],
-                              task_type: str,
-                              request_prompt: str) -> PromptType:
+                              task_type: Literal["transcribe", "translate"],
+                              request_prompt: str,
+                              to_language: Optional[str]) -> PromptType:
         if language is None:
             raise ValueError(
                 "Language must be specified when creating the Whisper prompt")
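Note the behavioral difference the diff preserves: Whisper still raises when no source language is given, whereas Gemma3n simply drops the language clause from its prompt. Translation requests go through the sibling route; how `to_language` is populated from the HTTP request is handled outside these hunks. A sketch against the same placeholder server:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("sample.wav", "rb") as f:  # placeholder audio file
    translation = client.audio.translations.create(
        model="google/gemma-3n-E2B-it",  # placeholder served model name
        file=f,
    )
print(translation.text)
```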