[Audio] Improve Audio Inference Scripts (offline/online) (#29279)
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
This commit is contained in:
@@ -495,27 +495,40 @@ def main(args):
|
||||
temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
|
||||
)
|
||||
|
||||
mm_data = req_data.multi_modal_data
|
||||
if not mm_data:
|
||||
mm_data = {}
|
||||
if audio_count > 0:
|
||||
mm_data = {
|
||||
"audio": [
|
||||
asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
|
||||
]
|
||||
}
|
||||
def get_input(start, end):
    """
    Build a single request dict from the enclosing scope's `req_data`.

    When the `[start, end)` range is non-empty, the multimodal payload is the
    `audio_and_sample_rate` of each entry in `audio_assets[start:end]`;
    otherwise the payload already attached to `req_data` is used (or an
    empty dict if there is none).

    The result always has a "multi_modal_data" key, plus either "prompt"
    or "prompt_token_ids" depending on which one `req_data` provides.
    """
    if end - start > 0:
        # An explicit audio range replaces whatever req_data carried.
        audio_payload = {
            "audio": [
                clip.audio_and_sample_rate for clip in audio_assets[start:end]
            ]
        }
    else:
        audio_payload = req_data.multi_modal_data or {}

    request = {"multi_modal_data": audio_payload}

    # Text prompt takes precedence; fall back to pre-tokenized ids.
    text_prompt = req_data.prompt
    if text_prompt:
        request["prompt"] = text_prompt
    else:
        request["prompt_token_ids"] = req_data.prompt_token_ids

    return request
