From adcf682fc7d1835d037da331922751e880c8bc25 Mon Sep 17 00:00:00 2001 From: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Date: Wed, 31 Dec 2025 18:34:18 -0500 Subject: [PATCH] [Audio] Improve Audio Inference Scripts (offline/online) (#29279) Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> --- examples/offline_inference/audio_language.py | 49 ++++++---- .../openai_transcription_client.py | 96 ++++++++++++++++--- 2 files changed, 113 insertions(+), 32 deletions(-) diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 8a182aee9..e9878382b 100755 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -495,27 +495,40 @@ def main(args): temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids ) - mm_data = req_data.multi_modal_data - if not mm_data: - mm_data = {} - if audio_count > 0: - mm_data = { - "audio": [ - asset.audio_and_sample_rate for asset in audio_assets[:audio_count] - ] - } + def get_input(start, end): + mm_data = req_data.multi_modal_data + if not mm_data: + mm_data = {} + if end - start > 0: + mm_data = { + "audio": [ + asset.audio_and_sample_rate for asset in audio_assets[start:end] + ] + } + inputs = {"multi_modal_data": mm_data} + + if req_data.prompt: + inputs["prompt"] = req_data.prompt + else: + inputs["prompt_token_ids"] = req_data.prompt_token_ids + + return inputs + + # Batch inference assert args.num_prompts > 0 - inputs = {"multi_modal_data": mm_data} - - if req_data.prompt: - inputs["prompt"] = req_data.prompt - else: - inputs["prompt_token_ids"] = req_data.prompt_token_ids - - if args.num_prompts > 1: - # Batch inference + if audio_count != 1: + inputs = get_input(0, audio_count) inputs = [inputs] * args.num_prompts + else: + # For single audio input, we need to vary the audio input + # to avoid deduplication in vLLM engine. + inputs = [] + for i in range(args.num_prompts): + start = i % len(audio_assets) + inp = get_input(start, start + 1) + inputs.append(inp) + # Add LoRA request if applicable lora_request = ( req_data.lora_requests * args.num_prompts if req_data.lora_requests else None diff --git a/examples/online_serving/openai_transcription_client.py b/examples/online_serving/openai_transcription_client.py index 0d1d73fb1..966bfd2a4 100644 --- a/examples/online_serving/openai_transcription_client.py +++ b/examples/online_serving/openai_transcription_client.py @@ -18,6 +18,7 @@ The script performs: 2. Streaming transcription using raw HTTP request to the vLLM server. """ +import argparse import asyncio from openai import AsyncOpenAI, OpenAI @@ -25,14 +26,14 @@ from openai import AsyncOpenAI, OpenAI from vllm.assets.audio import AudioAsset -def sync_openai(audio_path: str, client: OpenAI): +def sync_openai(audio_path: str, client: OpenAI, model: str): """ Perform synchronous transcription using OpenAI-compatible API. """ with open(audio_path, "rb") as f: transcription = client.audio.transcriptions.create( file=f, - model="openai/whisper-large-v3", + model=model, language="en", response_format="json", temperature=0.0, @@ -42,18 +43,18 @@ def sync_openai(audio_path: str, client: OpenAI): repetition_penalty=1.3, ), ) - print("transcription result:", transcription.text) + print("transcription result [sync]:", transcription.text) -async def stream_openai_response(audio_path: str, client: AsyncOpenAI): +async def stream_openai_response(audio_path: str, client: AsyncOpenAI, model: str): """ Perform asynchronous transcription using OpenAI-compatible API. """ - print("\ntranscription result:", end=" ") + print("\ntranscription result [stream]:", end=" ") with open(audio_path, "rb") as f: transcription = await client.audio.transcriptions.create( file=f, - model="openai/whisper-large-v3", + model=model, language="en", response_format="json", temperature=0.0, @@ -72,7 +73,47 @@ async def stream_openai_response(audio_path: str, client: AsyncOpenAI): print() # Final newline after stream ends -def main(): +def stream_api_response(audio_path: str, model: str, openai_api_base: str): + """ + Perform streaming transcription using raw HTTP requests to the vLLM API server. + """ + import json + import os + + import requests + + api_url = f"{openai_api_base}/audio/transcriptions" + headers = {"User-Agent": "Transcription-Client"} + with open(audio_path, "rb") as f: + files = {"file": (os.path.basename(audio_path), f)} + data = { + "stream": "true", + "model": model, + "language": "en", + "response_format": "json", + } + + print("\ntranscription result [stream]:", end=" ") + response = requests.post( + api_url, headers=headers, files=files, data=data, stream=True + ) + for chunk in response.iter_lines( + chunk_size=8192, decode_unicode=False, delimiter=b"\n" + ): + if chunk: + data = chunk[len("data: ") :] + data = json.loads(data.decode("utf-8")) + data = data["choices"][0] + delta = data["delta"]["content"] + print(delta, end="", flush=True) + + finish_reason = data.get("finish_reason") + if finish_reason is not None: + print(f"\n[Stream finished reason: {finish_reason}]") + break + + +def main(args): mary_had_lamb = str(AudioAsset("mary_had_lamb").get_local_path()) winning_call = str(AudioAsset("winning_call").get_local_path()) @@ -84,14 +125,41 @@ def main(): base_url=openai_api_base, ) - sync_openai(mary_had_lamb, client) + model = client.models.list().data[0].id + print(f"Using model: {model}") + + # Run the synchronous function + sync_openai(args.audio_path if args.audio_path else mary_had_lamb, client, model) + # Run the asynchronous function - client = AsyncOpenAI( - api_key=openai_api_key, - base_url=openai_api_base, - ) - asyncio.run(stream_openai_response(winning_call, client)) + if "openai" in model: + client = AsyncOpenAI( + api_key=openai_api_key, + base_url=openai_api_base, + ) + asyncio.run( + stream_openai_response( + args.audio_path if args.audio_path else winning_call, client, model + ) + ) + else: + stream_api_response( + args.audio_path if args.audio_path else winning_call, + model, + openai_api_base, + ) if __name__ == "__main__": - main() + # setup argparser + parser = argparse.ArgumentParser( + description="OpenAI Transcription Client using vLLM API Server" + ) + parser.add_argument( + "--audio_path", + type=str, + default=None, + help="The path to the audio file to transcribe.", + ) + args = parser.parse_args() + main(args)