[Model] Add transcription support for Qwen3-Omni (#29828)
Signed-off-by: Muhammad Hashmi <mhashmi@berkeley.edu>
Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: NickLucche <nlucches@redhat.com>
@@ -24,7 +24,7 @@
 
 from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
 from functools import partial
-from typing import Any
+from typing import Any, Literal, cast
 
 import numpy as np
 import torch
@@ -48,8 +48,9 @@ from transformers import __version__ as TRANSFORMERS_VERSION
 # isort: on
 
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import VllmConfig
+from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
 from vllm.model_executor.layers.attention.mm_encoder_attention import (
@@ -79,6 +80,7 @@ from vllm.multimodal.processing.processor import (
     PromptUpdateDetails,
 )
 from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.processor import cached_processor_from_config
 from vllm.v1.attention.backends.registry import AttentionBackendEnum
 
 from .interfaces import (
@@ -86,6 +88,7 @@ from .interfaces import (
     SupportsMRoPE,
     SupportsMultiModal,
     SupportsPP,
+    SupportsTranscription,
 )
 from .qwen2_5_omni_thinker import (
     Qwen2_5OmniAudioFeatureInputs,
@@ -110,6 +113,29 @@ from .vision import get_vit_attn_backend
 
 logger = init_logger(__name__)
 
+# Speech input languages supported by Qwen3-Omni
+# From: https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct
+ISO639_1_SUPPORTED_LANGS = {
+    "en": "English",
+    "zh": "Chinese",
+    "ko": "Korean",
+    "ja": "Japanese",
+    "de": "German",
+    "ru": "Russian",
+    "it": "Italian",
+    "fr": "French",
+    "es": "Spanish",
+    "pt": "Portuguese",
+    "ms": "Malay",
+    "nl": "Dutch",
+    "id": "Indonesian",
+    "tr": "Turkish",
+    "vi": "Vietnamese",
+    "yue": "Cantonese",
+    "ar": "Arabic",
+    "ur": "Urdu",
+}
+
 
 def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
     input_lengths_leave = input_lengths % 100
@@ -1572,6 +1598,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
     SupportsPP,
     SupportsMRoPE,
     Qwen3OmniMoeConditionalGenerationMixin,
+    SupportsTranscription,
 ):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -1593,6 +1620,8 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
         ],
     }
 
+    supported_languages = ISO639_1_SUPPORTED_LANGS
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> str | None:
         if modality.startswith("image"):
@@ -2085,6 +2114,77 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
             total_tokens = num_video + audio_len
         return np.concatenate(pos_ids_list, axis=1), total_tokens
 
+    @classmethod
+    def get_speech_to_text_config(
+        cls, model_config: ModelConfig, task_type: str
+    ) -> SpeechToTextConfig:
+        processor = cached_processor_from_config(
+            model_config, processor_cls=Qwen3OmniMoeProcessor
+        )
+        return SpeechToTextConfig(
+            max_audio_clip_s=processor.feature_extractor.chunk_length,
+            sample_rate=processor.feature_extractor.sampling_rate,
+            min_energy_split_window_size=None,
+        )
+
+    @classmethod
+    def get_generation_prompt(
+        cls,
+        audio: np.ndarray,
+        stt_config: SpeechToTextConfig,
+        model_config: ModelConfig,
+        language: str | None,
+        task_type: Literal["transcribe", "translate"],
+        request_prompt: str,
+        to_language: str | None,
+    ) -> PromptType:
+        """
+        Construct a transcription/translation prompt for Qwen3-Omni.
+        """
+        # Transcribe this audio [into <language>] | for transcription
+        # Translate this audio [from <language> into <to_language>] | for translation
+        instruction = "Transcribe" if task_type == "transcribe" else "Translate"
+        instruction += " this audio"
+
+        # Default to_language to English for translation
+        if task_type == "translate" and to_language is None:
+            to_language = "en"
+
+        # Get full language names from supported_languages mapping
+        full_lang_name = cls.supported_languages.get(language, "")
+        full_lang_name_to = cls.supported_languages.get(to_language, "")
+
+        if task_type == "transcribe" and full_lang_name:
+            instruction += f" into {full_lang_name}"
+        elif task_type == "translate":
+            if full_lang_name:
+                instruction += f" from {full_lang_name}"
+            if full_lang_name_to:
+                instruction += f" into {full_lang_name_to}"
+
+        instruction += "."
+
+        if request_prompt:
+            instruction += f" {request_prompt}"
+
+        processor = cached_processor_from_config(
+            model_config, processor_cls=Qwen3OmniMoeProcessor
+        )
+        # Audio placeholder format: <|audio_start|><|audio_pad|><|audio_end|>
+        audio_placeholder = "<|audio_start|><|audio_pad|><|audio_end|>"
+        user_content = f"{audio_placeholder}{instruction}"
+
+        messages = [{"role": "user", "content": user_content}]
+        prompt = processor.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+
+        audio_data = (audio, stt_config.sample_rate)
+        prompts_dict = {"multi_modal_data": {"audio": audio_data}, "prompt": prompt}
+        return cast(PromptType, prompts_dict)
+
     def get_mrope_input_positions(
         self,
         input_tokens: list[int],
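For context, a minimal usage sketch (not part of this commit) of exercising the transcription support added here through vLLM's OpenAI-compatible audio transcription endpoint; the served model name, port, audio file, and language are illustrative assumptions:

# Assumes a server started with something like:
#   vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct
# and a local "sample.wav"; values below are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("sample.wav", "rb") as audio_file:
    # Standard OpenAI-style transcription request served by vLLM.
    transcription = client.audio.transcriptions.create(
        model="Qwen/Qwen3-Omni-30B-A3B-Instruct",
        file=audio_file,
        language="en",
    )

print(transcription.text)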