@@ -720,6 +720,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `Qwen3VLForConditionalGeneration` | Qwen3-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-4B-Instruct`, etc. | ✅︎ | ✅︎ |
|
||||
| `Qwen3VLMoeForConditionalGeneration` | Qwen3-VL-MOE | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-30B-A3B-Instruct`, etc. | ✅︎ | ✅︎ |
|
||||
| `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, `Qwen/Qwen3-Omni-30B-A3B-Thinking` | ✅︎ | ✅︎ |
|
||||
| `Qwen3ASRForConditionalGeneration` | Qwen3-ASR | T + A<sup>+</sup> | `Qwen/Qwen3-ASR-1.7B` | ✅︎ | ✅︎ |
|
||||
| `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ |
|
||||
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ |
|
||||
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | |
|
||||
@@ -769,6 +770,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition.
|
||||
| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
|
||||
| `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ |
|
||||
| `GraniteSpeechForConditionalGeneration` | Granite Speech | `ibm-granite/granite-speech-3.3-2b`, `ibm-granite/granite-speech-3.3-8b`, etc. | ✅︎ | ✅︎ |
|
||||
| `Qwen3ASRForConditionalGeneration` | Qwen3-ASR | `Qwen/Qwen3-ASR-1.7B`, etc. | | ✅︎ |
|
||||
| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | ✅︎ | ✅︎ |
|
||||
| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | |
|
||||
|
||||
|
||||
@@ -330,6 +330,25 @@ def run_qwen2_5_omni(question: str, audio_count: int):
|
||||
)
|
||||
|
||||
|
||||
def run_qwen3_asr(question: str, audio_count: int) -> ModelRequestData:
|
||||
model_name = "Qwen/Qwen3-Asr-1.7B"
|
||||
|
||||
audio_in_prompt = "<|audio_start|><|audio_pad|><|audio_end|>\n" * audio_count
|
||||
prompt = f"<|im_start|>user\n{audio_in_prompt}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=5,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
# Ultravox 0.5-1B
|
||||
def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
|
||||
model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
|
||||
@@ -442,6 +461,7 @@ model_example_map = {
|
||||
"phi4_mm": run_phi4mm,
|
||||
"qwen2_audio": run_qwen2_audio,
|
||||
"qwen2_5_omni": run_qwen2_5_omni,
|
||||
"qwen3_asr": run_qwen3_asr,
|
||||
"ultravox": run_ultravox,
|
||||
"voxtral": run_voxtral,
|
||||
"whisper": run_whisper,
|
||||
|
||||
@@ -944,6 +944,12 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
max_model_len=4096,
|
||||
min_transformers_version="4.57",
|
||||
),
|
||||
"Qwen3ASRForConditionalGeneration": _HfExamplesInfo(
|
||||
"Qwen/Qwen3-ASR-1.7B",
|
||||
max_model_len=4096,
|
||||
min_transformers_version="4.57",
|
||||
is_available_online=False,
|
||||
),
|
||||
"RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", trust_remote_code=True),
|
||||
"SkyworkR1VChatModel": _HfExamplesInfo(
|
||||
"Skywork/Skywork-R1V-38B", trust_remote_code=True
|
||||
|
||||
567
vllm/model_executor/models/qwen3_asr.py
Normal file
567
vllm/model_executor/models/qwen3_asr.py
Normal file
@@ -0,0 +1,567 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright 2026 The Qwen team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen3-ASR model."""
|
||||
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.models.whisper import WhisperFeatureExtractor
|
||||
|
||||
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsMRoPE,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
SupportsTranscription,
|
||||
)
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
|
||||
from vllm.model_executor.models.qwen3_omni_moe_thinker import (
|
||||
Qwen2_5OmniAudioFeatureInputs,
|
||||
Qwen3OmniMoeAudioEncoder,
|
||||
Qwen3OmniMoeThinkerMultiModalProcessor,
|
||||
)
|
||||
from vllm.model_executor.models.utils import (
|
||||
AutoWeightsLoader,
|
||||
WeightsMapper,
|
||||
_merge_multimodal_embeddings,
|
||||
maybe_prefix,
|
||||
)
|
||||
from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (
|
||||
AudioItem,
|
||||
ModalityData,
|
||||
MultiModalDataDict,
|
||||
MultiModalFeatureSpec,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalKwargsItems,
|
||||
)
|
||||
from vllm.multimodal.parse import (
|
||||
AudioProcessorItems,
|
||||
DictEmbeddingItems,
|
||||
ModalityDataItems,
|
||||
MultiModalDataItems,
|
||||
MultiModalDataParser,
|
||||
)
|
||||
from vllm.multimodal.processing import (
|
||||
BaseDummyInputsBuilder,
|
||||
BaseProcessingInfo,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tokenizers import cached_tokenizer_from_config
|
||||
from vllm.transformers_utils.configs.qwen3_asr import (
|
||||
Qwen3ASRConfig,
|
||||
Qwen3ASRThinkerConfig,
|
||||
)
|
||||
from vllm.transformers_utils.processor import cached_processor_from_config
|
||||
from vllm.transformers_utils.processors.qwen3_asr import (
|
||||
Qwen3ASRProcessor,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
|
||||
input_lengths_leave = input_lengths % 100
|
||||
feat_lengths = (input_lengths_leave - 1) // 2 + 1
|
||||
output_lengths = (
|
||||
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
||||
)
|
||||
return output_lengths
|
||||
|
||||
|
||||
class Qwen3ASRProcessingInfo(BaseProcessingInfo):
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config(Qwen3ASRConfig).thinker_config
|
||||
|
||||
def get_hf_processor(self, **kwargs: object) -> Qwen3ASRProcessor:
|
||||
processor = self.ctx.get_hf_processor(
|
||||
Qwen3ASRProcessor,
|
||||
use_fast=kwargs.pop("use_fast", True),
|
||||
**kwargs,
|
||||
)
|
||||
if not hasattr(processor, "audio_token"):
|
||||
processor.audio_token = "<|audio_pad|>"
|
||||
return processor
|
||||
|
||||
def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
|
||||
hf_processor = self.get_hf_processor(**kwargs)
|
||||
feature_extractor = hf_processor.feature_extractor
|
||||
assert isinstance(feature_extractor, WhisperFeatureExtractor)
|
||||
return feature_extractor
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"audio": None}
|
||||
|
||||
|
||||
class Qwen3ASRDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3ASRProcessingInfo]):
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
|
||||
hf_processor = self.info.get_hf_processor()
|
||||
audio_token = hf_processor.audio_token
|
||||
|
||||
return audio_token * num_audios
|
||||
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
|
||||
feature_extractor = self.info.get_feature_extractor()
|
||||
|
||||
target_audio_length = (
|
||||
min(
|
||||
feature_extractor.chunk_length,
|
||||
30,
|
||||
)
|
||||
* feature_extractor.sampling_rate
|
||||
)
|
||||
|
||||
audio_overrides = mm_options.get("audio") if mm_options else None
|
||||
|
||||
return {
|
||||
"audio": self._get_dummy_audios(
|
||||
length=target_audio_length,
|
||||
num_audios=num_audios,
|
||||
overrides=audio_overrides,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _qwen3asr_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
||||
audio_feature_lengths = hf_inputs.get("audio_feature_lengths", torch.empty((0,)))
|
||||
return dict(
|
||||
input_audio_features=MultiModalFieldConfig.flat_from_sizes(
|
||||
"audio", audio_feature_lengths, dim=1
|
||||
),
|
||||
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
|
||||
audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
|
||||
)
|
||||
|
||||
|
||||
class Qwen3ASRMultiModalDataParser(MultiModalDataParser):
|
||||
def _parse_audio_data(
|
||||
self,
|
||||
data: dict[str, torch.Tensor] | ModalityData[AudioItem],
|
||||
) -> ModalityDataItems[Any, Any] | None:
|
||||
if isinstance(data, dict):
|
||||
return DictEmbeddingItems(
|
||||
data,
|
||||
modality="audio",
|
||||
required_fields={"input_audio_features", "audio_feature_lengths"},
|
||||
fields_factory=_qwen3asr_field_config,
|
||||
)
|
||||
|
||||
return super()._parse_audio_data(data)
|
||||
|
||||
|
||||
class Qwen3ASRMultiModalProcessor(
|
||||
Qwen3OmniMoeThinkerMultiModalProcessor,
|
||||
):
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
|
||||
feature_extractor = self.info.get_feature_extractor()
|
||||
return Qwen3ASRMultiModalDataParser(
|
||||
target_sr=feature_extractor.sampling_rate,
|
||||
)
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return _qwen3asr_field_config(hf_inputs)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, Any],
|
||||
out_mm_kwargs: MultiModalKwargsItems,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
vocab = tokenizer.get_vocab()
|
||||
|
||||
audio_token = processor.audio_token
|
||||
audio_token_id = vocab[audio_token]
|
||||
|
||||
out_mm_data = out_mm_kwargs.get_data()
|
||||
audio_feature_lengths = out_mm_data.get("audio_feature_lengths")
|
||||
feature_attention_mask = out_mm_data.get("feature_attention_mask")
|
||||
if audio_feature_lengths is None and feature_attention_mask is None:
|
||||
audio_output_lengths = []
|
||||
elif audio_feature_lengths is not None:
|
||||
audio_output_lens = _get_feat_extract_output_lengths(audio_feature_lengths)
|
||||
audio_output_lengths = audio_output_lens.tolist()
|
||||
elif feature_attention_mask is not None:
|
||||
assert isinstance(feature_attention_mask, torch.Tensor)
|
||||
audio_output_lens = _get_feat_extract_output_lengths(
|
||||
feature_attention_mask.sum(-1)
|
||||
)
|
||||
audio_output_lengths = audio_output_lens.tolist()
|
||||
|
||||
def get_replacement_qwen2_audio(item_idx: int):
|
||||
num_features = audio_output_lengths[item_idx]
|
||||
if num_features == 0:
|
||||
audios = mm_items.get_items("audio", AudioProcessorItems)
|
||||
audio = audios.get(item_idx)
|
||||
raise ValueError(
|
||||
f"The audio {audio} (len={len(audio)}) is too short "
|
||||
"to be represented inside the model"
|
||||
)
|
||||
|
||||
return [audio_token_id] * num_features
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="audio",
|
||||
target=audio_token,
|
||||
replacement=get_replacement_qwen2_audio,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Qwen3ASRMultiModalProcessor,
|
||||
info=Qwen3ASRProcessingInfo,
|
||||
dummy_inputs=Qwen3ASRDummyInputsBuilder,
|
||||
)
|
||||
class Qwen3ASRForConditionalGeneration(
|
||||
nn.Module,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
SupportsMRoPE,
|
||||
SupportsTranscription,
|
||||
):
|
||||
supported_languages = ISO639_1_SUPPORTED_LANGS
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"thinker.lm_head.": "language_model.lm_head.",
|
||||
"thinker.model.": "language_model.model.",
|
||||
"thinker.": "",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
if modality.startswith("audio"):
|
||||
return "<|audio_start|><|audio_pad|><|audio_end|>"
|
||||
|
||||
raise ValueError("Only audio modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.vllm_config = vllm_config # needed for torch compile forward context
|
||||
thinker_config: Qwen3ASRThinkerConfig = (
|
||||
vllm_config.model_config.hf_config.thinker_config
|
||||
)
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.config = thinker_config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
self.audio_tower = Qwen3OmniMoeAudioEncoder(
|
||||
thinker_config.audio_config,
|
||||
prefix=maybe_prefix(prefix, "audio_tower"),
|
||||
)
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.language_model = Qwen3ForCausalLM(
|
||||
vllm_config=vllm_config.with_hf_config(
|
||||
thinker_config.text_config, architectures=["Qwen3ForCausalLM"]
|
||||
),
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def _parse_and_validate_audio_input(
|
||||
self, **kwargs: object
|
||||
) -> Qwen2_5OmniAudioFeatureInputs | None:
|
||||
input_audio_features = kwargs.pop("input_audio_features", None)
|
||||
audio_feature_lengths = kwargs.pop("audio_feature_lengths", None)
|
||||
feature_attention_mask = kwargs.pop("feature_attention_mask", None)
|
||||
if input_audio_features is None:
|
||||
return None
|
||||
|
||||
return Qwen2_5OmniAudioFeatureInputs(
|
||||
type="audio_features",
|
||||
input_features=input_audio_features,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
feature_attention_mask=feature_attention_mask,
|
||||
)
|
||||
|
||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||
mm_input_by_modality = {}
|
||||
|
||||
# Preserve the order of modalities if there are multiple of them
|
||||
# from the order of kwargs.
|
||||
for input_key in kwargs:
|
||||
if (
|
||||
input_key in ("input_audio_features")
|
||||
and "audio" not in mm_input_by_modality
|
||||
):
|
||||
mm_input_by_modality["audio"] = self._parse_and_validate_audio_input(
|
||||
**kwargs
|
||||
)
|
||||
return mm_input_by_modality
|
||||
|
||||
def _process_audio_input(
|
||||
self,
|
||||
audio_input: Qwen2_5OmniAudioFeatureInputs,
|
||||
audio_hashes: list[str] | None = None,
|
||||
cached_audio_features: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
input_features = audio_input["input_features"]
|
||||
audio_feature_lengths = audio_input["audio_feature_lengths"]
|
||||
|
||||
audio_output_lengths = _get_feat_extract_output_lengths(audio_feature_lengths)
|
||||
|
||||
audio_features = self.audio_tower(
|
||||
input_features.to(self.audio_tower.dtype),
|
||||
feature_lens=audio_feature_lengths,
|
||||
aftercnn_lens=audio_output_lengths,
|
||||
)
|
||||
return audio_features.split(audio_output_lengths.tolist())
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
|
||||
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not mm_input_by_modality:
|
||||
return []
|
||||
|
||||
# The result multimodal_embeddings is tuple of tensors, with each
|
||||
# tensor correspoending to a multimodal data item (image or video).
|
||||
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
|
||||
|
||||
# NOTE: It is important to iterate over the keys in this dictionary
|
||||
# to preserve the order of the modalities.
|
||||
for modality in mm_input_by_modality:
|
||||
multimodal_input = mm_input_by_modality[modality]
|
||||
if modality == "audio":
|
||||
audio_embeddings = self._process_audio_input(multimodal_input)
|
||||
multimodal_embeddings += tuple(audio_embeddings)
|
||||
return multimodal_embeddings
|
||||
|
||||
def embed_input_ids(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self._embed_text_input_ids(
|
||||
input_ids,
|
||||
self.language_model.embed_input_ids,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
|
||||
return inputs_embeds
|
||||
|
||||
inputs_embeds = _merge_multimodal_embeddings(
|
||||
inputs_embeds=inputs_embeds,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs: object,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
hidden_states = self.language_model.model(
|
||||
input_ids,
|
||||
positions,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor | None:
|
||||
return self.language_model.compute_logits(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=["talker.", "code2wav."],
|
||||
)
|
||||
loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
return loaded_weights
|
||||
|
||||
def get_mrope_input_positions(
|
||||
self,
|
||||
input_tokens: list[int],
|
||||
mm_features: list[MultiModalFeatureSpec],
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
seq_len = len(input_tokens)
|
||||
|
||||
if not mm_features:
|
||||
# No audio features, just return linear positions
|
||||
llm_positions = (
|
||||
torch.arange(seq_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
)
|
||||
return llm_positions.clone(), 0
|
||||
|
||||
llm_pos_ids_list: list[torch.Tensor] = []
|
||||
st = 0
|
||||
|
||||
for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
|
||||
offset = mm_feature.mm_position.offset
|
||||
|
||||
# Get audio feature length from mm_feature data
|
||||
audio_feature_length = mm_feature.data["audio_feature_lengths"].data
|
||||
if isinstance(audio_feature_length, torch.Tensor):
|
||||
audio_feature_length = audio_feature_length.item()
|
||||
audio_len = _get_feat_extract_output_lengths(
|
||||
torch.tensor(audio_feature_length)
|
||||
).item()
|
||||
|
||||
# Text segment before audio (includes audio_start token)
|
||||
text_len = offset - st
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
text_positions = (
|
||||
torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
llm_pos_ids_list.append(text_positions)
|
||||
st_idx = st_idx + text_len
|
||||
|
||||
# Audio token segment
|
||||
audio_positions = (
|
||||
torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
llm_pos_ids_list.append(audio_positions)
|
||||
|
||||
st = offset + audio_len
|
||||
|
||||
# Handle remaining text (includes audio_end and any trailing text)
|
||||
if st < seq_len:
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
text_len = seq_len - st
|
||||
final_text_positions = (
|
||||
torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
llm_pos_ids_list.append(final_text_positions)
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
if llm_positions.shape[1] != seq_len:
|
||||
raise RuntimeError("Position ids length mismatch with input ids length")
|
||||
|
||||
mrope_position_delta = (llm_positions.max() + 1 - seq_len).item()
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""
|
||||
Get the module prefix in multimodal models
|
||||
"""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
tower_model=["audio_tower."],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_speech_to_text_config(
|
||||
cls, model_config: ModelConfig, task_type: str
|
||||
) -> SpeechToTextConfig:
|
||||
processor = cached_processor_from_config(model_config)
|
||||
feature_extractor: WhisperFeatureExtractor = processor.feature_extractor
|
||||
return SpeechToTextConfig(
|
||||
max_audio_clip_s=feature_extractor.chunk_length,
|
||||
sample_rate=feature_extractor.sampling_rate,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_generation_prompt(
|
||||
cls,
|
||||
audio: np.ndarray,
|
||||
model_config: ModelConfig,
|
||||
stt_config: SpeechToTextConfig,
|
||||
language: str | None,
|
||||
task_type: Literal["transcribe", "translate"],
|
||||
request_prompt: str,
|
||||
to_language: str | None,
|
||||
) -> PromptType:
|
||||
"""Get the generation prompt to be used for transcription requests."""
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
audio_placeholder = cls.get_placeholder_str("audio", 0)
|
||||
|
||||
if task_type not in ("transcribe", "translate"):
|
||||
raise ValueError(
|
||||
f"Unsupported task_type '{task_type}'. "
|
||||
"Supported task types are 'transcribe' and 'translate'."
|
||||
)
|
||||
full_lang_name_to = cls.supported_languages.get(to_language, to_language)
|
||||
if to_language is None:
|
||||
prompt = (
|
||||
f"<|im_start|>user\n{audio_placeholder}<|im_end|>\n"
|
||||
f"<|im_start|>assistant\n"
|
||||
)
|
||||
else:
|
||||
prompt = (
|
||||
f"<|im_start|>user\n{audio_placeholder}<|im_end|>\n"
|
||||
f"<|im_start|>assistant\nlanguage {full_lang_name_to}<asr_text>"
|
||||
)
|
||||
|
||||
prompt_token_ids = tokenizer.encode(prompt)
|
||||
prompt_dict = {
|
||||
"prompt_token_ids": prompt_token_ids,
|
||||
"multi_modal_data": {"audio": audio},
|
||||
}
|
||||
return cast(PromptType, prompt_dict)
|
||||
@@ -436,6 +436,10 @@ _MULTIMODAL_MODELS = {
|
||||
"qwen3_omni_moe_thinker",
|
||||
"Qwen3OmniMoeThinkerForConditionalGeneration",
|
||||
),
|
||||
"Qwen3ASRForConditionalGeneration": (
|
||||
"qwen3_asr",
|
||||
"Qwen3ASRForConditionalGeneration",
|
||||
),
|
||||
"Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"), # noqa: E501
|
||||
"Qwen3VLMoeForConditionalGeneration": (
|
||||
"qwen3_vl_moe",
|
||||
|
||||
@@ -97,6 +97,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
|
||||
ultravox="UltravoxConfig",
|
||||
step3_vl="Step3VLConfig",
|
||||
step3_text="Step3TextConfig",
|
||||
qwen3_asr="Qwen3ASRConfig",
|
||||
qwen3_next="Qwen3NextConfig",
|
||||
lfm2_moe="Lfm2MoeConfig",
|
||||
tarsier2="Tarsier2Config",
|
||||
|
||||
@@ -52,6 +52,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
|
||||
"Step3VLConfig": "vllm.transformers_utils.configs.step3_vl",
|
||||
"Step3VisionEncoderConfig": "vllm.transformers_utils.configs.step3_vl",
|
||||
"Step3TextConfig": "vllm.transformers_utils.configs.step3_vl",
|
||||
"Qwen3ASRConfig": "vllm.transformers_utils.configs.qwen3_asr",
|
||||
"Qwen3NextConfig": "vllm.transformers_utils.configs.qwen3_next",
|
||||
"Tarsier2Config": "vllm.transformers_utils.configs.tarsier2",
|
||||
# Special case: DeepseekV3Config is from HuggingFace Transformers
|
||||
@@ -94,6 +95,7 @@ __all__ = [
|
||||
"Step3VLConfig",
|
||||
"Step3VisionEncoderConfig",
|
||||
"Step3TextConfig",
|
||||
"Qwen3ASRConfig",
|
||||
"Qwen3NextConfig",
|
||||
"Tarsier2Config",
|
||||
]
|
||||
|
||||
436
vllm/transformers_utils/configs/qwen3_asr.py
Normal file
436
vllm/transformers_utils/configs/qwen3_asr.py
Normal file
@@ -0,0 +1,436 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# ruff: noqa
|
||||
# mypy: ignore-errors
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Qwen3ASRAudioEncoderConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen3ASRAudioEncoder`]. It is used to instantiate a
|
||||
Qwen3-ASR audio encoder according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the audio encoder of the Qwen2-Audio
|
||||
architecture.
|
||||
|
||||
e.g. [Qwen/Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
num_mel_bins (`int`, *optional*, defaults to 128):
|
||||
Number of mel features used per input features. Should correspond to the value used in the
|
||||
`Qwen3ASRProcessor` class.
|
||||
encoder_layers (`int`, *optional*, defaults to 32):
|
||||
Number of encoder layers.
|
||||
encoder_attention_heads (`int`, *optional*, defaults to 20):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
encoder_ffn_dim (`int`, *optional*, defaults to 5120):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
|
||||
d_model (`int`, *optional*, defaults to 1280):
|
||||
Dimensionality of the layers.
|
||||
dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
activation_function (`str`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
||||
activation_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for activations inside the fully connected layer.
|
||||
scale_embedding (`bool`, *optional*, defaults to `False`):
|
||||
Scale embeddings by diving by sqrt(d_model).
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
max_source_positions (`int`, *optional*, defaults to 1500):
|
||||
The maximum sequence length of log-mel filter-bank features that this model might ever be used with.
|
||||
n_window (`int`, *optional*, defaults to 100):
|
||||
The chunk for conv and flash attn in AudioEncoder.
|
||||
output_dim (`int`, *optional*, defaults to 3584):
|
||||
The output dimension of AudioEncoder.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Qwen3ASRAudioEncoderConfig, Qwen3ASRAudioEncoder
|
||||
|
||||
>>> # Initializing a Qwen3ASRAudioEncoderConfig
|
||||
>>> configuration = Qwen3ASRAudioEncoderConfig()
|
||||
|
||||
>>> # Initializing a Qwen3ASRAudioEncoder (with random weights)
|
||||
>>> model = Qwen3ASRAudioEncoder(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "qwen3_asr_audio_encoder"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_mel_bins=128,
|
||||
encoder_layers=32,
|
||||
encoder_attention_heads=20,
|
||||
encoder_ffn_dim=5120,
|
||||
d_model=1280,
|
||||
dropout=0,
|
||||
attention_dropout=0,
|
||||
activation_function="gelu",
|
||||
activation_dropout=0,
|
||||
scale_embedding=False,
|
||||
initializer_range=0.02,
|
||||
max_source_positions=1500,
|
||||
n_window=100,
|
||||
output_dim=3584,
|
||||
n_window_infer=400,
|
||||
conv_chunksize=500,
|
||||
downsample_hidden_size=480,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.num_mel_bins = num_mel_bins
|
||||
self.d_model = d_model
|
||||
self.encoder_layers = encoder_layers
|
||||
self.encoder_attention_heads = encoder_attention_heads
|
||||
self.encoder_ffn_dim = encoder_ffn_dim
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.activation_function = activation_function
|
||||
self.activation_dropout = activation_dropout
|
||||
self.num_hidden_layers = encoder_layers
|
||||
self.initializer_range = initializer_range
|
||||
self.scale_embedding = (
|
||||
scale_embedding # scale factor will be sqrt(d_model) if True
|
||||
)
|
||||
self.max_source_positions = max_source_positions
|
||||
self.n_window = n_window
|
||||
self.output_dim = output_dim
|
||||
self.n_window_infer = n_window_infer
|
||||
self.conv_chunksize = conv_chunksize
|
||||
self.downsample_hidden_size = downsample_hidden_size
|
||||
|
||||
|
||||
class Qwen3ASRTextConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen3ASRTextModel`]. It is used to instantiate a
|
||||
Qwen3-ASR model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of
|
||||
Qwen3-ASR-1.7B [Qwen/Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 151936):
|
||||
Vocabulary size of the Qwen3ASR model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Qwen3ASRModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 22016):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 32):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details, check out [this
|
||||
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
|
||||
head_dim (`int`, *optional*, defaults to 128):
|
||||
The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 128000):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
rope_theta (`float`, *optional*, defaults to 5000000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
|
||||
```python
|
||||
>>> from transformers import Qwen3ASRTextModel, Qwen3ASRTextConfig
|
||||
|
||||
>>> # Initializing a Qwen3ASR style configuration
|
||||
>>> configuration = Qwen3ASRTextConfig()
|
||||
|
||||
>>> # Initializing a model from the Qwen3-VL-7B style configuration
|
||||
>>> model = Qwen3ASRTextModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "qwen3_asr_text"
|
||||
base_config_key = "text_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=151936,
|
||||
hidden_size=4096,
|
||||
intermediate_size=22016,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=32,
|
||||
head_dim=128,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=128000,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=5000000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
|
||||
|
||||
class Qwen3ASRThinkerConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen3ASRThinker`]. It is used to instantiate a
|
||||
Qwen3-ASR-Thinker model according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the thinker component of the Qwen3-Omni
|
||||
architecture.
|
||||
|
||||
e.g. [Qwen/Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
audio_config (`dict`, *optional*):
|
||||
The config dictionary of the audio backbone.
|
||||
text_config (`dict`, *optional*):
|
||||
The config dictionary of the text backbone.
|
||||
audio_token_id (`int`, *optional*, defaults to 151646):
|
||||
The audio token id to encode the audio prompt.
|
||||
audio_start_token_id (`int`, *optional*, defaults to 151647):
|
||||
The audio start token id to encode the audio prompt.
|
||||
user_token_id (`int`, *optional*, defaults to 872):
|
||||
The user token id to encode the user token.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Qwen3ASRThinkerModel, Qwen3ASRThinkerConfig
|
||||
|
||||
>>> # Initializing a default Qwen3ASRThinkerConfig
|
||||
>>> configuration = Qwen3ASRThinkerConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights) from the default configuration
|
||||
>>> model = Qwen3ASRThinkerModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "qwen3_asr_thinker"
|
||||
|
||||
attribute_map = {}
|
||||
sub_configs = {
|
||||
"audio_config": Qwen3ASRAudioEncoderConfig,
|
||||
"text_config": Qwen3ASRTextConfig,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
audio_config=None,
|
||||
text_config=None,
|
||||
audio_token_id=151646,
|
||||
audio_start_token_id=151647,
|
||||
user_token_id=872,
|
||||
initializer_range=0.02,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.user_token_id = user_token_id
|
||||
self.audio_start_token_id = audio_start_token_id
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
if isinstance(audio_config, dict):
|
||||
audio_config = Qwen3ASRAudioEncoderConfig(**audio_config)
|
||||
elif audio_config is None:
|
||||
audio_config = Qwen3ASRAudioEncoderConfig()
|
||||
self.audio_config = audio_config
|
||||
|
||||
if isinstance(text_config, dict):
|
||||
text_config = Qwen3ASRTextConfig(**text_config)
|
||||
elif text_config is None:
|
||||
text_config = Qwen3ASRTextConfig()
|
||||
self.text_config = text_config
|
||||
self.audio_token_id = audio_token_id
|
||||
|
||||
|
||||
class Qwen3ASRConfig(PretrainedConfig):
|
||||
"""
|
||||
This is the configuration class to store the configuration of a [`Qwen3ASRForConditionalGeneration`]. It is used to instantiate a Qwen3ASR
|
||||
model according to the specified sub-models configurations, defining the model architecture.
|
||||
|
||||
Instantiating a configuration with the defaults will yield a similar configuration to that of the
|
||||
[Qwen/Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
thinker_config (`dict`, *optional*): Configuration of the underlying thinker sub-model.
|
||||
support_languages (`List[str]`, *optional*): The languages supported by the model.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import (
|
||||
... Qwen3ASRThinkerConfig,
|
||||
... Qwen3ASRForConditionalGeneration,
|
||||
... Qwen3ASRConfig,
|
||||
... )
|
||||
|
||||
>>> # Initializing a Qwen3ASR style configuration
|
||||
>>> configuration = Qwen3ASRConfig()
|
||||
|
||||
>>> # Initializing a model from the configuration
|
||||
>>> model = Qwen3ASRForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "qwen3_asr"
|
||||
sub_configs = {
|
||||
"thinker_config": Qwen3ASRThinkerConfig,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
thinker_config=None,
|
||||
support_languages=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
if thinker_config is None:
|
||||
thinker_config = {}
|
||||
logger.info(
|
||||
"thinker_config is None. Initializing thinker model with default values"
|
||||
)
|
||||
|
||||
self.thinker_config = Qwen3ASRThinkerConfig(**thinker_config)
|
||||
self.support_languages = support_languages
|
||||
|
||||
def get_text_config(self, decoder=False) -> "PretrainedConfig":
|
||||
"""
|
||||
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
|
||||
itself. On specific composite models, it is under a set of valid names.
|
||||
|
||||
Args:
|
||||
decoder (`Optional[bool]`, *optional*, defaults to `False`):
|
||||
If set to `True`, then only search for decoder config names.
|
||||
"""
|
||||
# Overridden for deeply nested config like Qwen2.5-Omni. We don't have any omni model
|
||||
# except for Qwen yet. This has to be generalized if more deeply nested configs are
|
||||
# added. NOTE: currently method used only by vLLM
|
||||
return self.thinker_config.get_text_config()
|
||||
|
||||
|
||||
__all__ = ["Qwen3ASRConfig", "Qwen3ASRThinkerConfig", "Qwen3ASRAudioEncoderConfig"]
|
||||
231
vllm/transformers_utils/processors/qwen3_asr.py
Normal file
231
vllm/transformers_utils/processors/qwen3_asr.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# ruff: noqa
|
||||
# mypy: ignore-errors
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import regex as re
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import AutoProcessor
|
||||
from transformers.audio_utils import AudioInput
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
|
||||
class Qwen3ASRProcessorKwargs(ProcessingKwargs, total=False):
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
"padding_side": "left",
|
||||
},
|
||||
"audio_kwargs": {
|
||||
"sampling_rate": 16000,
|
||||
"padding": True,
|
||||
"return_attention_mask": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _get_feat_extract_output_lengths(input_lengths):
|
||||
"""
|
||||
Computes the output length of the convolutional layers and the output length of the audio encoder
|
||||
"""
|
||||
|
||||
input_lengths_leave = input_lengths % 100
|
||||
feat_lengths = (input_lengths_leave - 1) // 2 + 1
|
||||
output_lengths = (
|
||||
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
||||
)
|
||||
return output_lengths
|
||||
|
||||
|
||||
class Qwen3ASRProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Qwen3ASR processor.
|
||||
[`Qwen3ASRProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`], and [`Qwen2TokenizerFast`]. See the
|
||||
[`~Qwen3ASRProcessor.__call__`] and [`~Qwen3ASRProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
feature_extractor ([`WhisperFeatureExtractor`], *optional*):
|
||||
The audio feature extractor.
|
||||
tokenizer ([`Qwen2TokenizerFast`], *optional*):
|
||||
The text tokenizer.
|
||||
chat_template (`Optional[str]`, *optional*):
|
||||
The Jinja template to use for formatting the conversation. If not provided, the default chat template is used.
|
||||
"""
|
||||
|
||||
attributes = ["feature_extractor", "tokenizer"]
|
||||
feature_extractor_class = "WhisperFeatureExtractor"
|
||||
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
|
||||
|
||||
def __init__(self, feature_extractor=None, tokenizer=None, chat_template=None):
|
||||
super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
|
||||
self.audio_token = self.tokenizer.audio_token
|
||||
self.audio_bos_token = self.tokenizer.audio_bos_token
|
||||
self.audio_eos_token = self.tokenizer.audio_eos_token
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: TextInput = None,
|
||||
audio: AudioInput = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `text`
|
||||
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
|
||||
the text. To prepare the audio(s), this method forwards the `audio` and `kwargs` arguments to
|
||||
WhisperFeatureExtractor's [`~WhisperFeatureExtractor.__call__`] if `audio` is not `None`. Please refer to the doctsring
|
||||
of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
audio (`np.ndarray`, `List[np.ndarray]`):
|
||||
The audio or batch of audio to be prepared. Each audio can be a NumPy array.
|
||||
"""
|
||||
|
||||
if text is None:
|
||||
raise ValueError("You need to specify either a `text` input to process.")
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
Qwen3ASRProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if audio is not None:
|
||||
output_kwargs["audio_kwargs"]["padding"] = True
|
||||
output_kwargs["audio_kwargs"]["truncation"] = False
|
||||
audio_inputs = self.feature_extractor(
|
||||
audio, **output_kwargs["audio_kwargs"]
|
||||
)
|
||||
audio_inputs["feature_attention_mask"] = audio_inputs.pop(
|
||||
"attention_mask"
|
||||
) # rename feature_attention_mask to prevent conflicts later on
|
||||
audio_inputs["input_features"] = audio_inputs.pop(
|
||||
"input_features"
|
||||
) # rename input_features to prevent conflicts later on
|
||||
audio_lengths = iter(
|
||||
_get_feat_extract_output_lengths(
|
||||
audio_inputs["feature_attention_mask"].sum(-1)
|
||||
)
|
||||
)
|
||||
else:
|
||||
audio_inputs = {}
|
||||
audio_lengths = iter([])
|
||||
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
text = self.replace_multimodal_special_tokens(
|
||||
text,
|
||||
audio_lengths,
|
||||
)
|
||||
|
||||
texts_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||
|
||||
return BatchFeature(
|
||||
data={**texts_inputs, **audio_inputs},
|
||||
tensor_type=kwargs.get("return_tensors"),
|
||||
)
|
||||
|
||||
def replace_multimodal_special_tokens(
|
||||
self,
|
||||
text,
|
||||
audio_lengths,
|
||||
):
|
||||
processed_text = []
|
||||
for sample in text:
|
||||
positions = []
|
||||
special_tokens = [re.escape(tok) for tok in [self.audio_token]]
|
||||
pattern = "|".join(special_tokens)
|
||||
positions = sorted(
|
||||
[
|
||||
(match.start(), match.group())
|
||||
for match in re.finditer(pattern, sample)
|
||||
]
|
||||
)
|
||||
positions.sort(key=lambda x: x[0])
|
||||
|
||||
for _, special_token in positions:
|
||||
if special_token == self.audio_token:
|
||||
sample = sample.replace(
|
||||
self.audio_token,
|
||||
"<|audio_placeholder|>" * next(audio_lengths),
|
||||
1,
|
||||
)
|
||||
|
||||
sample = sample.replace("<|audio_placeholder|>", self.audio_token)
|
||||
processed_text.append(sample)
|
||||
return processed_text
|
||||
|
||||
def get_chunked_index(
|
||||
self, token_indices: np.ndarray, tokens_per_chunk: int
|
||||
) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Splits token index list into chunks based on token value ranges.
|
||||
|
||||
Given a list of token indices, returns a list of (start, end) index tuples representing
|
||||
slices of the list where the token values fall within successive ranges of `tokens_per_chunk`.
|
||||
|
||||
For example, if `tokens_per_chunk` is 1000, the function will create chunks such that:
|
||||
- the first chunk contains token values < 1000,
|
||||
- the second chunk contains values >= 1000 and < 2000, and so on.
|
||||
|
||||
Parameters:
|
||||
token_indices (`np.ndarray`): A monotonically increasing list of token index values.
|
||||
tokens_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
|
||||
|
||||
Returns:
|
||||
`list[tuple[int, int]]`: A list of tuples, each representing the start (inclusive)
|
||||
and end (exclusive) indices of a chunk in `token_indices`.
|
||||
"""
|
||||
|
||||
def _iter():
|
||||
i, start_idx = 0, 0 # skip bos token
|
||||
current_chunk = 1
|
||||
while i < len(token_indices): # skip eos token
|
||||
if token_indices[i] >= current_chunk * tokens_per_chunk:
|
||||
yield (start_idx, i)
|
||||
start_idx = i
|
||||
current_chunk += 1
|
||||
i += 1
|
||||
yield (start_idx, len(token_indices))
|
||||
|
||||
return list(_iter())
|
||||
|
||||
def apply_chat_template(self, conversations, chat_template=None, **kwargs):
|
||||
return super().apply_chat_template(conversations, chat_template, **kwargs)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
feature_extractor_input_names = self.feature_extractor.model_input_names
|
||||
return list(
|
||||
dict.fromkeys(
|
||||
tokenizer_input_names
|
||||
+ feature_extractor_input_names
|
||||
+ ["feature_attention_mask"]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
AutoProcessor.register("Qwen3ASRProcessor", Qwen3ASRProcessor)
|
||||
Reference in New Issue
Block a user