[Voxtral Realtime] Refactor & Improve buffering logic (#34428)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2026-02-12 18:46:43 +01:00
committed by GitHub
parent 1100a97621
commit 6c0baee610
7 changed files with 121 additions and 224 deletions

View File

@@ -155,9 +155,7 @@ class VoxtralProcessorAdapter:
assert audio.ndim == 1
if not self._audio_processor.audio_config.is_streaming:
audio = self._audio_processor.pad(
audio, self.sampling_rate, is_online_streaming=False
)
audio = self._audio_processor.pad(audio, self.sampling_rate)
audio_tokens = [self.begin_audio_token_id] + [
self.audio_token_id

View File

@@ -3,7 +3,7 @@
import asyncio
import math
from collections.abc import AsyncGenerator, Mapping
from collections.abc import AsyncGenerator, Iterable, Iterator, Mapping
from typing import Literal
import numpy as np
@@ -18,7 +18,7 @@ from mistral_common.tokens.tokenizers.audio import Audio, AudioConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.data import PromptType, StreamingInput, TokensPrompt
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsRealtime
from vllm.model_executor.models.voxtral import (
@@ -47,8 +47,6 @@ from .utils import (
logger = init_logger(__name__)
_PRE_ALLOCATE_BUFFER_SIZE_IN_S = 30
class VoxtralRealtimeMultiModalProcessor(VoxtralMultiModalProcessor):
def __init__(
@@ -130,84 +128,81 @@ def _expand_tensor(input_tensor: torch.Tensor, scaling: int) -> torch.Tensor:
class VoxtralRealtimeBuffer:
def __init__(self, config: AudioConfig) -> None:
def __init__(self, config: AudioConfig, prompt_tokens: list[int]) -> None:
self._config = config
self._look_ahead_in_ms = config.streaming_look_ahead_ms
self._look_back_in_ms = config.streaming_look_back_ms
_look_ahead_in_ms = self._config.streaming_look_ahead_ms
_look_back_in_ms = self._config.streaming_look_back_ms
self._look_ahead_in_samples = self._ms_to_samples(_look_ahead_in_ms)
self._look_back_in_samples = self._ms_to_samples(_look_back_in_ms)
self._sampling_rate = self._config.sampling_rate
# None signals the end
self._audio_queue: asyncio.Queue[np.ndarray | None] = asyncio.Queue()
self._leftover: np.ndarray | None = None
self._token_queue: asyncio.Queue[int] = asyncio.Queue()
self._look_ahead = self._get_len_in_samples(self._look_ahead_in_ms)
self._look_back = self._get_len_in_samples(self._look_back_in_ms)
self._streaming_size = self._get_len_in_samples(1000 / self._config.frame_rate)
self._initial_end = len(prompt_tokens) * self._config.raw_audio_length_per_tok
for token in prompt_tokens:
self._token_queue.put_nowait(token)
# mutable objects
streaming_delay = self._get_len_in_samples(self._config.transcription_delay_ms)
self._start = 0
self._end = streaming_delay + self._streaming_size
def _generate_frame_size_and_num_tokens(self) -> Iterator[tuple[int, int]]:
streaming_step_size = self._ms_to_samples(1000 / self._config.frame_rate)
start = 0
end = self._initial_end
while True:
frame_start = max(start - self._look_back_in_samples, 0)
frame_end = end + self._look_ahead_in_samples
frame_size = frame_end - frame_start
num_tokens = (end - start) / self._config.raw_audio_length_per_tok
assert num_tokens.is_integer()
yield frame_size, int(num_tokens)
start = end
end += streaming_step_size
# always pre-allocate 30 second buffers
self._buffer_size = _PRE_ALLOCATE_BUFFER_SIZE_IN_S * self._sampling_rate
self._buffer: np.ndarray = np.empty(self._buffer_size, dtype=np.float32)
self._filled_buffer_len = 0
def _ms_to_samples(self, ms: float) -> int:
len_ = self._config.sampling_rate * ms / 1000
assert len_.is_integer(), len_
return int(len_)
@property
def start_idx(self):
return max(self._start - self._look_back, 0)
async def append_audio(self, audio_array: np.ndarray | None) -> None:
await self._audio_queue.put(audio_array)
@property
def end_idx(self):
return self._end + self._look_ahead
async def append_tokens(self, tokens: Iterable[int]) -> None:
for token in tokens:
await self._token_queue.put(token)
@property
def is_audio_complete(self) -> bool:
return self._filled_buffer_len >= self.end_idx
async def get_input_stream(self) -> AsyncGenerator[StreamingInput]:
for frame_size, num_tokens in self._generate_frame_size_and_num_tokens():
next_tokens = [await self._token_queue.get() for _ in range(num_tokens)]
def _get_len_in_samples(self, len_in_ms: float) -> int:
_len_in_s = self._sampling_rate * len_in_ms / 1000
assert _len_in_s.is_integer(), _len_in_s
len_in_s = int(_len_in_s)
audio_arrays: list[np.ndarray] = (
[self._leftover] if self._leftover is not None else []
)
while sum(len(arr) for arr in audio_arrays) < frame_size:
arr = await self._audio_queue.get()
if arr is None:
return
audio_arrays.append(arr)
return len_in_s
audio_array = np.concatenate(audio_arrays)
frame = audio_array[:frame_size]
def _allocate_new_buffer(self) -> None:
# allocate new buffer
new_buffer = np.empty(self._buffer_size, dtype=np.float32)
left_to_copy = max(self._filled_buffer_len - self.start_idx, 0)
# The current stride took look_ahead_in_samples audio of the next sample
# In addition the next sample will take look_back_in_samples audio of
# the current sample => So let's put both of this into the leftover
stride = (
frame_size - self._look_ahead_in_samples - self._look_back_in_samples
)
assert stride > 0, f"{stride=} must be positive"
if left_to_copy > 0:
new_buffer[:left_to_copy] = self._buffer[
self.start_idx : self._filled_buffer_len
]
self._leftover = audio_array[stride:]
del self._buffer
self._buffer = new_buffer
self._filled_buffer_len = left_to_copy
self._start = self._look_back
self._end = self._start + self._streaming_size
def write_audio(self, audio: np.ndarray) -> None:
put_end_idx = self._filled_buffer_len + len(audio)
if put_end_idx > self._buffer_size:
self._allocate_new_buffer()
self._buffer[self._filled_buffer_len : self._filled_buffer_len + len(audio)] = (
audio
)
self._filled_buffer_len += len(audio)
def read_audio(self) -> np.ndarray | None:
if not self.is_audio_complete:
return None
audio = self._buffer[self.start_idx : self.end_idx]
self._start = self._end
self._end += self._streaming_size
return audio
yield StreamingInput(
TokensPrompt(
prompt_token_ids=next_tokens,
multi_modal_data={"audio": (frame, None)},
)
)
@MULTIMODAL_REGISTRY.register_processor(
@@ -234,7 +229,7 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
)
audio_config = self.tokenizer.instruct.audio_encoder.audio_config
self.n_delay_tokens = audio_config.num_delay_tokens
self.n_delay_tokens = audio_config.get_num_delay_tokens()
# for realtime transcription
@classmethod
@@ -248,45 +243,47 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
audio_encoder = tokenizer.instruct.audio_encoder
config = audio_encoder.audio_config
buffer = VoxtralRealtimeBuffer(config)
is_first_yield = True
# Get prompt tokens (streaming prefix tokens) without encoding audio
prompt_tokens = (
tokenizer.instruct.start() + audio_encoder.encode_streaming_tokens()
)
async for audio in audio_stream:
buffer.write_audio(audio)
# Get left/right padding audio
left_pad, right_pad = audio_encoder.get_padding_audio()
while (new_audio := buffer.read_audio()) is not None:
if is_first_yield:
# make sure that input_stream is empty
assert input_stream.empty()
buffer = VoxtralRealtimeBuffer(config, prompt_tokens)
audio = Audio(new_audio, config.sampling_rate, format="wav")
# Feed audio with padding into buffer in background
async def feed_audio():
yielded_first_chunk = False
async for audio_chunk in audio_stream:
if not yielded_first_chunk:
yielded_first_chunk = True
# Prepend left padding before first real audio
await buffer.append_audio(left_pad.audio_array)
await buffer.append_audio(audio_chunk)
# Append right padding at the end
await buffer.append_audio(right_pad.audio_array)
await buffer.append_audio(None) # signal end
request = TranscriptionRequest(
streaming=StreamingMode.ONLINE,
audio=RawAudio.from_audio(audio),
language=None,
)
# mistral tokenizer takes care
# of preparing the first prompt inputs
# and does some left-silence padding
# for improved performance
audio_enc = tokenizer.mistral.encode_transcription(request)
token_ids = audio_enc.tokens
new_audio = audio_enc.audios[0].audio_array
is_first_yield = False
else:
# pop last element from input_stream
all_outputs = await asyncio.wait_for(
input_stream.get(), timeout=VLLM_ENGINE_ITERATION_TIMEOUT_S
)
token_ids = all_outputs[-1:]
multi_modal_data = {"audio": (new_audio, None)}
yield TokensPrompt(
prompt_token_ids=token_ids, multi_modal_data=multi_modal_data
# Feed output tokens back into buffer in background
async def feed_tokens():
while True:
all_outputs = await asyncio.wait_for(
input_stream.get(),
timeout=VLLM_ENGINE_ITERATION_TIMEOUT_S,
)
await buffer.append_tokens(all_outputs[-1:])
audio_task = asyncio.create_task(feed_audio())
token_task = asyncio.create_task(feed_tokens())
try:
async for streaming_input in buffer.get_input_stream():
yield streaming_input.prompt
finally:
audio_task.cancel()
token_task.cancel()
@property
def audio_config(self):