[Core][VLM] Add precise multi-modal placeholder tracking (#8346)
Signed-off-by: Peter Salas <peter@fixie.ai>
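
The tests below exercise chunked prefill with audio inputs, which is the case that placeholder
tracking enables: the engine needs to know exactly which token positions each multi-modal item's
placeholder occupies, so a partial prefill step can tell how much of each audio it covers. The
sketch below is only a minimal illustration of that bookkeeping, not the code added by this PR;
all class and function names here are made up for the example.

    from dataclasses import dataclass
    from typing import List


    @dataclass
    class PlaceholderRange:
        # Illustrative record: where one multi-modal item sits in the prompt.
        offset: int  # index of the first placeholder token
        length: int  # number of consecutive placeholder tokens


    def find_placeholder_ranges(token_ids: List[int],
                                placeholder_id: int) -> List[PlaceholderRange]:
        # One range per contiguous run of placeholder tokens.
        ranges: List[PlaceholderRange] = []
        i = 0
        while i < len(token_ids):
            if token_ids[i] != placeholder_id:
                i += 1
                continue
            start = i
            while i < len(token_ids) and token_ids[i] == placeholder_id:
                i += 1
            ranges.append(PlaceholderRange(offset=start, length=i - start))
        return ranges

A prefill step that covers token positions [a, b) can then intersect [a, b) with each range to
decide which items' audio features it actually needs.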
@@ -2,8 +2,10 @@ from typing import List, Optional, Tuple, Type
 
 import numpy as np
 import pytest
+import pytest_asyncio
 from transformers import AutoModel, AutoTokenizer, BatchEncoding
 
+from tests.utils import RemoteOpenAIServer
 from vllm.sequence import SampleLogprobs
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 
@@ -17,6 +19,13 @@ AudioTuple = Tuple[np.ndarray, int]
 VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
 HF_PLACEHOLDER = "<|audio|>"
 
+CHUNKED_PREFILL_KWARGS = {
+    "enable_chunked_prefill": True,
+    "max_num_seqs": 2,
+    # Use a very small limit to exercise chunked prefill.
+    "max_num_batched_tokens": 16
+}
+
 
 @pytest.fixture(scope="session")
 def audio_assets():
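
The very small max_num_batched_tokens above is what makes these tests interesting: it forces the
engine to prefill a prompt over several scheduler steps, so a run of audio placeholder tokens can
straddle a step boundary. A rough sketch of how such a budget slices a prompt, purely illustrative
and not vLLM's scheduler code:

    def prefill_chunks(prompt_len: int, max_num_batched_tokens: int):
        """Yield the (start, end) token ranges handled by each prefill step."""
        start = 0
        while start < prompt_len:
            end = min(start + max_num_batched_tokens, prompt_len)
            yield start, end
            start = end


    # A prompt of ~50 tokens is prefilled in four steps under the 16-token
    # budget, so an audio placeholder run can easily span two steps.
    print(list(prefill_chunks(50, 16)))  # [(0, 16), (16, 32), (32, 48), (48, 50)]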
@@ -30,6 +39,26 @@ def audio(request):
     return AudioAsset(request.param)
 
 
+@pytest.fixture(params=({}, CHUNKED_PREFILL_KWARGS))
+def server(request, audio_assets):
+    args = [
+        "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
+        f"--limit-mm-per-prompt=audio={len(audio_assets)}"
+    ] + [
+        f"--{key.replace('_','-')}={value}"
+        for key, value in request.param.items()
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client
+
+
 def _get_prompt(audio_count, question, placeholder):
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     placeholder = f"{placeholder}\n" * audio_count
@@ -68,8 +97,7 @@ def run_test(
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
-    tensor_parallel_size: int,
-    distributed_executor_backend: Optional[str] = None,
+    **kwargs,
 ):
     """Inference result should be the same between hf and vllm."""
     torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
@@ -79,11 +107,8 @@ def run_test(
     # if we run HF first, the cuda initialization will be done and it
     # will hurt multiprocessing backend with fork method (the default method).
 
-    with vllm_runner(model,
-                     dtype=dtype,
-                     tensor_parallel_size=tensor_parallel_size,
-                     distributed_executor_backend=distributed_executor_backend,
-                     enforce_eager=True) as vllm_model:
+    with vllm_runner(model, dtype=dtype, enforce_eager=True,
+                     **kwargs) as vllm_model:
         vllm_outputs_per_audio = [
             vllm_model.generate_greedy_logprobs([vllm_prompt],
                                                 max_tokens,
@@ -135,18 +160,16 @@ def run_multi_audio_test(
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
-    tensor_parallel_size: int,
-    distributed_executor_backend: Optional[str] = None,
+    **kwargs,
 ):
     with vllm_runner(model,
                      dtype=dtype,
-                     tensor_parallel_size=tensor_parallel_size,
-                     distributed_executor_backend=distributed_executor_backend,
                      enforce_eager=True,
                      limit_mm_per_prompt={
                          "audio":
                          max((len(audio) for _, audio in prompts_and_audios))
-                     }) as vllm_model:
+                     },
+                     **kwargs) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy_logprobs(
             [prompt for prompt, _ in prompts_and_audios],
             max_tokens,
@@ -162,8 +185,9 @@ def run_multi_audio_test(
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS])
 def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
-                num_logprobs: int) -> None:
+                num_logprobs: int, vllm_kwargs: dict) -> None:
 
     vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER)
     hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER)
@@ -175,7 +199,7 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
         dtype=dtype,
         max_tokens=max_tokens,
         num_logprobs=num_logprobs,
-        tensor_parallel_size=1,
+        **vllm_kwargs,
     )
 
 
@@ -183,9 +207,10 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS])
 def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
-                                     max_tokens: int,
-                                     num_logprobs: int) -> None:
+                                     max_tokens: int, num_logprobs: int,
+                                     vllm_kwargs: dict) -> None:
 
     vllm_prompt = _get_prompt(len(audio_assets),
                               "Describe each of the audios above.",
@@ -198,5 +223,37 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
         dtype=dtype,
         max_tokens=max_tokens,
         num_logprobs=num_logprobs,
-        tensor_parallel_size=1,
+        **vllm_kwargs,
     )
+
+
+@pytest.mark.asyncio
+async def test_online_inference(client, audio_assets):
+    """Exercises online inference with/without chunked prefill enabled."""
+
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            *[{
+                "type": "audio_url",
+                "audio_url": {
+                    "url": audio.url
+                }
+            } for audio in audio_assets],
+            {
+                "type":
+                "text",
+                "text":
+                f"What's happening in these {len(audio_assets)} audio clips?"
+            },
+        ],
+    }]
+
+    chat_completion = await client.chat.completions.create(model=MODEL_NAME,
+                                                           messages=messages,
+                                                           max_tokens=10)
+
+    assert len(chat_completion.choices) == 1
+    choice = chat_completion.choices[0]
+    assert choice.finish_reason == "length"
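
For reference, the offline-inference analogue of what these tests exercise looks roughly like the
following. It is a sketch only: it assumes vLLM's LLM entrypoint and the AudioAsset test helper,
the model name is a stand-in for MODEL_NAME, the asset name is assumed, and the prompt skips the
chat template that _get_prompt would normally apply.

    from vllm import LLM, SamplingParams
    from vllm.assets.audio import AudioAsset

    llm = LLM(
        model="fixie-ai/ultravox-v0_3",  # stand-in for MODEL_NAME
        max_model_len=4096,
        enforce_eager=True,
        limit_mm_per_prompt={"audio": 1},
        # Same knobs as CHUNKED_PREFILL_KWARGS: a tiny token budget forces
        # prompts (and their audio placeholders) to prefill across steps.
        enable_chunked_prefill=True,
        max_num_seqs=2,
        max_num_batched_tokens=16,
    )

    # Asset name assumed; any short audio clip works here.
    audio, sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate

    # Same placeholder string as VLLM_PLACEHOLDER above; the real tests wrap
    # it with the model's chat template.
    prompt = "<|reserved_special_token_0|>\nDescribe the audio above."

    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {"audio": (audio, sample_rate)},
        },
        SamplingParams(temperature=0.0, max_tokens=64),
    )
    print(outputs[0].outputs[0].text)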