[ASR] Fix spacing between chunks in multi-chunk audio transcription (#39116)
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,271 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""ASR inter-chunk spacing: ``asr_inter_chunk_separator`` and transcription
|
||||||
|
serving (mocked).
|
||||||
|
|
||||||
|
Unit tests cover the helper and ``SupportsTranscription.no_space_languages``.
|
||||||
|
Integration-style tests exercise ``OpenAIServingTranscription`` streaming and
|
||||||
|
``create_transcription`` without loading a model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.config.speech_to_text import SpeechToTextConfig
|
||||||
|
from vllm.entrypoints.openai.engine.protocol import (
|
||||||
|
ErrorResponse,
|
||||||
|
RequestResponseMetadata,
|
||||||
|
)
|
||||||
|
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||||
|
from vllm.entrypoints.openai.speech_to_text.protocol import TranscriptionRequest
|
||||||
|
from vllm.entrypoints.openai.speech_to_text.serving import OpenAIServingTranscription
|
||||||
|
from vllm.entrypoints.openai.speech_to_text.speech_to_text import (
|
||||||
|
OpenAISpeechToText,
|
||||||
|
asr_inter_chunk_separator,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.models.interfaces import SupportsTranscription
|
||||||
|
from vllm.outputs import CompletionOutput, RequestOutput
|
||||||
|
|
||||||
|
# --- Unit: helper + protocol -------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_no_space_languages_includes_zh_and_ja():
    """The protocol-level default covers exactly Chinese and Japanese."""
    expected = {"zh", "ja"}
    assert SupportsTranscription.no_space_languages == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ("language", "expected_sep"),
    [
        # Space-separated languages; lookup is case-insensitive.
        ("en", " "),
        ("EN", " "),
        # No-space languages from the protocol default.
        ("zh", ""),
        ("ZH", ""),
        ("ja", ""),
        # Unspecified language falls back to a single space.
        (None, " "),
    ],
)
def test_asr_inter_chunk_separator_matches_protocol(language, expected_sep):
    """The helper honours ``SupportsTranscription.no_space_languages``."""
    result = asr_inter_chunk_separator(
        language, SupportsTranscription.no_space_languages
    )
    assert result == expected_sep
|
||||||
|
|
||||||
|
|
||||||
|
def test_joined_chunks_english_has_space_between():
    """Joining English chunks with the helper's separator yields spaced text."""
    separator = asr_inter_chunk_separator(
        "en", SupportsTranscription.no_space_languages
    )
    joined = separator.join(["hello", "world"])
    assert joined == "hello world"
|
||||||
|
|
||||||
|
|
||||||
|
def test_joined_chunks_chinese_has_no_space_between():
    """Joining Chinese chunks with the helper's separator adds no spaces."""
    separator = asr_inter_chunk_separator(
        "zh", SupportsTranscription.no_space_languages
    )
    joined = separator.join(["你好", "世界"])
    assert joined == "你好世界"
|
||||||
|
|
||||||
|
|
||||||
|
# --- Integration: serving (no model) -----------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _StubTranscriptionModel:
    """Minimal stand-in for a SupportsTranscription implementation (no torch)."""

    # Mirrors the SupportsTranscription protocol default used by the helper.
    no_space_languages: set[str] = {"ja", "zh"}
    supports_segment_timestamp = False

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        """Return a fixed ASR config: 16 kHz sample rate, 5 s audio chunks."""
        config = SpeechToTextConfig(
            sample_rate=16000.0,
            max_audio_clip_s=5.0,
        )
        return config

    @classmethod
    def post_process_output(cls, text: str) -> str:
        """Identity post-processing — the stub leaves model text untouched."""
        return text
|
||||||
|
|
||||||
|
|
||||||
|
def _request_output(text: str) -> RequestOutput:
    """Build a finished, single-choice ``RequestOutput`` carrying *text*."""
    completion = CompletionOutput(
        index=0,
        text=text,
        token_ids=(1, 2, 3),
        cumulative_logprob=None,
        logprobs=None,
        finish_reason="stop",
    )
    return RequestOutput(
        request_id="rid",
        prompt=None,
        prompt_token_ids=None,
        prompt_logprobs=None,
        outputs=[completion],
        finished=True,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _sse_delta_contents(sse_body: str) -> list[str]:
|
||||||
|
"""Extract ``choices[0].delta.content`` from each ``data:`` line (streaming API)."""
|
||||||
|
contents: list[str] = []
|
||||||
|
for line in sse_body.splitlines():
|
||||||
|
if not line.startswith("data: "):
|
||||||
|
continue
|
||||||
|
payload = line.removeprefix("data: ").strip()
|
||||||
|
if payload == "[DONE]":
|
||||||
|
continue
|
||||||
|
obj = json.loads(payload)
|
||||||
|
for choice in obj.get("choices") or []:
|
||||||
|
delta = choice.get("delta") or {}
|
||||||
|
if "content" in delta:
|
||||||
|
contents.append(delta["content"])
|
||||||
|
return contents
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_transcription_stream_generator_english_inserts_space_between_chunks():
    """Online streaming: first output per audio chunk is prefixed with *separator*."""

    # One async generator per audio chunk, each yielding a single finished output.
    async def gen_hello() -> AsyncGenerator[RequestOutput, None]:
        yield _request_output("hello")

    async def gen_world() -> AsyncGenerator[RequestOutput, None]:
        yield _request_output("world")

    # Bypass __init__ (which needs an engine) and set only the attributes
    # transcription_stream_generator reads.
    serving = OpenAIServingTranscription.__new__(OpenAIServingTranscription)
    serving.enable_force_include_usage = False
    serving.model_cls = _StubTranscriptionModel
    serving.task_type = "transcribe"
    request = SimpleNamespace(
        model="stub-model",
        stream_include_usage=False,
        stream_continuous_usage_stats=False,
    )
    # English is not a no-space language, so the separator is a single space.
    sep = asr_inter_chunk_separator("en", _StubTranscriptionModel.no_space_languages)
    assert sep == " "

    out_lines: list[str] = []
    # Call the unbound method with the hand-built instance as ``self``.
    agen = OpenAIServingTranscription.transcription_stream_generator(
        serving,
        request=request,
        result_generator=[gen_hello(), gen_world()],
        request_id="test-req",
        request_metadata=RequestResponseMetadata(request_id="test-req"),
        audio_duration_s=1.0,
        separator=sep,
    )
    async for line in agen:
        out_lines.append(line)
    sse = "".join(out_lines)
    # Concatenating the streamed deltas must reproduce space-joined chunks.
    combined = "".join(_sse_delta_contents(sse))
    assert combined.strip() == "hello world"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_transcription_stream_generator_chinese_no_space_between_chunks():
    """Online streaming: Chinese chunks are concatenated without a separator."""

    # One async generator per audio chunk, each yielding one finished output.
    async def gen_a() -> AsyncGenerator[RequestOutput, None]:
        yield _request_output("你好")

    async def gen_b() -> AsyncGenerator[RequestOutput, None]:
        yield _request_output("世界")

    # Bypass __init__ (which needs an engine) and set only the attributes
    # transcription_stream_generator reads.
    serving = OpenAIServingTranscription.__new__(OpenAIServingTranscription)
    serving.enable_force_include_usage = False
    serving.model_cls = _StubTranscriptionModel
    serving.task_type = "transcribe"
    request = SimpleNamespace(
        model="stub-model",
        stream_include_usage=False,
        stream_continuous_usage_stats=False,
    )
    # Chinese is a no-space language, so the separator is empty.
    sep = asr_inter_chunk_separator("zh", _StubTranscriptionModel.no_space_languages)
    assert sep == ""

    out_lines: list[str] = []
    # Call the unbound method with the hand-built instance as ``self``.
    agen = OpenAIServingTranscription.transcription_stream_generator(
        serving,
        request=request,
        result_generator=[gen_a(), gen_b()],
        request_id="test-req-zh",
        request_metadata=RequestResponseMetadata(request_id="test-req-zh"),
        audio_duration_s=1.0,
        separator=sep,
    )
    async for line in agen:
        out_lines.append(line)
    # No space may appear between the two chunks' text.
    combined = "".join(_sse_delta_contents("".join(out_lines)))
    assert combined == "你好世界"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_create_transcription_non_streaming_joins_chunks_by_language():
    """``create_transcription`` uses the same separator logic as the helper."""

    # One async generator per audio chunk; consumed in order via side_effect.
    async def gen_hello() -> AsyncGenerator[RequestOutput, None]:
        yield _request_output("hello")

    async def gen_world() -> AsyncGenerator[RequestOutput, None]:
        yield _request_output("world")

    # Mock just enough of the engine client for OpenAIServingTranscription
    # construction and a non-streaming request.
    engine_client = MagicMock()
    engine_client.model_config = MagicMock()
    engine_client.model_config.get_diff_sampling_param.return_value = {
        "max_tokens": 256,
        "temperature": 0.0,
    }
    engine_client.model_config.max_model_len = 8192
    engine_client.errored = False
    # Two generate() calls — one per preprocessed audio chunk.
    engine_client.generate.side_effect = [gen_hello(), gen_world()]

    models = MagicMock(spec=OpenAIServingModels)
    models.lora_requests = {}
    models.is_base_model.return_value = True

    # Pretend preprocessing produced two audio chunks and a 1 s duration.
    preprocess_mock = AsyncMock(return_value=([MagicMock(), MagicMock()], 1.0))

    with (
        patch(
            "vllm.model_executor.model_loader.get_model_cls",
            return_value=_StubTranscriptionModel,
        ),
        patch.object(OpenAISpeechToText, "_preprocess_speech_to_text", preprocess_mock),
    ):
        serving = OpenAIServingTranscription(engine_client, models, request_logger=None)

        # English request: chunks must be joined with a space.
        req_en = TranscriptionRequest.model_construct(
            file=MagicMock(),
            model="stub-model",
            language="en",
            stream=False,
            response_format="json",
        )
        out_en = await serving.create_transcription(
            b"\x00\x00", req_en, raw_request=None
        )
        assert not isinstance(out_en, ErrorResponse)
        assert out_en.text == "hello world"

        async def gen_nihao() -> AsyncGenerator[RequestOutput, None]:
            yield _request_output("你好")

        async def gen_shijie() -> AsyncGenerator[RequestOutput, None]:
            yield _request_output("世界")

        # Re-arm the mock for the second (Chinese) request.
        engine_client.generate.side_effect = [gen_nihao(), gen_shijie()]

        # Chinese request: chunks must be joined without a space.
        req_zh = TranscriptionRequest.model_construct(
            file=MagicMock(),
            model="stub-model",
            language="zh",
            stream=False,
            response_format="json",
        )
        out_zh = await serving.create_transcription(
            b"\x00\x00", req_zh, raw_request=None
        )
        assert not isinstance(out_zh, ErrorResponse)
        assert out_zh.text == "你好世界"
|
||||||
@@ -86,6 +86,7 @@ class OpenAIServingTranscription(OpenAISpeechToText):
|
|||||||
request_id: str,
|
request_id: str,
|
||||||
request_metadata: RequestResponseMetadata,
|
request_metadata: RequestResponseMetadata,
|
||||||
audio_duration_s: float,
|
audio_duration_s: float,
|
||||||
|
separator: str,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
generator = self._speech_to_text_stream_generator(
|
generator = self._speech_to_text_stream_generator(
|
||||||
request=request,
|
request=request,
|
||||||
@@ -96,6 +97,7 @@ class OpenAIServingTranscription(OpenAISpeechToText):
|
|||||||
chunk_object_type="transcription.chunk",
|
chunk_object_type="transcription.chunk",
|
||||||
response_stream_choice_class=TranscriptionResponseStreamChoice,
|
response_stream_choice_class=TranscriptionResponseStreamChoice,
|
||||||
stream_response_class=TranscriptionStreamResponse,
|
stream_response_class=TranscriptionStreamResponse,
|
||||||
|
separator=separator,
|
||||||
)
|
)
|
||||||
async for chunk in generator:
|
async for chunk in generator:
|
||||||
yield chunk
|
yield chunk
|
||||||
@@ -157,6 +159,7 @@ class OpenAIServingTranslation(OpenAISpeechToText):
|
|||||||
request_id: str,
|
request_id: str,
|
||||||
request_metadata: RequestResponseMetadata,
|
request_metadata: RequestResponseMetadata,
|
||||||
audio_duration_s: float,
|
audio_duration_s: float,
|
||||||
|
separator: str,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
generator = self._speech_to_text_stream_generator(
|
generator = self._speech_to_text_stream_generator(
|
||||||
request=request,
|
request=request,
|
||||||
@@ -167,6 +170,7 @@ class OpenAIServingTranslation(OpenAISpeechToText):
|
|||||||
chunk_object_type="translation.chunk",
|
chunk_object_type="translation.chunk",
|
||||||
response_stream_choice_class=TranslationResponseStreamChoice,
|
response_stream_choice_class=TranslationResponseStreamChoice,
|
||||||
stream_response_class=TranslationStreamResponse,
|
stream_response_class=TranslationStreamResponse,
|
||||||
|
separator=separator,
|
||||||
)
|
)
|
||||||
async for chunk in generator:
|
async for chunk in generator:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import io
|
|||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
import zlib
|
import zlib
|
||||||
from collections.abc import AsyncGenerator, Callable
|
from collections.abc import AsyncGenerator, Callable, Set
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Final, Literal, TypeAlias, TypeVar, cast
|
from typing import Final, Literal, TypeAlias, TypeVar, cast
|
||||||
|
|
||||||
@@ -69,6 +69,17 @@ ResponseType: TypeAlias = (
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def asr_inter_chunk_separator(
|
||||||
|
language: str | None, no_space_languages: Set[str]
|
||||||
|
) -> str:
|
||||||
|
"""Space to insert between ASR text chunks for streaming and non-streaming join.
|
||||||
|
|
||||||
|
Languages in ``no_space_languages`` (e.g. Chinese, Japanese) use an empty
|
||||||
|
separator; others use a single ASCII space.
|
||||||
|
"""
|
||||||
|
return "" if language and language.lower() in no_space_languages else " "
|
||||||
|
|
||||||
|
|
||||||
class OpenAISpeechToText(OpenAIServing):
|
class OpenAISpeechToText(OpenAIServing):
|
||||||
"""Base class for speech-to-text operations like transcription and
|
"""Base class for speech-to-text operations like transcription and
|
||||||
translation."""
|
translation."""
|
||||||
@@ -486,9 +497,18 @@ class OpenAISpeechToText(OpenAIServing):
|
|||||||
|
|
||||||
list_result_generator.append(generator)
|
list_result_generator.append(generator)
|
||||||
|
|
||||||
|
separator = asr_inter_chunk_separator(
|
||||||
|
request.language, self.model_cls.no_space_languages
|
||||||
|
)
|
||||||
|
|
||||||
if request.stream:
|
if request.stream:
|
||||||
return stream_generator_method(
|
return stream_generator_method(
|
||||||
request, list_result_generator, request_id, request_metadata, duration_s
|
request,
|
||||||
|
list_result_generator,
|
||||||
|
request_id,
|
||||||
|
request_metadata,
|
||||||
|
duration_s,
|
||||||
|
separator,
|
||||||
)
|
)
|
||||||
# Non-streaming response.
|
# Non-streaming response.
|
||||||
total_segments = []
|
total_segments = []
|
||||||
@@ -500,7 +520,6 @@ class OpenAISpeechToText(OpenAIServing):
|
|||||||
"translate": TranslationSegment,
|
"translate": TranslationSegment,
|
||||||
}
|
}
|
||||||
segment_class: type[SpeechToTextSegment] = segments_types[self.task_type]
|
segment_class: type[SpeechToTextSegment] = segments_types[self.task_type]
|
||||||
text = ""
|
|
||||||
chunk_size_in_s = self.asr_config.max_audio_clip_s
|
chunk_size_in_s = self.asr_config.max_audio_clip_s
|
||||||
if chunk_size_in_s is None:
|
if chunk_size_in_s is None:
|
||||||
assert len(list_result_generator) == 1, (
|
assert len(list_result_generator) == 1, (
|
||||||
@@ -528,7 +547,7 @@ class OpenAISpeechToText(OpenAIServing):
|
|||||||
else:
|
else:
|
||||||
raw_text = op.outputs[0].text
|
raw_text = op.outputs[0].text
|
||||||
text_parts.append(self.model_cls.post_process_output(raw_text))
|
text_parts.append(self.model_cls.post_process_output(raw_text))
|
||||||
text = "".join(text_parts)
|
text = separator.join(text_parts)
|
||||||
if self.task_type == "transcribe":
|
if self.task_type == "transcribe":
|
||||||
final_response: ResponseType
|
final_response: ResponseType
|
||||||
# add usage in TranscriptionResponse.
|
# add usage in TranscriptionResponse.
|
||||||
@@ -581,6 +600,7 @@ class OpenAISpeechToText(OpenAIServing):
|
|||||||
| type[TranslationResponseStreamChoice],
|
| type[TranslationResponseStreamChoice],
|
||||||
stream_response_class: type[TranscriptionStreamResponse]
|
stream_response_class: type[TranscriptionStreamResponse]
|
||||||
| type[TranslationStreamResponse],
|
| type[TranslationStreamResponse],
|
||||||
|
separator: str,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
created_time = int(time.time())
|
created_time = int(time.time())
|
||||||
model_name = request.model
|
model_name = request.model
|
||||||
@@ -597,6 +617,7 @@ class OpenAISpeechToText(OpenAIServing):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
for result_generator in list_result_generator:
|
for result_generator in list_result_generator:
|
||||||
|
beginning_of_chunk = True
|
||||||
async for res in result_generator:
|
async for res in result_generator:
|
||||||
# On first result.
|
# On first result.
|
||||||
if res.prompt_token_ids is not None:
|
if res.prompt_token_ids is not None:
|
||||||
@@ -614,6 +635,14 @@ class OpenAISpeechToText(OpenAIServing):
|
|||||||
assert len(res.outputs) == 1
|
assert len(res.outputs) == 1
|
||||||
output = res.outputs[0]
|
output = res.outputs[0]
|
||||||
|
|
||||||
|
# don't add a separator before the first chunk
|
||||||
|
if (
|
||||||
|
result_generator is not list_result_generator[0]
|
||||||
|
and beginning_of_chunk
|
||||||
|
):
|
||||||
|
output.text = separator + output.text
|
||||||
|
beginning_of_chunk = False
|
||||||
|
|
||||||
# TODO: For models that output structured formats (e.g.,
|
# TODO: For models that output structured formats (e.g.,
|
||||||
# Qwen3-ASR with "language X<asr_text>" prefix), streaming
|
# Qwen3-ASR with "language X<asr_text>" prefix), streaming
|
||||||
# would need buffering to strip the prefix properly since
|
# would need buffering to strip the prefix properly since
|
||||||
|
|||||||
@@ -2007,6 +2007,7 @@ class CohereAsrForConditionalGeneration(
|
|||||||
supports_transcription_only = True
|
supports_transcription_only = True
|
||||||
supported_languages = ISO639_1_SUPPORTED_LANGS
|
supported_languages = ISO639_1_SUPPORTED_LANGS
|
||||||
skip_warmup_audio_preprocessing = True
|
skip_warmup_audio_preprocessing = True
|
||||||
|
no_space_languages = {"ja", "zh"}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_language(cls, language: str | None) -> str | None:
|
def validate_language(cls, language: str | None) -> str | None:
|
||||||
|
|||||||
@@ -1098,6 +1098,12 @@ class SupportsTranscription(Protocol):
|
|||||||
:meth:`get_language_token_ids`.
|
:meth:`get_language_token_ids`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
no_space_languages: ClassVar[set[str]] = {"ja", "zh"}
|
||||||
|
"""
|
||||||
|
Languages that don't need a space between words.
|
||||||
|
For example, Japanese (ja) and Chinese (zh) don't need a space between words.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs):
|
def __init_subclass__(cls, **kwargs):
|
||||||
super().__init_subclass__(**kwargs)
|
super().__init_subclass__(**kwargs)
|
||||||
# language codes in supported_languages
|
# language codes in supported_languages
|
||||||
|
|||||||
Reference in New Issue
Block a user