[Model] Add transcription support for Qwen3-Omni (#29828)

Signed-off-by: Muhammad Hashmi <mhashmi@berkeley.edu>
Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Muhammad Hashmi
2026-02-04 13:17:47 -08:00
committed by GitHub
parent 4292c90a2a
commit 535de06cb1
3 changed files with 104 additions and 2 deletions

View File

@@ -24,7 +24,7 @@
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from functools import partial
from typing import Any
from typing import Any, Literal, cast
import numpy as np
import torch
@@ -48,8 +48,9 @@ from transformers import __version__ as TRANSFORMERS_VERSION
# isort: on
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.attention.mm_encoder_attention import (
@@ -79,6 +80,7 @@ from vllm.multimodal.processing.processor import (
PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interfaces import (
@@ -86,6 +88,7 @@ from .interfaces import (
SupportsMRoPE,
SupportsMultiModal,
SupportsPP,
SupportsTranscription,
)
from .qwen2_5_omni_thinker import (
Qwen2_5OmniAudioFeatureInputs,
@@ -110,6 +113,29 @@ from .vision import get_vit_attn_backend
logger = init_logger(__name__)
# Speech input languages supported by Qwen3-Omni
# From: https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct
# Maps speech-input language codes to the human-readable names that are
# interpolated into the transcription/translation instruction prompt
# (see get_generation_prompt below in this file).
# NOTE(review): despite the constant's name, "yue" (Cantonese) is an
# ISO 639-3 code, not 639-1 — kept as-is to match the model card.
ISO639_1_SUPPORTED_LANGS: dict[str, str] = {
"en": "English",
"zh": "Chinese",
"ko": "Korean",
"ja": "Japanese",
"de": "German",
"ru": "Russian",
"it": "Italian",
"fr": "French",
"es": "Spanish",
"pt": "Portuguese",
"ms": "Malay",
"nl": "Dutch",
"id": "Indonesian",
"tr": "Turkish",
"vi": "Vietnamese",
"yue": "Cantonese",
"ar": "Arabic",
"ur": "Urdu",
}
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
input_lengths_leave = input_lengths % 100
@@ -1572,6 +1598,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
SupportsPP,
SupportsMRoPE,
Qwen3OmniMoeConditionalGenerationMixin,
SupportsTranscription,
):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
@@ -1593,6 +1620,8 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
],
}
supported_languages = ISO639_1_SUPPORTED_LANGS
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
@@ -2085,6 +2114,77 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
total_tokens = num_video + audio_len
return np.concatenate(pos_ids_list, axis=1), total_tokens
@classmethod
def get_speech_to_text_config(
    cls, model_config: ModelConfig, task_type: str
) -> SpeechToTextConfig:
    """Derive the speech-to-text configuration from the HF processor.

    Loads (via the processor cache) the Qwen3-Omni processor for
    ``model_config`` and mirrors its feature extractor's chunk length and
    sampling rate into a ``SpeechToTextConfig``. ``task_type`` is accepted
    for interface compatibility but does not affect the result.
    """
    feature_extractor = cached_processor_from_config(
        model_config, processor_cls=Qwen3OmniMoeProcessor
    ).feature_extractor
    return SpeechToTextConfig(
        # Longest audio clip (seconds) the feature extractor handles per chunk.
        max_audio_clip_s=feature_extractor.chunk_length,
        sample_rate=feature_extractor.sampling_rate,
        # Disable energy-based splitting of long audio inputs.
        min_energy_split_window_size=None,
    )
@classmethod
def get_generation_prompt(
    cls,
    audio: np.ndarray,
    stt_config: SpeechToTextConfig,
    model_config: ModelConfig,
    language: str | None,
    task_type: Literal["transcribe", "translate"],
    request_prompt: str,
    to_language: str | None,
) -> PromptType:
    """Construct a transcription/translation prompt for Qwen3-Omni.

    The instruction takes one of two shapes:
      - "Transcribe this audio [into <language>]." for transcription
      - "Translate this audio [from <language> into <to_language>]."
        for translation
    followed by the optional ``request_prompt``. The audio placeholder is
    prepended and the whole message is run through the processor's chat
    template with a generation prompt appended.
    """
    # Translation targets English when no destination language is given.
    if task_type == "translate" and to_language is None:
        to_language = "en"

    # Resolve codes to full names; unknown/None codes resolve to "" and
    # the corresponding clause is simply omitted from the instruction.
    src_name = cls.supported_languages.get(language, "")
    tgt_name = cls.supported_languages.get(to_language, "")

    parts = [
        "Transcribe" if task_type == "transcribe" else "Translate",
        " this audio",
    ]
    if task_type == "transcribe":
        if src_name:
            parts.append(f" into {src_name}")
    else:
        if src_name:
            parts.append(f" from {src_name}")
        if tgt_name:
            parts.append(f" into {tgt_name}")
    parts.append(".")
    if request_prompt:
        parts.append(f" {request_prompt}")
    instruction = "".join(parts)

    processor = cached_processor_from_config(
        model_config, processor_cls=Qwen3OmniMoeProcessor
    )
    # Audio placeholder format: <|audio_start|><|audio_pad|><|audio_end|>
    placeholder = "<|audio_start|><|audio_pad|><|audio_end|>"
    chat_messages = [{"role": "user", "content": f"{placeholder}{instruction}"}]
    rendered_prompt = processor.apply_chat_template(
        chat_messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    return cast(
        PromptType,
        {
            "multi_modal_data": {"audio": (audio, stt_config.sample_rate)},
            "prompt": rendered_prompt,
        },
    )
def get_mrope_input_positions(
self,
input_tokens: list[int],