[Voxtral Realtime] Refactor & Improve buffering logic (#34428)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2026-02-12 18:46:43 +01:00
committed by GitHub
parent 1100a97621
commit 6c0baee610
7 changed files with 121 additions and 224 deletions

View File

@@ -31,7 +31,7 @@ partial-json-parser # used for parsing partial JSON outputs
pyzmq >= 25.0.0 pyzmq >= 25.0.0
msgspec msgspec
gguf >= 0.17.0 gguf >= 0.17.0
mistral_common[image] >= 1.9.0 mistral_common[image] >= 1.9.1
opencv-python-headless >= 4.13.0 # required for video IO opencv-python-headless >= 4.13.0 # required for video IO
pyyaml pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12

View File

@@ -23,7 +23,7 @@ jiwer # required for audio tests
timm # required for internvl test timm # required for internvl test
transformers_stream_generator # required for qwen-vl test transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.9.0 # required for voxtral test mistral_common[image,audio] >= 1.9.1 # required for voxtral test
num2words # required for smolvlm test num2words # required for smolvlm test
opencv-python-headless >= 4.13.0 # required for video test opencv-python-headless >= 4.13.0 # required for video test
datamodel_code_generator # required for minicpm3 test datamodel_code_generator # required for minicpm3 test

View File

@@ -30,7 +30,7 @@ torchaudio==2.10.0
torchvision==0.25.0 torchvision==0.25.0
transformers_stream_generator # required for qwen-vl test transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.9.0 # required for voxtral test mistral_common[image,audio] >= 1.9.1 # required for voxtral test
num2words # required for smolvlm test num2words # required for smolvlm test
open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py
opencv-python-headless >= 4.13.0 # required for video test opencv-python-headless >= 4.13.0 # required for video test

View File

@@ -499,7 +499,7 @@ mbstrdecoder==1.1.3
# typepy # typepy
mdurl==0.1.2 mdurl==0.1.2
# via markdown-it-py # via markdown-it-py
mistral-common==1.9.0 mistral-common==1.9.1
# via -r requirements/test.in # via -r requirements/test.in
mlflow==2.22.0 mlflow==2.22.0
# via terratorch # via terratorch

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from dataclasses import asdict from dataclasses import asdict
import pytest import pytest
@@ -10,14 +9,13 @@ from mistral_common.protocol.transcription.request import (
StreamingMode, StreamingMode,
TranscriptionRequest, TranscriptionRequest,
) )
from mistral_common.tokens.tokenizers.audio import AudioConfig
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from vllm import LLM, EngineArgs, SamplingParams from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TokensPrompt from vllm.model_executor.models.voxtral_realtime import VoxtralRealtimeBuffer
from vllm.v1.engine.async_llm import AsyncLLM, StreamingInput from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "mistralai/Voxtral-Mini-4B-Realtime-2602" MODEL_NAME = "mistralai/Voxtral-Mini-4B-Realtime-2602"
ENGINE_CONFIG = dict( ENGINE_CONFIG = dict(
@@ -114,136 +112,40 @@ def test_voxtral_realtime_forward(audio_assets, tokenizer, engine):
assert texts == EXPECTED_TEXT assert texts == EXPECTED_TEXT
class RealTimeAudioInput:
"""
This class is used to stream an audio file just as
if it would be streamed in real-time.
"""
def __init__(self, tokenizer: MistralTokenizer) -> None:
self._tokenizer = tokenizer
self._config: AudioConfig = (
self._tokenizer.instruct_tokenizer.audio_encoder.audio_config
)
self._look_ahead_in_ms = self._config.streaming_look_ahead_ms
self._look_back_in_ms = self._config.streaming_look_back_ms
self._sampling_rate = self._config.sampling_rate
self._audio: Audio | None = None
# mutable objects
self._start = 0
n_left_pad_samples = (
self._config.raw_audio_length_per_tok * self._config.n_left_pad_tokens
)
self._end = self.streaming_delay + n_left_pad_samples + self.streaming_size
self._queue: asyncio.Queue[StreamingInput | None] = asyncio.Queue()
@classmethod
async def create(cls, audio: Audio, tokenizer: MistralTokenizer):
self = cls(tokenizer)
# we're doing "OFFLINE" encoding here to right & left pad the audio since
# we have access to the whole audio
# if we'd do an actual online realtime streaming application we
# should instead pass `StreamingMode.ONLINE`
req = TranscriptionRequest(
streaming=StreamingMode.OFFLINE,
audio=RawAudio.from_audio(audio),
language=None,
)
audio_enc = self._tokenizer.encode_transcription(req)
self._audio = audio_enc.audios[0]
# add first request
await self.add_tokens(audio_enc.tokens)
return self
@property
def look_ahead(self) -> int:
return self._get_len_in_samples(self._look_ahead_in_ms)
@property
def look_back(self) -> int:
return self._get_len_in_samples(self._look_back_in_ms)
@property
def streaming_delay(self) -> int:
return self._get_len_in_samples(self._config.transcription_delay_ms)
@property
def streaming_size(self) -> int:
stream_size_in_ms = 1000 / self._config.frame_rate
return self._get_len_in_samples(stream_size_in_ms)
def _get_len_in_samples(self, len_in_ms: float) -> int:
_len_in_s = self._sampling_rate * len_in_ms / 1000
assert _len_in_s.is_integer(), _len_in_s
len_in_s = int(_len_in_s)
return len_in_s
async def add_tokens(self, tokens: list[int]) -> None:
assert self._audio is not None
if self._start >= len(self._audio.audio_array):
self.stop()
return
_end = self._end + self.look_ahead
_start = max(0, self._start - self.look_back)
multi_modal_data = {"audio": (self._audio.audio_array[_start:_end], None)}
prompt = TokensPrompt(
prompt_token_ids=tokens, multi_modal_data=multi_modal_data
)
await self._queue.put(StreamingInput(prompt))
# increase
self._start = self._end
self._end = self._end + self.streaming_size
def stop(self):
self._queue.put_nowait(None)
async def generator(self):
while (item := await self._queue.get()) is not None:
yield item
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine): async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine):
sampling_params = SamplingParams(temperature=0.0, max_tokens=1) sampling_params = SamplingParams(temperature=0.0, max_tokens=1)
audio_config = tokenizer.instruct_tokenizer.audio_encoder.audio_config
output_tokens_list = [] output_tokens_list = []
for i, audio_asset in enumerate(audio_assets): for i, audio_asset in enumerate(audio_assets):
output_tokens = [] output_tokens = []
audio = Audio.from_file(audio_asset.get_local_path(), strict=False) audio = Audio.from_file(audio_asset.get_local_path(), strict=False)
streaming_input = await RealTimeAudioInput.create(
audio=audio, tokenizer=tokenizer req = TranscriptionRequest(
streaming=StreamingMode.OFFLINE,
audio=RawAudio.from_audio(audio),
language=None,
) )
audio_enc = tokenizer.encode_transcription(req)
buffer = VoxtralRealtimeBuffer(audio_config, audio_enc.tokens)
await buffer.append_audio(audio_enc.audios[0].audio_array)
await buffer.append_audio(None)
request_id = f"session-{i}" request_id = f"session-{i}"
async for resp in async_engine.generate( async for resp in async_engine.generate(
prompt=streaming_input.generator(), prompt=buffer.get_input_stream(),
sampling_params=sampling_params, sampling_params=sampling_params,
request_id=request_id, request_id=request_id,
): ):
tokens = resp.outputs[0].token_ids[-1:] tokens = resp.outputs[0].token_ids[-1:]
output_tokens.extend(tokens) output_tokens.extend(tokens)
await streaming_input.add_tokens(tokens) await buffer.append_tokens(tokens)
output_tokens_list.append(output_tokens) output_tokens_list.append(output_tokens)
texts = [tokenizer.decode(output_tokens) for output_tokens in output_tokens_list] texts = [tokenizer.decode(output_tokens) for output_tokens in output_tokens_list]
texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my") texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")
assert texts == EXPECTED_TEXT assert texts == EXPECTED_TEXT

