[Model] Add UltravoxModel and UltravoxConfig (#7615)
@@ -1,138 +1,36 @@
-import math
-import sys
-import time
-from typing import Dict, List, Optional, Tuple, Union, cast
-from unittest.mock import patch
+from typing import Dict, List

-import librosa
-import numpy as np
 import openai
 import pytest
-import requests
-import torch

-from vllm import ModelRegistry
-from vllm.config import MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY
-from vllm.inputs.data import LLMInputs
-from vllm.inputs.registry import InputContext
-from vllm.model_executor.models.interfaces import SupportsMultiModal
-from vllm.model_executor.models.opt import OPTForCausalLM
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.base import MultiModalInputs
-from vllm.multimodal.image import (cached_get_tokenizer,
-                                   repeat_and_pad_image_tokens)
+from vllm.assets.audio import AudioAsset
 from vllm.multimodal.utils import encode_audio_base64, fetch_audio
-from vllm.utils import get_open_port

-from ...utils import VLLM_PATH
+from ...utils import RemoteOpenAIServer

-chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
-assert chatml_jinja_path.exists()
-
-MODEL_NAME = "facebook/opt-125m"
+MODEL_NAME = "fixie-ai/ultravox-v0_3"
 TEST_AUDIO_URLS = [
-    "https://upload.wikimedia.org/wikipedia/en/b/bf/Dave_Niehaus_Winning_Call_1995_AL_Division_Series.ogg",
+    AudioAsset("winning_call").url,
 ]


-def server_function(port):
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "4096",
+        "--enforce-eager",
+    ]

-    def fake_input_mapper(ctx: InputContext, data: object):
-        assert isinstance(data, tuple)
-        (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data)
-
-        # Resample it to 1 sample per second
-        audio = librosa.resample(audio, orig_sr=sr, target_sr=1)
-        return MultiModalInputs({"processed_audio": torch.from_numpy(audio)})
-
-    def fake_input_processor(ctx: InputContext, llm_inputs: LLMInputs):
-        multi_modal_data = llm_inputs.get("multi_modal_data")
-        if multi_modal_data is None or "audio" not in multi_modal_data:
-            return llm_inputs
-
-        audio, sr = multi_modal_data.get("audio")
-        audio_duration = math.ceil(len(audio) / sr)
-
-        new_prompt, new_token_ids = repeat_and_pad_image_tokens(
-            cached_get_tokenizer(ctx.model_config.tokenizer),
-            llm_inputs.get("prompt"),
-            llm_inputs["prompt_token_ids"],
-            image_token_id=62,  # "_"
-            repeat_count=audio_duration)
-
-        return LLMInputs(prompt_token_ids=new_token_ids,
-                         prompt=new_prompt,
-                         multi_modal_data=multi_modal_data)
-
-    @MULTIMODAL_REGISTRY.register_input_mapper("audio", fake_input_mapper)
-    @MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
-        "audio", lambda *_, **__: 100)
-    @INPUT_REGISTRY.register_input_processor(fake_input_processor)
-    class FakeAudioModel(OPTForCausalLM, SupportsMultiModal):
-
-        def __init__(self, *args, multimodal_config: MultiModalConfig,
-                     **kwargs):
-            assert multimodal_config is not None
-            super().__init__(*args, **kwargs)
-
-        def forward(
-            self,
-            *args,
-            processed_audio: Optional[torch.Tensor] = None,
-            **kwargs,
-        ) -> torch.Tensor:
-            return super().forward(*args, **kwargs)
-
-    ModelRegistry.register_model("OPTForCausalLM", FakeAudioModel)
-
-    with patch(
-            "vllm.entrypoints.chat_utils._mm_token_str",
-            lambda *_, **__: "_"), patch(
-                "vllm.model_executor.models.ModelRegistry.is_multimodal_model"
-            ) as mock:
-        mock.return_value = True
-        sys.argv = ["placeholder.py"] + \
-            (f"--model {MODEL_NAME} --gpu-memory-utilization 0.10 "
-             "--dtype bfloat16 --enforce-eager --api-key token-abc123 "
-             f"--port {port} --chat-template {chatml_jinja_path} "
-             "--disable-frontend-multiprocessing").split()
-        import runpy
-        runpy.run_module('vllm.entrypoints.openai.api_server',
-                         run_name='__main__')
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server


 @pytest.fixture(scope="module")
-def client():
-    port = get_open_port()
-    ctx = torch.multiprocessing.get_context("spawn")
-    server = ctx.Process(target=server_function, args=(port, ))
-    server.start()
-    MAX_SERVER_START_WAIT_S = 60
-    client = openai.AsyncOpenAI(
-        base_url=f"http://localhost:{port}/v1",
-        api_key="token-abc123",
-    )
-    # run health check
-    health_url = f"http://localhost:{port}/health"
-    start = time.time()
-    while True:
-        try:
-            if requests.get(health_url).status_code == 200:
-                break
-        except Exception as err:
-            result = server.exitcode
-            if result is not None:
-                raise RuntimeError("Server exited unexpectedly.") from err
-
-            time.sleep(0.5)
-            if time.time() - start > MAX_SERVER_START_WAIT_S:
-                raise RuntimeError("Server failed to start in time.") from err
-
-    try:
-        yield client
-    finally:
-        server.kill()
+def client(server):
+    return server.get_async_client()


 @pytest.fixture(scope="session")
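With the fake OPT-based audio model removed, the fixtures above start a real Ultravox server through RemoteOpenAIServer and hand tests an async OpenAI client. A minimal sketch of how a test consumes these fixtures, assuming the "audio_url" chat content type exercised by this file; the test name and prompt text are illustrative:

    @pytest.mark.asyncio
    @pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
    async def test_audio_url_sketch(client: openai.AsyncOpenAI,
                                    audio_url: str):
        # One user turn pairing an audio clip (by URL) with a text question.
        messages = [{
            "role": "user",
            "content": [
                {"type": "audio_url", "audio_url": {"url": audio_url}},
                {"type": "text", "text": "What's happening in this audio?"},
            ],
        }]
        # max_tokens=10 keeps the run cheap and makes the
        # finish_reason == "length" assertions in the hunks below hold.
        chat_completion = await client.chat.completions.create(
            model=MODEL_NAME, messages=messages, max_tokens=10)
        assert chat_completion.choices[0].finish_reason == "length"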
@@ -176,7 +74,7 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=36, total_tokens=46)
+        completion_tokens=10, prompt_tokens=202, total_tokens=212)

     message = choice.message
     message = chat_completion.choices[0].message
@@ -231,7 +129,7 @@ async def test_single_chat_session_audio_base64encoded(
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=36, total_tokens=46)
+        completion_tokens=10, prompt_tokens=202, total_tokens=212)

     message = choice.message
     message = chat_completion.choices[0].message
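The base64 test exercises the same endpoint but inlines the clip as a data URL instead of a plain URL; prompt_tokens grows to 202 for the same reason as above, since Ultravox expands the audio into placeholder tokens. A sketch of building such a payload with the helpers kept in the imports above; the (waveform, sampling_rate) return shape of fetch_audio and the encode_audio_base64 signature are assumed from their usage in this file:

    # Assumed: fetch_audio returns (waveform, sampling_rate) and
    # encode_audio_base64 turns that pair into a base64 string.
    audio, sampling_rate = fetch_audio(TEST_AUDIO_URLS[0])
    base64_audio = encode_audio_base64(audio, sampling_rate)
    data_url = f"data:audio/ogg;base64,{base64_audio}"
    # data_url then replaces the plain URL in the "audio_url" chunk
    # shown in the sketch after the first hunk.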