[Voxtral Realtime] Refactor & Improve buffering logic (#34428)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
1100a97621
commit
6c0baee610
@@ -155,9 +155,7 @@ class VoxtralProcessorAdapter:
|
||||
assert audio.ndim == 1
|
||||
|
||||
if not self._audio_processor.audio_config.is_streaming:
|
||||
audio = self._audio_processor.pad(
|
||||
audio, self.sampling_rate, is_online_streaming=False
|
||||
)
|
||||
audio = self._audio_processor.pad(audio, self.sampling_rate)
|
||||
|
||||
audio_tokens = [self.begin_audio_token_id] + [
|
||||
self.audio_token_id
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from collections.abc import AsyncGenerator, Iterable, Iterator, Mapping
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
@@ -18,7 +18,7 @@ from mistral_common.tokens.tokenizers.audio import Audio, AudioConfig
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S
|
||||
from vllm.inputs.data import PromptType, TokensPrompt
|
||||
from vllm.inputs.data import PromptType, StreamingInput, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsRealtime
|
||||
from vllm.model_executor.models.voxtral import (
|
||||
@@ -47,8 +47,6 @@ from .utils import (
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_PRE_ALLOCATE_BUFFER_SIZE_IN_S = 30
|
||||
|
||||
|
||||
class VoxtralRealtimeMultiModalProcessor(VoxtralMultiModalProcessor):
|
||||
def __init__(
|
||||
@@ -130,84 +128,81 @@ def _expand_tensor(input_tensor: torch.Tensor, scaling: int) -> torch.Tensor:
|
||||
|
||||
|
||||
class VoxtralRealtimeBuffer:
|
||||
def __init__(self, config: AudioConfig) -> None:
|
||||
def __init__(self, config: AudioConfig, prompt_tokens: list[int]) -> None:
|
||||
self._config = config
|
||||
|
||||
self._look_ahead_in_ms = config.streaming_look_ahead_ms
|
||||
self._look_back_in_ms = config.streaming_look_back_ms
|
||||
_look_ahead_in_ms = self._config.streaming_look_ahead_ms
|
||||
_look_back_in_ms = self._config.streaming_look_back_ms
|
||||
self._look_ahead_in_samples = self._ms_to_samples(_look_ahead_in_ms)
|
||||
self._look_back_in_samples = self._ms_to_samples(_look_back_in_ms)
|
||||
|
||||
self._sampling_rate = self._config.sampling_rate
|
||||
# None signals the end
|
||||
self._audio_queue: asyncio.Queue[np.ndarray | None] = asyncio.Queue()
|
||||
self._leftover: np.ndarray | None = None
|
||||
self._token_queue: asyncio.Queue[int] = asyncio.Queue()
|
||||
|
||||
self._look_ahead = self._get_len_in_samples(self._look_ahead_in_ms)
|
||||
self._look_back = self._get_len_in_samples(self._look_back_in_ms)
|
||||
self._streaming_size = self._get_len_in_samples(1000 / self._config.frame_rate)
|
||||
self._initial_end = len(prompt_tokens) * self._config.raw_audio_length_per_tok
|
||||
for token in prompt_tokens:
|
||||
self._token_queue.put_nowait(token)
|
||||
|
||||
# mutable objects
|
||||
streaming_delay = self._get_len_in_samples(self._config.transcription_delay_ms)
|
||||
self._start = 0
|
||||
self._end = streaming_delay + self._streaming_size
|
||||
def _generate_frame_size_and_num_tokens(self) -> Iterator[tuple[int, int]]:
|
||||
streaming_step_size = self._ms_to_samples(1000 / self._config.frame_rate)
|
||||
start = 0
|
||||
end = self._initial_end
|
||||
while True:
|
||||
frame_start = max(start - self._look_back_in_samples, 0)
|
||||
frame_end = end + self._look_ahead_in_samples
|
||||
frame_size = frame_end - frame_start
|
||||
num_tokens = (end - start) / self._config.raw_audio_length_per_tok
|
||||
assert num_tokens.is_integer()
|
||||
yield frame_size, int(num_tokens)
|
||||
start = end
|
||||
end += streaming_step_size
|
||||
|
||||
# always pre-allocate 30 second buffers
|
||||
self._buffer_size = _PRE_ALLOCATE_BUFFER_SIZE_IN_S * self._sampling_rate
|
||||
self._buffer: np.ndarray = np.empty(self._buffer_size, dtype=np.float32)
|
||||
self._filled_buffer_len = 0
|
||||
def _ms_to_samples(self, ms: float) -> int:
|
||||
len_ = self._config.sampling_rate * ms / 1000
|
||||
assert len_.is_integer(), len_
|
||||
return int(len_)
|
||||
|
||||
@property
|
||||
def start_idx(self):
|
||||
return max(self._start - self._look_back, 0)
|
||||
async def append_audio(self, audio_array: np.ndarray | None) -> None:
|
||||
await self._audio_queue.put(audio_array)
|
||||
|
||||
@property
|
||||
def end_idx(self):
|
||||
return self._end + self._look_ahead
|
||||
async def append_tokens(self, tokens: Iterable[int]) -> None:
|
||||
for token in tokens:
|
||||
await self._token_queue.put(token)
|
||||
|
||||
@property
|
||||
def is_audio_complete(self) -> bool:
|
||||
return self._filled_buffer_len >= self.end_idx
|
||||
async def get_input_stream(self) -> AsyncGenerator[StreamingInput]:
|
||||
for frame_size, num_tokens in self._generate_frame_size_and_num_tokens():
|
||||
next_tokens = [await self._token_queue.get() for _ in range(num_tokens)]
|
||||
|
||||
def _get_len_in_samples(self, len_in_ms: float) -> int:
|
||||
_len_in_s = self._sampling_rate * len_in_ms / 1000
|
||||
assert _len_in_s.is_integer(), _len_in_s
|
||||
len_in_s = int(_len_in_s)
|
||||
audio_arrays: list[np.ndarray] = (
|
||||
[self._leftover] if self._leftover is not None else []
|
||||
)
|
||||
while sum(len(arr) for arr in audio_arrays) < frame_size:
|
||||
arr = await self._audio_queue.get()
|
||||
if arr is None:
|
||||
return
|
||||
audio_arrays.append(arr)
|
||||
|
||||
return len_in_s
|
||||
audio_array = np.concatenate(audio_arrays)
|
||||
frame = audio_array[:frame_size]
|
||||
|
||||
def _allocate_new_buffer(self) -> None:
|
||||
# allocate new buffer
|
||||
new_buffer = np.empty(self._buffer_size, dtype=np.float32)
|
||||
left_to_copy = max(self._filled_buffer_len - self.start_idx, 0)
|
||||
# The current stride took look_ahead_in_samples audio of the next sample
|
||||
# In addition the next sample will take look_back_in_samples audio of
|
||||
# the current sample => So let's put both of this into the leftover
|
||||
stride = (
|
||||
frame_size - self._look_ahead_in_samples - self._look_back_in_samples
|
||||
)
|
||||
assert stride > 0, f"{stride=} must be positive"
|
||||
|
||||
if left_to_copy > 0:
|
||||
new_buffer[:left_to_copy] = self._buffer[
|
||||
self.start_idx : self._filled_buffer_len
|
||||
]
|
||||
self._leftover = audio_array[stride:]
|
||||
|
||||
del self._buffer
|
||||
self._buffer = new_buffer
|
||||
|
||||
self._filled_buffer_len = left_to_copy
|
||||
self._start = self._look_back
|
||||
self._end = self._start + self._streaming_size
|
||||
|
||||
def write_audio(self, audio: np.ndarray) -> None:
|
||||
put_end_idx = self._filled_buffer_len + len(audio)
|
||||
|
||||
if put_end_idx > self._buffer_size:
|
||||
self._allocate_new_buffer()
|
||||
|
||||
self._buffer[self._filled_buffer_len : self._filled_buffer_len + len(audio)] = (
|
||||
audio
|
||||
)
|
||||
self._filled_buffer_len += len(audio)
|
||||
|
||||
def read_audio(self) -> np.ndarray | None:
|
||||
if not self.is_audio_complete:
|
||||
return None
|
||||
|
||||
audio = self._buffer[self.start_idx : self.end_idx]
|
||||
self._start = self._end
|
||||
self._end += self._streaming_size
|
||||
|
||||
return audio
|
||||
yield StreamingInput(
|
||||
TokensPrompt(
|
||||
prompt_token_ids=next_tokens,
|
||||
multi_modal_data={"audio": (frame, None)},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
@@ -234,7 +229,7 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
|
||||
)
|
||||
|
||||
audio_config = self.tokenizer.instruct.audio_encoder.audio_config
|
||||
self.n_delay_tokens = audio_config.num_delay_tokens
|
||||
self.n_delay_tokens = audio_config.get_num_delay_tokens()
|
||||
|
||||
# for realtime transcription
|
||||
@classmethod
|
||||
@@ -248,45 +243,47 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
|
||||
audio_encoder = tokenizer.instruct.audio_encoder
|
||||
config = audio_encoder.audio_config
|
||||
|
||||
buffer = VoxtralRealtimeBuffer(config)
|
||||
is_first_yield = True
|
||||
# Get prompt tokens (streaming prefix tokens) without encoding audio
|
||||
prompt_tokens = (
|
||||
tokenizer.instruct.start() + audio_encoder.encode_streaming_tokens()
|
||||
)
|
||||
|
||||
async for audio in audio_stream:
|
||||
buffer.write_audio(audio)
|
||||
# Get left/right padding audio
|
||||
left_pad, right_pad = audio_encoder.get_padding_audio()
|
||||
|
||||
while (new_audio := buffer.read_audio()) is not None:
|
||||
if is_first_yield:
|
||||
# make sure that input_stream is empty
|
||||
assert input_stream.empty()
|
||||
buffer = VoxtralRealtimeBuffer(config, prompt_tokens)
|
||||
|
||||
audio = Audio(new_audio, config.sampling_rate, format="wav")
|
||||
# Feed audio with padding into buffer in background
|
||||
async def feed_audio():
|
||||
yielded_first_chunk = False
|
||||
async for audio_chunk in audio_stream:
|
||||
if not yielded_first_chunk:
|
||||
yielded_first_chunk = True
|
||||
# Prepend left padding before first real audio
|
||||
await buffer.append_audio(left_pad.audio_array)
|
||||
await buffer.append_audio(audio_chunk)
|
||||
# Append right padding at the end
|
||||
await buffer.append_audio(right_pad.audio_array)
|
||||
await buffer.append_audio(None) # signal end
|
||||
|
||||
request = TranscriptionRequest(
|
||||
streaming=StreamingMode.ONLINE,
|
||||
audio=RawAudio.from_audio(audio),
|
||||
language=None,
|
||||
)
|
||||
# mistral tokenizer takes care
|
||||
# of preparing the first prompt inputs
|
||||
# and does some left-silence padding
|
||||
# for improved performance
|
||||
audio_enc = tokenizer.mistral.encode_transcription(request)
|
||||
|
||||
token_ids = audio_enc.tokens
|
||||
new_audio = audio_enc.audios[0].audio_array
|
||||
|
||||
is_first_yield = False
|
||||
else:
|
||||
# pop last element from input_stream
|
||||
all_outputs = await asyncio.wait_for(
|
||||
input_stream.get(), timeout=VLLM_ENGINE_ITERATION_TIMEOUT_S
|
||||
)
|
||||
token_ids = all_outputs[-1:]
|
||||
|
||||
multi_modal_data = {"audio": (new_audio, None)}
|
||||
yield TokensPrompt(
|
||||
prompt_token_ids=token_ids, multi_modal_data=multi_modal_data
|
||||
# Feed output tokens back into buffer in background
|
||||
async def feed_tokens():
|
||||
while True:
|
||||
all_outputs = await asyncio.wait_for(
|
||||
input_stream.get(),
|
||||
timeout=VLLM_ENGINE_ITERATION_TIMEOUT_S,
|
||||
)
|
||||
await buffer.append_tokens(all_outputs[-1:])
|
||||
|
||||
audio_task = asyncio.create_task(feed_audio())
|
||||
token_task = asyncio.create_task(feed_tokens())
|
||||
|
||||
try:
|
||||
async for streaming_input in buffer.get_input_stream():
|
||||
yield streaming_input.prompt
|
||||
finally:
|
||||
audio_task.cancel()
|
||||
token_task.cancel()
|
||||
|
||||
@property
|
||||
def audio_config(self):
|
||||
|
||||
Reference in New Issue
Block a user