View File

@@ -155,9 +155,7 @@ class VoxtralProcessorAdapter:
assert audio.ndim == 1 assert audio.ndim == 1
if not self._audio_processor.audio_config.is_streaming: if not self._audio_processor.audio_config.is_streaming:
audio = self._audio_processor.pad( audio = self._audio_processor.pad(audio, self.sampling_rate)
audio, self.sampling_rate, is_online_streaming=False
)
audio_tokens = [self.begin_audio_token_id] + [ audio_tokens = [self.begin_audio_token_id] + [
self.audio_token_id self.audio_token_id

View File

@@ -3,7 +3,7 @@
import asyncio import asyncio
import math import math
from collections.abc import AsyncGenerator, Mapping from collections.abc import AsyncGenerator, Iterable, Iterator, Mapping
from typing import Literal from typing import Literal
import numpy as np import numpy as np
@@ -18,7 +18,7 @@ from mistral_common.tokens.tokenizers.audio import Audio, AudioConfig
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S
from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.data import PromptType, StreamingInput, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsRealtime from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsRealtime
from vllm.model_executor.models.voxtral import ( from vllm.model_executor.models.voxtral import (
@@ -47,8 +47,6 @@ from .utils import (
logger = init_logger(__name__) logger = init_logger(__name__)
_PRE_ALLOCATE_BUFFER_SIZE_IN_S = 30
class VoxtralRealtimeMultiModalProcessor(VoxtralMultiModalProcessor): class VoxtralRealtimeMultiModalProcessor(VoxtralMultiModalProcessor):
def __init__( def __init__(
@@ -130,84 +128,81 @@ def _expand_tensor(input_tensor: torch.Tensor, scaling: int) -> torch.Tensor:
class VoxtralRealtimeBuffer: class VoxtralRealtimeBuffer:
def __init__(self, config: AudioConfig) -> None: def __init__(self, config: AudioConfig, prompt_tokens: list[int]) -> None:
self._config = config self._config = config
self._look_ahead_in_ms = config.streaming_look_ahead_ms _look_ahead_in_ms = self._config.streaming_look_ahead_ms
self._look_back_in_ms = config.streaming_look_back_ms _look_back_in_ms = self._config.streaming_look_back_ms
self._look_ahead_in_samples = self._ms_to_samples(_look_ahead_in_ms)
self._look_back_in_samples = self._ms_to_samples(_look_back_in_ms)
self._sampling_rate = self._config.sampling_rate # None signals the end
self._audio_queue: asyncio.Queue[np.ndarray | None] = asyncio.Queue()
self._leftover: np.ndarray | None = None
self._token_queue: asyncio.Queue[int] = asyncio.Queue()
self._look_ahead = self._get_len_in_samples(self._look_ahead_in_ms) self._initial_end = len(prompt_tokens) * self._config.raw_audio_length_per_tok
self._look_back = self._get_len_in_samples(self._look_back_in_ms) for token in prompt_tokens:
self._streaming_size = self._get_len_in_samples(1000 / self._config.frame_rate) self._token_queue.put_nowait(token)
# mutable objects def _generate_frame_size_and_num_tokens(self) -> Iterator[tuple[int, int]]:
streaming_delay = self._get_len_in_samples(self._config.transcription_delay_ms) streaming_step_size = self._ms_to_samples(1000 / self._config.frame_rate)
self._start = 0 start = 0
self._end = streaming_delay + self._streaming_size end = self._initial_end
while True:
frame_start = max(start - self._look_back_in_samples, 0)
frame_end = end + self._look_ahead_in_samples
frame_size = frame_end - frame_start
num_tokens = (end - start) / self._config.raw_audio_length_per_tok
assert num_tokens.is_integer()
yield frame_size, int(num_tokens)
start = end
end += streaming_step_size
# always pre-allocate 30 second buffers def _ms_to_samples(self, ms: float) -> int:
self._buffer_size = _PRE_ALLOCATE_BUFFER_SIZE_IN_S * self._sampling_rate len_ = self._config.sampling_rate * ms / 1000
self._buffer: np.ndarray = np.empty(self._buffer_size, dtype=np.float32) assert len_.is_integer(), len_
self._filled_buffer_len = 0 return int(len_)
@property async def append_audio(self, audio_array: np.ndarray | None) -> None:
def start_idx(self): await self._audio_queue.put(audio_array)
return max(self._start - self._look_back, 0)
@property async def append_tokens(self, tokens: Iterable[int]) -> None:
def end_idx(self): for token in tokens:
return self._end + self._look_ahead await self._token_queue.put(token)
@property async def get_input_stream(self) -> AsyncGenerator[StreamingInput]:
def is_audio_complete(self) -> bool: for frame_size, num_tokens in self._generate_frame_size_and_num_tokens():
return self._filled_buffer_len >= self.end_idx next_tokens = [await self._token_queue.get() for _ in range(num_tokens)]
def _get_len_in_samples(self, len_in_ms: float) -> int: audio_arrays: list[np.ndarray] = (
_len_in_s = self._sampling_rate * len_in_ms / 1000 [self._leftover] if self._leftover is not None else []
assert _len_in_s.is_integer(), _len_in_s )
len_in_s = int(_len_in_s) while sum(len(arr) for arr in audio_arrays) < frame_size:
arr = await self._audio_queue.get()
if arr is None:
return
audio_arrays.append(arr)
return len_in_s audio_array = np.concatenate(audio_arrays)
frame = audio_array[:frame_size]
def _allocate_new_buffer(self) -> None: # The current stride took look_ahead_in_samples audio of the next sample
# allocate new buffer # In addition the next sample will take look_back_in_samples audio of
new_buffer = np.empty(self._buffer_size, dtype=np.float32) # the current sample => So let's put both of this into the leftover
left_to_copy = max(self._filled_buffer_len - self.start_idx, 0) stride = (
frame_size - self._look_ahead_in_samples - self._look_back_in_samples
)
assert stride > 0, f"{stride=} must be positive"
if left_to_copy > 0: self._leftover = audio_array[stride:]
new_buffer[:left_to_copy] = self._buffer[
self.start_idx : self._filled_buffer_len
]
del self._buffer yield StreamingInput(
self._buffer = new_buffer TokensPrompt(
prompt_token_ids=next_tokens,
self._filled_buffer_len = left_to_copy multi_modal_data={"audio": (frame, None)},
self._start = self._look_back )
self._end = self._start + self._streaming_size )
def write_audio(self, audio: np.ndarray) -> None:
put_end_idx = self._filled_buffer_len + len(audio)
if put_end_idx > self._buffer_size:
self._allocate_new_buffer()
self._buffer[self._filled_buffer_len : self._filled_buffer_len + len(audio)] = (
audio
)
self._filled_buffer_len += len(audio)
def read_audio(self) -> np.ndarray | None:
if not self.is_audio_complete:
return None
audio = self._buffer[self.start_idx : self.end_idx]
self._start = self._end
self._end += self._streaming_size
return audio
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
@@ -234,7 +229,7 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
) )
audio_config = self.tokenizer.instruct.audio_encoder.audio_config audio_config = self.tokenizer.instruct.audio_encoder.audio_config
self.n_delay_tokens = audio_config.num_delay_tokens self.n_delay_tokens = audio_config.get_num_delay_tokens()
# for realtime transcription # for realtime transcription
@classmethod @classmethod
@@ -248,45 +243,47 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
audio_encoder = tokenizer.instruct.audio_encoder audio_encoder = tokenizer.instruct.audio_encoder
config = audio_encoder.audio_config config = audio_encoder.audio_config
buffer = VoxtralRealtimeBuffer(config) # Get prompt tokens (streaming prefix tokens) without encoding audio
is_first_yield = True prompt_tokens = (
tokenizer.instruct.start() + audio_encoder.encode_streaming_tokens()
)
async for audio in audio_stream: # Get left/right padding audio
buffer.write_audio(audio) left_pad, right_pad = audio_encoder.get_padding_audio()
while (new_audio := buffer.read_audio()) is not None: buffer = VoxtralRealtimeBuffer(config, prompt_tokens)
if is_first_yield:
# make sure that input_stream is empty
assert input_stream.empty()
audio = Audio(new_audio, config.sampling_rate, format="wav") # Feed audio with padding into buffer in background
async def feed_audio():
yielded_first_chunk = False
async for audio_chunk in audio_stream:
if not yielded_first_chunk:
yielded_first_chunk = True
# Prepend left padding before first real audio
await buffer.append_audio(left_pad.audio_array)
await buffer.append_audio(audio_chunk)
# Append right padding at the end
await buffer.append_audio(right_pad.audio_array)
await buffer.append_audio(None) # signal end
request = TranscriptionRequest( # Feed output tokens back into buffer in background
streaming=StreamingMode.ONLINE, async def feed_tokens():
audio=RawAudio.from_audio(audio), while True:
language=None, all_outputs = await asyncio.wait_for(
) input_stream.get(),
# mistral tokenizer takes care timeout=VLLM_ENGINE_ITERATION_TIMEOUT_S,
# of preparing the first prompt inputs
# and does some left-silence padding
# for improved performance
audio_enc = tokenizer.mistral.encode_transcription(request)
token_ids = audio_enc.tokens
new_audio = audio_enc.audios[0].audio_array
is_first_yield = False
else:
# pop last element from input_stream
all_outputs = await asyncio.wait_for(
input_stream.get(), timeout=VLLM_ENGINE_ITERATION_TIMEOUT_S
)
token_ids = all_outputs[-1:]
multi_modal_data = {"audio": (new_audio, None)}
yield TokensPrompt(
prompt_token_ids=token_ids, multi_modal_data=multi_modal_data
) )
await buffer.append_tokens(all_outputs[-1:])
audio_task = asyncio.create_task(feed_audio())
token_task = asyncio.create_task(feed_tokens())
try:
async for streaming_input in buffer.get_input_stream():
yield streaming_input.prompt
finally:
audio_task.cancel()
token_task.cancel()
@property @property
def audio_config(self): def audio_config(self):