|
||||
|
||||
# Batch inference
|
||||
assert args.num_prompts > 0
|
||||
inputs = {"multi_modal_data": mm_data}
|
||||
|
||||
if req_data.prompt:
|
||||
inputs["prompt"] = req_data.prompt
|
||||
else:
|
||||
inputs["prompt_token_ids"] = req_data.prompt_token_ids
|
||||
|
||||
if args.num_prompts > 1:
|
||||
# Batch inference
|
||||
if audio_count != 1:
|
||||
inputs = get_input(0, audio_count)
|
||||
inputs = [inputs] * args.num_prompts
|
||||
else:
|
||||
# For single audio input, we need to vary the audio input
|
||||
# to avoid deduplication in vLLM engine.
|
||||
inputs = []
|
||||
for i in range(args.num_prompts):
|
||||
start = i % len(audio_assets)
|
||||
inp = get_input(start, start + 1)
|
||||
inputs.append(inp)
|
||||
|
||||
# Add LoRA request if applicable
|
||||
lora_request = (
|
||||
req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
|
||||
|
||||
@@ -18,6 +18,7 @@ The script performs:
|
||||
2. Streaming transcription using raw HTTP request to the vLLM server.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
@@ -25,14 +26,14 @@ from openai import AsyncOpenAI, OpenAI
|
||||
from vllm.assets.audio import AudioAsset
|
||||
|
||||
|
||||
def sync_openai(audio_path: str, client: OpenAI):
|
||||
def sync_openai(audio_path: str, client: OpenAI, model: str):
|
||||
"""
|
||||
Perform synchronous transcription using OpenAI-compatible API.
|
||||
"""
|
||||
with open(audio_path, "rb") as f:
|
||||
transcription = client.audio.transcriptions.create(
|
||||
file=f,
|
||||
model="openai/whisper-large-v3",
|
||||
model=model,
|
||||
language="en",
|
||||
response_format="json",
|
||||
temperature=0.0,
|
||||
@@ -42,18 +43,18 @@ def sync_openai(audio_path: str, client: OpenAI):
|
||||
repetition_penalty=1.3,
|
||||
),
|
||||
)
|
||||
print("transcription result:", transcription.text)
|
||||
print("transcription result [sync]:", transcription.text)
|
||||
|
||||
|
||||
async def stream_openai_response(audio_path: str, client: AsyncOpenAI):
|
||||
async def stream_openai_response(audio_path: str, client: AsyncOpenAI, model: str):
|
||||
"""
|
||||
Perform asynchronous transcription using OpenAI-compatible API.
|
||||
"""
|
||||
print("\ntranscription result:", end=" ")
|
||||
print("\ntranscription result [stream]:", end=" ")
|
||||
with open(audio_path, "rb") as f:
|
||||
transcription = await client.audio.transcriptions.create(
|
||||
file=f,
|
||||
model="openai/whisper-large-v3",
|
||||
model=model,
|
||||
language="en",
|
||||
response_format="json",
|
||||
temperature=0.0,
|
||||
@@ -72,7 +73,47 @@ async def stream_openai_response(audio_path: str, client: AsyncOpenAI):
|
||||
print() # Final newline after stream ends
|
||||
|
||||
|
||||
def main():
|
||||
def stream_api_response(audio_path: str, model: str, openai_api_base: str):
    """
    Perform streaming transcription using raw HTTP requests to the vLLM API server.

    The audio file is uploaded as multipart form data to
    `{openai_api_base}/audio/transcriptions` with `stream=true`, and each
    streamed `data: ...` line is decoded as JSON and its delta text printed
    as it arrives.

    Args:
        audio_path: Path of the audio file to upload.
        model: Model identifier sent in the `model` form field.
        openai_api_base: Base URL of the API (without trailing slash); the
            `/audio/transcriptions` route is appended to it.

    Raises:
        requests.HTTPError: If the server replies with a 4xx/5xx status.
        json.JSONDecodeError: If a `data: ` line does not carry valid JSON.
    """
    import json
    import os

    import requests

    api_url = f"{openai_api_base}/audio/transcriptions"
    headers = {"User-Agent": "Transcription-Client"}
    form_fields = {
        "stream": "true",
        "model": model,
        "language": "en",
        "response_format": "json",
    }
    event_prefix = b"data: "

    with open(audio_path, "rb") as f:
        files = {"file": (os.path.basename(audio_path), f)}

        print("\ntranscription result [stream]:", end=" ")
        # The context manager releases the streamed connection even when the
        # loop exits early or parsing raises.
        with requests.post(
            api_url, headers=headers, files=files, data=form_fields, stream=True
        ) as response:
            # Fail loudly on HTTP errors instead of trying to parse an error
            # body as if it were a stream event.
            response.raise_for_status()

            for line in response.iter_lines(
                chunk_size=8192, decode_unicode=False, delimiter=b"\n"
            ):
                # Skip blank separators and any line that is not a data event.
                if not line.startswith(event_prefix):
                    continue

                payload = line[len(event_prefix) :].strip()
                # NOTE(review): OpenAI-style streams usually end with a
                # literal "[DONE]" marker, which is not JSON — confirm the
                # server emits it; handling it here is harmless otherwise.
                if payload == b"[DONE]":
                    break

                event = json.loads(payload.decode("utf-8"))
                choices = event.get("choices")
                if not choices:
                    # Nothing to print for events without choices.
                    continue

                choice = choices[0]
                content = (choice.get("delta") or {}).get("content")
                if content:
                    print(content, end="", flush=True)

                finish_reason = choice.get("finish_reason")
                if finish_reason is not None:
                    print(f"\n[Stream finished reason: {finish_reason}]")
                    break
|
||||
|
||||
|
||||
def main(args):
|
||||
mary_had_lamb = str(AudioAsset("mary_had_lamb").get_local_path())
|
||||
winning_call = str(AudioAsset("winning_call").get_local_path())
|
||||
|
||||
@@ -84,14 +125,41 @@ def main():
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
sync_openai(mary_had_lamb, client)
|
||||
model = client.models.list().data[0].id
|
||||
print(f"Using model: {model}")
|
||||
|
||||
# Run the synchronous function
|
||||
sync_openai(args.audio_path if args.audio_path else mary_had_lamb, client, model)
|
||||
|
||||
# Run the asynchronous function
|
||||
client = AsyncOpenAI(
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
asyncio.run(stream_openai_response(winning_call, client))
|
||||
if "openai" in model:
|
||||
client = AsyncOpenAI(
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
asyncio.run(
|
||||
stream_openai_response(
|
||||
args.audio_path if args.audio_path else winning_call, client, model
|
||||
)
|
||||
)
|
||||
else:
|
||||
stream_api_response(
|
||||
args.audio_path if args.audio_path else winning_call,
|
||||
model,
|
||||
openai_api_base,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# setup argparser
|
||||
parser = argparse.ArgumentParser(
|
||||
description="OpenAI Transcription Client using vLLM API Server"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--audio_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path to the audio file to transcribe.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
Reference in New Issue
Block a user