[Feature] Add Qwen3-ForcedAligner support via token classification pooling (#35367)
Signed-off-by: haosdent <haosdent@gmail.com>
This commit is contained in:
@@ -15,7 +15,7 @@ Many classification models support both (sequence) classification and token clas
|
||||
|
||||
!!! note
|
||||
|
||||
Pooling multitask support is deprecated and will be removed in v0.20. When the default pooling task (classify) is not
|
||||
Pooling multitask support is deprecated and will be removed in v0.20. When the default pooling task (classify) is not
|
||||
what you want, you need to manually specify it via `PoolerConfig(task="token_classify")` offline or
|
||||
`--pooler-config.task token_classify` online.
|
||||
|
||||
@@ -29,6 +29,12 @@ Offline: [examples/pooling/token_classify/ner_offline.py](../../../examples/pool
|
||||
|
||||
Online: [examples/pooling/token_classify/ner_online.py](../../../examples/pooling/token_classify/ner_online.py)
|
||||
|
||||
### Forced Alignment
|
||||
|
||||
Forced alignment takes audio and reference text as input and produces word-level timestamps.
|
||||
|
||||
Offline: [examples/pooling/token_classify/forced_alignment_offline.py](../../../examples/pooling/token_classify/forced_alignment_offline.py)
|
||||
|
||||
### Sparse retrieval (lexical matching)
|
||||
|
||||
The BAAI/bge-m3 model leverages token classification for sparse retrieval. For more information, see [this page](specific_models.md#baaibge-m3).
|
||||
@@ -43,12 +49,25 @@ The BAAI/bge-m3 model leverages token classification for sparse retrieval. For m
|
||||
| `Qwen3ForTokenClassification`<sup>C</sup> | Qwen3-based | `bd2lcco/Qwen3-0.6B-finetuned` | | |
|
||||
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* |
|
||||
|
||||
<sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./README.md#model-conversion))
|
||||
<sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./README.md#model-conversion))
|
||||
\* Feature support is the same as that of the original model.
|
||||
|
||||
If your model is not in the above list, we will try to automatically convert the model using
|
||||
[as_seq_cls_model][vllm.model_executor.models.adapters.as_seq_cls_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
|
||||
|
||||
### Multimodal Models
|
||||
|
||||
!!! note
|
||||
For more information about multimodal models inputs, see [this page](../supported_models.md#list-of-multimodal-language-models).
|
||||
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA](../../features/lora.md) | [PP](../../serving/parallelism_scaling.md) |
|
||||
| --------------------------------------------- | ------------------- | ----------------- | ------------------------------------------ | ------------------------------ | ------------------------------------------ |
|
||||
| `Qwen3ASRForcedAlignerForTokenClassification` | Qwen3-ForcedAligner | T + A<sup>+</sup> | `Qwen/Qwen3-ForcedAligner-0.6B` (see note) | | ✅︎ |
|
||||
|
||||
!!! note
|
||||
Forced alignment usage requires `--hf-overrides '{"architectures": ["Qwen3ASRForcedAlignerForTokenClassification"]}'`.
|
||||
Please refer to [examples/pooling/token_classify/forced_alignment_offline.py](../../../examples/pooling/token_classify/forced_alignment_offline.py).
|
||||
|
||||
### As Reward Models
|
||||
|
||||
Using token classification models as reward models. For details on reward models, see [Reward Models](reward.md).
|
||||
|
||||
90
examples/pooling/token_classify/forced_alignment_offline.py
Normal file
90
examples/pooling/token_classify/forced_alignment_offline.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Adapted from Qwen3-ForcedAligner inference:
|
||||
# https://github.com/QwenLM/Qwen3-ASR
|
||||
|
||||
"""
|
||||
Offline forced alignment example using Qwen3-ForcedAligner-0.6B.
|
||||
|
||||
Forced alignment takes audio and reference text as input and produces
|
||||
word-level timestamps. The model predicts a time bin at each <timestamp>
|
||||
token position; multiplying by ``timestamp_segment_time`` gives milliseconds.
|
||||
|
||||
Usage::
|
||||
|
||||
python forced_alignment_offline.py \
|
||||
--model Qwen/Qwen3-ForcedAligner-0.6B
|
||||
"""
|
||||
|
||||
from argparse import Namespace
|
||||
|
||||
import numpy as np
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
parser.set_defaults(
|
||||
model="Qwen/Qwen3-ForcedAligner-0.6B",
|
||||
runner="pooling",
|
||||
enforce_eager=True,
|
||||
hf_overrides={"architectures": ["Qwen3ASRForcedAlignerForTokenClassification"]},
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def build_prompt(words: list[str]) -> str:
|
||||
"""Build the forced alignment prompt from a word list.
|
||||
|
||||
Format: <|audio_start|><|audio_pad|><|audio_end|>
|
||||
word1<timestamp><timestamp>word2<timestamp><timestamp>...
|
||||
"""
|
||||
body = "<timestamp><timestamp>".join(words) + "<timestamp><timestamp>"
|
||||
return f"<|audio_start|><|audio_pad|><|audio_end|>{body}"
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
llm = LLM(**vars(args))
|
||||
|
||||
config = llm.llm_engine.vllm_config.model_config.hf_config
|
||||
timestamp_token_id = config.timestamp_token_id
|
||||
timestamp_segment_time = config.timestamp_segment_time
|
||||
|
||||
# Example: align these words against a 5-second audio clip
|
||||
words = ["Hello", "world"]
|
||||
prompt = build_prompt(words)
|
||||
|
||||
# Use a 5-second silent audio as placeholder (replace with real audio)
|
||||
sample_rate = 16000
|
||||
audio = np.zeros(sample_rate * 5, dtype=np.float32)
|
||||
|
||||
outputs = llm.encode(
|
||||
[{"prompt": prompt, "multi_modal_data": {"audio": audio}}],
|
||||
pooling_task="token_classify",
|
||||
)
|
||||
|
||||
for output in outputs:
|
||||
logits = output.outputs.data # [num_tokens, classify_num]
|
||||
predictions = logits.argmax(dim=-1)
|
||||
token_ids = output.prompt_token_ids
|
||||
|
||||
# Extract timestamps at <timestamp> positions
|
||||
ts_predictions = [
|
||||
pred.item() * timestamp_segment_time
|
||||
for tid, pred in zip(token_ids, predictions)
|
||||
if tid == timestamp_token_id
|
||||
]
|
||||
|
||||
# Pair up start/end times per word
|
||||
for i, word in enumerate(words):
|
||||
start_ms = ts_predictions[i * 2]
|
||||
end_ms = ts_predictions[i * 2 + 1]
|
||||
print(f"{word:15s} {start_ms / 1000:.3f}s - {end_ms / 1000:.3f}s")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
@@ -0,0 +1,67 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
MODEL = "Qwen/Qwen3-ForcedAligner-0.6B"
|
||||
CLASSIFY_NUM = 5000
|
||||
TIMESTAMP_TOKEN_ID = 151705
|
||||
|
||||
|
||||
def build_prompt(words: list[str]) -> str:
|
||||
body = "<timestamp><timestamp>".join(words) + "<timestamp><timestamp>"
|
||||
return f"<|audio_start|><|audio_pad|><|audio_end|>{body}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", [MODEL])
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@torch.inference_mode()
|
||||
def test_qwen3_forced_aligner(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
words = ["Hello", "world"]
|
||||
prompt = build_prompt(words)
|
||||
|
||||
# 5-second silent audio at 16kHz
|
||||
audio = np.zeros(16000 * 5, dtype=np.float32)
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
runner="pooling",
|
||||
dtype=dtype,
|
||||
enforce_eager=True,
|
||||
max_model_len=512,
|
||||
hf_overrides={
|
||||
"architectures": [
|
||||
"Qwen3ASRForcedAlignerForTokenClassification",
|
||||
],
|
||||
},
|
||||
) as vllm_model:
|
||||
outputs = vllm_model.llm.encode(
|
||||
[{"prompt": prompt, "multi_modal_data": {"audio": audio}}],
|
||||
pooling_task="token_classify",
|
||||
)
|
||||
|
||||
# Validate output structure
|
||||
assert len(outputs) == 1
|
||||
logits = outputs[0].outputs.data
|
||||
assert logits.dim() == 2
|
||||
assert logits.shape[1] == CLASSIFY_NUM
|
||||
|
||||
# Validate timestamp extraction
|
||||
token_ids = outputs[0].prompt_token_ids
|
||||
predictions = logits.argmax(dim=-1)
|
||||
ts_indices = [i for i, t in enumerate(token_ids) if t == TIMESTAMP_TOKEN_ID]
|
||||
|
||||
# 2 words x 2 timestamps each (start + end) = 4
|
||||
assert len(ts_indices) == 4
|
||||
|
||||
ts_preds = [predictions[i].item() for i in ts_indices]
|
||||
assert all(p >= 0 for p in ts_preds)
|
||||
# end >= start for each word
|
||||
assert ts_preds[1] >= ts_preds[0] # Hello
|
||||
assert ts_preds[3] >= ts_preds[2] # world
|
||||
@@ -1094,6 +1094,12 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
min_transformers_version="4.57",
|
||||
hf_overrides={"architectures": ["Qwen3ASRRealtimeGeneration"]},
|
||||
),
|
||||
"Qwen3ASRForcedAlignerForTokenClassification": _HfExamplesInfo(
|
||||
"Qwen/Qwen3-ForcedAligner-0.6B",
|
||||
max_model_len=4096,
|
||||
min_transformers_version="4.57",
|
||||
hf_overrides={"architectures": ["Qwen3ASRForcedAlignerForTokenClassification"]},
|
||||
),
|
||||
"RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", trust_remote_code=True),
|
||||
"SkyworkR1VChatModel": _HfExamplesInfo(
|
||||
"Skywork/Skywork-R1V-38B", trust_remote_code=True
|
||||
|
||||
120
vllm/model_executor/models/qwen3_asr_forced_aligner.py
Normal file
120
vllm/model_executor/models/qwen3_asr_forced_aligner.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Inference-only Qwen3-ASR ForcedAligner model (token classification)."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
|
||||
from vllm.model_executor.models.interfaces_base import default_pooling_type
|
||||
from vllm.model_executor.models.qwen3_asr import (
|
||||
Qwen3ASRDummyInputsBuilder,
|
||||
Qwen3ASRForConditionalGeneration,
|
||||
Qwen3ASRMultiModalProcessor,
|
||||
Qwen3ASRProcessingInfo,
|
||||
)
|
||||
from vllm.model_executor.models.utils import AutoWeightsLoader, WeightsMapper
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
|
||||
@default_pooling_type(tok_pooling_type="ALL")
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Qwen3ASRMultiModalProcessor,
|
||||
info=Qwen3ASRProcessingInfo,
|
||||
dummy_inputs=Qwen3ASRDummyInputsBuilder,
|
||||
)
|
||||
class Qwen3ASRForcedAlignerForTokenClassification(
|
||||
Qwen3ASRForConditionalGeneration,
|
||||
):
|
||||
"""Qwen3-ASR Forced Aligner model for per-token timestamp classification.
|
||||
|
||||
This model shares the audio tower and language model backbone with
|
||||
Qwen3-ASR, but replaces the LM head with a classification head that
|
||||
predicts time bins at ``<timestamp>`` token positions.
|
||||
|
||||
Usage::
|
||||
|
||||
llm = LLM(
|
||||
model="Qwen/Qwen3-ForcedAligner-0.6B",
|
||||
runner="pooling",
|
||||
hf_overrides={
|
||||
"architectures": ["Qwen3ASRForcedAlignerForTokenClassification"]
|
||||
},
|
||||
)
|
||||
outputs = llm.encode(
|
||||
[{"prompt": prompt, "multi_modal_data": {"audio": audio}}],
|
||||
pooling_task="token_classify",
|
||||
)
|
||||
"""
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
# Map thinker.lm_head -> classifier (not language_model.lm_head)
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"thinker.lm_head.": "classifier.",
|
||||
"thinker.model.": "language_model.model.",
|
||||
"thinker.": "",
|
||||
}
|
||||
)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
thinker_config = config.thinker_config
|
||||
|
||||
# Remove the unused generation head created by the base class;
|
||||
# the forced aligner uses a classifier head instead.
|
||||
self.language_model.lm_head = None
|
||||
self.language_model.logits_processor = None
|
||||
|
||||
self.classify_num = thinker_config.classify_num
|
||||
|
||||
# Classification head replaces lm_head for time-bin prediction.
|
||||
# Use model dtype (not head_dtype which defaults to float32 for
|
||||
# pooling models) to match the hidden state dtype.
|
||||
self.classifier = nn.Linear(
|
||||
thinker_config.text_config.hidden_size,
|
||||
self.classify_num,
|
||||
bias=False,
|
||||
dtype=vllm_config.model_config.dtype,
|
||||
)
|
||||
|
||||
# Token-level pooler to split per-token logits per request
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
self.pooler = pooler_for_token_classify(pooler_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs: object,
|
||||
) -> torch.Tensor:
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
# Run through language model backbone (transformer layers only)
|
||||
hidden_states = self.language_model.model(
|
||||
input_ids,
|
||||
positions,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
# Apply classification head -> [num_tokens, classify_num]
|
||||
return self.classifier(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=["talker.", "code2wav."],
|
||||
)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
@@ -292,6 +292,10 @@ _TOKEN_CLASSIFICATION_MODELS = {
|
||||
"modernbert",
|
||||
"ModernBertForTokenClassification",
|
||||
),
|
||||
"Qwen3ASRForcedAlignerForTokenClassification": (
|
||||
"qwen3_asr_forced_aligner",
|
||||
"Qwen3ASRForcedAlignerForTokenClassification",
|
||||
),
|
||||
}
|
||||
|
||||
_SEQUENCE_CLASSIFICATION_MODELS = {
|
||||
|
||||
@@ -342,12 +342,14 @@ class Qwen3ASRThinkerConfig(PretrainedConfig):
|
||||
audio_start_token_id=151647,
|
||||
user_token_id=872,
|
||||
initializer_range=0.02,
|
||||
classify_num=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.user_token_id = user_token_id
|
||||
self.audio_start_token_id = audio_start_token_id
|
||||
self.initializer_range = initializer_range
|
||||
self.classify_num = classify_num
|
||||
|
||||
if isinstance(audio_config, dict):
|
||||
audio_config = Qwen3ASRAudioEncoderConfig(**audio_config)
|
||||
@@ -406,6 +408,8 @@ class Qwen3ASRConfig(PretrainedConfig):
|
||||
self,
|
||||
thinker_config=None,
|
||||
support_languages=None,
|
||||
timestamp_token_id=None,
|
||||
timestamp_segment_time=None,
|
||||
**kwargs,
|
||||
):
|
||||
if thinker_config is None:
|
||||
@@ -416,6 +420,8 @@ class Qwen3ASRConfig(PretrainedConfig):
|
||||
|
||||
self.thinker_config = Qwen3ASRThinkerConfig(**thinker_config)
|
||||
self.support_languages = support_languages
|
||||
self.timestamp_token_id = timestamp_token_id
|
||||
self.timestamp_segment_time = timestamp_segment_time
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_text_config(self, decoder=False) -> "PretrainedConfig":
|
||||
|
||||
Reference in New Issue
Block a user