[Model] Add Ultravox support for multiple audio chunks (#7963)

This commit is contained in:
Peter Salas
2024-09-03 21:38:21 -07:00
committed by GitHub
parent e16fa99a6a
commit 2be8ec6e71
3 changed files with 196 additions and 113 deletions

View File

@@ -16,37 +16,32 @@ MODEL_NAME = "fixie-ai/ultravox-v0_3"
AudioTuple = Tuple[np.ndarray, int]
VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
HF_PLACEHOLDER = "<|audio|>"
@pytest.fixture(scope="session")
def audio_and_sample_rate():
def audio_assets():
from vllm.assets.audio import AudioAsset
return AudioAsset("mary_had_lamb").audio_and_sample_rate
return [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
@pytest.fixture
def prompts_and_audios(audio_and_sample_rate):
@pytest.fixture(scope="module", params=("mary_had_lamb", "winning_call"))
def audio(request):
from vllm.assets.audio import AudioAsset
return AudioAsset(request.param)
def _get_prompt(audio_count, question, placeholder):
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
placeholder = f"{placeholder}\n" * audio_count
vllm_placeholder = "<|reserved_special_token_0|>"
hf_placeholder = "<|audio|>"
question = "What's in the audio?"
vllm_prompt = tokenizer.apply_chat_template(
[{
'role': 'user',
'content': f"{vllm_placeholder}\n{question}"
}],
tokenize=False,
add_generation_prompt=True)
hf_prompt = tokenizer.apply_chat_template(
[{
'role': 'user',
'content': f"{hf_placeholder}\n{question}"
}],
tokenize=False,
add_generation_prompt=True)
return [(vllm_prompt, hf_prompt, audio_and_sample_rate)]
return tokenizer.apply_chat_template([{
'role': 'user',
'content': f"{placeholder}{question}"
}],
tokenize=False,
add_generation_prompt=True)
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
@@ -134,15 +129,71 @@ def run_test(
)
def run_multi_audio_test(
vllm_runner: Type[VllmRunner],
prompts_and_audios: List[Tuple[str, List[AudioTuple]]],
model: str,
*,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
limit_mm_per_prompt={
"audio":
max((len(audio) for _, audio in prompts_and_audios))
}) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
[prompt for prompt, _ in prompts_and_audios],
max_tokens,
num_logprobs=num_logprobs,
audios=[audios for _, audios in prompts_and_audios])
# The HuggingFace model doesn't support multiple audios yet, so
# just assert that some tokens were generated.
assert all(tokens for tokens, *_ in vllm_outputs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, prompts_and_audios, dtype: str,
max_tokens: int, num_logprobs: int) -> None:
def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
num_logprobs: int) -> None:
vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER)
hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER)
run_test(
hf_runner,
vllm_runner,
prompts_and_audios,
[(vllm_prompt, hf_prompt, audio.audio_and_sample_rate)],
MODEL_NAME,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
max_tokens: int,
num_logprobs: int) -> None:
vllm_prompt = _get_prompt(len(audio_assets),
"Describe each of the audios above.",
VLLM_PLACEHOLDER)
run_multi_audio_test(
vllm_runner,
[(vllm_prompt, [audio.audio_and_sample_rate
for audio in audio_assets])],
MODEL_NAME,
dtype=dtype,
max_tokens=max_tokens,