diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md
index 61d827ecc..a7b1c18a6 100644
--- a/docs/serving/openai_compatible_server.md
+++ b/docs/serving/openai_compatible_server.md
@@ -59,6 +59,8 @@ We currently support the following OpenAI APIs:
   - Only applicable to [Automatic Speech Recognition (ASR) models](../models/supported_models.md#transcription).
 - [Translation API](#translations-api) (`/v1/audio/translations`)
   - Only applicable to [Automatic Speech Recognition (ASR) models](../models/supported_models.md#transcription).
+- [Realtime API](#realtime-api) (`/v1/realtime`)
+  - Only applicable to [Automatic Speech Recognition (ASR) models](../models/supported_models.md#transcription).
 
 In addition, we have the following custom APIs:
 
@@ -567,6 +569,96 @@ The following extra parameters are supported:
 --8<-- "vllm/entrypoints/openai/protocol.py:translation-extra-params"
 ```
 
+### Realtime API
+
+The Realtime API provides WebSocket-based streaming audio transcription, allowing real-time speech-to-text while audio is being recorded.
+
+!!! note
+    To use the Realtime API, please install vLLM with the extra audio dependencies using `uv pip install vllm[audio]`.
+
+#### Audio Format
+
+Audio must be sent as base64-encoded PCM16 audio at a 16kHz sample rate, mono channel.
+
+#### Protocol Overview
+
+1. Client connects to `ws://host/v1/realtime`
+2. Server sends a `session.created` event
+3. Client optionally sends `session.update` with model/params
+4. Client sends `input_audio_buffer.commit` when ready
+5. Client sends `input_audio_buffer.append` events with base64 PCM16 chunks
+6. Server sends `transcription.delta` events with incremental text
+7. Server sends `transcription.done` with the final text and usage stats
+8. Repeat from step 5 for the next utterance
+9. Optionally, the client sends `input_audio_buffer.commit` with `final=True`
+   to signal that audio input is finished. This is useful when streaming audio files.
+
+#### Client → Server Events
+
+| Event | Description |
+|-------|-------------|
+| `input_audio_buffer.append` | Send base64-encoded audio chunk: `{"type": "input_audio_buffer.append", "audio": "<base64>"}` |
+| `input_audio_buffer.commit` | Trigger transcription processing or end: `{"type": "input_audio_buffer.commit", "final": bool}` |
+| `session.update` | Configure session: `{"type": "session.update", "model": "model-name"}` |
+
+#### Server → Client Events
+
+| Event | Description |
+|-------|-------------|
+| `session.created` | Connection established with session ID and timestamp |
+| `transcription.delta` | Incremental transcription text: `{"type": "transcription.delta", "delta": "text"}` |
+| `transcription.done` | Final transcription with usage stats |
+| `error` | Error notification with message and optional code |
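+
+#### Converting Audio to PCM16
+
+Audio captured at another sample rate must be resampled to 16kHz mono and converted to PCM16 before it is base64-encoded and appended. The helper below is a minimal sketch of that conversion, mirroring the logic in `examples/online_serving/openai_realtime_client.py`; it assumes `librosa` and `numpy` are installed.
+
+??? code
+
+    ```python
+    import base64
+
+    import librosa
+    import numpy as np
+
+    def audio_to_pcm16_base64(audio_path: str) -> str:
+        # Load and resample to 16kHz mono
+        audio, _ = librosa.load(audio_path, sr=16000, mono=True)
+        # Scale float32 [-1, 1] to int16 PCM16
+        pcm16 = (audio * 32767).astype(np.int16)
+        return base64.b64encode(pcm16.tobytes()).decode("utf-8")
+    ```
+
+For live microphone input, `examples/online_serving/openai_realtime_microphone_client.py` applies the same conversion to streaming chunks.
+
+#### Python WebSocket Example
+
+??? 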
code + + ```python + import asyncio + import base64 + import json + import websockets + + async def realtime_transcribe(): + uri = "ws://localhost:8000/v1/realtime" + + async with websockets.connect(uri) as ws: + # Wait for session.created + response = await ws.recv() + print(f"Session: {response}") + + # Commit buffer + await ws.send(json.dumps({ + "type": "input_audio_buffer.commit" + })) + + # Send audio chunks (example with file) + with open("audio.raw", "rb") as f: + while chunk := f.read(4096): + await ws.send(json.dumps({ + "type": "input_audio_buffer.append", + "audio": base64.b64encode(chunk).decode() + })) + + # Signal all audio is sent + await ws.send(json.dumps({ + "type": "input_audio_buffer.commit", + "final": True, + })) + + # Receive transcription + while True: + response = json.loads(await ws.recv()) + if response["type"] == "transcription.delta": + print(response["delta"], end="", flush=True) + elif response["type"] == "transcription.done": + print(f"\nFinal: {response['text']}") + break + + asyncio.run(realtime_transcribe()) + ``` + ### Tokenizer API Our Tokenizer API is a simple wrapper over [HuggingFace-style tokenizers](https://huggingface.co/docs/transformers/en/main_classes/tokenizer). diff --git a/examples/online_serving/openai_realtime_client.py b/examples/online_serving/openai_realtime_client.py new file mode 100644 index 000000000..5aa31e3e5 --- /dev/null +++ b/examples/online_serving/openai_realtime_client.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This script demonstrates how to use the vLLM Realtime WebSocket API to perform +audio transcription by uploading an audio file. + +Before running this script, you must start the vLLM server with a realtime-capable +model, for example: + + vllm serve mistralai/Voxtral-Mini-3B-Realtime-2602 --enforce-eager + +Requirements: +- vllm with audio support +- websockets +- librosa +- numpy + +The script: +1. Connects to the Realtime WebSocket endpoint +2. Converts an audio file to PCM16 @ 16kHz +3. Sends audio chunks to the server +4. Receives and prints transcription as it streams +""" + +import argparse +import asyncio +import base64 +import json + +import librosa +import numpy as np +import websockets + +from vllm.assets.audio import AudioAsset + + +def audio_to_pcm16_base64(audio_path: str) -> str: + """ + Load an audio file and convert it to base64-encoded PCM16 @ 16kHz. + """ + # Load audio and resample to 16kHz mono + audio, _ = librosa.load(audio_path, sr=16000, mono=True) + # Convert to PCM16 + pcm16 = (audio * 32767).astype(np.int16) + # Encode as base64 + return base64.b64encode(pcm16.tobytes()).decode("utf-8") + + +async def realtime_transcribe(audio_path: str, host: str, port: int, model: str): + """ + Connect to the Realtime API and transcribe an audio file. 
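+
+    Args:
+        audio_path: Path to the audio file to transcribe.
+        host: vLLM server host.
+        port: vLLM server port.
+        model: Name of the model being served.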
+    """
+    uri = f"ws://{host}:{port}/v1/realtime"
+
+    async with websockets.connect(uri) as ws:
+        # Wait for session.created
+        response = json.loads(await ws.recv())
+        if response["type"] == "session.created":
+            print(f"Session created: {response['id']}")
+        else:
+            print(f"Unexpected response: {response}")
+            return
+
+        # Validate model
+        await ws.send(json.dumps({"type": "session.update", "model": model}))
+
+        # Signal ready to start
+        await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))
+
+        # Convert audio file to base64 PCM16
+        print(f"Loading audio from: {audio_path}")
+        audio_base64 = audio_to_pcm16_base64(audio_path)
+
+        # Send audio in chunks (4KB of raw audio = ~5.5KB base64)
+        chunk_size = 4096
+        audio_bytes = base64.b64decode(audio_base64)
+        total_chunks = (len(audio_bytes) + chunk_size - 1) // chunk_size
+
+        print(f"Sending {total_chunks} audio chunks...")
+        for i in range(0, len(audio_bytes), chunk_size):
+            chunk = audio_bytes[i : i + chunk_size]
+            await ws.send(
+                json.dumps(
+                    {
+                        "type": "input_audio_buffer.append",
+                        "audio": base64.b64encode(chunk).decode("utf-8"),
+                    }
+                )
+            )
+
+        # Signal all audio is sent
+        await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": True}))
+        print("Audio sent. Waiting for transcription...\n")
+
+        # Receive transcription
+        print("Transcription: ", end="", flush=True)
+        while True:
+            response = json.loads(await ws.recv())
+            if response["type"] == "transcription.delta":
+                print(response["delta"], end="", flush=True)
+            elif response["type"] == "transcription.done":
+                print(f"\n\nFinal transcription: {response['text']}")
+                if response.get("usage"):
+                    print(f"Usage: {response['usage']}")
+                break
+            elif response["type"] == "error":
+                print(f"\nError: {response['error']}")
+                break
+
+
+def main(args):
+    if args.audio_path:
+        audio_path = args.audio_path
+    else:
+        # Use default audio asset
+        audio_path = str(AudioAsset("mary_had_lamb").get_local_path())
+        print(f"No audio path provided, using default: {audio_path}")
+
+    asyncio.run(realtime_transcribe(audio_path, args.host, args.port, args.model))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Realtime WebSocket Transcription Client"
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="mistralai/Voxtral-Mini-3B-Realtime-2602",
+        help="Name of the model that is being served.",
+    )
+    parser.add_argument(
+        "--audio_path",
+        type=str,
+        default=None,
+        help="Path to the audio file to transcribe.",
+    )
+    parser.add_argument(
+        "--host",
+        type=str,
+        default="localhost",
+        help="vLLM server host (default: localhost)",
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=8000,
+        help="vLLM server port (default: 8000)",
+    )
+    args = parser.parse_args()
+    main(args)
diff --git a/examples/online_serving/openai_realtime_microphone_client.py b/examples/online_serving/openai_realtime_microphone_client.py
new file mode 100644
index 000000000..fc80b1c50
--- /dev/null
+++ b/examples/online_serving/openai_realtime_microphone_client.py
@@ -0,0 +1,183 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Minimal Gradio demo for real-time speech transcription using the vLLM Realtime API.
+
+Start the vLLM server first:
+
+    vllm serve mistralai/Voxtral-Mini-3B-Realtime-2602 --enforce-eager
+
+Then run this script:
+
+    python openai_realtime_microphone_client.py --host localhost --port 8000
+
+Use --share to create a public Gradio link.
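+
+Microphone audio is converted to mono, resampled to 16kHz, and encoded as
+base64 PCM16 before being streamed to the server (see process_audio below).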
+ +Requirements: websockets, numpy, gradio +""" + +import argparse +import asyncio +import base64 +import json +import queue +import threading + +import gradio as gr +import numpy as np +import websockets + +SAMPLE_RATE = 16_000 + +# Global state +audio_queue: queue.Queue = queue.Queue() +transcription_text = "" +is_running = False +ws_url = "" +model = "" + + +async def websocket_handler(): + """Connect to WebSocket and handle audio streaming + transcription.""" + global transcription_text, is_running + + async with websockets.connect(ws_url) as ws: + # Wait for session.created + await ws.recv() + + # Validate model + await ws.send(json.dumps({"type": "session.update", "model": model})) + + # Signal ready + await ws.send(json.dumps({"type": "input_audio_buffer.commit"})) + + async def send_audio(): + while is_running: + try: + chunk = await asyncio.get_event_loop().run_in_executor( + None, lambda: audio_queue.get(timeout=0.1) + ) + await ws.send( + json.dumps( + {"type": "input_audio_buffer.append", "audio": chunk} + ) + ) + except queue.Empty: + continue + + async def receive_transcription(): + global transcription_text + async for message in ws: + data = json.loads(message) + if data.get("type") == "transcription.delta": + transcription_text += data["delta"] + + await asyncio.gather(send_audio(), receive_transcription()) + + +def start_websocket(): + """Start WebSocket connection in background thread.""" + global is_running + is_running = True + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(websocket_handler()) + except Exception as e: + print(f"WebSocket error: {e}") + + +def start_recording(): + """Start the transcription service.""" + global transcription_text + transcription_text = "" + thread = threading.Thread(target=start_websocket, daemon=True) + thread.start() + return gr.update(interactive=False), gr.update(interactive=True), "" + + +def stop_recording(): + """Stop the transcription service.""" + global is_running + is_running = False + return gr.update(interactive=True), gr.update(interactive=False), transcription_text + + +def process_audio(audio): + """Process incoming audio and queue for streaming.""" + global transcription_text + + if audio is None or not is_running: + return transcription_text + + sample_rate, audio_data = audio + + # Convert to mono if stereo + if len(audio_data.shape) > 1: + audio_data = audio_data.mean(axis=1) + + # Normalize to float + if audio_data.dtype == np.int16: + audio_float = audio_data.astype(np.float32) / 32767.0 + else: + audio_float = audio_data.astype(np.float32) + + # Resample to 16kHz if needed + if sample_rate != SAMPLE_RATE: + num_samples = int(len(audio_float) * SAMPLE_RATE / sample_rate) + audio_float = np.interp( + np.linspace(0, len(audio_float) - 1, num_samples), + np.arange(len(audio_float)), + audio_float, + ) + + # Convert to PCM16 and base64 encode + pcm16 = (audio_float * 32767).astype(np.int16) + b64_chunk = base64.b64encode(pcm16.tobytes()).decode("utf-8") + audio_queue.put(b64_chunk) + + return transcription_text + + +# Gradio interface +with gr.Blocks(title="Real-time Speech Transcription") as demo: + gr.Markdown("# Real-time Speech Transcription") + gr.Markdown("Click **Start** and speak into your microphone.") + + with gr.Row(): + start_btn = gr.Button("Start", variant="primary") + stop_btn = gr.Button("Stop", variant="stop", interactive=False) + + audio_input = gr.Audio(sources=["microphone"], streaming=True, type="numpy") + transcription_output = 
gr.Textbox(label="Transcription", lines=5)
+
+    start_btn.click(
+        start_recording, outputs=[start_btn, stop_btn, transcription_output]
+    )
+    stop_btn.click(stop_recording, outputs=[start_btn, stop_btn, transcription_output])
+    audio_input.stream(
+        process_audio, inputs=[audio_input], outputs=[transcription_output]
+    )
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Realtime WebSocket Transcription with Gradio"
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="mistralai/Voxtral-Mini-3B-Realtime-2602",
+        help="Name of the model that is being served.",
+    )
+    parser.add_argument(
+        "--host", type=str, default="localhost", help="vLLM server host"
+    )
+    parser.add_argument("--port", type=int, default=8000, help="vLLM server port")
+    parser.add_argument(
+        "--share", action="store_true", help="Create public Gradio link"
+    )
+    args = parser.parse_args()
+
+    ws_url = f"ws://{args.host}:{args.port}/v1/realtime"
+    model = args.model
+    demo.launch(share=args.share)
diff --git a/tests/entrypoints/openai/test_realtime_validation.py b/tests/entrypoints/openai/test_realtime_validation.py
new file mode 100644
index 000000000..d0a37cd5e
--- /dev/null
+++ b/tests/entrypoints/openai/test_realtime_validation.py
@@ -0,0 +1,133 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import asyncio
+import base64
+import json
+
+import librosa
+import numpy as np
+import pytest
+import websockets
+
+from vllm.assets.audio import AudioAsset
+
+from ...utils import RemoteOpenAIServer
+from .conftest import add_attention_backend
+
+MISTRAL_FORMAT_ARGS = [
+    "--tokenizer_mode",
+    "mistral",
+    "--config_format",
+    "mistral",
+    "--load_format",
+    "mistral",
+]
+
+MODEL_NAME = "mistralai/Voxtral-Mini-3B-Realtime-2602"
+
+
+def _audio_to_base64_pcm16(path: str, target_sr: int = 16000) -> str:
+    """Load audio file, convert to PCM16 @ target sample rate, base64 encode."""
+    audio, _ = librosa.load(path, sr=target_sr, mono=True)
+    # Convert float32 [-1, 1] to int16 [-32768, 32767]
+    audio_int16 = (audio * 32767).astype(np.int16)
+    audio_bytes = audio_int16.tobytes()
+    return base64.b64encode(audio_bytes).decode("utf-8")
+
+
+def _get_websocket_url(server: RemoteOpenAIServer) -> str:
+    """Convert HTTP URL to WebSocket URL for realtime endpoint."""
+    http_url = server.url_root
+    ws_url = http_url.replace("http://", "ws://")
+    return f"{ws_url}/v1/realtime"
+
+
+async def receive_event(ws, timeout: float = 60.0) -> dict:
+    """Receive and parse JSON event from WebSocket."""
+    message = await asyncio.wait_for(ws.recv(), timeout=timeout)
+    return json.loads(message)
+
+
+async def send_event(ws, event: dict) -> None:
+    """Send JSON event to WebSocket."""
+    await ws.send(json.dumps(event))
+
+
+@pytest.fixture
+def mary_had_lamb_audio_chunks() -> list[str]:
+    """Audio split into ~0.1 second chunks for streaming."""
+    path = AudioAsset("mary_had_lamb").get_local_path()
+    audio, _ = librosa.load(str(path), sr=16000, mono=True)
+
+    # Split into ~0.1 second chunks (1600 samples at 16kHz)
+    chunk_size = 1600
+    chunks = []
+    for i in range(0, len(audio), chunk_size):
+        chunk = audio[i : i + chunk_size]
+        chunk_int16 = (chunk * 32767).astype(np.int16)
+        chunk_bytes = chunk_int16.tobytes()
+        chunks.append(base64.b64encode(chunk_bytes).decode("utf-8"))
+
+    return chunks
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.skip(reason="Voxtral streaming is not yet public")
+async def 
test_multi_chunk_streaming( + model_name, mary_had_lamb_audio_chunks, rocm_aiter_fa_attention +): + """Test streaming multiple audio chunks before committing.""" + server_args = ["--enforce-eager"] + + if model_name.startswith("mistralai"): + server_args += MISTRAL_FORMAT_ARGS + + add_attention_backend(server_args, rocm_aiter_fa_attention) + + with RemoteOpenAIServer(model_name, server_args) as remote_server: + ws_url = _get_websocket_url(remote_server) + async with websockets.connect(ws_url) as ws: + # Receive session.created + event = await receive_event(ws, timeout=30.0) + assert event["type"] == "session.created" + + await send_event(ws, {"type": "session.update", "model": model_name}) + + # Send commit to start transcription + await send_event(ws, {"type": "input_audio_buffer.commit"}) + + # Send multiple audio chunks + for chunk in mary_had_lamb_audio_chunks: + await send_event( + ws, {"type": "input_audio_buffer.append", "audio": chunk} + ) + + # Send commit to end + await send_event(ws, {"type": "input_audio_buffer.commit", "final": True}) + + # Collect transcription deltas + full_text = "" + done_received = False + + while not done_received: + event = await receive_event(ws, timeout=60.0) + + if event["type"] == "transcription.delta": + full_text += event["delta"] + elif event["type"] == "transcription.done": + done_received = True + assert "text" in event + elif event["type"] == "error": + pytest.fail(f"Received error: {event}") + + # Verify transcription contains expected content + assert event["type"] == "transcription.done" + assert event["text"] == full_text + assert full_text == ( + " He has first words I spoke in the original phonograph." + " A little piece of practical poetry. Mary had a little lamb," + " it squeaked with quite a flow, and everywhere that Mary went," + " the lamb was sure to go" + ) diff --git a/tests/v1/e2e/test_streaming_input.py b/tests/v1/e2e/test_streaming_input.py index 40bb30d9a..a1eaa065a 100644 --- a/tests/v1/e2e/test_streaming_input.py +++ b/tests/v1/e2e/test_streaming_input.py @@ -19,11 +19,12 @@ import pytest import pytest_asyncio from vllm import SamplingParams +from vllm.inputs.data import StreamingInput from vllm.outputs import RequestOutput from vllm.platforms import current_platform from vllm.sampling_params import RequestOutputKind from vllm.utils.torch_utils import set_default_torch_num_threads -from vllm.v1.engine.async_llm import AsyncLLM, StreamingInput +from vllm.v1.engine.async_llm import AsyncLLM if not current_platform.is_cuda(): pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True) diff --git a/tests/v1/streaming_input/test_async_llm_streaming.py b/tests/v1/streaming_input/test_async_llm_streaming.py index 913576f70..992634387 100644 --- a/tests/v1/streaming_input/test_async_llm_streaming.py +++ b/tests/v1/streaming_input/test_async_llm_streaming.py @@ -7,9 +7,10 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from vllm.inputs.data import StreamingInput from vllm.outputs import RequestOutput from vllm.sampling_params import RequestOutputKind, SamplingParams -from vllm.v1.engine.async_llm import AsyncLLM, StreamingInput +from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.output_processor import RequestOutputCollector diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 205efd1d5..c4248cf83 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator, Iterable, Mapping from typing import Any 
from vllm.config import ModelConfig, VllmConfig
-from vllm.inputs.data import PromptType
+from vllm.inputs.data import PromptType, StreamingInput
 from vllm.lora.request import LoRARequest
 from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.plugins.io_processors import IOProcessor
@@ -49,7 +49,7 @@ class EngineClient(ABC):
     @abstractmethod
     def generate(
         self,
-        prompt: EngineCoreRequest | PromptType,
+        prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
         sampling_params: SamplingParams,
         request_id: str,
         *,
diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py
index cabf95e8d..e75d66bbf 100644
--- a/vllm/entrypoints/launcher.py
+++ b/vllm/entrypoints/launcher.py
@@ -36,6 +36,7 @@ async def serve_http(
         h11_max_header_count.
     """
     logger.info("Available routes are:")
+    # HTTP endpoints (routes with methods)
     for route in app.routes:
         methods = getattr(route, "methods", None)
         path = getattr(route, "path", None)
@@ -45,6 +46,17 @@
         logger.info("Route: %s, Methods: %s", path, ", ".join(methods))
 
+    # WebSocket and other method-less endpoints
+    for route in app.routes:
+        endpoint = getattr(route, "endpoint", None)
+        methods = getattr(route, "methods", None)
+        path = getattr(route, "path", None)
+
+        if endpoint is None or path is None or methods is not None:
+            continue
+
+        logger.info("Route: %s, Endpoint: %s", path, endpoint.__name__)
+
     # Extract header limit options if present
     h11_max_incomplete_event_size = uvicorn_kwargs.pop(
         "h11_max_incomplete_event_size", None
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 9639ba28e..a1ee3607a 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -196,6 +196,13 @@ def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) ->
     register_translations_api_router(app)
 
+    if "realtime" in supported_tasks:
+        from vllm.entrypoints.openai.realtime.api_router import (
+            attach_router as register_realtime_api_router,
+        )
+
+        register_realtime_api_router(app)
+
     if any(task in POOLING_TASKS for task in supported_tasks):
         from vllm.entrypoints.pooling import register_pooling_api_routers
 
@@ -319,6 +326,11 @@ async def init_app_state(
         engine_client, state, args, request_logger, supported_tasks
     )
 
+    if "realtime" in supported_tasks:
+        from vllm.entrypoints.openai.realtime.api_router import init_realtime_state
+
+        init_realtime_state(engine_client, state, args, request_logger, supported_tasks)
+
     if any(task in POOLING_TASKS for task in supported_tasks):
         from vllm.entrypoints.pooling import init_pooling_state
diff --git a/vllm/entrypoints/openai/realtime/__init__.py b/vllm/entrypoints/openai/realtime/__init__.py
new file mode 100644
index 000000000..208f01a7c
--- /dev/null
+++ b/vllm/entrypoints/openai/realtime/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
diff --git a/vllm/entrypoints/openai/realtime/api_router.py b/vllm/entrypoints/openai/realtime/api_router.py
new file mode 100644
index 000000000..fb7decbd7
--- /dev/null
+++ b/vllm/entrypoints/openai/realtime/api_router.py
@@ -0,0 +1,75 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import TYPE_CHECKING
+
+from fastapi import APIRouter, FastAPI, WebSocket
+
+from vllm.entrypoints.openai.realtime.connection import RealtimeConnection
+from vllm.entrypoints.openai.realtime.serving import OpenAIServingRealtime
+from 
vllm.logger import init_logger + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from argparse import Namespace + + from starlette.datastructures import State + + from vllm.engine.protocol import EngineClient + from vllm.entrypoints.logger import RequestLogger + from vllm.tasks import SupportedTask +else: + RequestLogger = object + +router = APIRouter() + + +@router.websocket("/v1/realtime") +async def realtime_endpoint(websocket: WebSocket): + """WebSocket endpoint for realtime audio transcription. + + Protocol: + 1. Client connects to ws://host/v1/realtime + 2. Server sends session.created event + 3. Client optionally sends session.update with model/params + 4. Client sends input_audio_buffer.commit when ready + 5. Client sends input_audio_buffer.append events with base64 PCM16 chunks + 6. Server processes and sends transcription.delta events + 7. Server sends transcription.done with final text + usage + 8. Repeat from step 5 for next utterance + 9. Optionally, client sends input_audio_buffer.commit with final=True + to signal audio input is finished. Useful when streaming audio files + + Audio format: PCM16, 16kHz, mono, base64-encoded + """ + app = websocket.app + serving = app.state.openai_serving_realtime + + connection = RealtimeConnection(websocket, serving) + await connection.handle_connection() + + +def attach_router(app: FastAPI): + """Attach the realtime router to the FastAPI app.""" + app.include_router(router) + logger.info("Realtime API router attached") + + +def init_realtime_state( + engine_client: "EngineClient", + state: "State", + args: "Namespace", + request_logger: RequestLogger | None, + supported_tasks: tuple["SupportedTask", ...], +): + state.openai_serving_realtime = ( + OpenAIServingRealtime( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + log_error_stack=args.log_error_stack, + ) + if "realtime" in supported_tasks + else None + ) diff --git a/vllm/entrypoints/openai/realtime/connection.py b/vllm/entrypoints/openai/realtime/connection.py new file mode 100644 index 000000000..6b779c720 --- /dev/null +++ b/vllm/entrypoints/openai/realtime/connection.py @@ -0,0 +1,285 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import base64 +import json +from collections.abc import AsyncGenerator +from http import HTTPStatus +from uuid import uuid4 + +import numpy as np +from fastapi import WebSocket +from starlette.websockets import WebSocketDisconnect + +from vllm import envs +from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo +from vllm.entrypoints.openai.realtime.protocol import ( + ErrorEvent, + InputAudioBufferAppend, + InputAudioBufferCommit, + SessionCreated, + TranscriptionDelta, + TranscriptionDone, +) +from vllm.entrypoints.openai.realtime.serving import OpenAIServingRealtime +from vllm.exceptions import VLLMValidationError +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class RealtimeConnection: + """Manages WebSocket lifecycle and state for realtime transcription. 
+
+    This class handles:
+    - WebSocket connection lifecycle (accept, receive, send, close)
+    - Event routing (session.update, append, commit)
+    - Audio buffering via asyncio.Queue
+    - Generation task management
+    - Error handling and cleanup
+    """
+
+    def __init__(self, websocket: WebSocket, serving: OpenAIServingRealtime):
+        self.websocket = websocket
+        self.connection_id = f"ws-{uuid4()}"
+        self.serving = serving
+        self.audio_queue: asyncio.Queue[np.ndarray | None] = asyncio.Queue()
+        self.generation_task: asyncio.Task | None = None
+
+        self._is_connected = False
+        self._is_input_finished = False
+        self._is_model_validated = False
+
+        self._max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
+
+    async def handle_connection(self):
+        """Main connection loop."""
+        await self.websocket.accept()
+        logger.debug("WebSocket connection accepted: %s", self.connection_id)
+        self._is_connected = True
+
+        # Send session created event
+        await self.send(SessionCreated())
+
+        try:
+            while True:
+                message = await self.websocket.receive_text()
+                try:
+                    event = json.loads(message)
+                    await self.handle_event(event)
+                except json.JSONDecodeError:
+                    await self.send_error("Invalid JSON", "invalid_json")
+                except Exception as e:
+                    logger.exception("Error handling event: %s", e)
+                    await self.send_error(str(e), "processing_error")
+        except WebSocketDisconnect:
+            logger.debug("WebSocket disconnected: %s", self.connection_id)
+            self._is_connected = False
+        except Exception as e:
+            logger.exception("Unexpected error in connection: %s", e)
+        finally:
+            await self.cleanup()
+
+    def _check_model(self, model: str | None) -> None | ErrorResponse:
+        if self.serving._is_model_supported(model):
+            return None
+
+        return self.serving.create_error_response(
+            message=f"The model `{model}` does not exist.",
+            err_type="NotFoundError",
+            status_code=HTTPStatus.NOT_FOUND,
+            param="model",
+        )
+
+    async def handle_event(self, event: dict):
+        """Route events to handlers.
+
+        Supported event types:
+        - session.update: Configure model
+        - input_audio_buffer.append: Add audio chunk to queue
+        - input_audio_buffer.commit: Start transcription generation
+        """
+        event_type = event.get("type")
+        if event_type == "session.update":
+            logger.debug("Session updated: %s", event)
+            # Reject the update if the requested model is not served
+            if self._check_model(event.get("model")) is not None:
+                await self.send_error(
+                    f"The model `{event.get('model')}` does not exist.",
+                    "model_not_found",
+                )
+                return
+            self._is_model_validated = True
+        elif event_type == "input_audio_buffer.append":
+            append_event = InputAudioBufferAppend(**event)
+            try:
+                audio_bytes = base64.b64decode(append_event.audio)
+                # Convert PCM16 bytes to float32 numpy array
+                audio_array = (
+                    np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32)
+                    / 32768.0
+                )
+
+                if len(audio_bytes) / 1024**2 > self._max_audio_filesize_mb:
+                    raise VLLMValidationError(
+                        "Maximum file size exceeded",
+                        parameter="audio_filesize_mb",
+                        value=len(audio_bytes) / 1024**2,
+                    )
+                if len(audio_array) == 0:
+                    raise VLLMValidationError("Can't process empty audio.")
+
+                # Put audio chunk in queue
+                self.audio_queue.put_nowait(audio_array)
+
+            except Exception as e:
+                logger.error("Failed to decode audio: %s", e)
+                await self.send_error("Invalid audio data", "invalid_audio")
+
+        elif event_type == "input_audio_buffer.commit":
+            if not self._is_model_validated:
+                err_msg = (
+                    "Model not validated. Make sure to validate the"
+                    " model by sending a session.update event."
+                )
+                await self.send_error(
+                    err_msg,
+                    "model_not_validated",
+                )
+                return
+
+            commit_event = InputAudioBufferCommit(**event)
+            # final signals that the audio is finished
+            if commit_event.final:
+                self._is_input_finished = True
+            else:
+                await self.start_generation()
+        else:
+            await self.send_error(f"Unknown event type: {event_type}", "unknown_event")
+
+    async def audio_stream_generator(self) -> AsyncGenerator[np.ndarray, None]:
+        """Generator that yields audio chunks from the queue."""
+        while True:
+            audio_chunk = await self.audio_queue.get()
+            if audio_chunk is None:  # Sentinel value to stop
+                break
+            yield audio_chunk
+
+    async def start_generation(self):
+        """Start the transcription generation task."""
+        if self.generation_task is not None and not self.generation_task.done():
+            logger.warning("Generation already in progress, ignoring commit")
+            return
+
+        # Create audio stream generator
+        audio_stream = self.audio_stream_generator()
+        input_stream = asyncio.Queue[list[int]]()
+
+        # Transform to StreamingInput generator
+        streaming_input_gen = self.serving.transcribe_realtime(
+            audio_stream, input_stream
+        )
+
+        # Start generation task
+        self.generation_task = asyncio.create_task(
+            self._run_generation(streaming_input_gen, input_stream)
+        )
+
+    async def _run_generation(
+        self,
+        streaming_input_gen: AsyncGenerator,
+        input_stream: asyncio.Queue[list[int]],
+    ):
+        """Run the generation and stream results back to the client.
+
+        This method:
+        1. Creates sampling parameters from session config
+        2. Passes the streaming input generator to engine.generate()
+        3. Streams transcription.delta events as text is generated
+        4. Sends final transcription.done event with usage stats
+        5. Feeds generated token IDs back to input_stream for next iteration
+        6. 
Cleans up the audio queue + """ + request_id = f"rt-{self.connection_id}-{uuid4()}" + full_text = "" + + prompt_token_ids_len: int = 0 + completion_tokens_len: int = 0 + + try: + # Create sampling params + from vllm.sampling_params import RequestOutputKind, SamplingParams + + sampling_params = SamplingParams.from_optional( + temperature=0.0, + max_tokens=1, + output_kind=RequestOutputKind.DELTA, + skip_clone=True, + ) + + # Pass the streaming input generator to the engine + # The engine will consume audio chunks as they arrive and + # stream back transcription results incrementally + result_gen = self.serving.engine_client.generate( + prompt=streaming_input_gen, + sampling_params=sampling_params, + request_id=request_id, + ) + + # Stream results back to client as they're generated + async for output in result_gen: + if output.outputs and len(output.outputs) > 0: + if not prompt_token_ids_len and output.prompt_token_ids: + prompt_token_ids_len = len(output.prompt_token_ids) + + delta = output.outputs[0].text + full_text += delta + + # append output to input + input_stream.put_nowait(list(output.outputs[0].token_ids)) + await self.send(TranscriptionDelta(delta=delta)) + + completion_tokens_len += len(output.outputs[0].token_ids) + + if not self._is_connected: + # finish because websocket connection was killed + break + + if self.audio_queue.empty() and self._is_input_finished: + # finish because client signals that audio input + # is finished + break + + usage = UsageInfo( + prompt_tokens=prompt_token_ids_len, + completion_tokens=completion_tokens_len, + total_tokens=prompt_token_ids_len + completion_tokens_len, + ) + + # Send final completion event + await self.send(TranscriptionDone(text=full_text, usage=usage)) + + # Clear queue for next utterance + while not self.audio_queue.empty(): + self.audio_queue.get_nowait() + + except Exception as e: + logger.exception("Error in generation: %s", e) + await self.send_error(str(e), "processing_error") + + async def send( + self, event: SessionCreated | TranscriptionDelta | TranscriptionDone + ): + """Send event to client.""" + data = event.model_dump_json() + await self.websocket.send_text(data) + + async def send_error(self, message: str, code: str | None = None): + """Send error event to client.""" + error_event = ErrorEvent(error=message, code=code) + await self.websocket.send_text(error_event.model_dump_json()) + + async def cleanup(self): + """Cleanup resources.""" + # Signal audio stream to stop + self.audio_queue.put_nowait(None) + + # Cancel generation task if running + if self.generation_task and not self.generation_task.done(): + self.generation_task.cancel() + + logger.debug("Connection cleanup complete: %s", self.connection_id) diff --git a/vllm/entrypoints/openai/realtime/protocol.py b/vllm/entrypoints/openai/realtime/protocol.py new file mode 100644 index 000000000..25c5cd39d --- /dev/null +++ b/vllm/entrypoints/openai/realtime/protocol.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import time +from typing import Literal + +from pydantic import Field + +from vllm.entrypoints.openai.engine.protocol import ( + OpenAIBaseModel, + UsageInfo, +) +from vllm.utils import random_uuid + +# Client -> Server Events + + +class InputAudioBufferAppend(OpenAIBaseModel): + """Append audio chunk to buffer""" + + type: Literal["input_audio_buffer.append"] = "input_audio_buffer.append" + audio: str # base64-encoded PCM16 @ 16kHz + + +class 
InputAudioBufferCommit(OpenAIBaseModel):
+    """Process accumulated audio buffer"""
+
+    type: Literal["input_audio_buffer.commit"] = "input_audio_buffer.commit"
+    final: bool = False
+
+
+class SessionUpdate(OpenAIBaseModel):
+    """Configure session parameters"""
+
+    type: Literal["session.update"] = "session.update"
+    model: str | None = None
+
+
+# Server -> Client Events
+class SessionCreated(OpenAIBaseModel):
+    """Connection established notification"""
+
+    type: Literal["session.created"] = "session.created"
+    id: str = Field(default_factory=lambda: f"sess-{random_uuid()}")
+    created: int = Field(default_factory=lambda: int(time.time()))
+
+
+class TranscriptionDelta(OpenAIBaseModel):
+    """Incremental transcription text"""
+
+    type: Literal["transcription.delta"] = "transcription.delta"
+    delta: str  # Incremental text
+
+
+class TranscriptionDone(OpenAIBaseModel):
+    """Final transcription with usage stats"""
+
+    type: Literal["transcription.done"] = "transcription.done"
+    text: str  # Complete transcription
+    usage: UsageInfo | None = None
+
+
+class ErrorEvent(OpenAIBaseModel):
+    """Error notification"""
+
+    type: Literal["error"] = "error"
+    error: str
+    code: str | None = None
diff --git a/vllm/entrypoints/openai/realtime/serving.py b/vllm/entrypoints/openai/realtime/serving.py
new file mode 100644
index 000000000..8a2d62a37
--- /dev/null
+++ b/vllm/entrypoints/openai/realtime/serving.py
@@ -0,0 +1,84 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import asyncio
+from collections.abc import AsyncGenerator
+from functools import cached_property
+from typing import Literal, cast
+
+import numpy as np
+
+from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.logger import RequestLogger
+from vllm.entrypoints.openai.engine.serving import OpenAIServing
+from vllm.entrypoints.openai.models.serving import OpenAIServingModels
+from vllm.inputs.data import PromptType, StreamingInput
+from vllm.logger import init_logger
+from vllm.model_executor.models.interfaces import SupportsRealtime
+
+logger = init_logger(__name__)
+
+
+class OpenAIServingRealtime(OpenAIServing):
+    """Realtime audio transcription service via WebSocket streaming.
+
+    Provides streaming audio-to-text transcription by transforming audio chunks
+    into StreamingInput objects that can be consumed by the engine.
+    """
+
+    def __init__(
+        self,
+        engine_client: EngineClient,
+        models: OpenAIServingModels,
+        *,
+        request_logger: RequestLogger | None,
+        log_error_stack: bool = False,
+    ):
+        super().__init__(
+            engine_client=engine_client,
+            models=models,
+            request_logger=request_logger,
+            log_error_stack=log_error_stack,
+        )
+
+        self.task_type: Literal["realtime"] = "realtime"
+
+        logger.info("OpenAIServingRealtime initialized for task: %s", self.task_type)
+
+    @cached_property
+    def model_cls(self) -> type[SupportsRealtime]:
+        """Get the model class that supports realtime transcription."""
+        from vllm.model_executor.model_loader import get_model_cls
+
+        model_cls = get_model_cls(self.model_config)
+        return cast(type[SupportsRealtime], model_cls)
+
+    async def transcribe_realtime(
+        self,
+        audio_stream: AsyncGenerator[np.ndarray, None],
+        input_stream: asyncio.Queue[list[int]],
+    ) -> AsyncGenerator[StreamingInput, None]:
+        """Transform audio stream into StreamingInput for engine.generate().
+
+        Args:
+            audio_stream: Async generator yielding float32 numpy audio arrays
+            input_stream: Queue containing context token IDs from previous
+                generation outputs. 
Used for autoregressive multi-turn
+                processing where each generation's output becomes the context
+                for the next iteration.
+
+        Yields:
+            StreamingInput objects containing audio prompts for the engine
+        """
+
+        # mypy cannot infer the concrete async generator type here
+        # TODO(Patrick) - fix this
+        stream_input_iter = cast(
+            AsyncGenerator[PromptType, None],
+            self.model_cls.buffer_realtime_audio(
+                audio_stream, input_stream, self.model_config
+            ),
+        )
+
+        async for prompt in stream_input_iter:
+            yield StreamingInput(prompt=prompt)
diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py
index 1f138a72d..94f56e0b3 100644
--- a/vllm/inputs/data.py
+++ b/vllm/inputs/data.py
@@ -1,11 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, cast
 
 import torch
 from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar
 
+from vllm.sampling_params import SamplingParams
+
 if TYPE_CHECKING:
     from vllm.multimodal.inputs import (
         MultiModalDataDict,
@@ -357,3 +360,15 @@ def to_enc_dec_tuple_list(
         (enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"])
         for enc_dec_prompt in enc_dec_prompts
     ]
+
+
+@dataclass
+class StreamingInput:
+    """Input data for a streaming generation request.
+
+    This is used with generate() to support multi-turn streaming sessions
+    where inputs are provided via an async generator.
+    """
+
+    prompt: PromptType
+    sampling_params: SamplingParams | None = None
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index ea763afd5..e6ee212af 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from collections.abc import Callable, Iterable, Mapping, MutableSequence
+import asyncio
+from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, MutableSequence
 from contextlib import ExitStack, contextmanager, nullcontext
 from typing import (
     TYPE_CHECKING,
@@ -1015,6 +1016,37 @@ class SupportsQuant:
         return None
 
 
+@runtime_checkable
+class SupportsRealtime(Protocol):
+    """The interface required for all models that support realtime transcription."""
+
+    supports_realtime: ClassVar[Literal[True]] = True
+
+    @classmethod
+    async def buffer_realtime_audio(
+        cls,
+        audio_stream: AsyncGenerator[np.ndarray, None],
+        input_stream: asyncio.Queue[list[int]],
+        model_config: ModelConfig,
+    ) -> AsyncGenerator[PromptType, None]: ...
+
+
+@overload
+def supports_realtime(
+    model: type[object],
+) -> TypeIs[type[SupportsRealtime]]: ...
+
+
+@overload
+def supports_realtime(model: object) -> TypeIs[SupportsRealtime]: ... 
+ + +def supports_realtime( + model: type[object] | object, +) -> TypeIs[type[SupportsRealtime]] | TypeIs[SupportsRealtime]: + return getattr(model, "supports_realtime", False) + + @runtime_checkable class SupportsTranscription(Protocol): """The interface required for all models that support transcription.""" diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index c828aa7e5..483f431eb 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import inspect import math from collections.abc import Iterable, Mapping, Sequence from functools import cached_property, partial @@ -20,7 +19,6 @@ from mistral_common.protocol.transcription.request import TranscriptionRequest from mistral_common.tokens.tokenizers.audio import ( Audio, AudioEncoder, - TranscriptionFormat, ) from transformers import BatchFeature, TensorType, WhisperConfig from transformers.tokenization_utils_base import TextInput @@ -163,19 +161,10 @@ class VoxtralProcessorAdapter: assert isinstance(audio, np.ndarray) assert audio.ndim == 1 - # pad if necessary - # TODO(Patrick) - remove once mistral-common is bumped - if ( - self._audio_processor.audio_config.transcription_format - != TranscriptionFormat.STREAMING - ): - sig = inspect.signature(self._audio_processor.pad) - if "is_online_streaming" in sig.parameters: - audio = self._audio_processor.pad( - audio, self.sampling_rate, is_online_streaming=False - ) - else: - audio = self._audio_processor.pad(audio, self.sampling_rate) + if not self._audio_processor.audio_config.is_streaming: + audio = self._audio_processor.pad( + audio, self.sampling_rate, is_online_streaming=False + ) audio_tokens = [self.begin_audio_token_id] + [ self.audio_token_id diff --git a/vllm/model_executor/models/voxtral_streaming.py b/vllm/model_executor/models/voxtral_streaming.py index 3d1bb1933..5ff561f73 100644 --- a/vllm/model_executor/models/voxtral_streaming.py +++ b/vllm/model_executor/models/voxtral_streaming.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio import math -from collections.abc import Mapping +from collections.abc import AsyncGenerator, Mapping from typing import Literal, cast import numpy as np @@ -12,12 +13,14 @@ from mistral_common.protocol.transcription.request import ( StreamingMode, TranscriptionRequest, ) -from mistral_common.tokens.tokenizers.audio import Audio +from mistral_common.tokens.tokenizers.audio import Audio, AudioConfig +from vllm.compilation.decorators import support_torch_compile from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig -from vllm.inputs.data import PromptType +from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S +from vllm.inputs.data import PromptType, TokensPrompt from vllm.logger import init_logger -from vllm.model_executor.models.interfaces import MultiModalEmbeddings +from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsRealtime from vllm.model_executor.models.voxtral import ( VoxtralDummyInputsBuilder, VoxtralForConditionalGeneration, @@ -44,6 +47,8 @@ from .utils import ( logger = init_logger(__name__) +_PRE_ALLOCATE_BUFFER_SIZE_IN_S = 30 + class VoxtralStreamingMultiModalProcessor(VoxtralMultiModalProcessor): def __init__( @@ -124,29 +129,164 @@ def _expand_tensor(input_tensor: torch.Tensor, scaling: int) -> 
torch.Tensor: return (base.unsqueeze(1) + offsets).view(-1) +class VoxtralRealtimeBuffer: + def __init__(self, config: AudioConfig) -> None: + self._config = config + + self._look_ahead_in_ms = config.streaming_look_ahead_ms + self._look_back_in_ms = config.streaming_look_back_ms + + self._sampling_rate = self._config.sampling_rate + + self._look_ahead = self._get_len_in_samples(self._look_ahead_in_ms) + self._look_back = self._get_len_in_samples(self._look_back_in_ms) + self._streaming_size = self._get_len_in_samples(1000 / self._config.frame_rate) + + # mutable objects + streaming_delay = self._get_len_in_samples(self._config.transcription_delay_ms) + self._start = 0 + self._end = streaming_delay + self._streaming_size + + # always pre-allocate 30 second buffers + self._buffer_size = _PRE_ALLOCATE_BUFFER_SIZE_IN_S * self._sampling_rate + self._buffer: np.ndarray = np.empty(self._buffer_size, dtype=np.float32) + self._filled_buffer_len = 0 + + @property + def start_idx(self): + return max(self._start - self._look_back, 0) + + @property + def end_idx(self): + return self._end + self._look_ahead + + @property + def is_audio_complete(self) -> bool: + return self._filled_buffer_len >= self.end_idx + + def _get_len_in_samples(self, len_in_ms: float) -> int: + _len_in_s = self._sampling_rate * len_in_ms / 1000 + assert _len_in_s.is_integer(), _len_in_s + len_in_s = int(_len_in_s) + + return len_in_s + + def _allocate_new_buffer(self) -> None: + # allocate new buffer + new_buffer = np.empty(self._buffer_size, dtype=np.float32) + left_to_copy = max(self._filled_buffer_len - self.start_idx, 0) + + if left_to_copy > 0: + new_buffer[:left_to_copy] = self._buffer[ + self.start_idx : self._filled_buffer_len + ] + + del self._buffer + self._buffer = new_buffer + + self._filled_buffer_len = left_to_copy + self._start = self._look_back + self._end = self._start + self._streaming_size + + def write_audio(self, audio: np.ndarray) -> None: + put_end_idx = self._filled_buffer_len + len(audio) + + if put_end_idx > self._buffer_size: + self._allocate_new_buffer() + + self._buffer[self._filled_buffer_len : self._filled_buffer_len + len(audio)] = ( + audio + ) + self._filled_buffer_len += len(audio) + + def read_audio(self) -> np.ndarray | None: + if not self.is_audio_complete: + return None + + audio = self._buffer[self.start_idx : self.end_idx] + self._start = self._end + self._end += self._streaming_size + + return audio + + @MULTIMODAL_REGISTRY.register_processor( VoxtralStreamingMultiModalProcessor, info=VoxtralProcessingInfo, dummy_inputs=VoxtralDummyInputsBuilder, ) -class VoxtralStreamingGeneration(VoxtralForConditionalGeneration): +@support_torch_compile +class VoxtralStreamingGeneration(VoxtralForConditionalGeneration, SupportsRealtime): requires_raw_input_tokens = True def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) + + assert ( + not vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs() + ), ( + "Voxtral streaming doesn't support full cudagraphs yet. " + "Please use PIECEWISE." 
+ ) + self.time_embedding: TimeEmbedding = TimeEmbedding( dim=self.config.text_config.hidden_size ) audio_config = self.tokenizer.instruct.audio_encoder.audio_config - _n_delay_tokens = ( - audio_config.frame_rate * audio_config.transcription_delay_ms / 1000 - ) - assert _n_delay_tokens.is_integer(), ( - f"n_delay_tokens must be integer, got {_n_delay_tokens}" - ) + self.n_delay_tokens = audio_config.num_delay_tokens - self.n_delay_tokens = int(_n_delay_tokens) + # for realtime transcription + @classmethod + async def buffer_realtime_audio( + cls, + audio_stream: AsyncGenerator[np.ndarray, None], + input_stream: asyncio.Queue[list[int]], + model_config: ModelConfig, + ) -> AsyncGenerator[PromptType, None]: + tokenizer = cached_tokenizer_from_config(model_config) + audio_encoder = tokenizer.instruct.audio_encoder + config = audio_encoder.audio_config + + buffer = VoxtralRealtimeBuffer(config) + is_first_yield = True + + async for audio in audio_stream: + buffer.write_audio(audio) + + while (new_audio := buffer.read_audio()) is not None: + if is_first_yield: + # make sure that input_stream is empty + assert input_stream.empty() + + audio = Audio(new_audio, config.sampling_rate, format="wav") + + request = TranscriptionRequest( + streaming=StreamingMode.ONLINE, + audio=RawAudio.from_audio(audio), + language=None, + ) + # mistral tokenizer takes care + # of preparing the first prompt inputs + # and does some left-silence padding + # for improved performance + audio_enc = tokenizer.mistral.encode_transcription(request) + + token_ids = audio_enc.tokens + new_audio = audio_enc.audios[0].audio_array + + is_first_yield = False + else: + # pop last element from input_stream + all_outputs = await asyncio.wait_for( + input_stream.get(), timeout=VLLM_ENGINE_ITERATION_TIMEOUT_S + ) + token_ids = all_outputs[-1:] + + multi_modal_data = {"audio": (new_audio, None)} + yield TokensPrompt( + prompt_token_ids=token_ids, multi_modal_data=multi_modal_data + ) @property def audio_config(self): @@ -205,8 +345,9 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration): # sum pool text and audio embeddings inputs_embeds = audio_text_embeds + text_embeds - time_tensor = torch.tensor( - [self.n_delay_tokens], + time_tensor = torch.full( + (1,), + fill_value=self.n_delay_tokens, device=inputs_embeds.device, dtype=inputs_embeds.dtype, ) diff --git a/vllm/tasks.py b/vllm/tasks.py index bd3e5af77..b898bba69 100644 --- a/vllm/tasks.py +++ b/vllm/tasks.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Literal, get_args -GenerationTask = Literal["generate", "transcription"] +GenerationTask = Literal["generate", "transcription", "realtime"] GENERATION_TASKS: tuple[GenerationTask, ...] 
= get_args(GenerationTask) PoolingTask = Literal[ diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 9f40f41a1..994d5580f 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -7,7 +7,6 @@ import time import warnings from collections.abc import AsyncGenerator, Iterable, Mapping from copy import copy -from dataclasses import dataclass from typing import Any import torch @@ -19,6 +18,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import EngineClient from vllm.entrypoints.utils import _validate_truncation_size from vllm.inputs import PromptType +from vllm.inputs.data import StreamingInput from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry @@ -53,18 +53,6 @@ from vllm.v1.metrics.stats import IterationStats logger = init_logger(__name__) -@dataclass -class StreamingInput: - """Input data for a streaming generation request. - - This is used with generate() to support multi-turn streaming sessions - where inputs are provided via an async generator. - """ - - prompt: PromptType - sampling_params: SamplingParams | None = None - - class InputStreamError(Exception): """Wrapper for errors from the input stream generator. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8e21dea69..3ab7fcad7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -68,6 +68,7 @@ from vllm.model_executor.models.interfaces import ( supports_eagle3, supports_mrope, supports_multimodal_pruning, + supports_realtime, supports_transcription, supports_xdrope, ) @@ -2541,6 +2542,9 @@ class GPUModelRunner( supported_tasks.append("transcription") + if supports_realtime(model): + supported_tasks.append("realtime") + return supported_tasks def get_supported_pooling_tasks(self) -> list[PoolingTask]: