diff --git a/docs/models/pooling_models/token_classify.md b/docs/models/pooling_models/token_classify.md
index d669a716f..201ce4ea6 100644
--- a/docs/models/pooling_models/token_classify.md
+++ b/docs/models/pooling_models/token_classify.md
@@ -15,7 +15,7 @@ Many classification models support both (sequence) classification and token clas
 
 !!! note
 
-    Pooling multitask support is deprecated and will be removed in v0.20. When the default pooling task (classify) is not
+    Pooling multitask support is deprecated and will be removed in v0.20. When the default pooling task (classify) is not
     what you want, you need to manually specify it via `PoolerConfig(task="token_classify")` offline or
     `--pooler-config.task token_classify` online.
 
@@ -29,6 +29,12 @@ Offline: [examples/pooling/token_classify/ner_offline.py](../../../examples/pool
 
 Online: [examples/pooling/token_classify/ner_online.py](../../../examples/pooling/token_classify/ner_online.py)
 
+### Forced Alignment
+
+Forced alignment takes audio and reference text as input and produces word-level timestamps.
+
+Offline: [examples/pooling/token_classify/forced_alignment_offline.py](../../../examples/pooling/token_classify/forced_alignment_offline.py)
+
 ### Sparse retrieval (lexical matching)
 
 The BAAI/bge-m3 model leverages token classification for sparse retrieval. For more information, see [this page](specific_models.md#baaibge-m3).
@@ -43,12 +49,25 @@ The BAAI/bge-m3 model leverages token classification for sparse retrieval. For m
 | `Qwen3ForTokenClassification`<sup>C</sup> | Qwen3-based | `bd2lcco/Qwen3-0.6B-finetuned` | | |
 | `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* |
 
-<sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./README.md#model-conversion))
+<sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./README.md#model-conversion))
 \* Feature support is the same as that of the original model.
 
 If your model is not in the above list, we will try to automatically convert the model using
 [as_seq_cls_model][vllm.model_executor.models.adapters.as_seq_cls_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
 
+### Multimodal Models
+
+!!! note
+    For more information about multimodal model inputs, see [this page](../supported_models.md#list-of-multimodal-language-models).
+
+| Architecture | Models | Inputs | Example HF Models | [LoRA](../../features/lora.md) | [PP](../../serving/parallelism_scaling.md) |
+| --------------------------------------------- | ------------------- | ----------------- | ------------------------------------------ | ------------------------------ | ------------------------------------------ |
+| `Qwen3ASRForcedAlignerForTokenClassification` | Qwen3-ForcedAligner | T + A<sup>+</sup> | `Qwen/Qwen3-ForcedAligner-0.6B` (see note) | | ✅︎ |
+
+!!! note
+    Forced alignment usage requires `--hf-overrides '{"architectures": ["Qwen3ASRForcedAlignerForTokenClassification"]}'`.
+    Please refer to [examples/pooling/token_classify/forced_alignment_offline.py](../../../examples/pooling/token_classify/forced_alignment_offline.py).
+
 ### As Reward Models
 
 Using token classification models as reward models. For details on reward models, see [Reward Models](reward.md).
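For reviewers, the offline configuration the deprecation note describes boils down to the following; a minimal sketch, assuming `PoolerConfig` is importable from `vllm.config` and that it is passed through the `pooler_config` engine argument (the offline counterpart of `--pooler-config.task token_classify`):

```python
from vllm import LLM
from vllm.config import PoolerConfig

# Explicitly select the token_classify pooling task (the default is
# "classify") and force the forced-aligner architecture, mirroring the
# flags documented in the diff above.
llm = LLM(
    model="Qwen/Qwen3-ForcedAligner-0.6B",
    runner="pooling",
    pooler_config=PoolerConfig(task="token_classify"),
    hf_overrides={
        "architectures": ["Qwen3ASRForcedAlignerForTokenClassification"]
    },
)
```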
diff --git a/examples/pooling/token_classify/forced_alignment_offline.py b/examples/pooling/token_classify/forced_alignment_offline.py
new file mode 100644
index 000000000..e97a37357
--- /dev/null
+++ b/examples/pooling/token_classify/forced_alignment_offline.py
@@ -0,0 +1,90 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Adapted from Qwen3-ForcedAligner inference:
+# https://github.com/QwenLM/Qwen3-ASR
+
+"""
+Offline forced alignment example using Qwen3-ForcedAligner-0.6B.
+
+Forced alignment takes audio and reference text as input and produces
+word-level timestamps. The model predicts a time bin at each ``<timestamp>``
+token position; multiplying by ``timestamp_segment_time`` gives milliseconds.
+
+Usage::
+
+    python forced_alignment_offline.py \
+        --model Qwen/Qwen3-ForcedAligner-0.6B
+"""
+
+from argparse import Namespace
+
+import numpy as np
+
+from vllm import LLM, EngineArgs
+from vllm.utils.argparse_utils import FlexibleArgumentParser
+
+# Timestamp token wrapped around each word; the model predicts a time bin
+# at every occurrence (its ID is ``timestamp_token_id`` in the config).
+TIMESTAMP_TOKEN = "<timestamp>"
+
+
+def parse_args():
+    parser = FlexibleArgumentParser()
+    parser = EngineArgs.add_cli_args(parser)
+    parser.set_defaults(
+        model="Qwen/Qwen3-ForcedAligner-0.6B",
+        runner="pooling",
+        enforce_eager=True,
+        hf_overrides={"architectures": ["Qwen3ASRForcedAlignerForTokenClassification"]},
+    )
+    return parser.parse_args()
+
+
+def build_prompt(words: list[str]) -> str:
+    """Build the forced alignment prompt from a word list.
+
+    Format: <|audio_start|><|audio_pad|><|audio_end|>
+    <timestamp>word1<timestamp><timestamp>word2<timestamp>...
+    """
+    body = "".join(f"{TIMESTAMP_TOKEN}{word}{TIMESTAMP_TOKEN}" for word in words)
+    return f"<|audio_start|><|audio_pad|><|audio_end|>{body}"
+
+
+def main(args: Namespace):
+    llm = LLM(**vars(args))
+
+    config = llm.llm_engine.vllm_config.model_config.hf_config
+    timestamp_token_id = config.timestamp_token_id
+    timestamp_segment_time = config.timestamp_segment_time
+
+    # Example: align these words against a 5-second audio clip
+    words = ["Hello", "world"]
+    prompt = build_prompt(words)
+
+    # Use a 5-second silent audio as placeholder (replace with real audio)
+    sample_rate = 16000
+    audio = np.zeros(sample_rate * 5, dtype=np.float32)
+
+    outputs = llm.encode(
+        [{"prompt": prompt, "multi_modal_data": {"audio": audio}}],
+        pooling_task="token_classify",
+    )
+
+    for output in outputs:
+        logits = output.outputs.data  # [num_tokens, classify_num]
+        predictions = logits.argmax(dim=-1)
+        token_ids = output.prompt_token_ids
+
+        # Extract time-bin predictions at timestamp token positions
+        ts_predictions = [
+            pred.item() * timestamp_segment_time
+            for tid, pred in zip(token_ids, predictions)
+            if tid == timestamp_token_id
+        ]
+
+        # Pair up start/end times per word
+        for i, word in enumerate(words):
+            start_ms = ts_predictions[i * 2]
+            end_ms = ts_predictions[i * 2 + 1]
+            print(f"{word:15s} {start_ms / 1000:.3f}s - {end_ms / 1000:.3f}s")
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
diff --git a/tests/models/multimodal/pooling/test_qwen3_asr_forced_aligner.py b/tests/models/multimodal/pooling/test_qwen3_asr_forced_aligner.py
new file mode 100644
index 000000000..b29d0d585
--- /dev/null
+++ b/tests/models/multimodal/pooling/test_qwen3_asr_forced_aligner.py
@@ -0,0 +1,67 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import numpy as np
+import pytest
+import torch
+
+MODEL = "Qwen/Qwen3-ForcedAligner-0.6B"
+CLASSIFY_NUM = 5000
+TIMESTAMP_TOKEN_ID = 151705
+
+# Timestamp token wrapped around each word (the literal matching
+# TIMESTAMP_TOKEN_ID in the model's tokenizer)
+TIMESTAMP_TOKEN = "<timestamp>"
+
+
+def build_prompt(words: list[str]) -> str:
+    body = "".join(f"{TIMESTAMP_TOKEN}{word}{TIMESTAMP_TOKEN}" for word in words)
+    return f"<|audio_start|><|audio_pad|><|audio_end|>{body}"
+
+
+@pytest.mark.parametrize("model", [MODEL])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@torch.inference_mode()
+def test_qwen3_forced_aligner(
+    vllm_runner,
+    model: str,
+    dtype: str,
+) -> None:
+    words = ["Hello", "world"]
+    prompt = build_prompt(words)
+
+    # 5-second silent audio at 16kHz
+    audio = np.zeros(16000 * 5, dtype=np.float32)
+
+    with vllm_runner(
+        model,
+        runner="pooling",
+        dtype=dtype,
+        enforce_eager=True,
+        max_model_len=512,
+        hf_overrides={
+            "architectures": [
+                "Qwen3ASRForcedAlignerForTokenClassification",
+            ],
+        },
+    ) as vllm_model:
+        outputs = vllm_model.llm.encode(
+            [{"prompt": prompt, "multi_modal_data": {"audio": audio}}],
+            pooling_task="token_classify",
+        )
+
+    # Validate output structure
+    assert len(outputs) == 1
+    logits = outputs[0].outputs.data
+    assert logits.dim() == 2
+    assert logits.shape[1] == CLASSIFY_NUM
+
+    # Validate timestamp extraction
+    token_ids = outputs[0].prompt_token_ids
+    predictions = logits.argmax(dim=-1)
+    ts_indices = [i for i, t in enumerate(token_ids) if t == TIMESTAMP_TOKEN_ID]
+
+    # 2 words x 2 timestamps each (start + end) = 4
+    assert len(ts_indices) == 4
+
+    ts_preds = [predictions[i].item() for i in ts_indices]
+    assert all(p >= 0 for p in ts_preds)
+    # end >= start for each word
+    assert ts_preds[1] >= ts_preds[0]  # Hello
+    assert ts_preds[3] >= ts_preds[2]  # world
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 266fbfaf1..4e13b49e4 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -1094,6 +1094,12 @@ _MULTIMODAL_EXAMPLE_MODELS = {
         min_transformers_version="4.57",
         hf_overrides={"architectures": ["Qwen3ASRRealtimeGeneration"]},
     ),
+    "Qwen3ASRForcedAlignerForTokenClassification": _HfExamplesInfo(
+        "Qwen/Qwen3-ForcedAligner-0.6B",
+        max_model_len=4096,
+        min_transformers_version="4.57",
+        hf_overrides={"architectures": ["Qwen3ASRForcedAlignerForTokenClassification"]},
+    ),
     "RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", trust_remote_code=True),
     "SkyworkR1VChatModel": _HfExamplesInfo(
         "Skywork/Skywork-R1V-38B", trust_remote_code=True
diff --git a/vllm/model_executor/models/qwen3_asr_forced_aligner.py b/vllm/model_executor/models/qwen3_asr_forced_aligner.py
new file mode 100644
index 000000000..56c57f477
--- /dev/null
+++ b/vllm/model_executor/models/qwen3_asr_forced_aligner.py
@@ -0,0 +1,120 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Inference-only Qwen3-ASR ForcedAligner model (token classification)."""
+
+from collections.abc import Iterable
+
+import torch
+import torch.nn as nn
+
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
+from vllm.model_executor.models.interfaces_base import default_pooling_type
+from vllm.model_executor.models.qwen3_asr import (
+    Qwen3ASRDummyInputsBuilder,
+    Qwen3ASRForConditionalGeneration,
+    Qwen3ASRMultiModalProcessor,
+    Qwen3ASRProcessingInfo,
+)
+from vllm.model_executor.models.utils import AutoWeightsLoader, WeightsMapper
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.sequence import IntermediateTensors
+
+
+@default_pooling_type(tok_pooling_type="ALL")
+@MULTIMODAL_REGISTRY.register_processor(
+    Qwen3ASRMultiModalProcessor,
+    info=Qwen3ASRProcessingInfo,
+    dummy_inputs=Qwen3ASRDummyInputsBuilder,
+)
+class Qwen3ASRForcedAlignerForTokenClassification(
+    Qwen3ASRForConditionalGeneration,
+):
+    """Qwen3-ASR Forced Aligner model for per-token timestamp classification.
+
+    This model shares the audio tower and language model backbone with
+    Qwen3-ASR, but replaces the LM head with a classification head that
+    predicts time bins at ``<timestamp>`` token positions.
+
+    Usage::
+
+        llm = LLM(
+            model="Qwen/Qwen3-ForcedAligner-0.6B",
+            runner="pooling",
+            hf_overrides={
+                "architectures": ["Qwen3ASRForcedAlignerForTokenClassification"]
+            },
+        )
+        outputs = llm.encode(
+            [{"prompt": prompt, "multi_modal_data": {"audio": audio}}],
+            pooling_task="token_classify",
+        )
+    """
+
+    is_pooling_model = True
+
+    # Map thinker.lm_head -> classifier (not language_model.lm_head)
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "thinker.lm_head.": "classifier.",
+            "thinker.model.": "language_model.model.",
+            "thinker.": "",
+        }
+    )
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+
+        config = vllm_config.model_config.hf_config
+        thinker_config = config.thinker_config
+
+        # Remove the unused generation head created by the base class;
+        # the forced aligner uses a classifier head instead.
+        self.language_model.lm_head = None
+        self.language_model.logits_processor = None
+
+        self.classify_num = thinker_config.classify_num
+
+        # Classification head replaces lm_head for time-bin prediction.
+        # Use model dtype (not head_dtype, which defaults to float32 for
+        # pooling models) to match the hidden state dtype.
+        self.classifier = nn.Linear(
+            thinker_config.text_config.hidden_size,
+            self.classify_num,
+            bias=False,
+            dtype=vllm_config.model_config.dtype,
+        )
+
+        # Token-level pooler to split per-token logits per request
+        pooler_config = vllm_config.model_config.pooler_config
+        assert pooler_config is not None
+        self.pooler = pooler_for_token_classify(pooler_config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: IntermediateTensors | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+        **kwargs: object,
+    ) -> torch.Tensor:
+        if intermediate_tensors is not None:
+            inputs_embeds = None
+
+        # Run through the language model backbone (transformer layers only)
+        hidden_states = self.language_model.model(
+            input_ids,
+            positions,
+            intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+        )
+
+        # Apply classification head -> [num_tokens, classify_num]
+        return self.classifier(hidden_states)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=["talker.", "code2wav."],
+        )
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 839aba11c..2c72c5d68 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -292,6 +292,10 @@ _TOKEN_CLASSIFICATION_MODELS = {
         "modernbert",
         "ModernBertForTokenClassification",
     ),
+    "Qwen3ASRForcedAlignerForTokenClassification": (
+        "qwen3_asr_forced_aligner",
+        "Qwen3ASRForcedAlignerForTokenClassification",
+    ),
 }
 
 _SEQUENCE_CLASSIFICATION_MODELS = {
diff --git a/vllm/transformers_utils/configs/qwen3_asr.py b/vllm/transformers_utils/configs/qwen3_asr.py
index a08b2b7de..b5b35c833 100644
--- a/vllm/transformers_utils/configs/qwen3_asr.py
+++ b/vllm/transformers_utils/configs/qwen3_asr.py
@@ -342,12 +342,14 @@ class Qwen3ASRThinkerConfig(PretrainedConfig):
         audio_start_token_id=151647,
         user_token_id=872,
         initializer_range=0.02,
+        classify_num=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.user_token_id = user_token_id
         self.audio_start_token_id = audio_start_token_id
         self.initializer_range = initializer_range
+        self.classify_num = classify_num
 
         if isinstance(audio_config, dict):
             audio_config = Qwen3ASRAudioEncoderConfig(**audio_config)
@@ -406,6 +408,8 @@ class Qwen3ASRConfig(PretrainedConfig):
         self,
         thinker_config=None,
         support_languages=None,
+        timestamp_token_id=None,
+        timestamp_segment_time=None,
         **kwargs,
     ):
         if thinker_config is None:
@@ -416,6 +420,8 @@ class Qwen3ASRConfig(PretrainedConfig):
             self.thinker_config = Qwen3ASRThinkerConfig(**thinker_config)
 
         self.support_languages = support_languages
+        self.timestamp_token_id = timestamp_token_id
+        self.timestamp_segment_time = timestamp_segment_time
         super().__init__(**kwargs)
 
     def get_text_config(self, decoder=False) -> "PretrainedConfig":
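To make the data flow concrete: the two new top-level config fields, together with `classify_num` on the thinker config, drive the timestamp decoding in the offline example. A minimal sketch of that arithmetic; `151705` and `5000` are the constants used in the test above, while the 20 ms bin width is a hypothetical value for illustration:

```python
# Hypothetical config values: 151705 and 5000 match the test constants;
# the 20 ms bin width is an assumption, not taken from the checkpoint.
config = {
    "timestamp_token_id": 151705,  # token at which time bins are read out
    "timestamp_segment_time": 20,  # milliseconds per predicted bin
    "thinker_config": {"classify_num": 5000},  # number of time bins
}

# argmax over the classify_num logits at one timestamp position gives a
# bin index; multiplying by the segment time converts it to milliseconds.
bin_index = 47
milliseconds = bin_index * config["timestamp_segment_time"]
print(f"{milliseconds / 1000:.3f}s")  # -> 0.940s
```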