[Realtime API] Adds minimal realtime API based on websockets (#33187)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
committed by
GitHub
parent
1a7894dbdf
commit
10152d2194
@@ -1,7 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable, Iterable, Mapping, MutableSequence
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, MutableSequence
|
||||
from contextlib import ExitStack, contextmanager, nullcontext
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -1015,6 +1016,37 @@ class SupportsQuant:
|
||||
return None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsRealtime(Protocol):
|
||||
"""The interface required for all models that support transcription."""
|
||||
|
||||
supports_realtime: ClassVar[Literal[True]] = True
|
||||
|
||||
@classmethod
|
||||
async def buffer_realtime_audio(
|
||||
cls,
|
||||
audio_stream: AsyncGenerator[np.ndarray, None],
|
||||
input_stream: asyncio.Queue[list[int]],
|
||||
model_config: ModelConfig,
|
||||
) -> AsyncGenerator[PromptType, None]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def supports_realtime(
|
||||
model: type[object],
|
||||
) -> TypeIs[type[SupportsRealtime]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def supports_realtime(model: object) -> TypeIs[SupportsRealtime]: ...
|
||||
|
||||
|
||||
def supports_realtime(
|
||||
model: type[object] | object,
|
||||
) -> TypeIs[type[SupportsRealtime]] | TypeIs[SupportsRealtime]:
|
||||
return getattr(model, "supports_realtime", False)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsTranscription(Protocol):
|
||||
"""The interface required for all models that support transcription."""
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from functools import cached_property, partial
|
||||
@@ -20,7 +19,6 @@ from mistral_common.protocol.transcription.request import TranscriptionRequest
|
||||
from mistral_common.tokens.tokenizers.audio import (
|
||||
Audio,
|
||||
AudioEncoder,
|
||||
TranscriptionFormat,
|
||||
)
|
||||
from transformers import BatchFeature, TensorType, WhisperConfig
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
@@ -163,19 +161,10 @@ class VoxtralProcessorAdapter:
|
||||
assert isinstance(audio, np.ndarray)
|
||||
assert audio.ndim == 1
|
||||
|
||||
# pad if necessary
|
||||
# TODO(Patrick) - remove once mistral-common is bumped
|
||||
if (
|
||||
self._audio_processor.audio_config.transcription_format
|
||||
!= TranscriptionFormat.STREAMING
|
||||
):
|
||||
sig = inspect.signature(self._audio_processor.pad)
|
||||
if "is_online_streaming" in sig.parameters:
|
||||
audio = self._audio_processor.pad(
|
||||
audio, self.sampling_rate, is_online_streaming=False
|
||||
)
|
||||
else:
|
||||
audio = self._audio_processor.pad(audio, self.sampling_rate)
|
||||
if not self._audio_processor.audio_config.is_streaming:
|
||||
audio = self._audio_processor.pad(
|
||||
audio, self.sampling_rate, is_online_streaming=False
|
||||
)
|
||||
|
||||
audio_tokens = [self.begin_audio_token_id] + [
|
||||
self.audio_token_id
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from typing import Literal, cast
|
||||
|
||||
import numpy as np
|
||||
@@ -12,12 +13,14 @@ from mistral_common.protocol.transcription.request import (
|
||||
StreamingMode,
|
||||
TranscriptionRequest,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.audio import Audio
|
||||
from mistral_common.tokens.tokenizers.audio import Audio, AudioConfig
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S
|
||||
from vllm.inputs.data import PromptType, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsRealtime
|
||||
from vllm.model_executor.models.voxtral import (
|
||||
VoxtralDummyInputsBuilder,
|
||||
VoxtralForConditionalGeneration,
|
||||
@@ -44,6 +47,8 @@ from .utils import (
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_PRE_ALLOCATE_BUFFER_SIZE_IN_S = 30
|
||||
|
||||
|
||||
class VoxtralStreamingMultiModalProcessor(VoxtralMultiModalProcessor):
|
||||
def __init__(
|
||||
@@ -124,29 +129,164 @@ def _expand_tensor(input_tensor: torch.Tensor, scaling: int) -> torch.Tensor:
|
||||
return (base.unsqueeze(1) + offsets).view(-1)
|
||||
|
||||
|
||||
class VoxtralRealtimeBuffer:
|
||||
def __init__(self, config: AudioConfig) -> None:
|
||||
self._config = config
|
||||
|
||||
self._look_ahead_in_ms = config.streaming_look_ahead_ms
|
||||
self._look_back_in_ms = config.streaming_look_back_ms
|
||||
|
||||
self._sampling_rate = self._config.sampling_rate
|
||||
|
||||
self._look_ahead = self._get_len_in_samples(self._look_ahead_in_ms)
|
||||
self._look_back = self._get_len_in_samples(self._look_back_in_ms)
|
||||
self._streaming_size = self._get_len_in_samples(1000 / self._config.frame_rate)
|
||||
|
||||
# mutable objects
|
||||
streaming_delay = self._get_len_in_samples(self._config.transcription_delay_ms)
|
||||
self._start = 0
|
||||
self._end = streaming_delay + self._streaming_size
|
||||
|
||||
# always pre-allocate 30 second buffers
|
||||
self._buffer_size = _PRE_ALLOCATE_BUFFER_SIZE_IN_S * self._sampling_rate
|
||||
self._buffer: np.ndarray = np.empty(self._buffer_size, dtype=np.float32)
|
||||
self._filled_buffer_len = 0
|
||||
|
||||
@property
|
||||
def start_idx(self):
|
||||
return max(self._start - self._look_back, 0)
|
||||
|
||||
@property
|
||||
def end_idx(self):
|
||||
return self._end + self._look_ahead
|
||||
|
||||
@property
|
||||
def is_audio_complete(self) -> bool:
|
||||
return self._filled_buffer_len >= self.end_idx
|
||||
|
||||
def _get_len_in_samples(self, len_in_ms: float) -> int:
|
||||
_len_in_s = self._sampling_rate * len_in_ms / 1000
|
||||
assert _len_in_s.is_integer(), _len_in_s
|
||||
len_in_s = int(_len_in_s)
|
||||
|
||||
return len_in_s
|
||||
|
||||
def _allocate_new_buffer(self) -> None:
|
||||
# allocate new buffer
|
||||
new_buffer = np.empty(self._buffer_size, dtype=np.float32)
|
||||
left_to_copy = max(self._filled_buffer_len - self.start_idx, 0)
|
||||
|
||||
if left_to_copy > 0:
|
||||
new_buffer[:left_to_copy] = self._buffer[
|
||||
self.start_idx : self._filled_buffer_len
|
||||
]
|
||||
|
||||
del self._buffer
|
||||
self._buffer = new_buffer
|
||||
|
||||
self._filled_buffer_len = left_to_copy
|
||||
self._start = self._look_back
|
||||
self._end = self._start + self._streaming_size
|
||||
|
||||
def write_audio(self, audio: np.ndarray) -> None:
|
||||
put_end_idx = self._filled_buffer_len + len(audio)
|
||||
|
||||
if put_end_idx > self._buffer_size:
|
||||
self._allocate_new_buffer()
|
||||
|
||||
self._buffer[self._filled_buffer_len : self._filled_buffer_len + len(audio)] = (
|
||||
audio
|
||||
)
|
||||
self._filled_buffer_len += len(audio)
|
||||
|
||||
def read_audio(self) -> np.ndarray | None:
|
||||
if not self.is_audio_complete:
|
||||
return None
|
||||
|
||||
audio = self._buffer[self.start_idx : self.end_idx]
|
||||
self._start = self._end
|
||||
self._end += self._streaming_size
|
||||
|
||||
return audio
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
VoxtralStreamingMultiModalProcessor,
|
||||
info=VoxtralProcessingInfo,
|
||||
dummy_inputs=VoxtralDummyInputsBuilder,
|
||||
)
|
||||
class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
|
||||
@support_torch_compile
|
||||
class VoxtralStreamingGeneration(VoxtralForConditionalGeneration, SupportsRealtime):
|
||||
requires_raw_input_tokens = True
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
assert (
|
||||
not vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
), (
|
||||
"Voxtral streaming doesn't support full cudagraphs yet. "
|
||||
"Please use PIECEWISE."
|
||||
)
|
||||
|
||||
self.time_embedding: TimeEmbedding = TimeEmbedding(
|
||||
dim=self.config.text_config.hidden_size
|
||||
)
|
||||
|
||||
audio_config = self.tokenizer.instruct.audio_encoder.audio_config
|
||||
_n_delay_tokens = (
|
||||
audio_config.frame_rate * audio_config.transcription_delay_ms / 1000
|
||||
)
|
||||
assert _n_delay_tokens.is_integer(), (
|
||||
f"n_delay_tokens must be integer, got {_n_delay_tokens}"
|
||||
)
|
||||
self.n_delay_tokens = audio_config.num_delay_tokens
|
||||
|
||||
self.n_delay_tokens = int(_n_delay_tokens)
|
||||
# for realtime transcription
|
||||
@classmethod
|
||||
async def buffer_realtime_audio(
|
||||
cls,
|
||||
audio_stream: AsyncGenerator[np.ndarray, None],
|
||||
input_stream: asyncio.Queue[list[int]],
|
||||
model_config: ModelConfig,
|
||||
) -> AsyncGenerator[PromptType, None]:
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
audio_encoder = tokenizer.instruct.audio_encoder
|
||||
config = audio_encoder.audio_config
|
||||
|
||||
buffer = VoxtralRealtimeBuffer(config)
|
||||
is_first_yield = True
|
||||
|
||||
async for audio in audio_stream:
|
||||
buffer.write_audio(audio)
|
||||
|
||||
while (new_audio := buffer.read_audio()) is not None:
|
||||
if is_first_yield:
|
||||
# make sure that input_stream is empty
|
||||
assert input_stream.empty()
|
||||
|
||||
audio = Audio(new_audio, config.sampling_rate, format="wav")
|
||||
|
||||
request = TranscriptionRequest(
|
||||
streaming=StreamingMode.ONLINE,
|
||||
audio=RawAudio.from_audio(audio),
|
||||
language=None,
|
||||
)
|
||||
# mistral tokenizer takes care
|
||||
# of preparing the first prompt inputs
|
||||
# and does some left-silence padding
|
||||
# for improved performance
|
||||
audio_enc = tokenizer.mistral.encode_transcription(request)
|
||||
|
||||
token_ids = audio_enc.tokens
|
||||
new_audio = audio_enc.audios[0].audio_array
|
||||
|
||||
is_first_yield = False
|
||||
else:
|
||||
# pop last element from input_stream
|
||||
all_outputs = await asyncio.wait_for(
|
||||
input_stream.get(), timeout=VLLM_ENGINE_ITERATION_TIMEOUT_S
|
||||
)
|
||||
token_ids = all_outputs[-1:]
|
||||
|
||||
multi_modal_data = {"audio": (new_audio, None)}
|
||||
yield TokensPrompt(
|
||||
prompt_token_ids=token_ids, multi_modal_data=multi_modal_data
|
||||
)
|
||||
|
||||
@property
|
||||
def audio_config(self):
|
||||
@@ -205,8 +345,9 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
|
||||
# sum pool text and audio embeddings
|
||||
inputs_embeds = audio_text_embeds + text_embeds
|
||||
|
||||
time_tensor = torch.tensor(
|
||||
[self.n_delay_tokens],
|
||||
time_tensor = torch.full(
|
||||
(1,),
|
||||
fill_value=self.n_delay_tokens,
|
||||
device=inputs_embeds.device,
|
||||
dtype=inputs_embeds.dtype,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user