[Realtime API] Adds minimal realtime API based on websockets (#33187)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
Patrick von Platen
2026-01-30 11:41:29 +01:00
committed by GitHub
parent 1a7894dbdf
commit 10152d2194
21 changed files with 1316 additions and 48 deletions

View File

@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from collections.abc import Callable, Iterable, Mapping, MutableSequence
+import asyncio
+from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, MutableSequence
 from contextlib import ExitStack, contextmanager, nullcontext
 from typing import (
     TYPE_CHECKING,
@@ -1015,6 +1016,37 @@ class SupportsQuant:
         return None


+@runtime_checkable
+class SupportsRealtime(Protocol):
+    """The interface required for all models that support realtime transcription."""
+
+    supports_realtime: ClassVar[Literal[True]] = True
+
+    @classmethod
+    async def buffer_realtime_audio(
+        cls,
+        audio_stream: AsyncGenerator[np.ndarray, None],
+        input_stream: asyncio.Queue[list[int]],
+        model_config: ModelConfig,
+    ) -> AsyncGenerator[PromptType, None]: ...
+
+
+@overload
+def supports_realtime(
+    model: type[object],
+) -> TypeIs[type[SupportsRealtime]]: ...
+
+
+@overload
+def supports_realtime(model: object) -> TypeIs[SupportsRealtime]: ...
+
+
+def supports_realtime(
+    model: type[object] | object,
+) -> TypeIs[type[SupportsRealtime]] | TypeIs[SupportsRealtime]:
+    return getattr(model, "supports_realtime", False)
+
+
 @runtime_checkable
 class SupportsTranscription(Protocol):
     """The interface required for all models that support transcription."""

View File

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import inspect
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property, partial
@@ -20,7 +19,6 @@ from mistral_common.protocol.transcription.request import TranscriptionRequest
 from mistral_common.tokens.tokenizers.audio import (
     Audio,
     AudioEncoder,
-    TranscriptionFormat,
 )
 from transformers import BatchFeature, TensorType, WhisperConfig
 from transformers.tokenization_utils_base import TextInput
@@ -163,19 +161,10 @@ class VoxtralProcessorAdapter:
         assert isinstance(audio, np.ndarray)
         assert audio.ndim == 1

         # pad if necessary
-        # TODO(Patrick) - remove once mistral-common is bumped
-        if (
-            self._audio_processor.audio_config.transcription_format
-            != TranscriptionFormat.STREAMING
-        ):
-            sig = inspect.signature(self._audio_processor.pad)
-            if "is_online_streaming" in sig.parameters:
-                audio = self._audio_processor.pad(
-                    audio, self.sampling_rate, is_online_streaming=False
-                )
-            else:
-                audio = self._audio_processor.pad(audio, self.sampling_rate)
+        if not self._audio_processor.audio_config.is_streaming:
+            audio = self._audio_processor.pad(
+                audio, self.sampling_rate, is_online_streaming=False
+            )

         audio_tokens = [self.begin_audio_token_id] + [
             self.audio_token_id

View File

@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import asyncio
 import math
-from collections.abc import Mapping
+from collections.abc import AsyncGenerator, Mapping
 from typing import Literal, cast

 import numpy as np
@@ -12,12 +13,14 @@ from mistral_common.protocol.transcription.request import (
     StreamingMode,
     TranscriptionRequest,
 )
-from mistral_common.tokens.tokenizers.audio import Audio
+from mistral_common.tokens.tokenizers.audio import Audio, AudioConfig

+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
-from vllm.inputs.data import PromptType
+from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S
+from vllm.inputs.data import PromptType, TokensPrompt
 from vllm.logger import init_logger
-from vllm.model_executor.models.interfaces import MultiModalEmbeddings
+from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsRealtime
 from vllm.model_executor.models.voxtral import (
     VoxtralDummyInputsBuilder,
     VoxtralForConditionalGeneration,
@@ -44,6 +47,8 @@ from .utils import (

 logger = init_logger(__name__)

+_PRE_ALLOCATE_BUFFER_SIZE_IN_S = 30
+

 class VoxtralStreamingMultiModalProcessor(VoxtralMultiModalProcessor):
     def __init__(
@@ -124,29 +129,164 @@ def _expand_tensor(input_tensor: torch.Tensor, scaling: int) -> torch.Tensor:
     return (base.unsqueeze(1) + offsets).view(-1)


+class VoxtralRealtimeBuffer:
+    def __init__(self, config: AudioConfig) -> None:
+        self._config = config
+        self._look_ahead_in_ms = config.streaming_look_ahead_ms
+        self._look_back_in_ms = config.streaming_look_back_ms
+        self._sampling_rate = self._config.sampling_rate
+
+        self._look_ahead = self._get_len_in_samples(self._look_ahead_in_ms)
+        self._look_back = self._get_len_in_samples(self._look_back_in_ms)
+        self._streaming_size = self._get_len_in_samples(1000 / self._config.frame_rate)
+
+        # mutable state
+        streaming_delay = self._get_len_in_samples(self._config.transcription_delay_ms)
+        self._start = 0
+        self._end = streaming_delay + self._streaming_size
+
+        # always pre-allocate 30-second buffers
+        self._buffer_size = _PRE_ALLOCATE_BUFFER_SIZE_IN_S * self._sampling_rate
+        self._buffer: np.ndarray = np.empty(self._buffer_size, dtype=np.float32)
+        self._filled_buffer_len = 0
+
+    @property
+    def start_idx(self) -> int:
+        return max(self._start - self._look_back, 0)
+
+    @property
+    def end_idx(self) -> int:
+        return self._end + self._look_ahead
+
+    @property
+    def is_audio_complete(self) -> bool:
+        return self._filled_buffer_len >= self.end_idx
+
+    def _get_len_in_samples(self, len_in_ms: float) -> int:
+        _len_in_samples = self._sampling_rate * len_in_ms / 1000
+        assert _len_in_samples.is_integer(), _len_in_samples
+        return int(_len_in_samples)
+
+    def _allocate_new_buffer(self) -> None:
+        # allocate a fresh buffer and copy over the still-needed tail
+        # (everything from the current read window onwards)
+        new_buffer = np.empty(self._buffer_size, dtype=np.float32)
+        left_to_copy = max(self._filled_buffer_len - self.start_idx, 0)
+        if left_to_copy > 0:
+            new_buffer[:left_to_copy] = self._buffer[
+                self.start_idx : self._filled_buffer_len
+            ]
+        del self._buffer
+        self._buffer = new_buffer
+        self._filled_buffer_len = left_to_copy
+        self._start = self._look_back
+        self._end = self._start + self._streaming_size
+
+    def write_audio(self, audio: np.ndarray) -> None:
+        # assumes incoming chunks are much shorter than the 30 s buffer
+        put_end_idx = self._filled_buffer_len + len(audio)
+        if put_end_idx > self._buffer_size:
+            self._allocate_new_buffer()
+        self._buffer[self._filled_buffer_len : self._filled_buffer_len + len(audio)] = (
+            audio
+        )
+        self._filled_buffer_len += len(audio)
+
+    def read_audio(self) -> np.ndarray | None:
+        if not self.is_audio_complete:
+            return None
+        audio = self._buffer[self.start_idx : self.end_idx]
+        self._start = self._end
+        self._end += self._streaming_size
+        return audio
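For reviewers: a minimal sketch of the buffer's read/write contract. The `AudioConfig` stand-in below uses illustrative values only (16 kHz audio, 12.5 Hz token frame rate, 480 ms delay, no look-back/look-ahead); the real values come from mistral-common:

```python
import numpy as np
from types import SimpleNamespace

# Stand-in for mistral-common's AudioConfig; values are hypothetical.
cfg = SimpleNamespace(
    sampling_rate=16000,           # Hz
    frame_rate=12.5,               # one audio token per 80 ms
    streaming_look_ahead_ms=0,
    streaming_look_back_ms=0,
    transcription_delay_ms=480,
)

buf = VoxtralRealtimeBuffer(cfg)  # type: ignore[arg-type]

# The first window is complete once delay + one step of audio is
# buffered: (480 + 80) ms * 16 kHz = 8960 samples. After that, each
# read advances by one 80 ms step (1280 samples).
buf.write_audio(np.zeros(8000, dtype=np.float32))   # 0.5 s
assert buf.read_audio() is None                     # not enough yet
buf.write_audio(np.zeros(8000, dtype=np.float32))   # 1.0 s total
chunk = buf.read_audio()                            # samples [0, 8960)
assert chunk is not None and len(chunk) == 8960
```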
 @MULTIMODAL_REGISTRY.register_processor(
     VoxtralStreamingMultiModalProcessor,
     info=VoxtralProcessingInfo,
     dummy_inputs=VoxtralDummyInputsBuilder,
 )
-class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
+@support_torch_compile
+class VoxtralStreamingGeneration(VoxtralForConditionalGeneration, SupportsRealtime):
     requires_raw_input_tokens = True

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
         assert (
             not vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs()
         ), (
             "Voxtral streaming doesn't support full cudagraphs yet. "
             "Please use PIECEWISE."
         )

         self.time_embedding: TimeEmbedding = TimeEmbedding(
             dim=self.config.text_config.hidden_size
         )

         audio_config = self.tokenizer.instruct.audio_encoder.audio_config
-        self.n_delay_tokens = audio_config.num_delay_tokens
+        _n_delay_tokens = (
+            audio_config.frame_rate * audio_config.transcription_delay_ms / 1000
+        )
+        assert _n_delay_tokens.is_integer(), (
+            f"n_delay_tokens must be integer, got {_n_delay_tokens}"
+        )
+        self.n_delay_tokens = int(_n_delay_tokens)
+
+    # for realtime transcription
+    @classmethod
+    async def buffer_realtime_audio(
+        cls,
+        audio_stream: AsyncGenerator[np.ndarray, None],
+        input_stream: asyncio.Queue[list[int]],
+        model_config: ModelConfig,
+    ) -> AsyncGenerator[PromptType, None]:
+        tokenizer = cached_tokenizer_from_config(model_config)
+        audio_encoder = tokenizer.instruct.audio_encoder
+        config = audio_encoder.audio_config
+
+        buffer = VoxtralRealtimeBuffer(config)
+
+        is_first_yield = True
+        async for audio in audio_stream:
+            buffer.write_audio(audio)
+            while (new_audio := buffer.read_audio()) is not None:
+                if is_first_yield:
+                    # make sure that input_stream is empty
+                    assert input_stream.empty()
+                    audio = Audio(new_audio, config.sampling_rate, format="wav")
+                    request = TranscriptionRequest(
+                        streaming=StreamingMode.ONLINE,
+                        audio=RawAudio.from_audio(audio),
+                        language=None,
+                    )
+                    # the mistral tokenizer takes care of preparing the
+                    # first prompt inputs and does some left-silence
+                    # padding for improved performance
+                    audio_enc = tokenizer.mistral.encode_transcription(request)
+                    token_ids = audio_enc.tokens
+                    new_audio = audio_enc.audios[0].audio_array
+                    is_first_yield = False
+                else:
+                    # pop the last element from input_stream
+                    all_outputs = await asyncio.wait_for(
+                        input_stream.get(), timeout=VLLM_ENGINE_ITERATION_TIMEOUT_S
+                    )
+                    token_ids = all_outputs[-1:]

+                multi_modal_data = {"audio": (new_audio, None)}
+                yield TokensPrompt(
+                    prompt_token_ids=token_ids, multi_modal_data=multi_modal_data
+                )

     @property
     def audio_config(self):
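The intended call pattern is a feedback loop: the serving layer pulls prompts from `buffer_realtime_audio` while pushing each step's generated token ids back through `input_stream`. A rough sketch (`generate_one_step` and the wiring are hypothetical; the real loop lives in the new websocket serving code):

```python
import asyncio

async def run_realtime(model_config, mic_chunks):
    # mic_chunks: async generator yielding float32 numpy audio chunks
    input_stream: asyncio.Queue[list[int]] = asyncio.Queue()

    prompts = VoxtralStreamingGeneration.buffer_realtime_audio(
        mic_chunks, input_stream, model_config
    )
    async for prompt in prompts:
        # Run one engine step on `prompt` (hypothetical helper), then feed
        # the new token ids back so the next prompt continues from them.
        new_token_ids = await generate_one_step(prompt)
        await input_stream.put(new_token_ids)
```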
@@ -205,8 +345,9 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
         # sum pool text and audio embeddings
         inputs_embeds = audio_text_embeds + text_embeds

-        time_tensor = torch.tensor(
-            [self.n_delay_tokens],
+        time_tensor = torch.full(
+            (1,),
+            fill_value=self.n_delay_tokens,
             device=inputs_embeds.device,
             dtype=inputs_embeds.dtype,
         )
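Worked example of the delay-token arithmetic that seeds this time tensor, with hypothetical config values (the real `frame_rate` and `transcription_delay_ms` come from mistral-common's `AudioConfig`):

```python
# Hypothetical values, for illustration only.
frame_rate = 12.5             # audio tokens per second
transcription_delay_ms = 480  # transcription lags the audio by this much
sampling_rate = 16000         # Hz

# n_delay_tokens, as computed in VoxtralStreamingGeneration.__init__:
n_delay_tokens = frame_rate * transcription_delay_ms / 1000  # 6.0
assert n_delay_tokens.is_integer()

# samples per streaming step, as in VoxtralRealtimeBuffer:
streaming_size = sampling_rate * (1000 / frame_rate) / 1000  # 1280.0
```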