From 6c0baee61025f258c6d56830d0150feab34c45ab Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 12 Feb 2026 18:46:43 +0100 Subject: [PATCH] [Voxtral Realtime] Refactor & Improve buffering logic (#34428) Signed-off-by: Patrick von Platen Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- requirements/common.txt | 4 +- requirements/nightly_torch_test.txt | 2 +- requirements/test.in | 2 +- requirements/test.txt | 2 +- .../generation/test_voxtral_realtime.py | 128 ++--------- vllm/model_executor/models/voxtral.py | 4 +- .../model_executor/models/voxtral_realtime.py | 203 +++++++++--------- 7 files changed, 121 insertions(+), 224 deletions(-) diff --git a/requirements/common.txt b/requirements/common.txt index 297447cf2..ef320c5e2 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -31,7 +31,7 @@ partial-json-parser # used for parsing partial JSON outputs pyzmq >= 25.0.0 msgspec gguf >= 0.17.0 -mistral_common[image] >= 1.9.0 +mistral_common[image] >= 1.9.1 opencv-python-headless >= 4.13.0 # required for video IO pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 @@ -52,4 +52,4 @@ anthropic >= 0.71.0 model-hosting-container-standards >= 0.1.13, < 1.0.0 mcp grpcio -grpcio-reflection \ No newline at end of file +grpcio-reflection diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt index a45634d0c..cc5ea519a 100644 --- a/requirements/nightly_torch_test.txt +++ b/requirements/nightly_torch_test.txt @@ -23,7 +23,7 @@ jiwer # required for audio tests timm # required for internvl test transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test -mistral_common[image,audio] >= 1.9.0 # required for voxtral test +mistral_common[image,audio] >= 1.9.1 # required for voxtral test num2words # required for smolvlm test opencv-python-headless >= 4.13.0 # required for video test datamodel_code_generator # required for minicpm3 test diff --git a/requirements/test.in b/requirements/test.in index 8a97c0e88..1c43d4446 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -30,7 +30,7 @@ torchaudio==2.10.0 torchvision==0.25.0 transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test -mistral_common[image,audio] >= 1.9.0 # required for voxtral test +mistral_common[image,audio] >= 1.9.1 # required for voxtral test num2words # required for smolvlm test open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py opencv-python-headless >= 4.13.0 # required for video test diff --git a/requirements/test.txt b/requirements/test.txt index fbe3228d2..f2ab8037a 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -499,7 +499,7 @@ mbstrdecoder==1.1.3 # typepy mdurl==0.1.2 # via markdown-it-py -mistral-common==1.9.0 +mistral-common==1.9.1 # via -r requirements/test.in mlflow==2.22.0 # via terratorch diff --git a/tests/models/multimodal/generation/test_voxtral_realtime.py b/tests/models/multimodal/generation/test_voxtral_realtime.py index 96f60bb5c..2b769e3ed 100644 --- a/tests/models/multimodal/generation/test_voxtral_realtime.py +++ b/tests/models/multimodal/generation/test_voxtral_realtime.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import asyncio from dataclasses import asdict import pytest @@ -10,14 +9,13 @@ from mistral_common.protocol.transcription.request import ( StreamingMode, TranscriptionRequest, ) -from mistral_common.tokens.tokenizers.audio import AudioConfig from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from vllm import LLM, EngineArgs, SamplingParams from vllm.assets.audio import AudioAsset from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.inputs.data import TokensPrompt -from vllm.v1.engine.async_llm import AsyncLLM, StreamingInput +from vllm.model_executor.models.voxtral_realtime import VoxtralRealtimeBuffer +from vllm.v1.engine.async_llm import AsyncLLM MODEL_NAME = "mistralai/Voxtral-Mini-4B-Realtime-2602" ENGINE_CONFIG = dict( @@ -114,136 +112,40 @@ def test_voxtral_realtime_forward(audio_assets, tokenizer, engine): assert texts == EXPECTED_TEXT -class RealTimeAudioInput: - """ - This class is used to stream an audio file just as - if it would be streamed in real-time. - """ - - def __init__(self, tokenizer: MistralTokenizer) -> None: - self._tokenizer = tokenizer - self._config: AudioConfig = ( - self._tokenizer.instruct_tokenizer.audio_encoder.audio_config - ) - - self._look_ahead_in_ms = self._config.streaming_look_ahead_ms - self._look_back_in_ms = self._config.streaming_look_back_ms - - self._sampling_rate = self._config.sampling_rate - - self._audio: Audio | None = None - - # mutable objects - self._start = 0 - - n_left_pad_samples = ( - self._config.raw_audio_length_per_tok * self._config.n_left_pad_tokens - ) - self._end = self.streaming_delay + n_left_pad_samples + self.streaming_size - self._queue: asyncio.Queue[StreamingInput | None] = asyncio.Queue() - - @classmethod - async def create(cls, audio: Audio, tokenizer: MistralTokenizer): - self = cls(tokenizer) - - # we're doing "OFFLINE" encoding here to right & left pad the audio since - # we have access to the whole audio - # if we'd do an actual online realtime streaming application we - # should instead pass `StreamingMode.ONLINE` - req = TranscriptionRequest( - streaming=StreamingMode.OFFLINE, - audio=RawAudio.from_audio(audio), - language=None, - ) - audio_enc = self._tokenizer.encode_transcription(req) - self._audio = audio_enc.audios[0] - - # add first request - await self.add_tokens(audio_enc.tokens) - - return self - - @property - def look_ahead(self) -> int: - return self._get_len_in_samples(self._look_ahead_in_ms) - - @property - def look_back(self) -> int: - return self._get_len_in_samples(self._look_back_in_ms) - - @property - def streaming_delay(self) -> int: - return self._get_len_in_samples(self._config.transcription_delay_ms) - - @property - def streaming_size(self) -> int: - stream_size_in_ms = 1000 / self._config.frame_rate - return self._get_len_in_samples(stream_size_in_ms) - - def _get_len_in_samples(self, len_in_ms: float) -> int: - _len_in_s = self._sampling_rate * len_in_ms / 1000 - assert _len_in_s.is_integer(), _len_in_s - len_in_s = int(_len_in_s) - - return len_in_s - - async def add_tokens(self, tokens: list[int]) -> None: - assert self._audio is not None - if self._start >= len(self._audio.audio_array): - self.stop() - return - - _end = self._end + self.look_ahead - _start = max(0, self._start - self.look_back) - - multi_modal_data = {"audio": (self._audio.audio_array[_start:_end], None)} - - prompt = TokensPrompt( - prompt_token_ids=tokens, multi_modal_data=multi_modal_data - ) - - await self._queue.put(StreamingInput(prompt)) - - # increase - self._start = self._end - self._end = self._end + self.streaming_size - - def stop(self): - self._queue.put_nowait(None) - - async def generator(self): - while (item := await self._queue.get()) is not None: - yield item - - @pytest.mark.asyncio async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine): sampling_params = SamplingParams(temperature=0.0, max_tokens=1) + audio_config = tokenizer.instruct_tokenizer.audio_encoder.audio_config output_tokens_list = [] for i, audio_asset in enumerate(audio_assets): output_tokens = [] audio = Audio.from_file(audio_asset.get_local_path(), strict=False) - streaming_input = await RealTimeAudioInput.create( - audio=audio, tokenizer=tokenizer + + req = TranscriptionRequest( + streaming=StreamingMode.OFFLINE, + audio=RawAudio.from_audio(audio), + language=None, ) + audio_enc = tokenizer.encode_transcription(req) + + buffer = VoxtralRealtimeBuffer(audio_config, audio_enc.tokens) + await buffer.append_audio(audio_enc.audios[0].audio_array) + await buffer.append_audio(None) request_id = f"session-{i}" async for resp in async_engine.generate( - prompt=streaming_input.generator(), + prompt=buffer.get_input_stream(), sampling_params=sampling_params, request_id=request_id, ): tokens = resp.outputs[0].token_ids[-1:] - output_tokens.extend(tokens) - await streaming_input.add_tokens(tokens) + await buffer.append_tokens(tokens) output_tokens_list.append(output_tokens) texts = [tokenizer.decode(output_tokens) for output_tokens in output_tokens_list] - texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my") - assert texts == EXPECTED_TEXT diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index 2dbfe0a95..cc9856f28 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -155,9 +155,7 @@ class VoxtralProcessorAdapter: assert audio.ndim == 1 if not self._audio_processor.audio_config.is_streaming: - audio = self._audio_processor.pad( - audio, self.sampling_rate, is_online_streaming=False - ) + audio = self._audio_processor.pad(audio, self.sampling_rate) audio_tokens = [self.begin_audio_token_id] + [ self.audio_token_id diff --git a/vllm/model_executor/models/voxtral_realtime.py b/vllm/model_executor/models/voxtral_realtime.py index 6c4d20d35..81406c66b 100644 --- a/vllm/model_executor/models/voxtral_realtime.py +++ b/vllm/model_executor/models/voxtral_realtime.py @@ -3,7 +3,7 @@ import asyncio import math -from collections.abc import AsyncGenerator, Mapping +from collections.abc import AsyncGenerator, Iterable, Iterator, Mapping from typing import Literal import numpy as np @@ -18,7 +18,7 @@ from mistral_common.tokens.tokenizers.audio import Audio, AudioConfig from vllm.compilation.decorators import support_torch_compile from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S -from vllm.inputs.data import PromptType, TokensPrompt +from vllm.inputs.data import PromptType, StreamingInput, TokensPrompt from vllm.logger import init_logger from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsRealtime from vllm.model_executor.models.voxtral import ( @@ -47,8 +47,6 @@ from .utils import ( logger = init_logger(__name__) -_PRE_ALLOCATE_BUFFER_SIZE_IN_S = 30 - class VoxtralRealtimeMultiModalProcessor(VoxtralMultiModalProcessor): def __init__( @@ -130,84 +128,81 @@ def _expand_tensor(input_tensor: torch.Tensor, scaling: int) -> torch.Tensor: class VoxtralRealtimeBuffer: - def __init__(self, config: AudioConfig) -> None: + def __init__(self, config: AudioConfig, prompt_tokens: list[int]) -> None: self._config = config - self._look_ahead_in_ms = config.streaming_look_ahead_ms - self._look_back_in_ms = config.streaming_look_back_ms + _look_ahead_in_ms = self._config.streaming_look_ahead_ms + _look_back_in_ms = self._config.streaming_look_back_ms + self._look_ahead_in_samples = self._ms_to_samples(_look_ahead_in_ms) + self._look_back_in_samples = self._ms_to_samples(_look_back_in_ms) - self._sampling_rate = self._config.sampling_rate + # None signals the end + self._audio_queue: asyncio.Queue[np.ndarray | None] = asyncio.Queue() + self._leftover: np.ndarray | None = None + self._token_queue: asyncio.Queue[int] = asyncio.Queue() - self._look_ahead = self._get_len_in_samples(self._look_ahead_in_ms) - self._look_back = self._get_len_in_samples(self._look_back_in_ms) - self._streaming_size = self._get_len_in_samples(1000 / self._config.frame_rate) + self._initial_end = len(prompt_tokens) * self._config.raw_audio_length_per_tok + for token in prompt_tokens: + self._token_queue.put_nowait(token) - # mutable objects - streaming_delay = self._get_len_in_samples(self._config.transcription_delay_ms) - self._start = 0 - self._end = streaming_delay + self._streaming_size + def _generate_frame_size_and_num_tokens(self) -> Iterator[tuple[int, int]]: + streaming_step_size = self._ms_to_samples(1000 / self._config.frame_rate) + start = 0 + end = self._initial_end + while True: + frame_start = max(start - self._look_back_in_samples, 0) + frame_end = end + self._look_ahead_in_samples + frame_size = frame_end - frame_start + num_tokens = (end - start) / self._config.raw_audio_length_per_tok + assert num_tokens.is_integer() + yield frame_size, int(num_tokens) + start = end + end += streaming_step_size - # always pre-allocate 30 second buffers - self._buffer_size = _PRE_ALLOCATE_BUFFER_SIZE_IN_S * self._sampling_rate - self._buffer: np.ndarray = np.empty(self._buffer_size, dtype=np.float32) - self._filled_buffer_len = 0 + def _ms_to_samples(self, ms: float) -> int: + len_ = self._config.sampling_rate * ms / 1000 + assert len_.is_integer(), len_ + return int(len_) - @property - def start_idx(self): - return max(self._start - self._look_back, 0) + async def append_audio(self, audio_array: np.ndarray | None) -> None: + await self._audio_queue.put(audio_array) - @property - def end_idx(self): - return self._end + self._look_ahead + async def append_tokens(self, tokens: Iterable[int]) -> None: + for token in tokens: + await self._token_queue.put(token) - @property - def is_audio_complete(self) -> bool: - return self._filled_buffer_len >= self.end_idx + async def get_input_stream(self) -> AsyncGenerator[StreamingInput]: + for frame_size, num_tokens in self._generate_frame_size_and_num_tokens(): + next_tokens = [await self._token_queue.get() for _ in range(num_tokens)] - def _get_len_in_samples(self, len_in_ms: float) -> int: - _len_in_s = self._sampling_rate * len_in_ms / 1000 - assert _len_in_s.is_integer(), _len_in_s - len_in_s = int(_len_in_s) + audio_arrays: list[np.ndarray] = ( + [self._leftover] if self._leftover is not None else [] + ) + while sum(len(arr) for arr in audio_arrays) < frame_size: + arr = await self._audio_queue.get() + if arr is None: + return + audio_arrays.append(arr) - return len_in_s + audio_array = np.concatenate(audio_arrays) + frame = audio_array[:frame_size] - def _allocate_new_buffer(self) -> None: - # allocate new buffer - new_buffer = np.empty(self._buffer_size, dtype=np.float32) - left_to_copy = max(self._filled_buffer_len - self.start_idx, 0) + # The current stride took look_ahead_in_samples audio of the next sample + # In addition the next sample will take look_back_in_samples audio of + # the current sample => So let's put both of this into the leftover + stride = ( + frame_size - self._look_ahead_in_samples - self._look_back_in_samples + ) + assert stride > 0, f"{stride=} must be positive" - if left_to_copy > 0: - new_buffer[:left_to_copy] = self._buffer[ - self.start_idx : self._filled_buffer_len - ] + self._leftover = audio_array[stride:] - del self._buffer - self._buffer = new_buffer - - self._filled_buffer_len = left_to_copy - self._start = self._look_back - self._end = self._start + self._streaming_size - - def write_audio(self, audio: np.ndarray) -> None: - put_end_idx = self._filled_buffer_len + len(audio) - - if put_end_idx > self._buffer_size: - self._allocate_new_buffer() - - self._buffer[self._filled_buffer_len : self._filled_buffer_len + len(audio)] = ( - audio - ) - self._filled_buffer_len += len(audio) - - def read_audio(self) -> np.ndarray | None: - if not self.is_audio_complete: - return None - - audio = self._buffer[self.start_idx : self.end_idx] - self._start = self._end - self._end += self._streaming_size - - return audio + yield StreamingInput( + TokensPrompt( + prompt_token_ids=next_tokens, + multi_modal_data={"audio": (frame, None)}, + ) + ) @MULTIMODAL_REGISTRY.register_processor( @@ -234,7 +229,7 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim ) audio_config = self.tokenizer.instruct.audio_encoder.audio_config - self.n_delay_tokens = audio_config.num_delay_tokens + self.n_delay_tokens = audio_config.get_num_delay_tokens() # for realtime transcription @classmethod @@ -248,45 +243,47 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim audio_encoder = tokenizer.instruct.audio_encoder config = audio_encoder.audio_config - buffer = VoxtralRealtimeBuffer(config) - is_first_yield = True + # Get prompt tokens (streaming prefix tokens) without encoding audio + prompt_tokens = ( + tokenizer.instruct.start() + audio_encoder.encode_streaming_tokens() + ) - async for audio in audio_stream: - buffer.write_audio(audio) + # Get left/right padding audio + left_pad, right_pad = audio_encoder.get_padding_audio() - while (new_audio := buffer.read_audio()) is not None: - if is_first_yield: - # make sure that input_stream is empty - assert input_stream.empty() + buffer = VoxtralRealtimeBuffer(config, prompt_tokens) - audio = Audio(new_audio, config.sampling_rate, format="wav") + # Feed audio with padding into buffer in background + async def feed_audio(): + yielded_first_chunk = False + async for audio_chunk in audio_stream: + if not yielded_first_chunk: + yielded_first_chunk = True + # Prepend left padding before first real audio + await buffer.append_audio(left_pad.audio_array) + await buffer.append_audio(audio_chunk) + # Append right padding at the end + await buffer.append_audio(right_pad.audio_array) + await buffer.append_audio(None) # signal end - request = TranscriptionRequest( - streaming=StreamingMode.ONLINE, - audio=RawAudio.from_audio(audio), - language=None, - ) - # mistral tokenizer takes care - # of preparing the first prompt inputs - # and does some left-silence padding - # for improved performance - audio_enc = tokenizer.mistral.encode_transcription(request) - - token_ids = audio_enc.tokens - new_audio = audio_enc.audios[0].audio_array - - is_first_yield = False - else: - # pop last element from input_stream - all_outputs = await asyncio.wait_for( - input_stream.get(), timeout=VLLM_ENGINE_ITERATION_TIMEOUT_S - ) - token_ids = all_outputs[-1:] - - multi_modal_data = {"audio": (new_audio, None)} - yield TokensPrompt( - prompt_token_ids=token_ids, multi_modal_data=multi_modal_data + # Feed output tokens back into buffer in background + async def feed_tokens(): + while True: + all_outputs = await asyncio.wait_for( + input_stream.get(), + timeout=VLLM_ENGINE_ITERATION_TIMEOUT_S, ) + await buffer.append_tokens(all_outputs[-1:]) + + audio_task = asyncio.create_task(feed_audio()) + token_task = asyncio.create_task(feed_tokens()) + + try: + async for streaming_input in buffer.get_input_stream(): + yield streaming_input.prompt + finally: + audio_task.cancel() + token_task.cancel() @property def audio_config(self):