[Voxtral Realtime] Refactor & Improve buffering logic (#34428)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
1100a97621
commit
6c0baee610
@@ -31,7 +31,7 @@ partial-json-parser # used for parsing partial JSON outputs
|
|||||||
pyzmq >= 25.0.0
|
pyzmq >= 25.0.0
|
||||||
msgspec
|
msgspec
|
||||||
gguf >= 0.17.0
|
gguf >= 0.17.0
|
||||||
mistral_common[image] >= 1.9.0
|
mistral_common[image] >= 1.9.1
|
||||||
opencv-python-headless >= 4.13.0 # required for video IO
|
opencv-python-headless >= 4.13.0 # required for video IO
|
||||||
pyyaml
|
pyyaml
|
||||||
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
||||||
@@ -52,4 +52,4 @@ anthropic >= 0.71.0
|
|||||||
model-hosting-container-standards >= 0.1.13, < 1.0.0
|
model-hosting-container-standards >= 0.1.13, < 1.0.0
|
||||||
mcp
|
mcp
|
||||||
grpcio
|
grpcio
|
||||||
grpcio-reflection
|
grpcio-reflection
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ jiwer # required for audio tests
|
|||||||
timm # required for internvl test
|
timm # required for internvl test
|
||||||
transformers_stream_generator # required for qwen-vl test
|
transformers_stream_generator # required for qwen-vl test
|
||||||
matplotlib # required for qwen-vl test
|
matplotlib # required for qwen-vl test
|
||||||
mistral_common[image,audio] >= 1.9.0 # required for voxtral test
|
mistral_common[image,audio] >= 1.9.1 # required for voxtral test
|
||||||
num2words # required for smolvlm test
|
num2words # required for smolvlm test
|
||||||
opencv-python-headless >= 4.13.0 # required for video test
|
opencv-python-headless >= 4.13.0 # required for video test
|
||||||
datamodel_code_generator # required for minicpm3 test
|
datamodel_code_generator # required for minicpm3 test
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ torchaudio==2.10.0
|
|||||||
torchvision==0.25.0
|
torchvision==0.25.0
|
||||||
transformers_stream_generator # required for qwen-vl test
|
transformers_stream_generator # required for qwen-vl test
|
||||||
matplotlib # required for qwen-vl test
|
matplotlib # required for qwen-vl test
|
||||||
mistral_common[image,audio] >= 1.9.0 # required for voxtral test
|
mistral_common[image,audio] >= 1.9.1 # required for voxtral test
|
||||||
num2words # required for smolvlm test
|
num2words # required for smolvlm test
|
||||||
open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py
|
open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py
|
||||||
opencv-python-headless >= 4.13.0 # required for video test
|
opencv-python-headless >= 4.13.0 # required for video test
|
||||||
|
|||||||
@@ -499,7 +499,7 @@ mbstrdecoder==1.1.3
|
|||||||
# typepy
|
# typepy
|
||||||
mdurl==0.1.2
|
mdurl==0.1.2
|
||||||
# via markdown-it-py
|
# via markdown-it-py
|
||||||
mistral-common==1.9.0
|
mistral-common==1.9.1
|
||||||
# via -r requirements/test.in
|
# via -r requirements/test.in
|
||||||
mlflow==2.22.0
|
mlflow==2.22.0
|
||||||
# via terratorch
|
# via terratorch
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import asyncio
|
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -10,14 +9,13 @@ from mistral_common.protocol.transcription.request import (
|
|||||||
StreamingMode,
|
StreamingMode,
|
||||||
TranscriptionRequest,
|
TranscriptionRequest,
|
||||||
)
|
)
|
||||||
from mistral_common.tokens.tokenizers.audio import AudioConfig
|
|
||||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||||
|
|
||||||
from vllm import LLM, EngineArgs, SamplingParams
|
from vllm import LLM, EngineArgs, SamplingParams
|
||||||
from vllm.assets.audio import AudioAsset
|
from vllm.assets.audio import AudioAsset
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
from vllm.inputs.data import TokensPrompt
|
from vllm.model_executor.models.voxtral_realtime import VoxtralRealtimeBuffer
|
||||||
from vllm.v1.engine.async_llm import AsyncLLM, StreamingInput
|
from vllm.v1.engine.async_llm import AsyncLLM
|
||||||
|
|
||||||
MODEL_NAME = "mistralai/Voxtral-Mini-4B-Realtime-2602"
|
MODEL_NAME = "mistralai/Voxtral-Mini-4B-Realtime-2602"
|
||||||
ENGINE_CONFIG = dict(
|
ENGINE_CONFIG = dict(
|
||||||
@@ -114,136 +112,40 @@ def test_voxtral_realtime_forward(audio_assets, tokenizer, engine):
|
|||||||
assert texts == EXPECTED_TEXT
|
assert texts == EXPECTED_TEXT
|
||||||
|
|
||||||
|
|
||||||
class RealTimeAudioInput:
|
|
||||||
"""
|
|
||||||
This class is used to stream an audio file just as
|
|
||||||
if it would be streamed in real-time.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, tokenizer: MistralTokenizer) -> None:
|
|
||||||
self._tokenizer = tokenizer
|
|
||||||
self._config: AudioConfig = (
|
|
||||||
self._tokenizer.instruct_tokenizer.audio_encoder.audio_config
|
|
||||||
)
|
|
||||||
|
|
||||||
self._look_ahead_in_ms = self._config.streaming_look_ahead_ms
|
|
||||||
self._look_back_in_ms = self._config.streaming_look_back_ms
|
|
||||||
|
|
||||||
self._sampling_rate = self._config.sampling_rate
|
|
||||||
|
|
||||||
self._audio: Audio | None = None
|
|
||||||
|
|
||||||
# mutable objects
|
|
||||||
self._start = 0
|
|
||||||
|
|
||||||
n_left_pad_samples = (
|
|
||||||
self._config.raw_audio_length_per_tok * self._config.n_left_pad_tokens
|
|
||||||
)
|
|
||||||
self._end = self.streaming_delay + n_left_pad_samples + self.streaming_size
|
|
||||||
self._queue: asyncio.Queue[StreamingInput | None] = asyncio.Queue()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def create(cls, audio: Audio, tokenizer: MistralTokenizer):
|
|
||||||
self = cls(tokenizer)
|
|
||||||
|
|
||||||
# we're doing "OFFLINE" encoding here to right & left pad the audio since
|
|
||||||
# we have access to the whole audio
|
|
||||||
# if we'd do an actual online realtime streaming application we
|
|
||||||
# should instead pass `StreamingMode.ONLINE`
|
|
||||||
req = TranscriptionRequest(
|
|
||||||
streaming=StreamingMode.OFFLINE,
|
|
||||||
audio=RawAudio.from_audio(audio),
|
|
||||||
language=None,
|
|
||||||
)
|
|
||||||
audio_enc = self._tokenizer.encode_transcription(req)
|
|
||||||
self._audio = audio_enc.audios[0]
|
|
||||||
|
|
||||||
# add first request
|
|
||||||
await self.add_tokens(audio_enc.tokens)
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
@property
|
|
||||||
def look_ahead(self) -> int:
|
|
||||||
return self._get_len_in_samples(self._look_ahead_in_ms)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def look_back(self) -> int:
|
|
||||||
return self._get_len_in_samples(self._look_back_in_ms)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def streaming_delay(self) -> int:
|
|
||||||
return self._get_len_in_samples(self._config.transcription_delay_ms)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def streaming_size(self) -> int:
|
|
||||||
stream_size_in_ms = 1000 / self._config.frame_rate
|
|
||||||
return self._get_len_in_samples(stream_size_in_ms)
|
|
||||||
|
|
||||||
def _get_len_in_samples(self, len_in_ms: float) -> int:
|
|
||||||
_len_in_s = self._sampling_rate * len_in_ms / 1000
|
|
||||||
assert _len_in_s.is_integer(), _len_in_s
|
|
||||||
len_in_s = int(_len_in_s)
|
|
||||||
|
|
||||||
return len_in_s
|
|
||||||
|
|
||||||
async def add_tokens(self, tokens: list[int]) -> None:
|
|
||||||
assert self._audio is not None
|
|
||||||
if self._start >= len(self._audio.audio_array):
|
|
||||||
self.stop()
|
|
||||||
return
|
|
||||||
|
|
||||||
_end = self._end + self.look_ahead
|
|
||||||
_start = max(0, self._start - self.look_back)
|
|
||||||
|
|
||||||
multi_modal_data = {"audio": (self._audio.audio_array[_start:_end], None)}
|
|
||||||
|
|
||||||
prompt = TokensPrompt(
|
|
||||||
prompt_token_ids=tokens, multi_modal_data=multi_modal_data
|
|
||||||
)
|
|
||||||
|
|
||||||
await self._queue.put(StreamingInput(prompt))
|
|
||||||
|
|
||||||
# increase
|
|
||||||
self._start = self._end
|
|
||||||
self._end = self._end + self.streaming_size
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
self._queue.put_nowait(None)
|
|
||||||
|
|
||||||
async def generator(self):
|
|
||||||
while (item := await self._queue.get()) is not None:
|
|
||||||
yield item
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine):
|
async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine):
|
||||||
sampling_params = SamplingParams(temperature=0.0, max_tokens=1)
|
sampling_params = SamplingParams(temperature=0.0, max_tokens=1)
|
||||||
|
audio_config = tokenizer.instruct_tokenizer.audio_encoder.audio_config
|
||||||
|
|
||||||
output_tokens_list = []
|
output_tokens_list = []
|
||||||
for i, audio_asset in enumerate(audio_assets):
|
for i, audio_asset in enumerate(audio_assets):
|
||||||
output_tokens = []
|
output_tokens = []
|
||||||
audio = Audio.from_file(audio_asset.get_local_path(), strict=False)
|
audio = Audio.from_file(audio_asset.get_local_path(), strict=False)
|
||||||
streaming_input = await RealTimeAudioInput.create(
|
|
||||||
audio=audio, tokenizer=tokenizer
|
req = TranscriptionRequest(
|
||||||
|
streaming=StreamingMode.OFFLINE,
|
||||||
|
audio=RawAudio.from_audio(audio),
|
||||||
|
language=None,
|
||||||
)
|
)
|
||||||
|
audio_enc = tokenizer.encode_transcription(req)
|
||||||
|
|
||||||
|
buffer = VoxtralRealtimeBuffer(audio_config, audio_enc.tokens)
|
||||||
|
await buffer.append_audio(audio_enc.audios[0].audio_array)
|
||||||
|
await buffer.append_audio(None)
|
||||||
|
|
||||||
request_id = f"session-{i}"
|
request_id = f"session-{i}"
|
||||||
|
|
||||||
async for resp in async_engine.generate(
|
async for resp in async_engine.generate(
|
||||||
prompt=streaming_input.generator(),
|
prompt=buffer.get_input_stream(),
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
):
|
):
|
||||||
tokens = resp.outputs[0].token_ids[-1:]
|
tokens = resp.outputs[0].token_ids[-1:]
|
||||||
|
|
||||||
output_tokens.extend(tokens)
|
output_tokens.extend(tokens)
|
||||||
await streaming_input.add_tokens(tokens)
|
await buffer.append_tokens(tokens)
|
||||||
|
|
||||||
output_tokens_list.append(output_tokens)
|
output_tokens_list.append(output_tokens)
|
||||||
|
|
||||||
texts = [tokenizer.decode(output_tokens) for output_tokens in output_tokens_list]
|
texts = [tokenizer.decode(output_tokens) for output_tokens in output_tokens_list]
|
||||||
|
|
||||||
texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")
|
texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")
|
||||||
|
|
||||||
assert texts == EXPECTED_TEXT
|
assert texts == EXPECTED_TEXT
|
||||||
|
|||||||
@@ -155,9 +155,7 @@ class VoxtralProcessorAdapter:
|
|||||||
assert audio.ndim == 1
|
assert audio.ndim == 1
|
||||||
|
|
||||||
if not self._audio_processor.audio_config.is_streaming:
|
if not self._audio_processor.audio_config.is_streaming:
|
||||||
audio = self._audio_processor.pad(
|
audio = self._audio_processor.pad(audio, self.sampling_rate)
|
||||||
audio, self.sampling_rate, is_online_streaming=False
|
|
||||||
)
|
|
||||||
|
|
||||||
audio_tokens = [self.begin_audio_token_id] + [
|
audio_tokens = [self.begin_audio_token_id] + [
|
||||||
self.audio_token_id
|
self.audio_token_id
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import math
|
import math
|
||||||
from collections.abc import AsyncGenerator, Mapping
|
from collections.abc import AsyncGenerator, Iterable, Iterator, Mapping
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -18,7 +18,7 @@ from mistral_common.tokens.tokenizers.audio import Audio, AudioConfig
|
|||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
||||||
from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S
|
from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S
|
||||||
from vllm.inputs.data import PromptType, TokensPrompt
|
from vllm.inputs.data import PromptType, StreamingInput, TokensPrompt
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsRealtime
|
from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsRealtime
|
||||||
from vllm.model_executor.models.voxtral import (
|
from vllm.model_executor.models.voxtral import (
|
||||||
@@ -47,8 +47,6 @@ from .utils import (
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
_PRE_ALLOCATE_BUFFER_SIZE_IN_S = 30
|
|
||||||
|
|
||||||
|
|
||||||
class VoxtralRealtimeMultiModalProcessor(VoxtralMultiModalProcessor):
|
class VoxtralRealtimeMultiModalProcessor(VoxtralMultiModalProcessor):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -130,84 +128,81 @@ def _expand_tensor(input_tensor: torch.Tensor, scaling: int) -> torch.Tensor:
|
|||||||
|
|
||||||
|
|
||||||
class VoxtralRealtimeBuffer:
|
class VoxtralRealtimeBuffer:
|
||||||
def __init__(self, config: AudioConfig) -> None:
|
def __init__(self, config: AudioConfig, prompt_tokens: list[int]) -> None:
|
||||||
self._config = config
|
self._config = config
|
||||||
|
|
||||||
self._look_ahead_in_ms = config.streaming_look_ahead_ms
|
_look_ahead_in_ms = self._config.streaming_look_ahead_ms
|
||||||
self._look_back_in_ms = config.streaming_look_back_ms
|
_look_back_in_ms = self._config.streaming_look_back_ms
|
||||||
|
self._look_ahead_in_samples = self._ms_to_samples(_look_ahead_in_ms)
|
||||||
|
self._look_back_in_samples = self._ms_to_samples(_look_back_in_ms)
|
||||||
|
|
||||||
self._sampling_rate = self._config.sampling_rate
|
# None signals the end
|
||||||
|
self._audio_queue: asyncio.Queue[np.ndarray | None] = asyncio.Queue()
|
||||||
|
self._leftover: np.ndarray | None = None
|
||||||
|
self._token_queue: asyncio.Queue[int] = asyncio.Queue()
|
||||||
|
|
||||||
self._look_ahead = self._get_len_in_samples(self._look_ahead_in_ms)
|
self._initial_end = len(prompt_tokens) * self._config.raw_audio_length_per_tok
|
||||||
self._look_back = self._get_len_in_samples(self._look_back_in_ms)
|
for token in prompt_tokens:
|
||||||
self._streaming_size = self._get_len_in_samples(1000 / self._config.frame_rate)
|
self._token_queue.put_nowait(token)
|
||||||
|
|
||||||
# mutable objects
|
def _generate_frame_size_and_num_tokens(self) -> Iterator[tuple[int, int]]:
|
||||||
streaming_delay = self._get_len_in_samples(self._config.transcription_delay_ms)
|
streaming_step_size = self._ms_to_samples(1000 / self._config.frame_rate)
|
||||||
self._start = 0
|
start = 0
|
||||||
self._end = streaming_delay + self._streaming_size
|
end = self._initial_end
|
||||||
|
while True:
|
||||||
|
frame_start = max(start - self._look_back_in_samples, 0)
|
||||||
|
frame_end = end + self._look_ahead_in_samples
|
||||||
|
frame_size = frame_end - frame_start
|
||||||
|
num_tokens = (end - start) / self._config.raw_audio_length_per_tok
|
||||||
|
assert num_tokens.is_integer()
|
||||||
|
yield frame_size, int(num_tokens)
|
||||||
|
start = end
|
||||||
|
end += streaming_step_size
|
||||||
|
|
||||||
# always pre-allocate 30 second buffers
|
def _ms_to_samples(self, ms: float) -> int:
|
||||||
self._buffer_size = _PRE_ALLOCATE_BUFFER_SIZE_IN_S * self._sampling_rate
|
len_ = self._config.sampling_rate * ms / 1000
|
||||||
self._buffer: np.ndarray = np.empty(self._buffer_size, dtype=np.float32)
|
assert len_.is_integer(), len_
|
||||||
self._filled_buffer_len = 0
|
return int(len_)
|
||||||
|
|
||||||
@property
|
async def append_audio(self, audio_array: np.ndarray | None) -> None:
|
||||||
def start_idx(self):
|
await self._audio_queue.put(audio_array)
|
||||||
return max(self._start - self._look_back, 0)
|
|
||||||
|
|
||||||
@property
|
async def append_tokens(self, tokens: Iterable[int]) -> None:
|
||||||
def end_idx(self):
|
for token in tokens:
|
||||||
return self._end + self._look_ahead
|
await self._token_queue.put(token)
|
||||||
|
|
||||||
@property
|
async def get_input_stream(self) -> AsyncGenerator[StreamingInput]:
|
||||||
def is_audio_complete(self) -> bool:
|
for frame_size, num_tokens in self._generate_frame_size_and_num_tokens():
|
||||||
return self._filled_buffer_len >= self.end_idx
|
next_tokens = [await self._token_queue.get() for _ in range(num_tokens)]
|
||||||
|
|
||||||
def _get_len_in_samples(self, len_in_ms: float) -> int:
|
audio_arrays: list[np.ndarray] = (
|
||||||
_len_in_s = self._sampling_rate * len_in_ms / 1000
|
[self._leftover] if self._leftover is not None else []
|
||||||
assert _len_in_s.is_integer(), _len_in_s
|
)
|
||||||
len_in_s = int(_len_in_s)
|
while sum(len(arr) for arr in audio_arrays) < frame_size:
|
||||||
|
arr = await self._audio_queue.get()
|
||||||
|
if arr is None:
|
||||||
|
return
|
||||||
|
audio_arrays.append(arr)
|
||||||
|
|
||||||
return len_in_s
|
audio_array = np.concatenate(audio_arrays)
|
||||||
|
frame = audio_array[:frame_size]
|
||||||
|
|
||||||
def _allocate_new_buffer(self) -> None:
|
# The current stride took look_ahead_in_samples audio of the next sample
|
||||||
# allocate new buffer
|
# In addition the next sample will take look_back_in_samples audio of
|
||||||
new_buffer = np.empty(self._buffer_size, dtype=np.float32)
|
# the current sample => So let's put both of this into the leftover
|
||||||
left_to_copy = max(self._filled_buffer_len - self.start_idx, 0)
|
stride = (
|
||||||
|
frame_size - self._look_ahead_in_samples - self._look_back_in_samples
|
||||||
|
)
|
||||||
|
assert stride > 0, f"{stride=} must be positive"
|
||||||
|
|
||||||
if left_to_copy > 0:
|
self._leftover = audio_array[stride:]
|
||||||
new_buffer[:left_to_copy] = self._buffer[
|
|
||||||
self.start_idx : self._filled_buffer_len
|
|
||||||
]
|
|
||||||
|
|
||||||
del self._buffer
|
yield StreamingInput(
|
||||||
self._buffer = new_buffer
|
TokensPrompt(
|
||||||
|
prompt_token_ids=next_tokens,
|
||||||
self._filled_buffer_len = left_to_copy
|
multi_modal_data={"audio": (frame, None)},
|
||||||
self._start = self._look_back
|
)
|
||||||
self._end = self._start + self._streaming_size
|
)
|
||||||
|
|
||||||
def write_audio(self, audio: np.ndarray) -> None:
|
|
||||||
put_end_idx = self._filled_buffer_len + len(audio)
|
|
||||||
|
|
||||||
if put_end_idx > self._buffer_size:
|
|
||||||
self._allocate_new_buffer()
|
|
||||||
|
|
||||||
self._buffer[self._filled_buffer_len : self._filled_buffer_len + len(audio)] = (
|
|
||||||
audio
|
|
||||||
)
|
|
||||||
self._filled_buffer_len += len(audio)
|
|
||||||
|
|
||||||
def read_audio(self) -> np.ndarray | None:
|
|
||||||
if not self.is_audio_complete:
|
|
||||||
return None
|
|
||||||
|
|
||||||
audio = self._buffer[self.start_idx : self.end_idx]
|
|
||||||
self._start = self._end
|
|
||||||
self._end += self._streaming_size
|
|
||||||
|
|
||||||
return audio
|
|
||||||
|
|
||||||
|
|
||||||
@MULTIMODAL_REGISTRY.register_processor(
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
@@ -234,7 +229,7 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
|
|||||||
)
|
)
|
||||||
|
|
||||||
audio_config = self.tokenizer.instruct.audio_encoder.audio_config
|
audio_config = self.tokenizer.instruct.audio_encoder.audio_config
|
||||||
self.n_delay_tokens = audio_config.num_delay_tokens
|
self.n_delay_tokens = audio_config.get_num_delay_tokens()
|
||||||
|
|
||||||
# for realtime transcription
|
# for realtime transcription
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -248,45 +243,47 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
|
|||||||
audio_encoder = tokenizer.instruct.audio_encoder
|
audio_encoder = tokenizer.instruct.audio_encoder
|
||||||
config = audio_encoder.audio_config
|
config = audio_encoder.audio_config
|
||||||
|
|
||||||
buffer = VoxtralRealtimeBuffer(config)
|
# Get prompt tokens (streaming prefix tokens) without encoding audio
|
||||||
is_first_yield = True
|
prompt_tokens = (
|
||||||
|
tokenizer.instruct.start() + audio_encoder.encode_streaming_tokens()
|
||||||
|
)
|
||||||
|
|
||||||
async for audio in audio_stream:
|
# Get left/right padding audio
|
||||||
buffer.write_audio(audio)
|
left_pad, right_pad = audio_encoder.get_padding_audio()
|
||||||
|
|
||||||
while (new_audio := buffer.read_audio()) is not None:
|
buffer = VoxtralRealtimeBuffer(config, prompt_tokens)
|
||||||
if is_first_yield:
|
|
||||||
# make sure that input_stream is empty
|
|
||||||
assert input_stream.empty()
|
|
||||||
|
|
||||||
audio = Audio(new_audio, config.sampling_rate, format="wav")
|
# Feed audio with padding into buffer in background
|
||||||
|
async def feed_audio():
|
||||||
|
yielded_first_chunk = False
|
||||||
|
async for audio_chunk in audio_stream:
|
||||||
|
if not yielded_first_chunk:
|
||||||
|
yielded_first_chunk = True
|
||||||
|
# Prepend left padding before first real audio
|
||||||
|
await buffer.append_audio(left_pad.audio_array)
|
||||||
|
await buffer.append_audio(audio_chunk)
|
||||||
|
# Append right padding at the end
|
||||||
|
await buffer.append_audio(right_pad.audio_array)
|
||||||
|
await buffer.append_audio(None) # signal end
|
||||||
|
|
||||||
request = TranscriptionRequest(
|
# Feed output tokens back into buffer in background
|
||||||
streaming=StreamingMode.ONLINE,
|
async def feed_tokens():
|
||||||
audio=RawAudio.from_audio(audio),
|
while True:
|
||||||
language=None,
|
all_outputs = await asyncio.wait_for(
|
||||||
)
|
input_stream.get(),
|
||||||
# mistral tokenizer takes care
|
timeout=VLLM_ENGINE_ITERATION_TIMEOUT_S,
|
||||||
# of preparing the first prompt inputs
|
|
||||||
# and does some left-silence padding
|
|
||||||
# for improved performance
|
|
||||||
audio_enc = tokenizer.mistral.encode_transcription(request)
|
|
||||||
|
|
||||||
token_ids = audio_enc.tokens
|
|
||||||
new_audio = audio_enc.audios[0].audio_array
|
|
||||||
|
|
||||||
is_first_yield = False
|
|
||||||
else:
|
|
||||||
# pop last element from input_stream
|
|
||||||
all_outputs = await asyncio.wait_for(
|
|
||||||
input_stream.get(), timeout=VLLM_ENGINE_ITERATION_TIMEOUT_S
|
|
||||||
)
|
|
||||||
token_ids = all_outputs[-1:]
|
|
||||||
|
|
||||||
multi_modal_data = {"audio": (new_audio, None)}
|
|
||||||
yield TokensPrompt(
|
|
||||||
prompt_token_ids=token_ids, multi_modal_data=multi_modal_data
|
|
||||||
)
|
)
|
||||||
|
await buffer.append_tokens(all_outputs[-1:])
|
||||||
|
|
||||||
|
audio_task = asyncio.create_task(feed_audio())
|
||||||
|
token_task = asyncio.create_task(feed_tokens())
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for streaming_input in buffer.get_input_stream():
|
||||||
|
yield streaming_input.prompt
|
||||||
|
finally:
|
||||||
|
audio_task.cancel()
|
||||||
|
token_task.cancel()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def audio_config(self):
|
def audio_config(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user