[Feature] Add Qwen3-ForcedAligner support via token classification pooling (#35367)

Signed-off-by: haosdent <haosdent@gmail.com>
This commit is contained in:
haosdent
2026-03-29 08:27:52 +08:00
committed by GitHub
parent fafca38adc
commit d39b8daf5f
7 changed files with 314 additions and 2 deletions

View File

@@ -15,7 +15,7 @@ Many classification models support both (sequence) classification and token clas
!!! note
Pooling multitask support is deprecated and will be removed in v0.20. When the default pooling task (classify) is not
Pooling multitask support is deprecated and will be removed in v0.20. When the default pooling task (classify) is not
what you want, you need to manually specify it via `PoolerConfig(task="token_classify")` offline or
`--pooler-config.task token_classify` online.
@@ -29,6 +29,12 @@ Offline: [examples/pooling/token_classify/ner_offline.py](../../../examples/pool
Online: [examples/pooling/token_classify/ner_online.py](../../../examples/pooling/token_classify/ner_online.py)
### Forced Alignment
Forced alignment takes audio and reference text as input and produces word-level timestamps.
Offline: [examples/pooling/token_classify/forced_alignment_offline.py](../../../examples/pooling/token_classify/forced_alignment_offline.py)
### Sparse retrieval (lexical matching)
The BAAI/bge-m3 model leverages token classification for sparse retrieval. For more information, see [this page](specific_models.md#baaibge-m3).
@@ -43,12 +49,25 @@ The BAAI/bge-m3 model leverages token classification for sparse retrieval. For m
| `Qwen3ForTokenClassification`<sup>C</sup> | Qwen3-based | `bd2lcco/Qwen3-0.6B-finetuned` | | |
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* |
<sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./README.md#model-conversion))
<sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./README.md#model-conversion))
\* Feature support is the same as that of the original model.
If your model is not in the above list, we will try to automatically convert the model using
[as_seq_cls_model][vllm.model_executor.models.adapters.as_seq_cls_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
### Multimodal Models
!!! note
For more information about multimodal models inputs, see [this page](../supported_models.md#list-of-multimodal-language-models).
| Architecture | Models | Inputs | Example HF Models | [LoRA](../../features/lora.md) | [PP](../../serving/parallelism_scaling.md) |
| --------------------------------------------- | ------------------- | ----------------- | ------------------------------------------ | ------------------------------ | ------------------------------------------ |
| `Qwen3ASRForcedAlignerForTokenClassification` | Qwen3-ForcedAligner | T + A<sup>+</sup> | `Qwen/Qwen3-ForcedAligner-0.6B` (see note) | | ✅︎ |
!!! note
Forced alignment usage requires `--hf-overrides '{"architectures": ["Qwen3ASRForcedAlignerForTokenClassification"]}'`.
Please refer to [examples/pooling/token_classify/forced_alignment_offline.py](../../../examples/pooling/token_classify/forced_alignment_offline.py).
### As Reward Models
Using token classification models as reward models. For details on reward models, see [Reward Models](reward.md).

View File

@@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from Qwen3-ForcedAligner inference:
# https://github.com/QwenLM/Qwen3-ASR
"""
Offline forced alignment example using Qwen3-ForcedAligner-0.6B.
Forced alignment takes audio and reference text as input and produces
word-level timestamps. The model predicts a time bin at each <timestamp>
token position; multiplying by ``timestamp_segment_time`` gives milliseconds.
Usage::
python forced_alignment_offline.py \
--model Qwen/Qwen3-ForcedAligner-0.6B
"""
from argparse import Namespace
import numpy as np
from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
parser.set_defaults(
model="Qwen/Qwen3-ForcedAligner-0.6B",
runner="pooling",
enforce_eager=True,
hf_overrides={"architectures": ["Qwen3ASRForcedAlignerForTokenClassification"]},
)
return parser.parse_args()
def build_prompt(words: list[str]) -> str:
"""Build the forced alignment prompt from a word list.
Format: <|audio_start|><|audio_pad|><|audio_end|>
word1<timestamp><timestamp>word2<timestamp><timestamp>...
"""
body = "<timestamp><timestamp>".join(words) + "<timestamp><timestamp>"
return f"<|audio_start|><|audio_pad|><|audio_end|>{body}"
def main(args: Namespace):
llm = LLM(**vars(args))
config = llm.llm_engine.vllm_config.model_config.hf_config
timestamp_token_id = config.timestamp_token_id
timestamp_segment_time = config.timestamp_segment_time
# Example: align these words against a 5-second audio clip
words = ["Hello", "world"]
prompt = build_prompt(words)
# Use a 5-second silent audio as placeholder (replace with real audio)
sample_rate = 16000
audio = np.zeros(sample_rate * 5, dtype=np.float32)
outputs = llm.encode(
[{"prompt": prompt, "multi_modal_data": {"audio": audio}}],
pooling_task="token_classify",
)
for output in outputs:
logits = output.outputs.data # [num_tokens, classify_num]
predictions = logits.argmax(dim=-1)
token_ids = output.prompt_token_ids
# Extract timestamps at <timestamp> positions
ts_predictions = [
pred.item() * timestamp_segment_time
for tid, pred in zip(token_ids, predictions)
if tid == timestamp_token_id
]
# Pair up start/end times per word
for i, word in enumerate(words):
start_ms = ts_predictions[i * 2]
end_ms = ts_predictions[i * 2 + 1]
print(f"{word:15s} {start_ms / 1000:.3f}s - {end_ms / 1000:.3f}s")
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import pytest
import torch
MODEL = "Qwen/Qwen3-ForcedAligner-0.6B"
CLASSIFY_NUM = 5000
TIMESTAMP_TOKEN_ID = 151705
def build_prompt(words: list[str]) -> str:
body = "<timestamp><timestamp>".join(words) + "<timestamp><timestamp>"
return f"<|audio_start|><|audio_pad|><|audio_end|>{body}"
@pytest.mark.parametrize("model", [MODEL])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@torch.inference_mode()
def test_qwen3_forced_aligner(
vllm_runner,
model: str,
dtype: str,
) -> None:
words = ["Hello", "world"]
prompt = build_prompt(words)
# 5-second silent audio at 16kHz
audio = np.zeros(16000 * 5, dtype=np.float32)
with vllm_runner(
model,
runner="pooling",
dtype=dtype,
enforce_eager=True,
max_model_len=512,
hf_overrides={
"architectures": [
"Qwen3ASRForcedAlignerForTokenClassification",
],
},
) as vllm_model:
outputs = vllm_model.llm.encode(
[{"prompt": prompt, "multi_modal_data": {"audio": audio}}],
pooling_task="token_classify",
)
# Validate output structure
assert len(outputs) == 1
logits = outputs[0].outputs.data
assert logits.dim() == 2
assert logits.shape[1] == CLASSIFY_NUM
# Validate timestamp extraction
token_ids = outputs[0].prompt_token_ids
predictions = logits.argmax(dim=-1)
ts_indices = [i for i, t in enumerate(token_ids) if t == TIMESTAMP_TOKEN_ID]
# 2 words x 2 timestamps each (start + end) = 4
assert len(ts_indices) == 4
ts_preds = [predictions[i].item() for i in ts_indices]
assert all(p >= 0 for p in ts_preds)
# end >= start for each word
assert ts_preds[1] >= ts_preds[0] # Hello
assert ts_preds[3] >= ts_preds[2] # world

View File

@@ -1094,6 +1094,12 @@ _MULTIMODAL_EXAMPLE_MODELS = {
min_transformers_version="4.57",
hf_overrides={"architectures": ["Qwen3ASRRealtimeGeneration"]},
),
"Qwen3ASRForcedAlignerForTokenClassification": _HfExamplesInfo(
"Qwen/Qwen3-ForcedAligner-0.6B",
max_model_len=4096,
min_transformers_version="4.57",
hf_overrides={"architectures": ["Qwen3ASRForcedAlignerForTokenClassification"]},
),
"RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", trust_remote_code=True),
"SkyworkR1VChatModel": _HfExamplesInfo(
"Skywork/Skywork-R1V-38B", trust_remote_code=True

View File

@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3-ASR ForcedAligner model (token classification)."""
from collections.abc import Iterable
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
from vllm.model_executor.models.interfaces_base import default_pooling_type
from vllm.model_executor.models.qwen3_asr import (
Qwen3ASRDummyInputsBuilder,
Qwen3ASRForConditionalGeneration,
Qwen3ASRMultiModalProcessor,
Qwen3ASRProcessingInfo,
)
from vllm.model_executor.models.utils import AutoWeightsLoader, WeightsMapper
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
@default_pooling_type(tok_pooling_type="ALL")
@MULTIMODAL_REGISTRY.register_processor(
Qwen3ASRMultiModalProcessor,
info=Qwen3ASRProcessingInfo,
dummy_inputs=Qwen3ASRDummyInputsBuilder,
)
class Qwen3ASRForcedAlignerForTokenClassification(
Qwen3ASRForConditionalGeneration,
):
"""Qwen3-ASR Forced Aligner model for per-token timestamp classification.
This model shares the audio tower and language model backbone with
Qwen3-ASR, but replaces the LM head with a classification head that
predicts time bins at ``<timestamp>`` token positions.
Usage::
llm = LLM(
model="Qwen/Qwen3-ForcedAligner-0.6B",
runner="pooling",
hf_overrides={
"architectures": ["Qwen3ASRForcedAlignerForTokenClassification"]
},
)
outputs = llm.encode(
[{"prompt": prompt, "multi_modal_data": {"audio": audio}}],
pooling_task="token_classify",
)
"""
is_pooling_model = True
# Map thinker.lm_head -> classifier (not language_model.lm_head)
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"thinker.lm_head.": "classifier.",
"thinker.model.": "language_model.model.",
"thinker.": "",
}
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
thinker_config = config.thinker_config
# Remove the unused generation head created by the base class;
# the forced aligner uses a classifier head instead.
self.language_model.lm_head = None
self.language_model.logits_processor = None
self.classify_num = thinker_config.classify_num
# Classification head replaces lm_head for time-bin prediction.
# Use model dtype (not head_dtype which defaults to float32 for
# pooling models) to match the hidden state dtype.
self.classifier = nn.Linear(
thinker_config.text_config.hidden_size,
self.classify_num,
bias=False,
dtype=vllm_config.model_config.dtype,
)
# Token-level pooler to split per-token logits per request
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = pooler_for_token_classify(pooler_config)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor:
if intermediate_tensors is not None:
inputs_embeds = None
# Run through language model backbone (transformer layers only)
hidden_states = self.language_model.model(
input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds,
)
# Apply classification head -> [num_tokens, classify_num]
return self.classifier(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=["talker.", "code2wav."],
)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

View File

@@ -292,6 +292,10 @@ _TOKEN_CLASSIFICATION_MODELS = {
"modernbert",
"ModernBertForTokenClassification",
),
"Qwen3ASRForcedAlignerForTokenClassification": (
"qwen3_asr_forced_aligner",
"Qwen3ASRForcedAlignerForTokenClassification",
),
}
_SEQUENCE_CLASSIFICATION_MODELS = {

View File

@@ -342,12 +342,14 @@ class Qwen3ASRThinkerConfig(PretrainedConfig):
audio_start_token_id=151647,
user_token_id=872,
initializer_range=0.02,
classify_num=None,
**kwargs,
):
super().__init__(**kwargs)
self.user_token_id = user_token_id
self.audio_start_token_id = audio_start_token_id
self.initializer_range = initializer_range
self.classify_num = classify_num
if isinstance(audio_config, dict):
audio_config = Qwen3ASRAudioEncoderConfig(**audio_config)
@@ -406,6 +408,8 @@ class Qwen3ASRConfig(PretrainedConfig):
self,
thinker_config=None,
support_languages=None,
timestamp_token_id=None,
timestamp_segment_time=None,
**kwargs,
):
if thinker_config is None:
@@ -416,6 +420,8 @@ class Qwen3ASRConfig(PretrainedConfig):
self.thinker_config = Qwen3ASRThinkerConfig(**thinker_config)
self.support_languages = support_languages
self.timestamp_token_id = timestamp_token_id
self.timestamp_segment_time = timestamp_segment_time
super().__init__(**kwargs)
def get_text_config(self, decoder=False) -> "PretrainedConfig":