Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -7,14 +7,14 @@ from vllm.assets.audio import AudioAsset
|
||||
|
||||
@pytest.fixture
|
||||
def mary_had_lamb():
|
||||
path = AudioAsset('mary_had_lamb').get_local_path()
|
||||
path = AudioAsset("mary_had_lamb").get_local_path()
|
||||
with open(str(path), "rb") as f:
|
||||
yield f
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def winning_call():
|
||||
path = AudioAsset('winning_call').get_local_path()
|
||||
path = AudioAsset("winning_call").get_local_path()
|
||||
with open(str(path), "rb") as f:
|
||||
yield f
|
||||
|
||||
@@ -22,6 +22,6 @@ def winning_call():
|
||||
@pytest.fixture
|
||||
def foscolo():
|
||||
# Test translation it->en
|
||||
path = AudioAsset('azacinto_foscolo').get_local_path()
|
||||
path = AudioAsset("azacinto_foscolo").get_local_path()
|
||||
with open(str(path), "rb") as f:
|
||||
yield f
|
||||
|
||||
@@ -44,14 +44,15 @@ def run_test(more_args):
|
||||
print(f"Running with: {args}")
|
||||
|
||||
with RemoteOpenAIServer(
|
||||
MODEL_NAME, args,
|
||||
max_wait_seconds=MAX_WAIT_SECONDS) as remote_server:
|
||||
MODEL_NAME, args, max_wait_seconds=MAX_WAIT_SECONDS
|
||||
) as remote_server:
|
||||
url = f"{remote_server.url_for('v1')}/completions"
|
||||
|
||||
model_args = (
|
||||
f"model={MODEL_NAME},"
|
||||
f"base_url={url},"
|
||||
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
|
||||
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False"
|
||||
)
|
||||
|
||||
results = lm_eval.simple_evaluate(
|
||||
model="local-completions",
|
||||
@@ -60,15 +61,18 @@ def run_test(more_args):
|
||||
)
|
||||
|
||||
measured_value = results["results"][TASK][FILTER]
|
||||
assert (measured_value - RTOL < EXPECTED_VALUE
|
||||
and measured_value + RTOL > EXPECTED_VALUE
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
assert (
|
||||
measured_value - RTOL < EXPECTED_VALUE
|
||||
and measured_value + RTOL > EXPECTED_VALUE
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda()
|
||||
and not current_platform.is_tpu()
|
||||
and not current_platform.is_xpu(),
|
||||
reason="V1 currently only supported on CUDA, XPU and TPU")
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda()
|
||||
and not current_platform.is_tpu()
|
||||
and not current_platform.is_xpu(),
|
||||
reason="V1 currently only supported on CUDA, XPU and TPU",
|
||||
)
|
||||
def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Run with the V1 Engine."""
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ a baseline.
|
||||
This simulates real work usage of the API and makes sure that the frontend and
|
||||
AsyncLLMEngine are working correctly.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import time
|
||||
@@ -45,7 +46,8 @@ async def transcribe_audio(client, tokenizer, y, sr):
|
||||
# NOTE there's no streaming in transcriptions, can't measure ttft
|
||||
latency = end_time - start_time
|
||||
num_output_tokens = len(
|
||||
tokenizer(transcription.text, add_special_tokens=False).input_ids)
|
||||
tokenizer(transcription.text, add_special_tokens=False).input_ids
|
||||
)
|
||||
return latency, num_output_tokens, transcription.text
|
||||
|
||||
|
||||
@@ -73,8 +75,8 @@ async def process_dataset(model, client, data, concurrent_request):
|
||||
for sample in data:
|
||||
audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
|
||||
task = asyncio.create_task(
|
||||
bound_transcribe(sem, client, tokenizer, (audio, sr),
|
||||
sample["text"]))
|
||||
bound_transcribe(sem, client, tokenizer, (audio, sr), sample["text"])
|
||||
)
|
||||
tasks.append(task)
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
@@ -98,34 +100,35 @@ def print_performance_metrics(results, total_time):
|
||||
|
||||
|
||||
def add_duration(sample):
|
||||
y, sr = sample['audio']["array"], sample['audio']["sampling_rate"]
|
||||
sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000
|
||||
y, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
|
||||
sample["duration_ms"] = librosa.get_duration(y=y, sr=sr) * 1000
|
||||
return sample
|
||||
|
||||
|
||||
def load_hf_dataset(dataset_repo: str, split='validation', **hf_kwargs):
|
||||
def load_hf_dataset(dataset_repo: str, split="validation", **hf_kwargs):
|
||||
## Load and filter the dataset
|
||||
dataset = load_dataset(dataset_repo, split=split, **hf_kwargs)
|
||||
if 'duration_ms' not in dataset[0]:
|
||||
if "duration_ms" not in dataset[0]:
|
||||
# compute duration to filter
|
||||
dataset = dataset.map(add_duration)
|
||||
|
||||
# Whisper max supported duration
|
||||
dataset = dataset.filter(lambda example: example['duration_ms'] < 30000)
|
||||
dataset = dataset.filter(lambda example: example["duration_ms"] < 30000)
|
||||
return dataset
|
||||
|
||||
|
||||
def run_evaluation(model: str,
|
||||
client,
|
||||
dataset,
|
||||
max_concurrent_reqs: int,
|
||||
n_examples: int = -1,
|
||||
print_metrics: bool = True):
|
||||
def run_evaluation(
|
||||
model: str,
|
||||
client,
|
||||
dataset,
|
||||
max_concurrent_reqs: int,
|
||||
n_examples: int = -1,
|
||||
print_metrics: bool = True,
|
||||
):
|
||||
if n_examples > 0:
|
||||
dataset = dataset.select(range(n_examples))
|
||||
start = time.perf_counter()
|
||||
results = asyncio.run(
|
||||
process_dataset(model, client, dataset, max_concurrent_reqs))
|
||||
results = asyncio.run(process_dataset(model, client, dataset, max_concurrent_reqs))
|
||||
end = time.perf_counter()
|
||||
total_time = end - start
|
||||
print(f"Total Test Time: {total_time:.4f} seconds")
|
||||
@@ -135,8 +138,7 @@ def run_evaluation(model: str,
|
||||
predictions = [res[2] for res in results]
|
||||
references = [res[3] for res in results]
|
||||
wer = load("wer")
|
||||
wer_score = 100 * wer.compute(references=references,
|
||||
predictions=predictions)
|
||||
wer_score = 100 * wer.compute(references=references, predictions=predictions)
|
||||
print("WER:", wer_score)
|
||||
return wer_score
|
||||
|
||||
@@ -145,26 +147,25 @@ def run_evaluation(model: str,
|
||||
@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"])
|
||||
# Original dataset is 20GB+ in size, hence we use a pre-filtered slice.
|
||||
@pytest.mark.parametrize(
|
||||
"dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"])
|
||||
"dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"]
|
||||
)
|
||||
# NOTE: Expected WER measured with equivalent hf.transformers args:
|
||||
# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered.
|
||||
@pytest.mark.parametrize("expected_wer", [12.744980])
|
||||
def test_wer_correctness(model_name,
|
||||
dataset_repo,
|
||||
expected_wer,
|
||||
n_examples=-1,
|
||||
max_concurrent_request=None):
|
||||
def test_wer_correctness(
|
||||
model_name, dataset_repo, expected_wer, n_examples=-1, max_concurrent_request=None
|
||||
):
|
||||
# TODO refactor to use `ASRDataset`
|
||||
with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server:
|
||||
with RemoteOpenAIServer(model_name, ["--enforce-eager"]) as remote_server:
|
||||
dataset = load_hf_dataset(dataset_repo)
|
||||
|
||||
if not max_concurrent_request:
|
||||
# No max concurrency
|
||||
max_concurrent_request = n_examples if n_examples > 0\
|
||||
else len(dataset)
|
||||
max_concurrent_request = n_examples if n_examples > 0 else len(dataset)
|
||||
|
||||
client = remote_server.get_async_client()
|
||||
wer = run_evaluation(model_name, client, dataset,
|
||||
max_concurrent_request, n_examples)
|
||||
wer = run_evaluation(
|
||||
model_name, client, dataset, max_concurrent_request, n_examples
|
||||
)
|
||||
if expected_wer:
|
||||
torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2)
|
||||
|
||||
@@ -44,15 +44,11 @@ async def client(server):
|
||||
ids=["completion", "chat"],
|
||||
argnames=["create_func_gen", "content_body"],
|
||||
argvalues=[
|
||||
(lambda x: x.completions.create, {
|
||||
"prompt": " ".join(['A'] * 10_000)
|
||||
}),
|
||||
(lambda x: x.chat.completions.create, {
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": " ".join(['A'] * 10_000)
|
||||
}]
|
||||
}),
|
||||
(lambda x: x.completions.create, {"prompt": " ".join(["A"] * 10_000)}),
|
||||
(
|
||||
lambda x: x.chat.completions.create,
|
||||
{"messages": [{"role": "user", "content": " ".join(["A"] * 10_000)}]},
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_with_and_without_truncate(
|
||||
@@ -65,15 +61,15 @@ async def test_with_and_without_truncate(
|
||||
body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}
|
||||
|
||||
num_requests = 10
|
||||
truncate_prompt_tokens = ([1000] * (num_requests // 2) + [None] *
|
||||
(num_requests - num_requests // 2))
|
||||
truncate_prompt_tokens = [1000] * (num_requests // 2) + [None] * (
|
||||
num_requests - num_requests // 2
|
||||
)
|
||||
random.shuffle(truncate_prompt_tokens)
|
||||
|
||||
bodies = [{
|
||||
**body, "extra_body": {
|
||||
'truncate_prompt_tokens': t
|
||||
}
|
||||
} for t in truncate_prompt_tokens]
|
||||
bodies = [
|
||||
{**body, "extra_body": {"truncate_prompt_tokens": t}}
|
||||
for t in truncate_prompt_tokens
|
||||
]
|
||||
|
||||
async def get_status_code(**kwargs):
|
||||
try:
|
||||
|
||||
@@ -56,24 +56,18 @@ def base64_encoded_audio() -> dict[str, str]:
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
|
||||
async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
|
||||
model_name: str, audio_url: str):
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "audio_url",
|
||||
"audio_url": {
|
||||
"url": audio_url
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's happening in this audio?"
|
||||
},
|
||||
],
|
||||
}]
|
||||
async def test_single_chat_session_audio(
|
||||
client: openai.AsyncOpenAI, model_name: str, audio_url: str
|
||||
):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio_url", "audio_url": {"url": audio_url}},
|
||||
{"type": "text", "text": "What's happening in this audio?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# test single completion
|
||||
chat_completion = await client.chat.completions.create(
|
||||
@@ -82,13 +76,15 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
|
||||
max_completion_tokens=10,
|
||||
logprobs=True,
|
||||
temperature=0.0,
|
||||
top_logprobs=5)
|
||||
top_logprobs=5,
|
||||
)
|
||||
assert len(chat_completion.choices) == 1
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
assert choice.finish_reason == "length"
|
||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=10, prompt_tokens=202, total_tokens=212)
|
||||
completion_tokens=10, prompt_tokens=202, total_tokens=212
|
||||
)
|
||||
|
||||
message = choice.message
|
||||
message = chat_completion.choices[0].message
|
||||
@@ -110,56 +106,52 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
|
||||
async def test_error_on_invalid_audio_url_type(client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
audio_url: str):
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "audio_url",
|
||||
"audio_url": audio_url
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's happening in this audio?"
|
||||
},
|
||||
],
|
||||
}]
|
||||
async def test_error_on_invalid_audio_url_type(
|
||||
client: openai.AsyncOpenAI, model_name: str, audio_url: str
|
||||
):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio_url", "audio_url": audio_url},
|
||||
{"type": "text", "text": "What's happening in this audio?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# audio_url should be a dict {"url": "some url"}, not directly a string
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
_ = await client.chat.completions.create(model=model_name,
|
||||
messages=messages,
|
||||
max_completion_tokens=10,
|
||||
temperature=0.0)
|
||||
_ = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_completion_tokens=10,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
|
||||
async def test_single_chat_session_audio_base64encoded(
|
||||
client: openai.AsyncOpenAI, model_name: str, audio_url: str,
|
||||
base64_encoded_audio: dict[str, str]):
|
||||
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "audio_url",
|
||||
"audio_url": {
|
||||
"url":
|
||||
f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's happening in this audio?"
|
||||
},
|
||||
],
|
||||
}]
|
||||
client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
audio_url: str,
|
||||
base64_encoded_audio: dict[str, str],
|
||||
):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "audio_url",
|
||||
"audio_url": {
|
||||
"url": f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}"
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "What's happening in this audio?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# test single completion
|
||||
chat_completion = await client.chat.completions.create(
|
||||
@@ -168,13 +160,15 @@ async def test_single_chat_session_audio_base64encoded(
|
||||
max_completion_tokens=10,
|
||||
logprobs=True,
|
||||
temperature=0.0,
|
||||
top_logprobs=5)
|
||||
top_logprobs=5,
|
||||
)
|
||||
assert len(chat_completion.choices) == 1
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
assert choice.finish_reason == "length"
|
||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=10, prompt_tokens=202, total_tokens=212)
|
||||
completion_tokens=10, prompt_tokens=202, total_tokens=212
|
||||
)
|
||||
|
||||
message = choice.message
|
||||
message = chat_completion.choices[0].message
|
||||
@@ -198,25 +192,26 @@ async def test_single_chat_session_audio_base64encoded(
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
|
||||
async def test_single_chat_session_input_audio(
|
||||
client: openai.AsyncOpenAI, model_name: str, audio_url: str,
|
||||
base64_encoded_audio: dict[str, str]):
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": base64_encoded_audio[audio_url],
|
||||
"format": "wav"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's happening in this audio?"
|
||||
},
|
||||
],
|
||||
}]
|
||||
client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
audio_url: str,
|
||||
base64_encoded_audio: dict[str, str],
|
||||
):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": base64_encoded_audio[audio_url],
|
||||
"format": "wav",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "What's happening in this audio?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# test single completion
|
||||
chat_completion = await client.chat.completions.create(
|
||||
@@ -224,13 +219,15 @@ async def test_single_chat_session_input_audio(
|
||||
messages=messages,
|
||||
max_completion_tokens=10,
|
||||
logprobs=True,
|
||||
top_logprobs=5)
|
||||
top_logprobs=5,
|
||||
)
|
||||
assert len(chat_completion.choices) == 1
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
assert choice.finish_reason == "length"
|
||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=10, prompt_tokens=202, total_tokens=212)
|
||||
completion_tokens=10, prompt_tokens=202, total_tokens=212
|
||||
)
|
||||
|
||||
message = choice.message
|
||||
message = chat_completion.choices[0].message
|
||||
@@ -252,24 +249,18 @@ async def test_single_chat_session_input_audio(
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
|
||||
async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
|
||||
model_name: str, audio_url: str):
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "audio_url",
|
||||
"audio_url": {
|
||||
"url": audio_url
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's happening in this audio?"
|
||||
},
|
||||
],
|
||||
}]
|
||||
async def test_chat_streaming_audio(
|
||||
client: openai.AsyncOpenAI, model_name: str, audio_url: str
|
||||
):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio_url", "audio_url": {"url": audio_url}},
|
||||
{"type": "text", "text": "What's happening in this audio?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# test single completion
|
||||
chat_completion = await client.chat.completions.create(
|
||||
@@ -309,27 +300,27 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
|
||||
async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
|
||||
model_name: str, audio_url: str,
|
||||
base64_encoded_audio: dict[str,
|
||||
str]):
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": base64_encoded_audio[audio_url],
|
||||
"format": "wav"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's happening in this audio?"
|
||||
},
|
||||
],
|
||||
}]
|
||||
async def test_chat_streaming_input_audio(
|
||||
client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
audio_url: str,
|
||||
base64_encoded_audio: dict[str, str],
|
||||
):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": base64_encoded_audio[audio_url],
|
||||
"format": "wav",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "What's happening in this audio?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# test single completion
|
||||
chat_completion = await client.chat.completions.create(
|
||||
@@ -369,26 +360,23 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize(
|
||||
"audio_urls", [TEST_AUDIO_URLS, TEST_AUDIO_URLS + [TEST_AUDIO_URLS[0]]])
|
||||
async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
|
||||
audio_urls: list[str]):
|
||||
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
*({
|
||||
"type": "audio_url",
|
||||
"audio_url": {
|
||||
"url": audio_url
|
||||
}
|
||||
} for audio_url in audio_urls),
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's happening in this audio?"
|
||||
},
|
||||
],
|
||||
}]
|
||||
"audio_urls", [TEST_AUDIO_URLS, TEST_AUDIO_URLS + [TEST_AUDIO_URLS[0]]]
|
||||
)
|
||||
async def test_multi_audio_input(
|
||||
client: openai.AsyncOpenAI, model_name: str, audio_urls: list[str]
|
||||
):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*(
|
||||
{"type": "audio_url", "audio_url": {"url": audio_url}}
|
||||
for audio_url in audio_urls
|
||||
),
|
||||
{"type": "text", "text": "What's happening in this audio?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
if len(audio_urls) > MAXIMUM_AUDIOS:
|
||||
with pytest.raises(openai.BadRequestError): # test multi-audio input
|
||||
|
||||
@@ -16,9 +16,9 @@ from ...utils import RemoteOpenAIServer
|
||||
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
@pytest.fixture(scope="module")
|
||||
def server_args(request: pytest.FixtureRequest) -> list[str]:
|
||||
""" Provide extra arguments to the server via indirect parametrization
|
||||
"""Provide extra arguments to the server via indirect parametrization
|
||||
|
||||
Usage:
|
||||
|
||||
@@ -80,8 +80,10 @@ async def client(server):
|
||||
"server_args",
|
||||
[
|
||||
pytest.param([], id="default-frontend-multiprocessing"),
|
||||
pytest.param(["--disable-frontend-multiprocessing"],
|
||||
id="disable-frontend-multiprocessing")
|
||||
pytest.param(
|
||||
["--disable-frontend-multiprocessing"],
|
||||
id="disable-frontend-multiprocessing",
|
||||
),
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
@@ -97,8 +99,10 @@ async def test_show_version(server: RemoteOpenAIServer):
|
||||
"server_args",
|
||||
[
|
||||
pytest.param([], id="default-frontend-multiprocessing"),
|
||||
pytest.param(["--disable-frontend-multiprocessing"],
|
||||
id="disable-frontend-multiprocessing")
|
||||
pytest.param(
|
||||
["--disable-frontend-multiprocessing"],
|
||||
id="disable-frontend-multiprocessing",
|
||||
),
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
@@ -112,11 +116,13 @@ async def test_check_health(server: RemoteOpenAIServer):
|
||||
@pytest.mark.parametrize(
|
||||
"server_args",
|
||||
[
|
||||
pytest.param(["--max-model-len", "10100"],
|
||||
id="default-frontend-multiprocessing"),
|
||||
pytest.param(
|
||||
["--max-model-len", "10100"], id="default-frontend-multiprocessing"
|
||||
),
|
||||
pytest.param(
|
||||
["--disable-frontend-multiprocessing", "--max-model-len", "10100"],
|
||||
id="disable-frontend-multiprocessing")
|
||||
id="disable-frontend-multiprocessing",
|
||||
),
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
@@ -131,14 +137,16 @@ async def test_request_cancellation(server: RemoteOpenAIServer):
|
||||
# Request about 2 million tokens
|
||||
for _ in range(200):
|
||||
task = asyncio.create_task(
|
||||
client.chat.completions.create(messages=chat_input,
|
||||
model=MODEL_NAME,
|
||||
max_tokens=10000,
|
||||
extra_body={"min_tokens": 10000}))
|
||||
client.chat.completions.create(
|
||||
messages=chat_input,
|
||||
model=MODEL_NAME,
|
||||
max_tokens=10000,
|
||||
extra_body={"min_tokens": 10000},
|
||||
)
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
done, pending = await asyncio.wait(tasks,
|
||||
return_when=asyncio.ALL_COMPLETED)
|
||||
done, pending = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
|
||||
|
||||
# Make sure all requests were sent to the server and timed out
|
||||
# (We don't want to hide other errors like 400s that would invalidate this
|
||||
@@ -151,16 +159,15 @@ async def test_request_cancellation(server: RemoteOpenAIServer):
|
||||
# If the server had not cancelled all the other requests, then it would not
|
||||
# be able to respond to this one within the timeout
|
||||
client = server.get_async_client(timeout=5)
|
||||
response = await client.chat.completions.create(messages=chat_input,
|
||||
model=MODEL_NAME,
|
||||
max_tokens=10)
|
||||
response = await client.chat.completions.create(
|
||||
messages=chat_input, model=MODEL_NAME, max_tokens=10
|
||||
)
|
||||
|
||||
assert len(response.choices) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_wrong_content_type(server: RemoteOpenAIServer):
|
||||
|
||||
chat_input = [{"role": "user", "content": "Write a long story"}]
|
||||
client = server.get_async_client()
|
||||
|
||||
@@ -169,17 +176,13 @@ async def test_request_wrong_content_type(server: RemoteOpenAIServer):
|
||||
messages=chat_input,
|
||||
model=MODEL_NAME,
|
||||
max_tokens=10000,
|
||||
extra_headers={
|
||||
"Content-Type": "application/x-www-form-urlencoded"
|
||||
})
|
||||
extra_headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"server_args",
|
||||
[
|
||||
pytest.param(["--enable-server-load-tracking"],
|
||||
id="enable-server-load-tracking")
|
||||
],
|
||||
[pytest.param(["--enable-server-load-tracking"], id="enable-server-load-tracking")],
|
||||
indirect=True,
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@@ -202,7 +205,8 @@ async def test_server_load(server: RemoteOpenAIServer):
|
||||
|
||||
# Start the completion request in a background thread.
|
||||
completion_future = asyncio.create_task(
|
||||
asyncio.to_thread(make_long_completion_request))
|
||||
asyncio.to_thread(make_long_completion_request)
|
||||
)
|
||||
|
||||
# Give a short delay to ensure the request has started.
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,7 +23,7 @@ def server():
|
||||
"--max-model-len",
|
||||
"4080",
|
||||
"--max-logprobs", # test prompt_logprobs equal to -1
|
||||
"151936"
|
||||
"151936",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
@@ -46,27 +46,26 @@ class TestCase(NamedTuple):
|
||||
"test_case",
|
||||
[
|
||||
TestCase(model_name=MODEL_NAME, echo=True),
|
||||
TestCase(model_name=MODEL_NAME, echo=False)
|
||||
TestCase(model_name=MODEL_NAME, echo=False),
|
||||
],
|
||||
)
|
||||
async def test_chat_session_with_echo_and_continue_final_message(
|
||||
client: openai.AsyncOpenAI, test_case: TestCase):
|
||||
client: openai.AsyncOpenAI, test_case: TestCase
|
||||
):
|
||||
saying: str = "Here is a common saying about apple. An apple a day, keeps"
|
||||
# test echo with continue_final_message parameter
|
||||
chat_completion = await client.chat.completions.create(
|
||||
model=test_case.model_name,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "tell me a common saying"
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": saying
|
||||
}],
|
||||
messages=[
|
||||
{"role": "user", "content": "tell me a common saying"},
|
||||
{"role": "assistant", "content": saying},
|
||||
],
|
||||
extra_body={
|
||||
"echo": test_case.echo,
|
||||
"continue_final_message": True,
|
||||
"add_generation_prompt": False
|
||||
})
|
||||
"add_generation_prompt": False,
|
||||
},
|
||||
)
|
||||
assert chat_completion.id is not None
|
||||
assert len(chat_completion.choices) == 1
|
||||
|
||||
@@ -83,13 +82,10 @@ async def test_chat_session_with_echo_and_continue_final_message(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_logprobs(client: openai.AsyncOpenAI):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Beijing is the capital of which country?"
|
||||
}]
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Beijing is the capital of which country?"},
|
||||
]
|
||||
|
||||
completion = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
@@ -103,13 +99,10 @@ async def test_prompt_logprobs(client: openai.AsyncOpenAI):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_top_logprobs(client: openai.AsyncOpenAI):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Beijing is the capital of which country?"
|
||||
}]
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Beijing is the capital of which country?"},
|
||||
]
|
||||
|
||||
completion = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
|
||||
@@ -49,10 +49,7 @@ async def test_chat_logit_bias_valid(client):
|
||||
|
||||
completion = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Testing valid logit bias"
|
||||
}],
|
||||
messages=[{"role": "user", "content": "Testing valid logit bias"}],
|
||||
max_tokens=5,
|
||||
logit_bias={str(valid_token_id): 1.0},
|
||||
)
|
||||
@@ -69,10 +66,7 @@ async def test_chat_logit_bias_invalid(client):
|
||||
with pytest.raises(openai.BadRequestError) as excinfo:
|
||||
await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Testing invalid logit bias"
|
||||
}],
|
||||
messages=[{"role": "user", "content": "Testing invalid logit bias"}],
|
||||
max_tokens=5,
|
||||
logit_bias={str(invalid_token_id): 1.0},
|
||||
)
|
||||
|
||||
@@ -4,8 +4,7 @@
|
||||
import pytest
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
|
||||
load_chat_template)
|
||||
from vllm.entrypoints.chat_utils import apply_hf_chat_template, load_chat_template
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
@@ -17,48 +16,54 @@ assert chatml_jinja_path.exists()
|
||||
|
||||
# Define models, templates, and their corresponding expected outputs
|
||||
MODEL_TEMPLATE_GENERATION_OUTPUT = [
|
||||
("facebook/opt-125m", chatml_jinja_path, True, False, """<|im_start|>user
|
||||
(
|
||||
"facebook/opt-125m",
|
||||
chatml_jinja_path,
|
||||
True,
|
||||
False,
|
||||
"""<|im_start|>user
|
||||
Hello<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Hi there!<|im_end|>
|
||||
<|im_start|>user
|
||||
What is the capital of<|im_end|>
|
||||
<|im_start|>assistant
|
||||
"""),
|
||||
("facebook/opt-125m", chatml_jinja_path, False, False, """<|im_start|>user
|
||||
""",
|
||||
),
|
||||
(
|
||||
"facebook/opt-125m",
|
||||
chatml_jinja_path,
|
||||
False,
|
||||
False,
|
||||
"""<|im_start|>user
|
||||
Hello<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Hi there!<|im_end|>
|
||||
<|im_start|>user
|
||||
What is the capital of"""),
|
||||
("facebook/opt-125m", chatml_jinja_path, False, True, """<|im_start|>user
|
||||
What is the capital of""",
|
||||
),
|
||||
(
|
||||
"facebook/opt-125m",
|
||||
chatml_jinja_path,
|
||||
False,
|
||||
True,
|
||||
"""<|im_start|>user
|
||||
Hello<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Hi there!<|im_end|>
|
||||
<|im_start|>user
|
||||
What is the capital of<|im_end|>
|
||||
<|im_start|>assistant
|
||||
The capital of"""),
|
||||
The capital of""",
|
||||
),
|
||||
]
|
||||
|
||||
TEST_MESSAGES = [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'Hello'
|
||||
},
|
||||
{
|
||||
'role': 'assistant',
|
||||
'content': 'Hi there!'
|
||||
},
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'What is the capital of'
|
||||
},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "What is the capital of"},
|
||||
]
|
||||
ASSISTANT_MESSAGE_TO_CONTINUE = {
|
||||
'role': 'assistant',
|
||||
'content': 'The capital of'
|
||||
}
|
||||
ASSISTANT_MESSAGE_TO_CONTINUE = {"role": "assistant", "content": "The capital of"}
|
||||
|
||||
|
||||
def test_load_chat_template():
|
||||
@@ -68,8 +73,11 @@ def test_load_chat_template():
|
||||
# Test assertions
|
||||
assert template_content is not None
|
||||
# Hard coded value for template_chatml.jinja
|
||||
assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
|
||||
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501
|
||||
assert (
|
||||
template_content
|
||||
== """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
|
||||
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""
|
||||
) # noqa: E501
|
||||
|
||||
|
||||
def test_no_load_chat_template_filelike():
|
||||
@@ -91,9 +99,11 @@ def test_no_load_chat_template_literallike():
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,template,add_generation_prompt,continue_final_message,expected_output",
|
||||
MODEL_TEMPLATE_GENERATION_OUTPUT)
|
||||
def test_get_gen_prompt(model, template, add_generation_prompt,
|
||||
continue_final_message, expected_output):
|
||||
MODEL_TEMPLATE_GENERATION_OUTPUT,
|
||||
)
|
||||
def test_get_gen_prompt(
|
||||
model, template, add_generation_prompt, continue_final_message, expected_output
|
||||
):
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
|
||||
@@ -106,7 +116,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
|
||||
hf_overrides=model_info.hf_overrides,
|
||||
skip_tokenizer_init=model_info.skip_tokenizer_init,
|
||||
enforce_eager=model_info.enforce_eager,
|
||||
dtype=model_info.dtype)
|
||||
dtype=model_info.dtype,
|
||||
)
|
||||
|
||||
# Initialize the tokenizer
|
||||
tokenizer = get_tokenizer(
|
||||
@@ -119,7 +130,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
|
||||
mock_request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE]
|
||||
if continue_final_message else TEST_MESSAGES,
|
||||
if continue_final_message
|
||||
else TEST_MESSAGES,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=continue_final_message,
|
||||
)
|
||||
@@ -138,4 +150,5 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
|
||||
# Test assertion
|
||||
assert result == expected_output, (
|
||||
f"The generated prompt does not match the expected output for "
|
||||
f"model {model} and template {template}")
|
||||
f"model {model} and template {template}"
|
||||
)
|
||||
|
||||
@@ -14,9 +14,14 @@ MODEL_NAME = "Qwen/QwQ-32B"
|
||||
@pytest.fixture(scope="module")
|
||||
def server(): # noqa: F811
|
||||
args = [
|
||||
"--max-model-len", "8192", "--enforce-eager", "--reasoning-parser",
|
||||
"deepseek_r1", "--enable-auto-tool-choice", "--tool-call-parser",
|
||||
"hermes"
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"--enforce-eager",
|
||||
"--reasoning-parser",
|
||||
"deepseek_r1",
|
||||
"--enable-auto-tool-choice",
|
||||
"--tool-call-parser",
|
||||
"hermes",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
@@ -29,50 +34,44 @@ async def client(server):
|
||||
yield async_client
|
||||
|
||||
|
||||
TOOLS = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The city to find the weather for, e.g. 'San Francisco'"
|
||||
TOOLS = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city to find the weather for, e.g. 'San Francisco'",
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"description": "the two-letter abbreviation for the state that the city is"
|
||||
" in, e.g. 'CA' which would mean 'California'",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"state": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"the two-letter abbreviation for the state that the city is"
|
||||
" in, e.g. 'CA' which would mean 'California'"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
}
|
||||
"required": ["city", "state", "unit"],
|
||||
},
|
||||
"required": ["city", "state", "unit"]
|
||||
}
|
||||
},
|
||||
}
|
||||
}]
|
||||
]
|
||||
|
||||
MESSAGES = [{
|
||||
"role": "user",
|
||||
"content": "Hi! How are you doing today?"
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": "I'm doing well! How can I help you?"
|
||||
}, {
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
|
||||
}]
|
||||
MESSAGES = [
|
||||
{"role": "user", "content": "Hi! How are you doing today?"},
|
||||
{"role": "assistant", "content": "I'm doing well! How can I help you?"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Can you tell me what the temperate will be in Dallas, in fahrenheit?",
|
||||
},
|
||||
]
|
||||
|
||||
FUNC_NAME = "get_current_weather"
|
||||
FUNC_ARGS = """{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}"""
|
||||
@@ -105,9 +104,7 @@ def extract_reasoning_and_calls(chunks: list):
|
||||
|
||||
# test streaming
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_streaming_of_tool_and_reasoning(
|
||||
client: openai.AsyncOpenAI):
|
||||
|
||||
async def test_chat_streaming_of_tool_and_reasoning(client: openai.AsyncOpenAI):
|
||||
stream = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=MESSAGES,
|
||||
@@ -120,8 +117,7 @@ async def test_chat_streaming_of_tool_and_reasoning(
|
||||
async for chunk in stream:
|
||||
chunks.append(chunk)
|
||||
|
||||
reasoning_content, arguments, function_names = extract_reasoning_and_calls(
|
||||
chunks)
|
||||
reasoning_content, arguments, function_names = extract_reasoning_and_calls(chunks)
|
||||
assert len(reasoning_content) > 0
|
||||
assert len(function_names) > 0 and function_names[0] == FUNC_NAME
|
||||
assert len(arguments) > 0 and arguments[0] == FUNC_ARGS
|
||||
@@ -130,7 +126,6 @@ async def test_chat_streaming_of_tool_and_reasoning(
|
||||
# test full generate
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_full_of_tool_and_reasoning(client: openai.AsyncOpenAI):
|
||||
|
||||
tool_calls = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=MESSAGES,
|
||||
@@ -140,7 +135,5 @@ async def test_chat_full_of_tool_and_reasoning(client: openai.AsyncOpenAI):
|
||||
)
|
||||
|
||||
assert len(tool_calls.choices[0].message.reasoning_content) > 0
|
||||
assert tool_calls.choices[0].message.tool_calls[0].function.name \
|
||||
== FUNC_NAME
|
||||
assert tool_calls.choices[0].message.tool_calls[0].function.arguments \
|
||||
== FUNC_ARGS
|
||||
assert tool_calls.choices[0].message.tool_calls[0].function.name == FUNC_NAME
|
||||
assert tool_calls.choices[0].message.tool_calls[0].function.arguments == FUNC_ARGS
|
||||
|
||||
@@ -40,7 +40,8 @@ async def client(server):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_stream_options_and_logprobs_with_long_prompts(
|
||||
client: openai.AsyncOpenAI):
|
||||
client: openai.AsyncOpenAI,
|
||||
):
|
||||
# Test stream with long prompt
|
||||
prompt = "What is the capital of France?" * 400
|
||||
|
||||
@@ -62,8 +63,9 @@ async def test_completion_stream_options_and_logprobs_with_long_prompts(
|
||||
async for chunk in stream:
|
||||
assert chunk.usage.prompt_tokens >= 0
|
||||
assert chunk.usage.completion_tokens >= 0
|
||||
assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
|
||||
chunk.usage.completion_tokens)
|
||||
assert chunk.usage.total_tokens == (
|
||||
chunk.usage.prompt_tokens + chunk.usage.completion_tokens
|
||||
)
|
||||
if not finished:
|
||||
tokens_received += 1
|
||||
assert chunk.choices[0].text
|
||||
@@ -77,15 +79,13 @@ async def test_completion_stream_options_and_logprobs_with_long_prompts(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_stream_options_and_logprobs_with_long_prompts(
|
||||
client: openai.AsyncOpenAI):
|
||||
client: openai.AsyncOpenAI,
|
||||
):
|
||||
# Test stream with long prompt
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "What is the capital of France?" * 400
|
||||
}]
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is the capital of France?" * 400},
|
||||
]
|
||||
stream = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
@@ -106,8 +106,9 @@ async def test_chat_completion_stream_options_and_logprobs_with_long_prompts(
|
||||
async for chunk in stream:
|
||||
assert chunk.usage.prompt_tokens >= 0
|
||||
assert chunk.usage.completion_tokens >= 0
|
||||
assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
|
||||
chunk.usage.completion_tokens)
|
||||
assert chunk.usage.total_tokens == (
|
||||
chunk.usage.prompt_tokens + chunk.usage.completion_tokens
|
||||
)
|
||||
|
||||
if not finished:
|
||||
if chunk.choices[0].delta.content == "":
|
||||
|
||||
@@ -5,8 +5,7 @@ import json
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
|
||||
validate_parsed_serve_args)
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
|
||||
from vllm.entrypoints.openai.serving_models import LoRAModulePath
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
@@ -15,7 +14,7 @@ from ...utils import VLLM_PATH
|
||||
LORA_MODULE = {
|
||||
"name": "module2",
|
||||
"path": "/path/to/module2",
|
||||
"base_model_name": "llama"
|
||||
"base_model_name": "llama",
|
||||
}
|
||||
CHATML_JINJA_PATH = VLLM_PATH / "examples/template_chatml.jinja"
|
||||
assert CHATML_JINJA_PATH.exists()
|
||||
@@ -31,45 +30,51 @@ def serve_parser():
|
||||
def test_config_arg_parsing(serve_parser, cli_config_file):
|
||||
args = serve_parser.parse_args([])
|
||||
assert args.port == 8000
|
||||
args = serve_parser.parse_args(['--config', cli_config_file])
|
||||
args = serve_parser.parse_args(["--config", cli_config_file])
|
||||
assert args.port == 12312
|
||||
args = serve_parser.parse_args([
|
||||
'--config',
|
||||
cli_config_file,
|
||||
'--port',
|
||||
'9000',
|
||||
])
|
||||
args = serve_parser.parse_args(
|
||||
[
|
||||
"--config",
|
||||
cli_config_file,
|
||||
"--port",
|
||||
"9000",
|
||||
]
|
||||
)
|
||||
assert args.port == 9000
|
||||
args = serve_parser.parse_args([
|
||||
'--port',
|
||||
'9000',
|
||||
'--config',
|
||||
cli_config_file,
|
||||
])
|
||||
args = serve_parser.parse_args(
|
||||
[
|
||||
"--port",
|
||||
"9000",
|
||||
"--config",
|
||||
cli_config_file,
|
||||
]
|
||||
)
|
||||
assert args.port == 9000
|
||||
|
||||
|
||||
### Tests for LoRA module parsing
|
||||
def test_valid_key_value_format(serve_parser):
|
||||
# Test old format: name=path
|
||||
args = serve_parser.parse_args([
|
||||
'--lora-modules',
|
||||
'module1=/path/to/module1',
|
||||
])
|
||||
expected = [LoRAModulePath(name='module1', path='/path/to/module1')]
|
||||
args = serve_parser.parse_args(
|
||||
[
|
||||
"--lora-modules",
|
||||
"module1=/path/to/module1",
|
||||
]
|
||||
)
|
||||
expected = [LoRAModulePath(name="module1", path="/path/to/module1")]
|
||||
assert args.lora_modules == expected
|
||||
|
||||
|
||||
def test_valid_json_format(serve_parser):
|
||||
# Test valid JSON format input
|
||||
args = serve_parser.parse_args([
|
||||
'--lora-modules',
|
||||
json.dumps(LORA_MODULE),
|
||||
])
|
||||
args = serve_parser.parse_args(
|
||||
[
|
||||
"--lora-modules",
|
||||
json.dumps(LORA_MODULE),
|
||||
]
|
||||
)
|
||||
expected = [
|
||||
LoRAModulePath(name='module2',
|
||||
path='/path/to/module2',
|
||||
base_model_name='llama')
|
||||
LoRAModulePath(name="module2", path="/path/to/module2", base_model_name="llama")
|
||||
]
|
||||
assert args.lora_modules == expected
|
||||
|
||||
@@ -77,47 +82,53 @@ def test_valid_json_format(serve_parser):
|
||||
def test_invalid_json_format(serve_parser):
|
||||
# Test invalid JSON format input, missing closing brace
|
||||
with pytest.raises(SystemExit):
|
||||
serve_parser.parse_args([
|
||||
'--lora-modules', '{"name": "module3", "path": "/path/to/module3"'
|
||||
])
|
||||
serve_parser.parse_args(
|
||||
["--lora-modules", '{"name": "module3", "path": "/path/to/module3"']
|
||||
)
|
||||
|
||||
|
||||
def test_invalid_type_error(serve_parser):
|
||||
# Test type error when values are not JSON or key=value
|
||||
with pytest.raises(SystemExit):
|
||||
serve_parser.parse_args([
|
||||
'--lora-modules',
|
||||
'invalid_format' # This is not JSON or key=value format
|
||||
])
|
||||
serve_parser.parse_args(
|
||||
[
|
||||
"--lora-modules",
|
||||
"invalid_format", # This is not JSON or key=value format
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_invalid_json_field(serve_parser):
|
||||
# Test valid JSON format but missing required fields
|
||||
with pytest.raises(SystemExit):
|
||||
serve_parser.parse_args([
|
||||
'--lora-modules',
|
||||
'{"name": "module4"}' # Missing required 'path' field
|
||||
])
|
||||
serve_parser.parse_args(
|
||||
[
|
||||
"--lora-modules",
|
||||
'{"name": "module4"}', # Missing required 'path' field
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_empty_values(serve_parser):
|
||||
# Test when no LoRA modules are provided
|
||||
args = serve_parser.parse_args(['--lora-modules', ''])
|
||||
args = serve_parser.parse_args(["--lora-modules", ""])
|
||||
assert args.lora_modules == []
|
||||
|
||||
|
||||
def test_multiple_valid_inputs(serve_parser):
|
||||
# Test multiple valid inputs (both old and JSON format)
|
||||
args = serve_parser.parse_args([
|
||||
'--lora-modules',
|
||||
'module1=/path/to/module1',
|
||||
json.dumps(LORA_MODULE),
|
||||
])
|
||||
args = serve_parser.parse_args(
|
||||
[
|
||||
"--lora-modules",
|
||||
"module1=/path/to/module1",
|
||||
json.dumps(LORA_MODULE),
|
||||
]
|
||||
)
|
||||
expected = [
|
||||
LoRAModulePath(name='module1', path='/path/to/module1'),
|
||||
LoRAModulePath(name='module2',
|
||||
path='/path/to/module2',
|
||||
base_model_name='llama')
|
||||
LoRAModulePath(name="module1", path="/path/to/module1"),
|
||||
LoRAModulePath(
|
||||
name="module2", path="/path/to/module2", base_model_name="llama"
|
||||
),
|
||||
]
|
||||
assert args.lora_modules == expected
|
||||
|
||||
@@ -133,40 +144,46 @@ def test_enable_auto_choice_passes_without_tool_call_parser(serve_parser):
|
||||
|
||||
def test_enable_auto_choice_passes_with_tool_call_parser(serve_parser):
|
||||
"""Ensure validation passes with tool choice enabled with a call parser"""
|
||||
args = serve_parser.parse_args(args=[
|
||||
"--enable-auto-tool-choice",
|
||||
"--tool-call-parser",
|
||||
"mistral",
|
||||
])
|
||||
args = serve_parser.parse_args(
|
||||
args=[
|
||||
"--enable-auto-tool-choice",
|
||||
"--tool-call-parser",
|
||||
"mistral",
|
||||
]
|
||||
)
|
||||
validate_parsed_serve_args(args)
|
||||
|
||||
|
||||
def test_enable_auto_choice_fails_with_enable_reasoning(serve_parser):
|
||||
"""Ensure validation fails if reasoning is enabled with auto tool choice"""
|
||||
args = serve_parser.parse_args(args=[
|
||||
"--enable-auto-tool-choice",
|
||||
"--reasoning-parser",
|
||||
"deepseek_r1",
|
||||
])
|
||||
args = serve_parser.parse_args(
|
||||
args=[
|
||||
"--enable-auto-tool-choice",
|
||||
"--reasoning-parser",
|
||||
"deepseek_r1",
|
||||
]
|
||||
)
|
||||
with pytest.raises(TypeError):
|
||||
validate_parsed_serve_args(args)
|
||||
|
||||
|
||||
def test_passes_with_reasoning_parser(serve_parser):
|
||||
"""Ensure validation passes if reasoning is enabled
|
||||
"""Ensure validation passes if reasoning is enabled
|
||||
with a reasoning parser"""
|
||||
args = serve_parser.parse_args(args=[
|
||||
"--reasoning-parser",
|
||||
"deepseek_r1",
|
||||
])
|
||||
args = serve_parser.parse_args(
|
||||
args=[
|
||||
"--reasoning-parser",
|
||||
"deepseek_r1",
|
||||
]
|
||||
)
|
||||
validate_parsed_serve_args(args)
|
||||
|
||||
|
||||
def test_chat_template_validation_for_happy_paths(serve_parser):
|
||||
"""Ensure validation passes if the chat template exists"""
|
||||
args = serve_parser.parse_args(
|
||||
args=["--chat-template",
|
||||
CHATML_JINJA_PATH.absolute().as_posix()])
|
||||
args=["--chat-template", CHATML_JINJA_PATH.absolute().as_posix()]
|
||||
)
|
||||
validate_parsed_serve_args(args)
|
||||
|
||||
|
||||
@@ -179,8 +196,14 @@ def test_chat_template_validation_for_sad_paths(serve_parser):
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cli_args, expected_middleware",
|
||||
[(["--middleware", "middleware1", "--middleware", "middleware2"
|
||||
], ["middleware1", "middleware2"]), ([], [])])
|
||||
[
|
||||
(
|
||||
["--middleware", "middleware1", "--middleware", "middleware2"],
|
||||
["middleware1", "middleware2"],
|
||||
),
|
||||
([], []),
|
||||
],
|
||||
)
|
||||
def test_middleware(serve_parser, cli_args, expected_middleware):
|
||||
"""Ensure multiple middleware args are parsed properly"""
|
||||
args = serve_parser.parse_args(args=cli_args)
|
||||
|
||||
@@ -12,7 +12,6 @@ MODEL_NAME = "Qwen/Qwen3-0.6B"
|
||||
|
||||
|
||||
class TestWorkerExtension:
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
"""Test non-pydantic return type."""
|
||||
return MODEL_NAME
|
||||
@@ -41,20 +40,18 @@ def server():
|
||||
"tests.entrypoints.openai.test_collective_rpc.TestWorkerExtension",
|
||||
]
|
||||
with RemoteOpenAIServer(
|
||||
MODEL_NAME,
|
||||
args,
|
||||
env_dict={
|
||||
"VLLM_SERVER_DEV_MODE": "1",
|
||||
"CUDA_VISIBLE_DEVICES": "0"
|
||||
},
|
||||
MODEL_NAME,
|
||||
args,
|
||||
env_dict={"VLLM_SERVER_DEV_MODE": "1", "CUDA_VISIBLE_DEVICES": "0"},
|
||||
) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
def test_get_model_name(server):
|
||||
"""Test basic response"""
|
||||
response = requests.post(server.url_for("collective_rpc"),
|
||||
json={"method": "get_model_name"})
|
||||
response = requests.post(
|
||||
server.url_for("collective_rpc"), json={"method": "get_model_name"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
results = response.json()
|
||||
assert "results" in results
|
||||
@@ -63,8 +60,9 @@ def test_get_model_name(server):
|
||||
|
||||
def test_return_none(server):
|
||||
"""Test return none"""
|
||||
response = requests.post(server.url_for("collective_rpc"),
|
||||
json={"method": "return_none"})
|
||||
response = requests.post(
|
||||
server.url_for("collective_rpc"), json={"method": "return_none"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
results = response.json()
|
||||
assert results["results"] == [None]
|
||||
@@ -74,12 +72,10 @@ def test_echo_args_kwargs(server):
|
||||
"""Test args, kwargs, and dict response"""
|
||||
args = ["arg1", "arg2"]
|
||||
kwargs = {"key1": "value1", "key2": "value2"}
|
||||
response = requests.post(server.url_for("collective_rpc"),
|
||||
json={
|
||||
"method": "echo_args_kwargs",
|
||||
"args": args,
|
||||
"kwargs": kwargs
|
||||
})
|
||||
response = requests.post(
|
||||
server.url_for("collective_rpc"),
|
||||
json={"method": "echo_args_kwargs", "args": args, "kwargs": kwargs},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
results = response.json()
|
||||
result = results["results"][0]
|
||||
|
||||
@@ -25,15 +25,12 @@ tools = [
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The city to find the weather for, e.g. 'Vienna'",
|
||||
"description": "The city to find the weather for, e.g. 'Vienna'",
|
||||
"default": "Vienna",
|
||||
},
|
||||
"country": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The country that the city is in, e.g. 'Austria'",
|
||||
"type": "string",
|
||||
"description": "The country that the city is in, e.g. 'Austria'",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
@@ -62,8 +59,7 @@ tools = [
|
||||
"include_forecast": {
|
||||
"type": "boolean",
|
||||
"default": False,
|
||||
"description":
|
||||
"Whether to include a 24-hour forecast",
|
||||
"description": "Whether to include a 24-hour forecast",
|
||||
"title": "Include Forecast",
|
||||
},
|
||||
"language": {
|
||||
@@ -89,21 +85,16 @@ tools = [
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The city to get the forecast for, e.g. 'Vienna'",
|
||||
"description": "The city to get the forecast for, e.g. 'Vienna'",
|
||||
"default": "Vienna",
|
||||
},
|
||||
"country": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The country that the city is in, e.g. 'Austria'",
|
||||
"type": "string",
|
||||
"description": "The country that the city is in, e.g. 'Austria'",
|
||||
},
|
||||
"days": {
|
||||
"type":
|
||||
"integer",
|
||||
"description":
|
||||
"Number of days to get the forecast for (1-7)",
|
||||
"type": "integer",
|
||||
"description": "Number of days to get the forecast for (1-7)",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
@@ -118,19 +109,11 @@ tools = [
|
||||
]
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Hi! How are you doing today?"},
|
||||
{"role": "assistant", "content": "I'm doing well! How can I help you?"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi! How are you doing today?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'm doing well! How can I help you?"
|
||||
},
|
||||
{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Can you tell me what the current weather is in Berlin and the "\
|
||||
"content": "Can you tell me what the current weather is in Berlin and the "
|
||||
"forecast for the next 5 days, in fahrenheit?",
|
||||
},
|
||||
]
|
||||
@@ -150,7 +133,7 @@ def server(): # noqa: F811
|
||||
"--reasoning-parser",
|
||||
"qwen3",
|
||||
"--gpu-memory-utilization",
|
||||
"0.4"
|
||||
"0.4",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
@@ -166,18 +149,22 @@ async def client(server):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("stream", [True, False])
|
||||
@pytest.mark.parametrize("tool_choice", [
|
||||
"auto", "required", {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather"
|
||||
}
|
||||
}
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"tool_choice",
|
||||
[
|
||||
"auto",
|
||||
"required",
|
||||
{"type": "function", "function": {"name": "get_current_weather"}},
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("enable_thinking", [True, False])
|
||||
async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
|
||||
stream: bool, tool_choice: Union[str, dict],
|
||||
enable_thinking: bool):
|
||||
async def test_function_tool_use(
|
||||
client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
stream: bool,
|
||||
tool_choice: Union[str, dict],
|
||||
enable_thinking: bool,
|
||||
):
|
||||
if not stream:
|
||||
# Non-streaming test
|
||||
chat_completion = await client.chat.completions.create(
|
||||
@@ -185,16 +172,11 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
|
||||
model=model_name,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
extra_body={
|
||||
"chat_template_kwargs": {
|
||||
"enable_thinking": enable_thinking
|
||||
}
|
||||
})
|
||||
extra_body={"chat_template_kwargs": {"enable_thinking": enable_thinking}},
|
||||
)
|
||||
if enable_thinking:
|
||||
assert chat_completion.choices[0].message.\
|
||||
reasoning_content is not None
|
||||
assert chat_completion.choices[0].message.\
|
||||
reasoning_content != ""
|
||||
assert chat_completion.choices[0].message.reasoning_content is not None
|
||||
assert chat_completion.choices[0].message.reasoning_content != ""
|
||||
assert chat_completion.choices[0].message.tool_calls is not None
|
||||
assert len(chat_completion.choices[0].message.tool_calls) > 0
|
||||
else:
|
||||
@@ -205,11 +187,8 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=True,
|
||||
extra_body={
|
||||
"chat_template_kwargs": {
|
||||
"enable_thinking": enable_thinking
|
||||
}
|
||||
})
|
||||
extra_body={"chat_template_kwargs": {"enable_thinking": enable_thinking}},
|
||||
)
|
||||
|
||||
output = []
|
||||
async for chunk in output_stream:
|
||||
@@ -237,12 +216,11 @@ def k2_server(): # noqa: F811
|
||||
]
|
||||
# hack to test kimi_k2 tool use tool_id format.
|
||||
# avoid error in is_deepseek_mla check by setting kv_lora_rank=null
|
||||
with RemoteOpenAIServer(MODEL_NAME,
|
||||
args,
|
||||
override_hf_configs={
|
||||
"model_type": 'kimi_k2',
|
||||
'kv_lora_rank': None
|
||||
}) as remote_server:
|
||||
with RemoteOpenAIServer(
|
||||
MODEL_NAME,
|
||||
args,
|
||||
override_hf_configs={"model_type": "kimi_k2", "kv_lora_rank": None},
|
||||
) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@@ -256,20 +234,20 @@ async def k2_client(k2_server):
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("stream", [True, False])
|
||||
@pytest.mark.parametrize("tool_choice", ["required"])
|
||||
async def test_tool_id_kimi_k2(k2_client: openai.AsyncOpenAI, model_name: str,
|
||||
stream: bool, tool_choice: str):
|
||||
|
||||
async def test_tool_id_kimi_k2(
|
||||
k2_client: openai.AsyncOpenAI, model_name: str, stream: bool, tool_choice: str
|
||||
):
|
||||
if not stream:
|
||||
# Non-streaming test
|
||||
chat_completion = await k2_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=model_name,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice)
|
||||
messages=messages, model=model_name, tools=tools, tool_choice=tool_choice
|
||||
)
|
||||
assert chat_completion.choices[0].message.tool_calls is not None
|
||||
assert len(chat_completion.choices[0].message.tool_calls) > 0
|
||||
assert chat_completion.choices[0].message.tool_calls[
|
||||
0].id == 'functions.get_current_weather:0'
|
||||
assert (
|
||||
chat_completion.choices[0].message.tool_calls[0].id
|
||||
== "functions.get_current_weather:0"
|
||||
)
|
||||
else:
|
||||
# Streaming test
|
||||
output_stream = await k2_client.chat.completions.create(
|
||||
@@ -277,42 +255,45 @@ async def test_tool_id_kimi_k2(k2_client: openai.AsyncOpenAI, model_name: str,
|
||||
model=model_name,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=True)
|
||||
stream=True,
|
||||
)
|
||||
|
||||
output = []
|
||||
async for chunk in output_stream:
|
||||
if chunk.choices and chunk.choices[0].delta.tool_calls:
|
||||
output.extend(chunk.choices[0].delta.tool_calls)
|
||||
for o in output:
|
||||
assert o.id is None or o.id == 'functions.get_current_weather:0'
|
||||
assert o.id is None or o.id == "functions.get_current_weather:0"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("arguments", ["{}", ''])
|
||||
async def test_no_args_tool_call(client: openai.AsyncOpenAI, model_name: str,
|
||||
arguments: str):
|
||||
@pytest.mark.parametrize("arguments", ["{}", ""])
|
||||
async def test_no_args_tool_call(
|
||||
client: openai.AsyncOpenAI, model_name: str, arguments: str
|
||||
):
|
||||
# Step 1: Define a tool that requires no parameters
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_time",
|
||||
"description":
|
||||
"Get the current date and time. No parameters needed.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {}, # No parameters
|
||||
"required": [] # No required fields
|
||||
}
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_time",
|
||||
"description": "Get the current date and time. No parameters needed.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {}, # No parameters
|
||||
"required": [], # No required fields
|
||||
},
|
||||
},
|
||||
}
|
||||
}]
|
||||
]
|
||||
messages = [{"role": "user", "content": "What time is it now?"}]
|
||||
# Step 2: Send user message and let model decide whether to call the tool
|
||||
response = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto" # Let model choose automatically
|
||||
tool_choice="auto", # Let model choose automatically
|
||||
)
|
||||
|
||||
# Step 3: Check if model wants to call a tool
|
||||
@@ -328,11 +309,13 @@ async def test_no_args_tool_call(client: openai.AsyncOpenAI, model_name: str,
|
||||
messages.append(message)
|
||||
current_time = datetime.datetime.now()
|
||||
result = current_time.isoformat()
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": result,
|
||||
})
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": result,
|
||||
}
|
||||
)
|
||||
# Step 5: Send tool result back to model to continue conversation
|
||||
final_response = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
|
||||
@@ -9,6 +9,7 @@ import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import torch
|
||||
|
||||
# downloading lora to test lora requests
|
||||
from openai import BadRequestError
|
||||
from transformers import AutoConfig
|
||||
@@ -23,8 +24,9 @@ CONFIG = AutoConfig.from_pretrained(MODEL_NAME)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=["use-lora"])
|
||||
def default_server_args(request: pytest.FixtureRequest,
|
||||
opt125_lora_files: str) -> list[str]:
|
||||
def default_server_args(
|
||||
request: pytest.FixtureRequest, opt125_lora_files: str
|
||||
) -> list[str]:
|
||||
args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
@@ -42,18 +44,20 @@ def default_server_args(request: pytest.FixtureRequest,
|
||||
lora_module_1 = {
|
||||
"name": LORA_SERVING_MODEL_NAME,
|
||||
"path": opt125_lora_files,
|
||||
"base_model_name": MODEL_NAME
|
||||
"base_model_name": MODEL_NAME,
|
||||
}
|
||||
|
||||
args.extend([
|
||||
"--enable-lora",
|
||||
"--lora-module",
|
||||
json.dumps(lora_module_1),
|
||||
"--max-lora-rank",
|
||||
"64",
|
||||
"--max-cpu-loras",
|
||||
"2",
|
||||
])
|
||||
args.extend(
|
||||
[
|
||||
"--enable-lora",
|
||||
"--lora-module",
|
||||
json.dumps(lora_module_1),
|
||||
"--max-lora-rank",
|
||||
"64",
|
||||
"--max-cpu-loras",
|
||||
"2",
|
||||
]
|
||||
)
|
||||
|
||||
return args
|
||||
|
||||
@@ -67,7 +71,7 @@ EXAMPLE_PROMPTS = [
|
||||
def _encode_embeds(embeds: torch.Tensor):
|
||||
buffer = io.BytesIO()
|
||||
torch.save(embeds, buffer)
|
||||
return base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -79,8 +83,7 @@ def example_prompt_embeds(hf_runner):
|
||||
return [_encode_embeds(item) for item in example_embeddings]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module",
|
||||
params=["", "--disable-frontend-multiprocessing"])
|
||||
@pytest.fixture(scope="module", params=["", "--disable-frontend-multiprocessing"])
|
||||
def server_with_prompt_embeds(default_server_args, request):
|
||||
if request.param:
|
||||
default_server_args.append(request.param)
|
||||
@@ -110,7 +113,8 @@ async def test_completions_with_prompt_embeds(
|
||||
prompt="", # Add empty prompt as required parameter
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
extra_body={"prompt_embeds": encoded_embeds})
|
||||
extra_body={"prompt_embeds": encoded_embeds},
|
||||
)
|
||||
assert len(completion.choices[0].text) >= 1
|
||||
assert completion.choices[0].prompt_logprobs is None
|
||||
|
||||
@@ -120,7 +124,8 @@ async def test_completions_with_prompt_embeds(
|
||||
prompt="", # Add empty prompt as required parameter
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
|
||||
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]},
|
||||
)
|
||||
assert len(completion.choices) == 2
|
||||
assert len(completion.choices[0].text) >= 1
|
||||
assert len(completion.choices[1].text) >= 1
|
||||
@@ -131,7 +136,8 @@ async def test_completions_with_prompt_embeds(
|
||||
prompt="", # Add empty prompt as required parameter
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
extra_body={"prompt_embeds": encoded_embeds})
|
||||
extra_body={"prompt_embeds": encoded_embeds},
|
||||
)
|
||||
single_output = single_completion.choices[0].text
|
||||
|
||||
stream = await client_with_prompt_embeds.completions.create(
|
||||
@@ -140,7 +146,8 @@ async def test_completions_with_prompt_embeds(
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
extra_body={"prompt_embeds": encoded_embeds})
|
||||
extra_body={"prompt_embeds": encoded_embeds},
|
||||
)
|
||||
chunks = []
|
||||
finish_reason_count = 0
|
||||
async for chunk in stream:
|
||||
@@ -159,12 +166,12 @@ async def test_completions_with_prompt_embeds(
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
|
||||
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]},
|
||||
)
|
||||
chunks_stream_embeds: list[list[str]] = [[], []]
|
||||
finish_reason_count = 0
|
||||
async for chunk in stream:
|
||||
chunks_stream_embeds[chunk.choices[0].index].append(
|
||||
chunk.choices[0].text)
|
||||
chunks_stream_embeds[chunk.choices[0].index].append(chunk.choices[0].text)
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
finish_reason_count += 1
|
||||
assert finish_reason_count == 2
|
||||
@@ -179,7 +186,8 @@ async def test_completions_with_prompt_embeds(
|
||||
prompt="This is a prompt",
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
extra_body={"prompt_embeds": encoded_embeds})
|
||||
extra_body={"prompt_embeds": encoded_embeds},
|
||||
)
|
||||
assert len(completion.choices) == 2
|
||||
completion_text_only = await client_with_prompt_embeds.completions.create(
|
||||
model=model_name,
|
||||
@@ -192,18 +200,18 @@ async def test_completions_with_prompt_embeds(
|
||||
prompt="",
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
extra_body={"prompt_embeds": encoded_embeds})
|
||||
extra_body={"prompt_embeds": encoded_embeds},
|
||||
)
|
||||
# Embeddings responses should be handled first
|
||||
assert completion_mixed.choices[0].text == completion_embeds_only.choices[
|
||||
0].text
|
||||
assert completion_mixed.choices[1].text == completion_text_only.choices[
|
||||
0].text
|
||||
assert completion_mixed.choices[0].text == completion_embeds_only.choices[0].text
|
||||
assert completion_mixed.choices[1].text == completion_text_only.choices[0].text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME, LORA_SERVING_MODEL_NAME])
|
||||
async def test_completions_errors_with_prompt_embeds(
|
||||
client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
|
||||
client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str
|
||||
):
|
||||
# Test error case: invalid prompt_embeds
|
||||
with pytest.raises(BadRequestError):
|
||||
await client_with_prompt_embeds.completions.create(
|
||||
@@ -211,7 +219,8 @@ async def test_completions_errors_with_prompt_embeds(
|
||||
model=model_name,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
extra_body={"prompt_embeds": "invalid_base64"})
|
||||
extra_body={"prompt_embeds": "invalid_base64"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -233,7 +242,8 @@ async def test_completions_with_logprobs_and_prompt_embeds(
|
||||
temperature=0.0,
|
||||
echo=False,
|
||||
logprobs=logprobs_arg,
|
||||
extra_body={"prompt_embeds": encoded_embeds})
|
||||
extra_body={"prompt_embeds": encoded_embeds},
|
||||
)
|
||||
|
||||
logprobs = completion.choices[0].logprobs
|
||||
assert logprobs is not None
|
||||
@@ -252,7 +262,8 @@ async def test_completions_with_logprobs_and_prompt_embeds(
|
||||
temperature=0.0,
|
||||
echo=False,
|
||||
logprobs=logprobs_arg,
|
||||
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
|
||||
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]},
|
||||
)
|
||||
|
||||
assert len(completion.choices) == 2
|
||||
for choice in completion.choices:
|
||||
@@ -262,8 +273,7 @@ async def test_completions_with_logprobs_and_prompt_embeds(
|
||||
assert len(logprobs.token_logprobs) == 5
|
||||
assert len(logprobs.top_logprobs) == 5
|
||||
for top_logprobs in logprobs.top_logprobs[1:]:
|
||||
assert max(logprobs_arg,
|
||||
1) <= len(top_logprobs) <= logprobs_arg + 1
|
||||
assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1
|
||||
assert len(logprobs.tokens) == 5
|
||||
|
||||
|
||||
@@ -280,8 +290,5 @@ async def test_prompt_logprobs_raises_error(
|
||||
prompt="",
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
extra_body={
|
||||
"prompt_embeds": encoded_embeds,
|
||||
"prompt_logprobs": True
|
||||
},
|
||||
extra_body={"prompt_embeds": encoded_embeds, "prompt_logprobs": True},
|
||||
)
|
||||
|
||||
@@ -16,8 +16,7 @@ from ...utils import RemoteOpenAIServer
|
||||
# need a multimodal model for these tests.
|
||||
|
||||
# Contains a modality specific lora alongside the base model
|
||||
MULTIMODAL_MODEL_NAME = snapshot_download(
|
||||
"microsoft/Phi-4-multimodal-instruct")
|
||||
MULTIMODAL_MODEL_NAME = snapshot_download("microsoft/Phi-4-multimodal-instruct")
|
||||
AUDIO_LORA_PATH = os.path.join(MULTIMODAL_MODEL_NAME, "speech-lora")
|
||||
|
||||
ACTIVE_MM_LORA_RESPONSE = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go." # noqa: E501
|
||||
@@ -25,7 +24,6 @@ ACTIVE_MM_LORA_RESPONSE = "Spoken text: The first words I spoke in the original
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def multimodal_server(): # noqa: F811
|
||||
|
||||
args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
@@ -45,11 +43,12 @@ def multimodal_server(): # noqa: F811
|
||||
"--gpu-memory-utilization",
|
||||
"0.8",
|
||||
"--default-mm-loras",
|
||||
f"{{\"audio\": \"{AUDIO_LORA_PATH}\"}}",
|
||||
f'{{"audio": "{AUDIO_LORA_PATH}"}}',
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MULTIMODAL_MODEL_NAME, args,
|
||||
max_wait_seconds=480) as remote_server:
|
||||
with RemoteOpenAIServer(
|
||||
MULTIMODAL_MODEL_NAME, args, max_wait_seconds=480
|
||||
) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@@ -70,25 +69,25 @@ async def test_default_mm_lora_chat_completions(
|
||||
multi_modal_client: openai.AsyncOpenAI,
|
||||
audio_assets: AudioTestAssets,
|
||||
):
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "Can you transcribe this audio?",
|
||||
}, {
|
||||
"type": "audio_url",
|
||||
"audio_url": {
|
||||
"url": audio_assets[0].url
|
||||
},
|
||||
}]
|
||||
}]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Can you transcribe this audio?",
|
||||
},
|
||||
{
|
||||
"type": "audio_url",
|
||||
"audio_url": {"url": audio_assets[0].url},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
chat_completion = await multi_modal_client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_completion_tokens=128,
|
||||
temperature=0.0)
|
||||
model=model_name, messages=messages, max_completion_tokens=128, temperature=0.0
|
||||
)
|
||||
|
||||
assert len(chat_completion.choices) > 0
|
||||
|
||||
|
||||
@@ -20,26 +20,18 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
|
||||
BADREQUEST_CASES = [
|
||||
(
|
||||
"test_rank",
|
||||
{
|
||||
"r": 1024
|
||||
},
|
||||
{"r": 1024},
|
||||
"is greater than max_lora_rank",
|
||||
),
|
||||
(
|
||||
"test_bias",
|
||||
{
|
||||
"bias": "all"
|
||||
},
|
||||
{"bias": "all"},
|
||||
"Adapter bias cannot be used without bias_enabled",
|
||||
),
|
||||
("test_dora", {
|
||||
"use_dora": True
|
||||
}, "does not yet support DoRA"),
|
||||
("test_dora", {"use_dora": True}, "does not yet support DoRA"),
|
||||
(
|
||||
"test_modules_to_save",
|
||||
{
|
||||
"modules_to_save": ["lm_head"]
|
||||
},
|
||||
{"modules_to_save": ["lm_head"]},
|
||||
"only supports modules_to_save being None",
|
||||
),
|
||||
]
|
||||
@@ -48,24 +40,23 @@ BADREQUEST_CASES = [
|
||||
@pytest.fixture(scope="module")
|
||||
def monkeypatch_module():
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
mpatch = MonkeyPatch()
|
||||
yield mpatch
|
||||
mpatch.undo()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[True])
|
||||
def server_with_lora_modules_json(request, monkeypatch_module,
|
||||
zephyr_lora_files):
|
||||
|
||||
def server_with_lora_modules_json(request, monkeypatch_module, zephyr_lora_files):
|
||||
use_v1 = request.param
|
||||
assert use_v1
|
||||
monkeypatch_module.setenv('VLLM_USE_V1', '1')
|
||||
monkeypatch_module.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
# Define the json format LoRA module configurations
|
||||
lora_module_1 = {
|
||||
"name": "zephyr-lora",
|
||||
"path": zephyr_lora_files,
|
||||
"base_model_name": MODEL_NAME
|
||||
"base_model_name": MODEL_NAME,
|
||||
}
|
||||
|
||||
args = [
|
||||
@@ -96,14 +87,12 @@ def server_with_lora_modules_json(request, monkeypatch_module,
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server_with_lora_modules_json):
|
||||
async with server_with_lora_modules_json.get_async_client(
|
||||
) as async_client:
|
||||
async with server_with_lora_modules_json.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_static_lora_lineage(client: openai.AsyncOpenAI,
|
||||
zephyr_lora_files):
|
||||
async def test_static_lora_lineage(client: openai.AsyncOpenAI, zephyr_lora_files):
|
||||
models = await client.models.list()
|
||||
models = models.data
|
||||
served_model = models[0]
|
||||
@@ -111,22 +100,18 @@ async def test_static_lora_lineage(client: openai.AsyncOpenAI,
|
||||
assert served_model.id == MODEL_NAME
|
||||
assert served_model.root == MODEL_NAME
|
||||
assert served_model.parent is None
|
||||
assert all(lora_model.root == zephyr_lora_files
|
||||
for lora_model in lora_models)
|
||||
assert all(lora_model.root == zephyr_lora_files for lora_model in lora_models)
|
||||
assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
|
||||
assert lora_models[0].id == "zephyr-lora"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI,
|
||||
zephyr_lora_files):
|
||||
|
||||
response = await client.post("load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={
|
||||
"lora_name": "zephyr-lora-3",
|
||||
"lora_path": zephyr_lora_files
|
||||
})
|
||||
async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI, zephyr_lora_files):
|
||||
response = await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={"lora_name": "zephyr-lora-3", "lora_path": zephyr_lora_files},
|
||||
)
|
||||
# Ensure adapter loads before querying /models
|
||||
assert "success" in response
|
||||
|
||||
@@ -141,37 +126,37 @@ async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI,
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_lora_not_found(client: openai.AsyncOpenAI):
|
||||
with pytest.raises(openai.NotFoundError):
|
||||
await client.post("load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={
|
||||
"lora_name": "notfound",
|
||||
"lora_path": "/not/an/adapter"
|
||||
})
|
||||
await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={"lora_name": "notfound", "lora_path": "/not/an/adapter"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI,
|
||||
tmp_path):
|
||||
async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI, tmp_path):
|
||||
invalid_files = tmp_path / "invalid_files"
|
||||
invalid_files.mkdir()
|
||||
(invalid_files / "adapter_config.json").write_text("this is not json")
|
||||
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
await client.post("load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={
|
||||
"lora_name": "invalid-json",
|
||||
"lora_path": str(invalid_files)
|
||||
})
|
||||
await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={"lora_name": "invalid-json", "lora_path": str(invalid_files)},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("test_name,config_change,expected_error",
|
||||
BADREQUEST_CASES)
|
||||
async def test_dynamic_lora_badrequests(client: openai.AsyncOpenAI, tmp_path,
|
||||
zephyr_lora_files, test_name: str,
|
||||
config_change: dict,
|
||||
expected_error: str):
|
||||
@pytest.mark.parametrize("test_name,config_change,expected_error", BADREQUEST_CASES)
|
||||
async def test_dynamic_lora_badrequests(
|
||||
client: openai.AsyncOpenAI,
|
||||
tmp_path,
|
||||
zephyr_lora_files,
|
||||
test_name: str,
|
||||
config_change: dict,
|
||||
expected_error: str,
|
||||
):
|
||||
# Create test directory
|
||||
test_dir = tmp_path / test_name
|
||||
|
||||
@@ -191,29 +176,28 @@ async def test_dynamic_lora_badrequests(client: openai.AsyncOpenAI, tmp_path,
|
||||
|
||||
# Test loading the adapter
|
||||
with pytest.raises(openai.BadRequestError, match=expected_error):
|
||||
await client.post("load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={
|
||||
"lora_name": test_name,
|
||||
"lora_path": str(test_dir)
|
||||
})
|
||||
await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={"lora_name": test_name, "lora_path": str(test_dir)},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_lora_adapters(client: openai.AsyncOpenAI, tmp_path,
|
||||
zephyr_lora_files):
|
||||
async def test_multiple_lora_adapters(
|
||||
client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files
|
||||
):
|
||||
"""Validate that many loras can be dynamically registered and inferenced
|
||||
with concurrently"""
|
||||
|
||||
# This test file configures the server with --max-cpu-loras=2 and this test
|
||||
# will concurrently load 10 adapters, so it should flex the LRU cache
|
||||
async def load_and_run_adapter(adapter_name: str):
|
||||
await client.post("load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={
|
||||
"lora_name": adapter_name,
|
||||
"lora_path": str(zephyr_lora_files)
|
||||
})
|
||||
await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={"lora_name": adapter_name, "lora_path": str(zephyr_lora_files)},
|
||||
)
|
||||
for _ in range(3):
|
||||
await client.completions.create(
|
||||
model=adapter_name,
|
||||
@@ -223,8 +207,7 @@ async def test_multiple_lora_adapters(client: openai.AsyncOpenAI, tmp_path,
|
||||
|
||||
lora_tasks = []
|
||||
for i in range(10):
|
||||
lora_tasks.append(
|
||||
asyncio.create_task(load_and_run_adapter(f"adapter_{i}")))
|
||||
lora_tasks.append(asyncio.create_task(load_and_run_adapter(f"adapter_{i}")))
|
||||
|
||||
results, _ = await asyncio.wait(lora_tasks)
|
||||
|
||||
@@ -234,8 +217,8 @@ async def test_multiple_lora_adapters(client: openai.AsyncOpenAI, tmp_path,
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loading_invalid_adapters_does_not_break_others(
|
||||
client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files):
|
||||
|
||||
client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files
|
||||
):
|
||||
invalid_files = tmp_path / "invalid_files"
|
||||
invalid_files.mkdir()
|
||||
(invalid_files / "adapter_config.json").write_text("this is not json")
|
||||
@@ -266,20 +249,18 @@ async def test_loading_invalid_adapters_does_not_break_others(
|
||||
# Run a bunch of bad adapter loads
|
||||
for _ in range(25):
|
||||
with suppress(openai.NotFoundError):
|
||||
await client.post("load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={
|
||||
"lora_name": "notfound",
|
||||
"lora_path": "/not/an/adapter"
|
||||
})
|
||||
await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={"lora_name": "notfound", "lora_path": "/not/an/adapter"},
|
||||
)
|
||||
for _ in range(25):
|
||||
with suppress(openai.BadRequestError):
|
||||
await client.post("load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={
|
||||
"lora_name": "invalid",
|
||||
"lora_path": str(invalid_files)
|
||||
})
|
||||
await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={"lora_name": "invalid", "lora_path": str(invalid_files)},
|
||||
)
|
||||
|
||||
# Ensure all the running requests with lora adapters succeeded
|
||||
stop_good_requests_event.set()
|
||||
@@ -288,12 +269,11 @@ async def test_loading_invalid_adapters_does_not_break_others(
|
||||
assert not isinstance(r, Exception), f"Got exception {r}"
|
||||
|
||||
# Ensure we can load another adapter and run it
|
||||
await client.post("load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={
|
||||
"lora_name": "valid",
|
||||
"lora_path": zephyr_lora_files
|
||||
})
|
||||
await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={"lora_name": "valid", "lora_path": zephyr_lora_files},
|
||||
)
|
||||
await client.completions.create(
|
||||
model="valid",
|
||||
prompt=["Hello there", "Foo bar bazz buzz"],
|
||||
@@ -310,12 +290,11 @@ async def test_beam_search_with_lora_adapters(
|
||||
"""Validate that async beam search can be used with lora."""
|
||||
|
||||
async def load_and_run_adapter(adapter_name: str):
|
||||
await client.post("load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={
|
||||
"lora_name": adapter_name,
|
||||
"lora_path": str(zephyr_lora_files)
|
||||
})
|
||||
await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={"lora_name": adapter_name, "lora_path": str(zephyr_lora_files)},
|
||||
)
|
||||
for _ in range(3):
|
||||
await client.completions.create(
|
||||
model=adapter_name,
|
||||
@@ -326,8 +305,7 @@ async def test_beam_search_with_lora_adapters(
|
||||
|
||||
lora_tasks = []
|
||||
for i in range(3):
|
||||
lora_tasks.append(
|
||||
asyncio.create_task(load_and_run_adapter(f"adapter_{i}")))
|
||||
lora_tasks.append(asyncio.create_task(load_and_run_adapter(f"adapter_{i}")))
|
||||
|
||||
results, _ = await asyncio.wait(lora_tasks)
|
||||
|
||||
|
||||
@@ -12,8 +12,7 @@ import pytest
|
||||
from vllm.config.multimodal import MultiModalConfig
|
||||
from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
@@ -33,14 +32,14 @@ class MockHFConfig:
|
||||
@dataclass
|
||||
class MockModelConfig:
|
||||
"""Minimal mock ModelConfig for testing."""
|
||||
|
||||
model: str = MODEL_NAME
|
||||
tokenizer: str = MODEL_NAME
|
||||
trust_remote_code: bool = False
|
||||
tokenizer_mode: str = "auto"
|
||||
max_model_len: int = 100
|
||||
tokenizer_revision: Optional[str] = None
|
||||
multimodal_config: MultiModalConfig = field(
|
||||
default_factory=MultiModalConfig)
|
||||
multimodal_config: MultiModalConfig = field(default_factory=MultiModalConfig)
|
||||
hf_config: MockHFConfig = field(default_factory=MockHFConfig)
|
||||
logits_processor_pattern: Optional[str] = None
|
||||
diff_sampling_param: Optional[dict] = None
|
||||
@@ -55,17 +54,21 @@ class MockModelConfig:
|
||||
|
||||
|
||||
class MockLoRAResolver(LoRAResolver):
|
||||
|
||||
async def resolve_lora(self, base_model_name: str,
|
||||
lora_name: str) -> Optional[LoRARequest]:
|
||||
async def resolve_lora(
|
||||
self, base_model_name: str, lora_name: str
|
||||
) -> Optional[LoRARequest]:
|
||||
if lora_name == "test-lora":
|
||||
return LoRARequest(lora_name="test-lora",
|
||||
lora_int_id=1,
|
||||
lora_local_path="/fake/path/test-lora")
|
||||
return LoRARequest(
|
||||
lora_name="test-lora",
|
||||
lora_int_id=1,
|
||||
lora_local_path="/fake/path/test-lora",
|
||||
)
|
||||
elif lora_name == "invalid-lora":
|
||||
return LoRARequest(lora_name="invalid-lora",
|
||||
lora_int_id=2,
|
||||
lora_local_path="/fake/path/invalid-lora")
|
||||
return LoRARequest(
|
||||
lora_name="invalid-lora",
|
||||
lora_int_id=2,
|
||||
lora_local_path="/fake/path/invalid-lora",
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@@ -96,8 +99,7 @@ def mock_serving_setup():
|
||||
return True
|
||||
if lora_request.lora_name == "invalid-lora":
|
||||
# Simulate failure during addition (e.g. invalid format)
|
||||
raise ValueError(f"Simulated failure adding LoRA: "
|
||||
f"{lora_request.lora_name}")
|
||||
raise ValueError(f"Simulated failure adding LoRA: {lora_request.lora_name}")
|
||||
return True
|
||||
|
||||
mock_engine.add_lora = AsyncMock(side_effect=mock_add_lora_side_effect)
|
||||
@@ -106,31 +108,31 @@ def mock_serving_setup():
|
||||
for _ in []:
|
||||
yield _
|
||||
|
||||
mock_engine.generate = MagicMock(spec=AsyncLLM.generate,
|
||||
side_effect=mock_generate)
|
||||
mock_engine.generate = MagicMock(spec=AsyncLLM.generate, side_effect=mock_generate)
|
||||
|
||||
mock_engine.generate.reset_mock()
|
||||
mock_engine.add_lora.reset_mock()
|
||||
|
||||
mock_model_config = MockModelConfig()
|
||||
models = OpenAIServingModels(engine_client=mock_engine,
|
||||
base_model_paths=BASE_MODEL_PATHS,
|
||||
model_config=mock_model_config)
|
||||
models = OpenAIServingModels(
|
||||
engine_client=mock_engine,
|
||||
base_model_paths=BASE_MODEL_PATHS,
|
||||
model_config=mock_model_config,
|
||||
)
|
||||
|
||||
serving_completion = OpenAIServingCompletion(mock_engine,
|
||||
mock_model_config,
|
||||
models,
|
||||
request_logger=None)
|
||||
serving_completion = OpenAIServingCompletion(
|
||||
mock_engine, mock_model_config, models, request_logger=None
|
||||
)
|
||||
|
||||
serving_completion._process_inputs = AsyncMock(return_value=(MagicMock(
|
||||
name="engine_request"), {}))
|
||||
serving_completion._process_inputs = AsyncMock(
|
||||
return_value=(MagicMock(name="engine_request"), {})
|
||||
)
|
||||
|
||||
return mock_engine, serving_completion
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serving_completion_with_lora_resolver(mock_serving_setup,
|
||||
monkeypatch):
|
||||
async def test_serving_completion_with_lora_resolver(mock_serving_setup, monkeypatch):
|
||||
monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")
|
||||
|
||||
mock_engine, serving_completion = mock_serving_setup
|
||||
@@ -152,14 +154,13 @@ async def test_serving_completion_with_lora_resolver(mock_serving_setup,
|
||||
assert called_lora_request.lora_name == lora_model_name
|
||||
|
||||
mock_engine.generate.assert_called_once()
|
||||
called_lora_request = mock_engine.generate.call_args[1]['lora_request']
|
||||
called_lora_request = mock_engine.generate.call_args[1]["lora_request"]
|
||||
assert isinstance(called_lora_request, LoRARequest)
|
||||
assert called_lora_request.lora_name == lora_model_name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serving_completion_resolver_not_found(mock_serving_setup,
|
||||
monkeypatch):
|
||||
async def test_serving_completion_resolver_not_found(mock_serving_setup, monkeypatch):
|
||||
monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")
|
||||
|
||||
mock_engine, serving_completion = mock_serving_setup
|
||||
@@ -182,7 +183,8 @@ async def test_serving_completion_resolver_not_found(mock_serving_setup,
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serving_completion_resolver_add_lora_fails(
|
||||
mock_serving_setup, monkeypatch):
|
||||
mock_serving_setup, monkeypatch
|
||||
):
|
||||
monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")
|
||||
|
||||
mock_engine, serving_completion = mock_serving_setup
|
||||
|
||||
@@ -54,19 +54,22 @@ def default_server_args():
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module",
|
||||
params=[
|
||||
"",
|
||||
"--enable-chunked-prefill",
|
||||
"--disable-frontend-multiprocessing",
|
||||
f"--show-hidden-metrics-for-version={PREV_MINOR_VERSION}",
|
||||
])
|
||||
@pytest.fixture(
|
||||
scope="module",
|
||||
params=[
|
||||
"",
|
||||
"--enable-chunked-prefill",
|
||||
"--disable-frontend-multiprocessing",
|
||||
f"--show-hidden-metrics-for-version={PREV_MINOR_VERSION}",
|
||||
],
|
||||
)
|
||||
def server(use_v1, default_server_args, request):
|
||||
if request.param:
|
||||
default_server_args.append(request.param)
|
||||
env_dict = dict(VLLM_USE_V1='1' if use_v1 else '0')
|
||||
with RemoteOpenAIServer(MODEL_NAME, default_server_args,
|
||||
env_dict=env_dict) as remote_server:
|
||||
env_dict = dict(VLLM_USE_V1="1" if use_v1 else "0")
|
||||
with RemoteOpenAIServer(
|
||||
MODEL_NAME, default_server_args, env_dict=env_dict
|
||||
) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@@ -87,30 +90,36 @@ _NUM_GENERATION_TOKENS_PER_REQUEST = 10
|
||||
# {metric_family: [(suffix, expected_value)]}
|
||||
EXPECTED_VALUES = {
|
||||
"vllm:time_to_first_token_seconds": [("_count", _NUM_REQUESTS)],
|
||||
"vllm:time_per_output_token_seconds":
|
||||
[("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))],
|
||||
"vllm:time_per_output_token_seconds": [
|
||||
("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))
|
||||
],
|
||||
"vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)],
|
||||
"vllm:request_queue_time_seconds": [("_count", _NUM_REQUESTS)],
|
||||
"vllm:request_inference_time_seconds": [("_count", _NUM_REQUESTS)],
|
||||
"vllm:request_prefill_time_seconds": [("_count", _NUM_REQUESTS)],
|
||||
"vllm:request_decode_time_seconds": [("_count", _NUM_REQUESTS)],
|
||||
"vllm:request_prompt_tokens":
|
||||
[("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST),
|
||||
("_count", _NUM_REQUESTS)],
|
||||
"vllm:request_generation_tokens":
|
||||
[("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
|
||||
("_count", _NUM_REQUESTS)],
|
||||
"vllm:request_prompt_tokens": [
|
||||
("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST),
|
||||
("_count", _NUM_REQUESTS),
|
||||
],
|
||||
"vllm:request_generation_tokens": [
|
||||
("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
|
||||
("_count", _NUM_REQUESTS),
|
||||
],
|
||||
"vllm:request_params_n": [("_count", _NUM_REQUESTS)],
|
||||
"vllm:request_params_max_tokens": [
|
||||
("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
|
||||
("_count", _NUM_REQUESTS)
|
||||
("_count", _NUM_REQUESTS),
|
||||
],
|
||||
"vllm:iteration_tokens_total":
|
||||
[("_sum", _NUM_REQUESTS *
|
||||
(_NUM_PROMPT_TOKENS_PER_REQUEST + _NUM_GENERATION_TOKENS_PER_REQUEST)),
|
||||
("_count", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST)],
|
||||
"vllm:prompt_tokens": [("_total",
|
||||
_NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
|
||||
"vllm:iteration_tokens_total": [
|
||||
(
|
||||
"_sum",
|
||||
_NUM_REQUESTS
|
||||
* (_NUM_PROMPT_TOKENS_PER_REQUEST + _NUM_GENERATION_TOKENS_PER_REQUEST),
|
||||
),
|
||||
("_count", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
|
||||
],
|
||||
"vllm:prompt_tokens": [("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
|
||||
"vllm:generation_tokens": [
|
||||
("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)
|
||||
],
|
||||
@@ -119,14 +128,16 @@ EXPECTED_VALUES = {
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_counts(server: RemoteOpenAIServer,
|
||||
client: openai.AsyncClient, use_v1: bool):
|
||||
async def test_metrics_counts(
|
||||
server: RemoteOpenAIServer, client: openai.AsyncClient, use_v1: bool
|
||||
):
|
||||
for _ in range(_NUM_REQUESTS):
|
||||
# sending a request triggers the metrics to be logged.
|
||||
await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=_TOKENIZED_PROMPT,
|
||||
max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST)
|
||||
max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST,
|
||||
)
|
||||
|
||||
response = requests.get(server.url_for("metrics"))
|
||||
print(response.text)
|
||||
@@ -134,9 +145,10 @@ async def test_metrics_counts(server: RemoteOpenAIServer,
|
||||
|
||||
# Loop over all expected metric_families
|
||||
for metric_family, suffix_values_list in EXPECTED_VALUES.items():
|
||||
if ((use_v1 and metric_family not in EXPECTED_METRICS_V1)
|
||||
or (not server.show_hidden_metrics
|
||||
and metric_family in HIDDEN_DEPRECATED_METRICS)):
|
||||
if (use_v1 and metric_family not in EXPECTED_METRICS_V1) or (
|
||||
not server.show_hidden_metrics
|
||||
and metric_family in HIDDEN_DEPRECATED_METRICS
|
||||
):
|
||||
continue
|
||||
|
||||
found_metric = False
|
||||
@@ -160,14 +172,15 @@ async def test_metrics_counts(server: RemoteOpenAIServer,
|
||||
assert sample.value == expected_value, (
|
||||
f"{metric_name_w_suffix} expected value of "
|
||||
f"{expected_value} did not match found value "
|
||||
f"{sample.value}")
|
||||
f"{sample.value}"
|
||||
)
|
||||
break
|
||||
assert found_suffix, (
|
||||
f"Did not find {metric_name_w_suffix} in prom endpoint"
|
||||
)
|
||||
break
|
||||
|
||||
assert found_metric, (f"Did not find {metric_family} in prom endpoint")
|
||||
assert found_metric, f"Did not find {metric_family} in prom endpoint"
|
||||
|
||||
|
||||
EXPECTED_METRICS = [
|
||||
@@ -290,30 +303,30 @@ HIDDEN_DEPRECATED_METRICS: list[str] = [
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_exist(server: RemoteOpenAIServer,
|
||||
client: openai.AsyncClient, use_v1: bool):
|
||||
async def test_metrics_exist(
|
||||
server: RemoteOpenAIServer, client: openai.AsyncClient, use_v1: bool
|
||||
):
|
||||
# sending a request triggers the metrics to be logged.
|
||||
await client.completions.create(model=MODEL_NAME,
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
await client.completions.create(
|
||||
model=MODEL_NAME, prompt="Hello, my name is", max_tokens=5, temperature=0.0
|
||||
)
|
||||
|
||||
response = requests.get(server.url_for("metrics"))
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
|
||||
for metric in (EXPECTED_METRICS_V1 if use_v1 else EXPECTED_METRICS):
|
||||
if (metric in HIDDEN_DEPRECATED_METRICS
|
||||
and not server.show_hidden_metrics):
|
||||
for metric in EXPECTED_METRICS_V1 if use_v1 else EXPECTED_METRICS:
|
||||
if metric in HIDDEN_DEPRECATED_METRICS and not server.show_hidden_metrics:
|
||||
continue
|
||||
assert metric in response.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abort_metrics_reset(server: RemoteOpenAIServer,
|
||||
client: openai.AsyncClient, use_v1: bool):
|
||||
|
||||
running_requests, waiting_requests, kv_cache_usage = (
|
||||
_get_running_metrics_from_api(server, use_v1))
|
||||
async def test_abort_metrics_reset(
|
||||
server: RemoteOpenAIServer, client: openai.AsyncClient, use_v1: bool
|
||||
):
|
||||
running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api(
|
||||
server, use_v1
|
||||
)
|
||||
|
||||
# Expect no running requests or kvcache usage
|
||||
assert running_requests == 0
|
||||
@@ -328,15 +341,18 @@ async def test_abort_metrics_reset(server: RemoteOpenAIServer,
|
||||
model=MODEL_NAME,
|
||||
prompt=_TOKENIZED_PROMPT,
|
||||
max_tokens=100, # Long generation to give time to abort
|
||||
temperature=0.0))
|
||||
temperature=0.0,
|
||||
)
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
# Wait a bit for requests to start processing
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Check that we have running requests
|
||||
running_requests, waiting_requests, kv_cache_usage = (
|
||||
_get_running_metrics_from_api(server, use_v1))
|
||||
running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api(
|
||||
server, use_v1
|
||||
)
|
||||
|
||||
# Expect running requests and kvcache usage
|
||||
assert running_requests > 0
|
||||
@@ -355,17 +371,18 @@ async def test_abort_metrics_reset(server: RemoteOpenAIServer,
|
||||
|
||||
# Verify running and waiting requests counts and KV cache usage are zero
|
||||
running_requests_after, waiting_requests_after, kv_cache_usage_after = (
|
||||
_get_running_metrics_from_api(server, use_v1))
|
||||
_get_running_metrics_from_api(server, use_v1)
|
||||
)
|
||||
|
||||
assert running_requests_after == 0,\
|
||||
(f"Expected 0 running requests after abort, got "
|
||||
f"{running_requests_after}")
|
||||
assert waiting_requests_after == 0,\
|
||||
(f"Expected 0 waiting requests after abort, got "
|
||||
f"{waiting_requests_after}")
|
||||
assert kv_cache_usage_after == 0,\
|
||||
(f"Expected 0% KV cache usage after abort, got "
|
||||
f"{kv_cache_usage_after}")
|
||||
assert running_requests_after == 0, (
|
||||
f"Expected 0 running requests after abort, got {running_requests_after}"
|
||||
)
|
||||
assert waiting_requests_after == 0, (
|
||||
f"Expected 0 waiting requests after abort, got {waiting_requests_after}"
|
||||
)
|
||||
assert kv_cache_usage_after == 0, (
|
||||
f"Expected 0% KV cache usage after abort, got {kv_cache_usage_after}"
|
||||
)
|
||||
|
||||
|
||||
def _get_running_metrics_from_api(server: RemoteOpenAIServer, use_v1: bool):
|
||||
@@ -377,8 +394,9 @@ def _get_running_metrics_from_api(server: RemoteOpenAIServer, use_v1: bool):
|
||||
# Verify running and waiting requests counts and KV cache usage are zero
|
||||
running_requests, waiting_requests, kv_cache_usage = None, None, None
|
||||
|
||||
kv_cache_usage_metric = ("vllm:kv_cache_usage_perc"
|
||||
if use_v1 else "vllm:gpu_cache_usage_perc")
|
||||
kv_cache_usage_metric = (
|
||||
"vllm:kv_cache_usage_perc" if use_v1 else "vllm:gpu_cache_usage_perc"
|
||||
)
|
||||
|
||||
for family in text_string_to_metric_families(response.text):
|
||||
if family.name == "vllm:num_requests_running":
|
||||
@@ -411,28 +429,31 @@ def test_metrics_exist_run_batch(use_v1: bool):
|
||||
port = "8001"
|
||||
server_url = f"http://{base_url}:{port}"
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
"w") as input_file, tempfile.NamedTemporaryFile(
|
||||
"r") as output_file:
|
||||
with (
|
||||
tempfile.NamedTemporaryFile("w") as input_file,
|
||||
tempfile.NamedTemporaryFile("r") as output_file,
|
||||
):
|
||||
input_file.write(input_batch)
|
||||
input_file.flush()
|
||||
proc = subprocess.Popen([
|
||||
sys.executable,
|
||||
"-m",
|
||||
"vllm.entrypoints.openai.run_batch",
|
||||
"-i",
|
||||
input_file.name,
|
||||
"-o",
|
||||
output_file.name,
|
||||
"--model",
|
||||
"intfloat/multilingual-e5-small",
|
||||
"--enable-metrics",
|
||||
"--url",
|
||||
base_url,
|
||||
"--port",
|
||||
port,
|
||||
],
|
||||
env={"VLLM_USE_V1": "1"})
|
||||
proc = subprocess.Popen(
|
||||
[
|
||||
sys.executable,
|
||||
"-m",
|
||||
"vllm.entrypoints.openai.run_batch",
|
||||
"-i",
|
||||
input_file.name,
|
||||
"-o",
|
||||
output_file.name,
|
||||
"--model",
|
||||
"intfloat/multilingual-e5-small",
|
||||
"--enable-metrics",
|
||||
"--url",
|
||||
base_url,
|
||||
"--port",
|
||||
port,
|
||||
],
|
||||
env={"VLLM_USE_V1": "1"},
|
||||
)
|
||||
|
||||
def is_server_up(url):
|
||||
try:
|
||||
|
||||
@@ -52,6 +52,5 @@ async def test_check_models(client: openai.AsyncOpenAI, zephyr_lora_files):
|
||||
lora_models = models[1:]
|
||||
assert served_model.id == MODEL_NAME
|
||||
assert served_model.root == MODEL_NAME
|
||||
assert all(lora_model.root == zephyr_lora_files
|
||||
for lora_model in lora_models)
|
||||
assert all(lora_model.root == zephyr_lora_files for lora_model in lora_models)
|
||||
assert lora_models[0].id == "zephyr-lora"
|
||||
|
||||
@@ -25,13 +25,10 @@ def run_and_test_dummy_opt_api_server(model, tp=1):
|
||||
client = server.get_client()
|
||||
completion = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Hello!"
|
||||
}],
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
generated_text = completion.choices[0].message.content
|
||||
|
||||
@@ -75,10 +75,11 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy):
|
||||
http://localhost:8000/v1/chat/completions
|
||||
""" # noqa: E501
|
||||
if hasattr(case, "body") and isinstance(case.body, dict):
|
||||
if ("messages" in case.body
|
||||
and isinstance(case.body["messages"], list)
|
||||
and len(case.body["messages"]) > 0):
|
||||
|
||||
if (
|
||||
"messages" in case.body
|
||||
and isinstance(case.body["messages"], list)
|
||||
and len(case.body["messages"]) > 0
|
||||
):
|
||||
for message in case.body["messages"]:
|
||||
if not isinstance(message, dict):
|
||||
continue
|
||||
@@ -86,10 +87,11 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy):
|
||||
# Check for invalid file type in tokenize endpoint
|
||||
if op.method.lower() == "post" and op.path == "/tokenize":
|
||||
content = message.get("content", [])
|
||||
if (isinstance(content, list) and len(content) > 0
|
||||
and any(
|
||||
item.get("type") == "file"
|
||||
for item in content)):
|
||||
if (
|
||||
isinstance(content, list)
|
||||
and len(content) > 0
|
||||
and any(item.get("type") == "file" for item in content)
|
||||
):
|
||||
return False
|
||||
|
||||
# Check for invalid tool_calls with non-function types
|
||||
@@ -106,10 +108,13 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy):
|
||||
# Causing a server error in EBNF grammar parsing
|
||||
# https://github.com/vllm-project/vllm/pull/22587#issuecomment-3195253421
|
||||
structured_outputs = case.body.get("structured_outputs", {})
|
||||
grammar = structured_outputs.get("grammar") if isinstance(
|
||||
structured_outputs, dict) else None
|
||||
grammar = (
|
||||
structured_outputs.get("grammar")
|
||||
if isinstance(structured_outputs, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
if grammar == '':
|
||||
if grammar == "":
|
||||
# Allow None (will be handled as no grammar)
|
||||
# But skip empty strings
|
||||
return False
|
||||
@@ -133,9 +138,8 @@ def test_openapi_stateless(case: schemathesis.Case):
|
||||
|
||||
timeout = {
|
||||
# requires a longer timeout
|
||||
("POST", "/v1/chat/completions"):
|
||||
LONG_TIMEOUT_SECONDS,
|
||||
("POST", "/v1/chat/completions"): LONG_TIMEOUT_SECONDS,
|
||||
}.get(key, DEFAULT_TIMEOUT_SECONDS)
|
||||
|
||||
#No need to verify SSL certificate for localhost
|
||||
# No need to verify SSL certificate for localhost
|
||||
case.call_and_validate(verify=False, timeout=timeout)
|
||||
|
||||
@@ -37,7 +37,7 @@ def server(request: pytest.FixtureRequest):
|
||||
"--enforce-eager",
|
||||
"--max-num-seqs",
|
||||
"2",
|
||||
*passed_params
|
||||
*passed_params,
|
||||
]
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
@@ -73,8 +73,9 @@ async def test_missing_api_token(server: RemoteOpenAIServer):
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_passed_api_token(server: RemoteOpenAIServer):
|
||||
response = requests.get(server.url_for("v1/models"),
|
||||
headers={"Authorization": "Bearer test"})
|
||||
response = requests.get(
|
||||
server.url_for("v1/models"), headers={"Authorization": "Bearer test"}
|
||||
)
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
|
||||
|
||||
@@ -110,7 +111,8 @@ async def test_enable_request_id_header(server: RemoteOpenAIServer):
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_request_id_header(server: RemoteOpenAIServer):
|
||||
response = requests.get(server.url_for("health"),
|
||||
headers={"X-Request-Id": "Custom"})
|
||||
response = requests.get(
|
||||
server.url_for("health"), headers={"X-Request-Id": "Custom"}
|
||||
)
|
||||
assert "X-Request-Id" in response.headers
|
||||
assert response.headers.get("X-Request-Id") == "Custom"
|
||||
|
||||
@@ -17,7 +17,7 @@ from ...utils import RemoteOpenAIServer
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v1_only(monkeypatch):
|
||||
monkeypatch.setenv('VLLM_USE_V1', '1')
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -28,15 +28,16 @@ async def test_empty_prompt():
|
||||
client = remote_server.get_async_client()
|
||||
|
||||
with pytest.raises(
|
||||
openai.BadRequestError,
|
||||
match=
|
||||
"Either prompt or prompt_embeds must be provided and non-empty."
|
||||
openai.BadRequestError,
|
||||
match="Either prompt or prompt_embeds must be provided and non-empty.",
|
||||
):
|
||||
await client.completions.create(model=model_name,
|
||||
prompt="",
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
extra_body={"prompt_embeds": []})
|
||||
await client.completions.create(
|
||||
model=model_name,
|
||||
prompt="",
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
extra_body={"prompt_embeds": []},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -46,23 +47,23 @@ async def test_out_of_vocab_token_ids():
|
||||
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||
client = remote_server.get_async_client()
|
||||
|
||||
with pytest.raises(openai.BadRequestError,
|
||||
match=re.compile('.*out of vocabulary.*').pattern):
|
||||
await client.completions.create(model=model_name,
|
||||
prompt=[999999],
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
with pytest.raises(
|
||||
openai.BadRequestError, match=re.compile(".*out of vocabulary.*").pattern
|
||||
):
|
||||
await client.completions.create(
|
||||
model=model_name, prompt=[999999], max_tokens=5, temperature=0.0
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype",
|
||||
[torch.float32, torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize(
|
||||
"layout",
|
||||
[torch.strided, torch.sparse_coo, torch.sparse_csc, torch.sparse_csr])
|
||||
"layout", [torch.strided, torch.sparse_coo, torch.sparse_csc, torch.sparse_csr]
|
||||
)
|
||||
@pytest.mark.parametrize("seq_len", [2, 10])
|
||||
@pytest.mark.parametrize("hidden_size", [2, 10])
|
||||
def test_load_prompt_embeds(dtype: torch.dtype, layout: torch.layout,
|
||||
seq_len: int, hidden_size: int):
|
||||
def test_load_prompt_embeds(
|
||||
dtype: torch.dtype, layout: torch.layout, seq_len: int, hidden_size: int
|
||||
):
|
||||
# construct arbitrary tensors of various dtypes, layouts, and sizes.
|
||||
# We need to check against different layouts to make sure that if a user
|
||||
# uses sparse tensors to reduce the transmission size of prompt embeddings,
|
||||
@@ -92,6 +93,6 @@ def test_load_prompt_embeds(dtype: torch.dtype, layout: torch.layout,
|
||||
loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"]
|
||||
assert loaded_tensor.device.type == "cpu"
|
||||
assert loaded_tensor.layout == torch.strided
|
||||
torch.testing.assert_close(loaded_tensor,
|
||||
tensor.to("cpu").to_dense(),
|
||||
equal_nan=True)
|
||||
torch.testing.assert_close(
|
||||
loaded_tensor, tensor.to("cpu").to_dense(), equal_nan=True
|
||||
)
|
||||
|
||||
@@ -13,6 +13,7 @@ MODEL_NAME = "openai/gpt-oss-20b"
|
||||
@pytest.fixture(scope="module")
|
||||
def monkeypatch_module():
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
mpatch = MonkeyPatch()
|
||||
yield mpatch
|
||||
mpatch.undo()
|
||||
@@ -36,8 +37,7 @@ def mcp_enabled_server(monkeypatch_module: pytest.MonkeyPatch):
|
||||
with monkeypatch_module.context() as m:
|
||||
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
|
||||
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
|
||||
m.setenv("GPT_OSS_SYSTEM_TOOL_MCP_LABELS",
|
||||
"code_interpreter,container")
|
||||
m.setenv("GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "code_interpreter,container")
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
@@ -57,23 +57,26 @@ async def mcp_enabled_client(mcp_enabled_server):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
|
||||
async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI,
|
||||
model_name: str):
|
||||
async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI, model_name: str):
|
||||
response = await mcp_enabled_client.responses.create(
|
||||
model=model_name,
|
||||
# TODO: Ideally should be able to set max tool calls
|
||||
# to prevent multi-turn, but it is not currently supported
|
||||
# would speed up the test
|
||||
input=("What's the first 4 digits after the decimal point of "
|
||||
"cube root of `19910212 * 20250910`? "
|
||||
"Show only the digits. The python interpreter is not stateful "
|
||||
"and you must print to see the output."),
|
||||
tools=[{
|
||||
"type": "mcp",
|
||||
"server_label": "code_interpreter",
|
||||
# URL unused for DemoToolServer
|
||||
"server_url": "http://localhost:8888"
|
||||
}],
|
||||
input=(
|
||||
"What's the first 4 digits after the decimal point of "
|
||||
"cube root of `19910212 * 20250910`? "
|
||||
"Show only the digits. The python interpreter is not stateful "
|
||||
"and you must print to see the output."
|
||||
),
|
||||
tools=[
|
||||
{
|
||||
"type": "mcp",
|
||||
"server_label": "code_interpreter",
|
||||
# URL unused for DemoToolServer
|
||||
"server_url": "http://localhost:8888",
|
||||
}
|
||||
],
|
||||
)
|
||||
assert response is not None
|
||||
assert response.status == "completed"
|
||||
@@ -83,23 +86,26 @@ async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI,
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
|
||||
async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI,
|
||||
model_name: str):
|
||||
async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI, model_name: str):
|
||||
response = await mcp_disabled_client.responses.create(
|
||||
model=model_name,
|
||||
# TODO: Ideally should be able to set max tool calls
|
||||
# to prevent multi-turn, but it is not currently supported
|
||||
# would speed up the test
|
||||
input=("What's the first 4 digits after the decimal point of "
|
||||
"cube root of `19910212 * 20250910`? "
|
||||
"Show only the digits. The python interpreter is not stateful "
|
||||
"and you must print to see the output."),
|
||||
tools=[{
|
||||
"type": "mcp",
|
||||
"server_label": "code_interpreter",
|
||||
# URL unused for DemoToolServer
|
||||
"server_url": "http://localhost:8888"
|
||||
}],
|
||||
input=(
|
||||
"What's the first 4 digits after the decimal point of "
|
||||
"cube root of `19910212 * 20250910`? "
|
||||
"Show only the digits. The python interpreter is not stateful "
|
||||
"and you must print to see the output."
|
||||
),
|
||||
tools=[
|
||||
{
|
||||
"type": "mcp",
|
||||
"server_label": "code_interpreter",
|
||||
# URL unused for DemoToolServer
|
||||
"server_url": "http://localhost:8888",
|
||||
}
|
||||
],
|
||||
)
|
||||
assert response is not None
|
||||
assert response.status == "completed"
|
||||
|
||||
@@ -17,6 +17,7 @@ MODEL_NAME = "openai/gpt-oss-20b"
|
||||
@pytest.fixture(scope="module")
|
||||
def monkeypatch_module():
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
mpatch = MonkeyPatch()
|
||||
yield mpatch
|
||||
mpatch.undo()
|
||||
@@ -94,22 +95,10 @@ async def test_chat(client: OpenAI, model_name: str):
|
||||
response = await client.responses.create(
|
||||
model=model_name,
|
||||
input=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Respond in Korean."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello!"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I help you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 13 * 24? Explain your answer."
|
||||
},
|
||||
{"role": "system", "content": "Respond in Korean."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "assistant", "content": "Hello! How can I help you today?"},
|
||||
{"role": "user", "content": "What is 13 * 24? Explain your answer."},
|
||||
],
|
||||
)
|
||||
assert response is not None
|
||||
@@ -124,10 +113,7 @@ async def test_chat_with_input_type(client: OpenAI, model_name: str):
|
||||
input=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "input_text",
|
||||
"text": "What is 13*24?"
|
||||
}],
|
||||
"content": [{"type": "input_text", "text": "What is 13*24?"}],
|
||||
},
|
||||
],
|
||||
)
|
||||
@@ -141,14 +127,10 @@ async def test_structured_output(client: OpenAI, model_name: str):
|
||||
response = await client.responses.create(
|
||||
model=model_name,
|
||||
input=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Extract the event information."
|
||||
},
|
||||
{"role": "system", "content": "Extract the event information."},
|
||||
{
|
||||
"role": "user",
|
||||
"content":
|
||||
"Alice and Bob are going to a science fair on Friday.",
|
||||
"content": "Alice and Bob are going to a science fair on Friday.",
|
||||
},
|
||||
],
|
||||
text={
|
||||
@@ -158,18 +140,9 @@ async def test_structured_output(client: OpenAI, model_name: str):
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"date": {
|
||||
"type": "string"
|
||||
},
|
||||
"participants": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"name": {"type": "string"},
|
||||
"date": {"type": "string"},
|
||||
"participants": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
"required": ["name", "date", "participants"],
|
||||
"additionalProperties": False,
|
||||
@@ -319,11 +292,10 @@ async def test_streaming_types(client: OpenAI, model_name: str):
|
||||
|
||||
stack_of_event_types = []
|
||||
async for event in response:
|
||||
if event.type == 'response.created':
|
||||
if event.type == "response.created":
|
||||
stack_of_event_types.append(event.type)
|
||||
elif event.type == 'response.completed':
|
||||
assert stack_of_event_types[-1] == pairs_of_event_types[
|
||||
event.type]
|
||||
elif event.type == "response.completed":
|
||||
assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
|
||||
stack_of_event_types.pop()
|
||||
if event.type.endswith("added"):
|
||||
stack_of_event_types.append(event.type)
|
||||
@@ -332,8 +304,7 @@ async def test_streaming_types(client: OpenAI, model_name: str):
|
||||
continue
|
||||
stack_of_event_types.append(event.type)
|
||||
elif event.type.endswith("done"):
|
||||
assert stack_of_event_types[-1] == pairs_of_event_types[
|
||||
event.type]
|
||||
assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
|
||||
stack_of_event_types.pop()
|
||||
assert len(stack_of_event_types) == 0
|
||||
|
||||
@@ -381,11 +352,12 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool):
|
||||
|
||||
# test vllm custom types are in the response
|
||||
if event.type in [
|
||||
"response.completed", "response.in_progress",
|
||||
"response.created"
|
||||
"response.completed",
|
||||
"response.in_progress",
|
||||
"response.created",
|
||||
]:
|
||||
assert 'input_messages' in event.response.model_extra
|
||||
assert 'output_messages' in event.response.model_extra
|
||||
assert "input_messages" in event.response.model_extra
|
||||
assert "output_messages" in event.response.model_extra
|
||||
|
||||
if current_event_mode != event.type:
|
||||
current_event_mode = event.type
|
||||
@@ -396,21 +368,21 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool):
|
||||
assert event.item.id != current_item_id
|
||||
current_item_id = event.item.id
|
||||
elif event.type in [
|
||||
"response.output_text.delta",
|
||||
"response.reasoning_text.delta"
|
||||
"response.output_text.delta",
|
||||
"response.reasoning_text.delta",
|
||||
]:
|
||||
assert event.item_id == current_item_id
|
||||
|
||||
# verify content_index_id is correct
|
||||
if event.type in [
|
||||
"response.content_part.added",
|
||||
"response.reasoning_part.added"
|
||||
"response.content_part.added",
|
||||
"response.reasoning_part.added",
|
||||
]:
|
||||
assert event.content_index != current_content_index
|
||||
current_content_index = event.content_index
|
||||
elif event.type in [
|
||||
"response.output_text.delta",
|
||||
"response.reasoning_text.delta"
|
||||
"response.output_text.delta",
|
||||
"response.reasoning_text.delta",
|
||||
]:
|
||||
assert event.content_index == current_content_index
|
||||
|
||||
@@ -420,8 +392,10 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool):
|
||||
print(f"{event.delta}", end="", flush=True)
|
||||
elif "response.code_interpreter_call_code.done" in event.type:
|
||||
print(f"Code: {event.code}", end="", flush=True)
|
||||
elif ("response.output_item.added" in event.type
|
||||
and event.item.type == "web_search_call"):
|
||||
elif (
|
||||
"response.output_item.added" in event.type
|
||||
and event.item.type == "web_search_call"
|
||||
):
|
||||
print(f"Web search: {event.item.action}", end="", flush=True)
|
||||
events.append(event)
|
||||
|
||||
@@ -432,9 +406,8 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool):
|
||||
if background:
|
||||
starting_after = 5
|
||||
async with await client.responses.retrieve(
|
||||
response_id=resp_id,
|
||||
stream=True,
|
||||
starting_after=starting_after) as stream:
|
||||
response_id=resp_id, stream=True, starting_after=starting_after
|
||||
) as stream:
|
||||
counter = starting_after
|
||||
async for event in stream:
|
||||
counter += 1
|
||||
@@ -448,9 +421,7 @@ async def test_web_search(client: OpenAI, model_name: str):
|
||||
response = await client.responses.create(
|
||||
model=model_name,
|
||||
input="Who is the president of South Korea as of now?",
|
||||
tools=[{
|
||||
"type": "web_search_preview"
|
||||
}],
|
||||
tools=[{"type": "web_search_preview"}],
|
||||
)
|
||||
assert response is not None
|
||||
assert response.status == "completed"
|
||||
@@ -465,16 +436,13 @@ async def test_code_interpreter(client: OpenAI, model_name: str):
|
||||
# TODO: Ideally should be able to set max tool calls
|
||||
# to prevent multi-turn, but it is not currently supported
|
||||
# would speed up the test
|
||||
input=("What's the first 4 digits after the decimal point of "
|
||||
"cube root of `19910212 * 20250910`? "
|
||||
"Show only the digits. The python interpreter is not stateful "
|
||||
"and you must print to see the output."),
|
||||
tools=[{
|
||||
"type": "code_interpreter",
|
||||
"container": {
|
||||
"type": "auto"
|
||||
}
|
||||
}],
|
||||
input=(
|
||||
"What's the first 4 digits after the decimal point of "
|
||||
"cube root of `19910212 * 20250910`? "
|
||||
"Show only the digits. The python interpreter is not stateful "
|
||||
"and you must print to see the output."
|
||||
),
|
||||
tools=[{"type": "code_interpreter", "container": {"type": "auto"}}],
|
||||
)
|
||||
assert response is not None
|
||||
assert response.status == "completed"
|
||||
@@ -505,26 +473,23 @@ def call_function(name, args):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_function_calling(client: OpenAI, model_name: str):
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description":
|
||||
"Get current temperature for provided coordinates in celsius.", # noqa
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"latitude": {
|
||||
"type": "number"
|
||||
},
|
||||
"longitude": {
|
||||
"type": "number"
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get current temperature for provided coordinates in celsius.", # noqa
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"latitude": {"type": "number"},
|
||||
"longitude": {"type": "number"},
|
||||
},
|
||||
"required": ["latitude", "longitude"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"required": ["latitude", "longitude"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"strict": True,
|
||||
}]
|
||||
"strict": True,
|
||||
}
|
||||
]
|
||||
|
||||
response = await client.responses.create(
|
||||
model=model_name,
|
||||
@@ -547,11 +512,13 @@ async def test_function_calling(client: OpenAI, model_name: str):
|
||||
|
||||
response_2 = await client.responses.create(
|
||||
model=model_name,
|
||||
input=[{
|
||||
"type": "function_call_output",
|
||||
"call_id": tool_call.call_id,
|
||||
"output": str(result),
|
||||
}],
|
||||
input=[
|
||||
{
|
||||
"type": "function_call_output",
|
||||
"call_id": tool_call.call_id,
|
||||
"output": str(result),
|
||||
}
|
||||
],
|
||||
tools=tools,
|
||||
previous_response_id=response.id,
|
||||
)
|
||||
@@ -591,17 +558,12 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str):
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description":
|
||||
"Get current temperature for provided coordinates in celsius.", # noqa
|
||||
"description": "Get current temperature for provided coordinates in celsius.", # noqa
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"latitude": {
|
||||
"type": "number"
|
||||
},
|
||||
"longitude": {
|
||||
"type": "number"
|
||||
},
|
||||
"latitude": {"type": "number"},
|
||||
"longitude": {"type": "number"},
|
||||
},
|
||||
"required": ["latitude", "longitude"],
|
||||
"additionalProperties": False,
|
||||
@@ -612,8 +574,7 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str):
|
||||
|
||||
response = await client.responses.create(
|
||||
model=model_name,
|
||||
input=
|
||||
"Help me plan a trip to a random place. And tell me the weather there.",
|
||||
input="Help me plan a trip to a random place. And tell me the weather there.",
|
||||
tools=tools,
|
||||
)
|
||||
assert response is not None
|
||||
@@ -630,11 +591,13 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str):
|
||||
|
||||
response_2 = await client.responses.create(
|
||||
model=model_name,
|
||||
input=[{
|
||||
"type": "function_call_output",
|
||||
"call_id": tool_call.call_id,
|
||||
"output": str(result),
|
||||
}],
|
||||
input=[
|
||||
{
|
||||
"type": "function_call_output",
|
||||
"call_id": tool_call.call_id,
|
||||
"output": str(result),
|
||||
}
|
||||
],
|
||||
tools=tools,
|
||||
previous_response_id=response.id,
|
||||
)
|
||||
@@ -652,11 +615,13 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str):
|
||||
|
||||
response_3 = await client.responses.create(
|
||||
model=model_name,
|
||||
input=[{
|
||||
"type": "function_call_output",
|
||||
"call_id": tool_call.call_id,
|
||||
"output": str(result),
|
||||
}],
|
||||
input=[
|
||||
{
|
||||
"type": "function_call_output",
|
||||
"call_id": tool_call.call_id,
|
||||
"output": str(result),
|
||||
}
|
||||
],
|
||||
tools=tools,
|
||||
previous_response_id=response_2.id,
|
||||
)
|
||||
@@ -668,26 +633,23 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_function_calling_required(client: OpenAI, model_name: str):
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description":
|
||||
"Get current temperature for provided coordinates in celsius.", # noqa
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"latitude": {
|
||||
"type": "number"
|
||||
},
|
||||
"longitude": {
|
||||
"type": "number"
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get current temperature for provided coordinates in celsius.", # noqa
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"latitude": {"type": "number"},
|
||||
"longitude": {"type": "number"},
|
||||
},
|
||||
"required": ["latitude", "longitude"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"required": ["latitude", "longitude"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"strict": True,
|
||||
}]
|
||||
"strict": True,
|
||||
}
|
||||
]
|
||||
|
||||
with pytest.raises(BadRequestError):
|
||||
await client.responses.create(
|
||||
@@ -717,31 +679,27 @@ async def test_system_message_with_tools(client: OpenAI, model_name: str):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_function_calling_full_history(client: OpenAI, model_name: str):
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description":
|
||||
"Get current temperature for provided coordinates in celsius.", # noqa
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"latitude": {
|
||||
"type": "number"
|
||||
},
|
||||
"longitude": {
|
||||
"type": "number"
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get current temperature for provided coordinates in celsius.", # noqa
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"latitude": {"type": "number"},
|
||||
"longitude": {"type": "number"},
|
||||
},
|
||||
"required": ["latitude", "longitude"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"required": ["latitude", "longitude"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"strict": True,
|
||||
}]
|
||||
"strict": True,
|
||||
}
|
||||
]
|
||||
|
||||
input_messages = [{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Paris today?"
|
||||
}]
|
||||
input_messages = [
|
||||
{"role": "user", "content": "What's the weather like in Paris today?"}
|
||||
]
|
||||
|
||||
response = await client.responses.create(
|
||||
model=model_name,
|
||||
@@ -758,8 +716,7 @@ async def test_function_calling_full_history(client: OpenAI, model_name: str):
|
||||
|
||||
result = call_function(name, args)
|
||||
|
||||
input_messages.extend(
|
||||
response.output) # append model's function call message
|
||||
input_messages.extend(response.output) # append model's function call message
|
||||
input_messages.append(
|
||||
{ # append result message
|
||||
"type": "function_call_output",
|
||||
@@ -780,12 +737,12 @@ async def test_function_calling_full_history(client: OpenAI, model_name: str):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_output_messages_enabled(client: OpenAI, model_name: str,
|
||||
server):
|
||||
async def test_output_messages_enabled(client: OpenAI, model_name: str, server):
|
||||
response = await client.responses.create(
|
||||
model=model_name,
|
||||
input="What is the capital of South Korea?",
|
||||
extra_body={"enable_response_messages": True})
|
||||
extra_body={"enable_response_messages": True},
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.status == "completed"
|
||||
|
||||
@@ -50,13 +50,16 @@ async def test_basic_completion_with_emoji(server):
|
||||
# Check against the expected prompt token IDs
|
||||
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
|
||||
encoded_tokens = tokenizer.encode(
|
||||
"Complete this sentence with emojis: I love coding 🚀")
|
||||
"Complete this sentence with emojis: I love coding 🚀"
|
||||
)
|
||||
# Check that encoded_tokens is a subsequence of prompt_token_ids
|
||||
assert any(completion.choices[0].prompt_token_ids[i:i +
|
||||
len(encoded_tokens)]
|
||||
== encoded_tokens for i in range(
|
||||
len(completion.choices[0].prompt_token_ids) -
|
||||
len(encoded_tokens) + 1))
|
||||
assert any(
|
||||
completion.choices[0].prompt_token_ids[i : i + len(encoded_tokens)]
|
||||
== encoded_tokens
|
||||
for i in range(
|
||||
len(completion.choices[0].prompt_token_ids) - len(encoded_tokens) + 1
|
||||
)
|
||||
)
|
||||
|
||||
# Verify token_ids field is present in the choice
|
||||
assert completion.choices[0].token_ids is not None
|
||||
@@ -86,44 +89,38 @@ async def test_basic_completion_with_emoji(server):
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_with_tool_use(server):
|
||||
"""Test chat completion with tool use (get_weather function)."""
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The unit of temperature",
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The unit of temperature",
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}]
|
||||
}
|
||||
]
|
||||
|
||||
async with server.get_async_client() as client:
|
||||
# Test with return_token_ids enabled
|
||||
response = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Paris?"
|
||||
},
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What's the weather like in Paris?"},
|
||||
],
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
@@ -145,10 +142,11 @@ async def test_chat_completion_with_tool_use(server):
|
||||
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
|
||||
prompt_text = tokenizer.decode(response.prompt_token_ids)
|
||||
assert prompt_text.startswith(
|
||||
"<|im_start|>system\nYou are a helpful assistant.")
|
||||
"<|im_start|>system\nYou are a helpful assistant."
|
||||
)
|
||||
assert prompt_text.endswith(
|
||||
"What's the weather like in Paris?<|im_end|>\n"
|
||||
"<|im_start|>assistant\n")
|
||||
"What's the weather like in Paris?<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
|
||||
response_text = tokenizer.decode(response.choices[0].token_ids)
|
||||
assert response_text.startswith('<tool_call>\n{"name": "get_weather"')
|
||||
@@ -164,14 +162,8 @@ async def test_chat_completion_with_tool_use(server):
|
||||
response_without = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Paris?"
|
||||
},
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What's the weather like in Paris?"},
|
||||
],
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
@@ -203,7 +195,7 @@ async def test_comparison_with_prompt_logprobs_and_logprobs(server):
|
||||
extra_body={
|
||||
"return_token_ids": True,
|
||||
"return_tokens_as_token_ids": True,
|
||||
"prompt_logprobs": 1
|
||||
"prompt_logprobs": 1,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -228,16 +220,17 @@ async def test_comparison_with_prompt_logprobs_and_logprobs(server):
|
||||
# The prompt_token_ids should match the prompt portion
|
||||
assert len(completion.choices[0].token_ids) < len(logprobs_token_ids)
|
||||
response_token_ids_length = len(completion.choices[0].token_ids)
|
||||
assert logprobs_token_ids[-response_token_ids_length:] == \
|
||||
completion.choices[0].token_ids
|
||||
assert (
|
||||
logprobs_token_ids[-response_token_ids_length:]
|
||||
== completion.choices[0].token_ids
|
||||
)
|
||||
|
||||
# Verify tokenizer consistency
|
||||
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
|
||||
|
||||
# Decode prompt tokens
|
||||
if completion.choices[0].prompt_token_ids:
|
||||
prompt_text = tokenizer.decode(
|
||||
completion.choices[0].prompt_token_ids)
|
||||
prompt_text = tokenizer.decode(completion.choices[0].prompt_token_ids)
|
||||
# The decoded prompt should match or close to original prompt
|
||||
assert "Hello, world" in prompt_text
|
||||
|
||||
@@ -255,10 +248,7 @@ async def test_comparison_with_prompt_logprobs_and_logprobs(server):
|
||||
stream=True,
|
||||
echo=False,
|
||||
logprobs=1,
|
||||
extra_body={
|
||||
"return_token_ids": True,
|
||||
"return_tokens_as_token_ids": True
|
||||
},
|
||||
extra_body={"return_token_ids": True, "return_tokens_as_token_ids": True},
|
||||
)
|
||||
|
||||
# Collect streamed tokens
|
||||
@@ -287,14 +277,8 @@ async def test_comparison_with_prompt_logprobs_and_logprobs(server):
|
||||
async def test_chat_completion_with_emoji_and_token_ids(server):
|
||||
"""Test chat completion with emojis to verify token_ids handling."""
|
||||
chat_messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You like to use emojis in your responses."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Repeat after me: I love cats 🐱"
|
||||
},
|
||||
{"role": "system", "content": "You like to use emojis in your responses."},
|
||||
{"role": "user", "content": "Repeat after me: I love cats 🐱"},
|
||||
]
|
||||
async with server.get_async_client() as client:
|
||||
response = await client.chat.completions.create(
|
||||
@@ -319,15 +303,16 @@ async def test_chat_completion_with_emoji_and_token_ids(server):
|
||||
|
||||
decoded_prompt = tokenizer.decode(response.prompt_token_ids)
|
||||
assert decoded_prompt.startswith(
|
||||
"<|im_start|>system\nYou like to use emojis in your responses.")
|
||||
"<|im_start|>system\nYou like to use emojis in your responses."
|
||||
)
|
||||
assert decoded_prompt.endswith(
|
||||
"I love cats 🐱<|im_end|>\n<|im_start|>assistant\n")
|
||||
"I love cats 🐱<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
|
||||
decoded_response = tokenizer.decode(response.choices[0].token_ids)
|
||||
# The content should match the response text
|
||||
# except the ending <|im_end|>
|
||||
assert decoded_response == response.choices[
|
||||
0].message.content + "<|im_end|>"
|
||||
assert decoded_response == response.choices[0].message.content + "<|im_end|>"
|
||||
|
||||
# Test with streaming
|
||||
stream = await client.chat.completions.create(
|
||||
@@ -348,14 +333,14 @@ async def test_chat_completion_with_emoji_and_token_ids(server):
|
||||
assert chunk.prompt_token_ids is not None
|
||||
assert isinstance(chunk.prompt_token_ids, list)
|
||||
# Check the prompt_token_ids match the initial prompt
|
||||
decoded_prompt_stream = tokenizer.decode(
|
||||
chunk.prompt_token_ids)
|
||||
decoded_prompt_stream = tokenizer.decode(chunk.prompt_token_ids)
|
||||
assert decoded_prompt_stream == decoded_prompt
|
||||
first_chunk = False
|
||||
else:
|
||||
chunk_dump = chunk.model_dump()
|
||||
assert "prompt_token_ids" not in chunk_dump, \
|
||||
assert "prompt_token_ids" not in chunk_dump, (
|
||||
"Subsequent chunks should not have prompt_token_ids"
|
||||
)
|
||||
|
||||
if chunk.choices:
|
||||
if chunk.choices[0].delta.content:
|
||||
|
||||
@@ -44,22 +44,19 @@ def server_fixture(request, default_server_args): # noqa: F811
|
||||
with RemoteOpenAIServer(MODEL_NAME, args_with_flag) as remote_server:
|
||||
yield (remote_server, True)
|
||||
else:
|
||||
with RemoteOpenAIServer(MODEL_NAME,
|
||||
default_server_args) as remote_server:
|
||||
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
|
||||
yield (remote_server, False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("server_fixture", [True, False], indirect=True)
|
||||
async def test_completion_return_tokens_as_token_ids_completion(
|
||||
server_fixture):
|
||||
async def test_completion_return_tokens_as_token_ids_completion(server_fixture):
|
||||
server, use_server_flag = server_fixture
|
||||
request_args = {}
|
||||
if not use_server_flag:
|
||||
request_args["return_tokens_as_token_ids"] = True
|
||||
|
||||
async with server.get_async_client() as client:
|
||||
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
# Include Unicode characters to test for dividing a single
|
||||
@@ -70,7 +67,8 @@ async def test_completion_return_tokens_as_token_ids_completion(
|
||||
temperature=0,
|
||||
max_tokens=10,
|
||||
logprobs=1,
|
||||
extra_body=request_args)
|
||||
extra_body=request_args,
|
||||
)
|
||||
|
||||
text = completion.choices[0].text
|
||||
token_strs = completion.choices[0].logprobs.tokens
|
||||
@@ -104,22 +102,22 @@ async def test_chat_return_tokens_as_token_ids_completion(server_fixture):
|
||||
# Include Unicode characters to test for dividing a single
|
||||
# character across multiple tokens: 🎉 is [28705, 31862] for the
|
||||
# Zephyr tokenizer
|
||||
messages=[{
|
||||
"role": "system",
|
||||
"content": "You like to respond in only emojis, like 🎉"
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Please write some emojis: 🐱🐶🎉"
|
||||
}],
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You like to respond in only emojis, like 🎉",
|
||||
},
|
||||
{"role": "user", "content": "Please write some emojis: 🐱🐶🎉"},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=8,
|
||||
logprobs=True,
|
||||
extra_body=request_args)
|
||||
extra_body=request_args,
|
||||
)
|
||||
|
||||
text = response.choices[0].message.content
|
||||
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
|
||||
token_ids = []
|
||||
for logprob_content in response.choices[0].logprobs.content:
|
||||
token_ids.append(
|
||||
int(logprob_content.token.removeprefix("token_id:")))
|
||||
token_ids.append(int(logprob_content.token.removeprefix("token_id:")))
|
||||
assert tokenizer.decode(token_ids, skip_special_tokens=True) == text
|
||||
|
||||
@@ -51,26 +51,31 @@ class TestCase(NamedTuple):
|
||||
model_name=MODEL_NAME,
|
||||
base_url=["v1"], # http://localhost:8000/v1
|
||||
api_key=ERROR_API_KEY,
|
||||
expected_error=openai.AuthenticationError),
|
||||
expected_error=openai.AuthenticationError,
|
||||
),
|
||||
TestCase(
|
||||
model_name=MODEL_NAME,
|
||||
base_url=[ROOT_PATH, "v1"], # http://localhost:8000/llm/v1
|
||||
api_key=ERROR_API_KEY,
|
||||
expected_error=openai.AuthenticationError),
|
||||
expected_error=openai.AuthenticationError,
|
||||
),
|
||||
TestCase(
|
||||
model_name=MODEL_NAME,
|
||||
base_url=["v1"], # http://localhost:8000/v1
|
||||
api_key=API_KEY,
|
||||
expected_error=None),
|
||||
expected_error=None,
|
||||
),
|
||||
TestCase(
|
||||
model_name=MODEL_NAME,
|
||||
base_url=[ROOT_PATH, "v1"], # http://localhost:8000/llm/v1
|
||||
api_key=API_KEY,
|
||||
expected_error=None),
|
||||
expected_error=None,
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_chat_session_root_path_with_api_key(server: RemoteOpenAIServer,
|
||||
test_case: TestCase):
|
||||
async def test_chat_session_root_path_with_api_key(
|
||||
server: RemoteOpenAIServer, test_case: TestCase
|
||||
):
|
||||
saying: str = "Here is a common saying about apple. An apple a day, keeps"
|
||||
ctx = contextlib.nullcontext()
|
||||
if test_case.expected_error is not None:
|
||||
@@ -79,20 +84,16 @@ async def test_chat_session_root_path_with_api_key(server: RemoteOpenAIServer,
|
||||
client = openai.AsyncOpenAI(
|
||||
api_key=test_case.api_key,
|
||||
base_url=server.url_for(*test_case.base_url),
|
||||
max_retries=0)
|
||||
max_retries=0,
|
||||
)
|
||||
chat_completion = await client.chat.completions.create(
|
||||
model=test_case.model_name,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "tell me a common saying"
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": saying
|
||||
}],
|
||||
extra_body={
|
||||
"continue_final_message": True,
|
||||
"add_generation_prompt": False
|
||||
})
|
||||
messages=[
|
||||
{"role": "user", "content": "tell me a common saying"},
|
||||
{"role": "assistant", "content": saying},
|
||||
],
|
||||
extra_body={"continue_final_message": True, "add_generation_prompt": False},
|
||||
)
|
||||
|
||||
assert chat_completion.id is not None
|
||||
assert len(chat_completion.choices) == 1
|
||||
|
||||
@@ -35,15 +35,24 @@ INPUT_RERANK_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/re
|
||||
|
||||
|
||||
def test_empty_file():
|
||||
with tempfile.NamedTemporaryFile(
|
||||
"w") as input_file, tempfile.NamedTemporaryFile(
|
||||
"r") as output_file:
|
||||
with (
|
||||
tempfile.NamedTemporaryFile("w") as input_file,
|
||||
tempfile.NamedTemporaryFile("r") as output_file,
|
||||
):
|
||||
input_file.write("")
|
||||
input_file.flush()
|
||||
proc = subprocess.Popen([
|
||||
"vllm", "run-batch", "-i", input_file.name, "-o", output_file.name,
|
||||
"--model", "intfloat/multilingual-e5-small"
|
||||
], )
|
||||
proc = subprocess.Popen(
|
||||
[
|
||||
"vllm",
|
||||
"run-batch",
|
||||
"-i",
|
||||
input_file.name,
|
||||
"-o",
|
||||
output_file.name,
|
||||
"--model",
|
||||
"intfloat/multilingual-e5-small",
|
||||
],
|
||||
)
|
||||
proc.communicate()
|
||||
proc.wait()
|
||||
assert proc.returncode == 0, f"{proc=}"
|
||||
@@ -53,15 +62,24 @@ def test_empty_file():
|
||||
|
||||
|
||||
def test_completions():
|
||||
with tempfile.NamedTemporaryFile(
|
||||
"w") as input_file, tempfile.NamedTemporaryFile(
|
||||
"r") as output_file:
|
||||
with (
|
||||
tempfile.NamedTemporaryFile("w") as input_file,
|
||||
tempfile.NamedTemporaryFile("r") as output_file,
|
||||
):
|
||||
input_file.write(INPUT_BATCH)
|
||||
input_file.flush()
|
||||
proc = subprocess.Popen([
|
||||
"vllm", "run-batch", "-i", input_file.name, "-o", output_file.name,
|
||||
"--model", "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
], )
|
||||
proc = subprocess.Popen(
|
||||
[
|
||||
"vllm",
|
||||
"run-batch",
|
||||
"-i",
|
||||
input_file.name,
|
||||
"-o",
|
||||
output_file.name,
|
||||
"--model",
|
||||
"NousResearch/Meta-Llama-3-8B-Instruct",
|
||||
],
|
||||
)
|
||||
proc.communicate()
|
||||
proc.wait()
|
||||
assert proc.returncode == 0, f"{proc=}"
|
||||
@@ -77,30 +95,48 @@ def test_completions_invalid_input():
|
||||
"""
|
||||
Ensure that we fail when the input doesn't conform to the openai api.
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(
|
||||
"w") as input_file, tempfile.NamedTemporaryFile(
|
||||
"r") as output_file:
|
||||
with (
|
||||
tempfile.NamedTemporaryFile("w") as input_file,
|
||||
tempfile.NamedTemporaryFile("r") as output_file,
|
||||
):
|
||||
input_file.write(INVALID_INPUT_BATCH)
|
||||
input_file.flush()
|
||||
proc = subprocess.Popen([
|
||||
"vllm", "run-batch", "-i", input_file.name, "-o", output_file.name,
|
||||
"--model", "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
], )
|
||||
proc = subprocess.Popen(
|
||||
[
|
||||
"vllm",
|
||||
"run-batch",
|
||||
"-i",
|
||||
input_file.name,
|
||||
"-o",
|
||||
output_file.name,
|
||||
"--model",
|
||||
"NousResearch/Meta-Llama-3-8B-Instruct",
|
||||
],
|
||||
)
|
||||
proc.communicate()
|
||||
proc.wait()
|
||||
assert proc.returncode != 0, f"{proc=}"
|
||||
|
||||
|
||||
def test_embeddings():
|
||||
with tempfile.NamedTemporaryFile(
|
||||
"w") as input_file, tempfile.NamedTemporaryFile(
|
||||
"r") as output_file:
|
||||
with (
|
||||
tempfile.NamedTemporaryFile("w") as input_file,
|
||||
tempfile.NamedTemporaryFile("r") as output_file,
|
||||
):
|
||||
input_file.write(INPUT_EMBEDDING_BATCH)
|
||||
input_file.flush()
|
||||
proc = subprocess.Popen([
|
||||
"vllm", "run-batch", "-i", input_file.name, "-o", output_file.name,
|
||||
"--model", "intfloat/multilingual-e5-small"
|
||||
], )
|
||||
proc = subprocess.Popen(
|
||||
[
|
||||
"vllm",
|
||||
"run-batch",
|
||||
"-i",
|
||||
input_file.name,
|
||||
"-o",
|
||||
output_file.name,
|
||||
"--model",
|
||||
"intfloat/multilingual-e5-small",
|
||||
],
|
||||
)
|
||||
proc.communicate()
|
||||
proc.wait()
|
||||
assert proc.returncode == 0, f"{proc=}"
|
||||
@@ -112,24 +148,26 @@ def test_embeddings():
|
||||
BatchRequestOutput.model_validate_json(line)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("input_batch",
|
||||
[INPUT_SCORE_BATCH, INPUT_RERANK_BATCH])
|
||||
@pytest.mark.parametrize("input_batch", [INPUT_SCORE_BATCH, INPUT_RERANK_BATCH])
|
||||
def test_score(input_batch):
|
||||
with tempfile.NamedTemporaryFile(
|
||||
"w") as input_file, tempfile.NamedTemporaryFile(
|
||||
"r") as output_file:
|
||||
with (
|
||||
tempfile.NamedTemporaryFile("w") as input_file,
|
||||
tempfile.NamedTemporaryFile("r") as output_file,
|
||||
):
|
||||
input_file.write(input_batch)
|
||||
input_file.flush()
|
||||
proc = subprocess.Popen([
|
||||
"vllm",
|
||||
"run-batch",
|
||||
"-i",
|
||||
input_file.name,
|
||||
"-o",
|
||||
output_file.name,
|
||||
"--model",
|
||||
"BAAI/bge-reranker-v2-m3",
|
||||
], )
|
||||
proc = subprocess.Popen(
|
||||
[
|
||||
"vllm",
|
||||
"run-batch",
|
||||
"-i",
|
||||
input_file.name,
|
||||
"-o",
|
||||
output_file.name,
|
||||
"--model",
|
||||
"BAAI/bge-reranker-v2-m3",
|
||||
],
|
||||
)
|
||||
proc.communicate()
|
||||
proc.wait()
|
||||
assert proc.returncode == 0, f"{proc=}"
|
||||
|
||||
@@ -15,8 +15,7 @@ import pytest_asyncio
|
||||
from vllm.config.multimodal import MultiModalConfig
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
@@ -31,14 +30,17 @@ GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"
|
||||
@pytest.fixture(scope="module")
|
||||
def monkeypatch_module():
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
mpatch = MonkeyPatch()
|
||||
yield mpatch
|
||||
mpatch.undo()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module",
|
||||
params=[True, False],
|
||||
ids=["with_tool_parser", "without_tool_parser"])
|
||||
@pytest.fixture(
|
||||
scope="module",
|
||||
params=[True, False],
|
||||
ids=["with_tool_parser", "without_tool_parser"],
|
||||
)
|
||||
def with_tool_parser(request) -> bool:
|
||||
return request.param
|
||||
|
||||
@@ -56,21 +58,25 @@ def default_server_args(with_tool_parser: bool):
|
||||
"0.8",
|
||||
]
|
||||
if with_tool_parser:
|
||||
args.extend([
|
||||
"--tool-call-parser",
|
||||
"openai",
|
||||
"--enable-auto-tool-choice",
|
||||
])
|
||||
args.extend(
|
||||
[
|
||||
"--tool-call-parser",
|
||||
"openai",
|
||||
"--enable-auto-tool-choice",
|
||||
]
|
||||
)
|
||||
return args
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def gptoss_server(monkeypatch_module: pytest.MonkeyPatch,
|
||||
default_server_args: list[str]):
|
||||
def gptoss_server(
|
||||
monkeypatch_module: pytest.MonkeyPatch, default_server_args: list[str]
|
||||
):
|
||||
with monkeypatch_module.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
|
||||
with RemoteOpenAIServer(GPT_OSS_MODEL_NAME,
|
||||
default_server_args) as remote_server:
|
||||
with RemoteOpenAIServer(
|
||||
GPT_OSS_MODEL_NAME, default_server_args
|
||||
) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@@ -81,44 +87,41 @@ async def gptoss_client(gptoss_server):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI,
|
||||
with_tool_parser: bool):
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string"
|
||||
},
|
||||
"state": {
|
||||
"type": "string"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
async def test_gpt_oss_chat_tool_call_streaming(
|
||||
gptoss_client: OpenAI, with_tool_parser: bool
|
||||
):
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"state": {"type": "string"},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["city", "state", "unit"],
|
||||
},
|
||||
"required": ["city", "state", "unit"],
|
||||
},
|
||||
},
|
||||
}]
|
||||
}
|
||||
]
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather in Dallas, TX?"
|
||||
},
|
||||
{"role": "user", "content": "What is the weather in Dallas, TX?"},
|
||||
]
|
||||
|
||||
stream = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages,
|
||||
tools=tools if with_tool_parser else None,
|
||||
stream=True)
|
||||
stream=True,
|
||||
)
|
||||
|
||||
name = None
|
||||
args_buf = ""
|
||||
@@ -143,43 +146,34 @@ async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,
|
||||
with_tool_parser: bool):
|
||||
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, with_tool_parser: bool):
|
||||
if not with_tool_parser:
|
||||
pytest.skip("skip non-tool for multi-turn tests")
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string"
|
||||
},
|
||||
"state": {
|
||||
"type": "string"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"state": {"type": "string"},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["city", "state", "unit"],
|
||||
},
|
||||
"required": ["city", "state", "unit"],
|
||||
},
|
||||
},
|
||||
}]
|
||||
}
|
||||
]
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather in Dallas, TX with celsius?"
|
||||
},
|
||||
{"role": "system", "content": "you are a helpful assistant"},
|
||||
{"role": "user", "content": "What is the weather in Dallas, TX with celsius?"},
|
||||
]
|
||||
|
||||
first = await gptoss_client.chat.completions.create(
|
||||
@@ -197,10 +191,9 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,
|
||||
assert not first_msg.content
|
||||
|
||||
messages.append({"role": "assistant", "content": args1})
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": "Now convert to celsius and return JSON only"
|
||||
})
|
||||
messages.append(
|
||||
{"role": "user", "content": "Now convert to celsius and return JSON only"}
|
||||
)
|
||||
|
||||
second = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
@@ -209,8 +202,9 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,
|
||||
temperature=0.0,
|
||||
)
|
||||
second_msg = second.choices[0].message
|
||||
assert (second_msg.content is not None and len(second_msg.content) > 0) or \
|
||||
(second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0)
|
||||
assert (second_msg.content is not None and len(second_msg.content) > 0) or (
|
||||
second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0
|
||||
)
|
||||
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
@@ -218,7 +212,7 @@ MODEL_NAME_SHORT = "gpt2"
|
||||
CHAT_TEMPLATE = "Dummy chat template for testing {}"
|
||||
BASE_MODEL_PATHS = [
|
||||
BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
|
||||
BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT)
|
||||
BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
|
||||
]
|
||||
|
||||
|
||||
@@ -251,21 +245,33 @@ class MockModelConfig:
|
||||
return self.diff_sampling_param or {}
|
||||
|
||||
|
||||
def _build_serving_chat(engine: AsyncLLM,
|
||||
model_config: MockModelConfig) -> OpenAIServingChat:
|
||||
models = OpenAIServingModels(engine_client=engine,
|
||||
base_model_paths=BASE_MODEL_PATHS,
|
||||
model_config=model_config)
|
||||
serving_chat = OpenAIServingChat(engine,
|
||||
model_config,
|
||||
models,
|
||||
response_role="assistant",
|
||||
chat_template=CHAT_TEMPLATE,
|
||||
chat_template_content_format="auto",
|
||||
request_logger=None)
|
||||
def _build_serving_chat(
|
||||
engine: AsyncLLM, model_config: MockModelConfig
|
||||
) -> OpenAIServingChat:
|
||||
models = OpenAIServingModels(
|
||||
engine_client=engine,
|
||||
base_model_paths=BASE_MODEL_PATHS,
|
||||
model_config=model_config,
|
||||
)
|
||||
serving_chat = OpenAIServingChat(
|
||||
engine,
|
||||
model_config,
|
||||
models,
|
||||
response_role="assistant",
|
||||
chat_template=CHAT_TEMPLATE,
|
||||
chat_template_content_format="auto",
|
||||
request_logger=None,
|
||||
)
|
||||
|
||||
async def _fake_process_inputs(request_id, engine_prompt, sampling_params,
|
||||
*, lora_request, trace_headers, priority):
|
||||
async def _fake_process_inputs(
|
||||
request_id,
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
*,
|
||||
lora_request,
|
||||
trace_headers,
|
||||
priority,
|
||||
):
|
||||
return dict(engine_prompt), {}
|
||||
|
||||
serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
|
||||
@@ -274,7 +280,6 @@ def _build_serving_chat(engine: AsyncLLM,
|
||||
|
||||
@dataclass
|
||||
class MockEngine:
|
||||
|
||||
async def get_model_config(self):
|
||||
return MockModelConfig()
|
||||
|
||||
@@ -284,13 +289,15 @@ async def _async_serving_chat_init():
|
||||
model_config = await engine.get_model_config()
|
||||
|
||||
models = OpenAIServingModels(engine, model_config, BASE_MODEL_PATHS)
|
||||
serving_completion = OpenAIServingChat(engine,
|
||||
model_config,
|
||||
models,
|
||||
response_role="assistant",
|
||||
chat_template=CHAT_TEMPLATE,
|
||||
chat_template_content_format="auto",
|
||||
request_logger=None)
|
||||
serving_completion = OpenAIServingChat(
|
||||
engine,
|
||||
model_config,
|
||||
models,
|
||||
response_role="assistant",
|
||||
chat_template=CHAT_TEMPLATE,
|
||||
chat_template_content_format="auto",
|
||||
request_logger=None,
|
||||
)
|
||||
return serving_completion
|
||||
|
||||
|
||||
@@ -336,10 +343,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
|
||||
|
||||
req = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "what is 1+1?"
|
||||
}],
|
||||
messages=[{"role": "user", "content": "what is 1+1?"}],
|
||||
)
|
||||
|
||||
with suppress(Exception):
|
||||
@@ -371,10 +375,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
|
||||
# Test Case 1: No max_tokens specified in request
|
||||
req = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "what is 1+1?"
|
||||
}],
|
||||
messages=[{"role": "user", "content": "what is 1+1?"}],
|
||||
)
|
||||
|
||||
with suppress(Exception):
|
||||
@@ -416,10 +417,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
|
||||
# Test case 1: No max_tokens specified, defaults to context_window
|
||||
req = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "what is 1+1?"
|
||||
}],
|
||||
messages=[{"role": "user", "content": "what is 1+1?"}],
|
||||
)
|
||||
|
||||
with suppress(Exception):
|
||||
@@ -446,11 +444,10 @@ async def test_serving_chat_should_set_correct_max_tokens():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serving_chat_could_load_correct_generation_config():
|
||||
|
||||
mock_model_config = MockModelConfig()
|
||||
mock_model_config.diff_sampling_param = {
|
||||
"temperature": 0.5,
|
||||
"repetition_penalty": 1.05
|
||||
"repetition_penalty": 1.05,
|
||||
}
|
||||
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
@@ -462,10 +459,7 @@ async def test_serving_chat_could_load_correct_generation_config():
|
||||
|
||||
req = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "what is 1+1?"
|
||||
}],
|
||||
messages=[{"role": "user", "content": "what is 1+1?"}],
|
||||
)
|
||||
|
||||
with suppress(Exception):
|
||||
@@ -508,10 +502,7 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
|
||||
# Test cache_salt
|
||||
req = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "what is 1+1?"
|
||||
}],
|
||||
messages=[{"role": "user", "content": "what is 1+1?"}],
|
||||
)
|
||||
|
||||
# By default, cache_salt in the engine prompt is not set
|
||||
|
||||
@@ -34,7 +34,8 @@ def serving() -> OpenAIServing:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_mistral_tokenizer_does_not_block_event_loop(
|
||||
serving: OpenAIServing):
|
||||
serving: OpenAIServing,
|
||||
):
|
||||
expected_tokens = [1, 2, 3]
|
||||
|
||||
# Mock the blocking version to sleep
|
||||
@@ -45,10 +46,9 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop(
|
||||
mock_tokenizer = Mock(spec=MistralTokenizer)
|
||||
mock_tokenizer.apply_chat_template.side_effect = mocked_apply_chat_template
|
||||
|
||||
task = serving._apply_mistral_chat_template_async(tokenizer=mock_tokenizer,
|
||||
messages=[],
|
||||
chat_template=None,
|
||||
tools=[])
|
||||
task = serving._apply_mistral_chat_template_async(
|
||||
tokenizer=mock_tokenizer, messages=[], chat_template=None, tools=[]
|
||||
)
|
||||
|
||||
# Ensure the event loop is not blocked
|
||||
blocked_count = 0
|
||||
@@ -66,4 +66,4 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop(
|
||||
# Ensure task completes
|
||||
tokens = await task
|
||||
assert tokens == expected_tokens, "Mocked blocking tokenizer was not called"
|
||||
assert blocked_count == 0, ("Event loop blocked during tokenization")
|
||||
assert blocked_count == 0, "Event loop blocked during tokenization"
|
||||
|
||||
@@ -8,19 +8,20 @@ import pytest
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse,
|
||||
LoadLoRAAdapterRequest,
|
||||
UnloadLoRAAdapterRequest)
|
||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
LoadLoRAAdapterRequest,
|
||||
UnloadLoRAAdapterRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
|
||||
LORA_LOADING_SUCCESS_MESSAGE = (
|
||||
"Success: LoRA adapter '{lora_name}' added successfully.")
|
||||
LORA_LOADING_SUCCESS_MESSAGE = "Success: LoRA adapter '{lora_name}' added successfully."
|
||||
LORA_UNLOADING_SUCCESS_MESSAGE = (
|
||||
"Success: LoRA adapter '{lora_name}' removed successfully.")
|
||||
"Success: LoRA adapter '{lora_name}' removed successfully."
|
||||
)
|
||||
|
||||
|
||||
async def _async_serving_models_init() -> OpenAIServingModels:
|
||||
@@ -29,10 +30,12 @@ async def _async_serving_models_init() -> OpenAIServingModels:
|
||||
# Set the max_model_len attribute to avoid missing attribute
|
||||
mock_model_config.max_model_len = 2048
|
||||
|
||||
serving_models = OpenAIServingModels(engine_client=mock_engine_client,
|
||||
base_model_paths=BASE_MODEL_PATHS,
|
||||
model_config=mock_model_config,
|
||||
lora_modules=None)
|
||||
serving_models = OpenAIServingModels(
|
||||
engine_client=mock_engine_client,
|
||||
base_model_paths=BASE_MODEL_PATHS,
|
||||
model_config=mock_model_config,
|
||||
lora_modules=None,
|
||||
)
|
||||
await serving_models.init_static_loras()
|
||||
|
||||
return serving_models
|
||||
@@ -42,19 +45,18 @@ async def _async_serving_models_init() -> OpenAIServingModels:
|
||||
async def test_serving_model_name():
|
||||
serving_models = await _async_serving_models_init()
|
||||
assert serving_models.model_name(None) == MODEL_NAME
|
||||
request = LoRARequest(lora_name="adapter",
|
||||
lora_path="/path/to/adapter2",
|
||||
lora_int_id=1)
|
||||
request = LoRARequest(
|
||||
lora_name="adapter", lora_path="/path/to/adapter2", lora_int_id=1
|
||||
)
|
||||
assert serving_models.model_name(request) == request.lora_name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_lora_adapter_success():
|
||||
serving_models = await _async_serving_models_init()
|
||||
request = LoadLoRAAdapterRequest(lora_name="adapter",
|
||||
lora_path="/path/to/adapter2")
|
||||
request = LoadLoRAAdapterRequest(lora_name="adapter", lora_path="/path/to/adapter2")
|
||||
response = await serving_models.load_lora_adapter(request)
|
||||
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
|
||||
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name="adapter")
|
||||
assert len(serving_models.lora_requests) == 1
|
||||
assert "adapter" in serving_models.lora_requests
|
||||
assert serving_models.lora_requests["adapter"].lora_name == "adapter"
|
||||
@@ -73,15 +75,16 @@ async def test_load_lora_adapter_missing_fields():
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_lora_adapter_duplicate():
|
||||
serving_models = await _async_serving_models_init()
|
||||
request = LoadLoRAAdapterRequest(lora_name="adapter1",
|
||||
lora_path="/path/to/adapter1")
|
||||
request = LoadLoRAAdapterRequest(
|
||||
lora_name="adapter1", lora_path="/path/to/adapter1"
|
||||
)
|
||||
response = await serving_models.load_lora_adapter(request)
|
||||
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
|
||||
lora_name='adapter1')
|
||||
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name="adapter1")
|
||||
assert len(serving_models.lora_requests) == 1
|
||||
|
||||
request = LoadLoRAAdapterRequest(lora_name="adapter1",
|
||||
lora_path="/path/to/adapter1")
|
||||
request = LoadLoRAAdapterRequest(
|
||||
lora_name="adapter1", lora_path="/path/to/adapter1"
|
||||
)
|
||||
response = await serving_models.load_lora_adapter(request)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert response.error.type == "InvalidUserInput"
|
||||
@@ -92,15 +95,15 @@ async def test_load_lora_adapter_duplicate():
|
||||
@pytest.mark.asyncio
|
||||
async def test_unload_lora_adapter_success():
|
||||
serving_models = await _async_serving_models_init()
|
||||
request = LoadLoRAAdapterRequest(lora_name="adapter1",
|
||||
lora_path="/path/to/adapter1")
|
||||
request = LoadLoRAAdapterRequest(
|
||||
lora_name="adapter1", lora_path="/path/to/adapter1"
|
||||
)
|
||||
response = await serving_models.load_lora_adapter(request)
|
||||
assert len(serving_models.lora_requests) == 1
|
||||
|
||||
request = UnloadLoRAAdapterRequest(lora_name="adapter1")
|
||||
response = await serving_models.unload_lora_adapter(request)
|
||||
assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
|
||||
lora_name='adapter1')
|
||||
assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(lora_name="adapter1")
|
||||
assert len(serving_models.lora_requests) == 0
|
||||
|
||||
|
||||
|
||||
@@ -34,11 +34,9 @@ class MockConversationContext(ConversationContext):
|
||||
def render_for_completion(self):
|
||||
return []
|
||||
|
||||
async def init_tool_sessions(self, tool_server, exit_stack, request_id,
|
||||
mcp_tools):
|
||||
async def init_tool_sessions(self, tool_server, exit_stack, request_id, mcp_tools):
|
||||
self.init_tool_sessions_called = True
|
||||
self.init_tool_sessions_args = (tool_server, exit_stack, request_id,
|
||||
mcp_tools)
|
||||
self.init_tool_sessions_args = (tool_server, exit_stack, request_id, mcp_tools)
|
||||
|
||||
async def cleanup_session(self) -> None:
|
||||
pass
|
||||
@@ -96,35 +94,31 @@ class TestInitializeToolSessions:
|
||||
return instance
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_tool_sessions(self, serving_responses_instance,
|
||||
mock_context, mock_exit_stack):
|
||||
async def test_initialize_tool_sessions(
|
||||
self, serving_responses_instance, mock_context, mock_exit_stack
|
||||
):
|
||||
"""Test that method works correctly with only MCP tools"""
|
||||
|
||||
request = ResponsesRequest(input="test input", tools=[])
|
||||
|
||||
# Call the method
|
||||
await serving_responses_instance._initialize_tool_sessions(
|
||||
request, mock_context, mock_exit_stack)
|
||||
request, mock_context, mock_exit_stack
|
||||
)
|
||||
assert mock_context.init_tool_sessions_called is False
|
||||
|
||||
# Create only MCP tools
|
||||
tools = [
|
||||
{
|
||||
"type": "web_search_preview"
|
||||
},
|
||||
{
|
||||
"type": "code_interpreter",
|
||||
"container": {
|
||||
"type": "auto"
|
||||
}
|
||||
},
|
||||
{"type": "web_search_preview"},
|
||||
{"type": "code_interpreter", "container": {"type": "auto"}},
|
||||
]
|
||||
|
||||
request = ResponsesRequest(input="test input", tools=tools)
|
||||
|
||||
# Call the method
|
||||
await serving_responses_instance._initialize_tool_sessions(
|
||||
request, mock_context, mock_exit_stack)
|
||||
request, mock_context, mock_exit_stack
|
||||
)
|
||||
|
||||
# Verify that init_tool_sessions was called
|
||||
assert mock_context.init_tool_sessions_called
|
||||
@@ -165,25 +159,20 @@ class TestValidateGeneratorInput:
|
||||
"""Test _validate_generator_input with valid prompt length"""
|
||||
# Create an engine prompt with valid length (less than max_model_len)
|
||||
valid_prompt_token_ids = list(range(5)) # 5 tokens < 100 max_model_len
|
||||
engine_prompt = EngineTokensPrompt(
|
||||
prompt_token_ids=valid_prompt_token_ids)
|
||||
engine_prompt = EngineTokensPrompt(prompt_token_ids=valid_prompt_token_ids)
|
||||
|
||||
# Call the method
|
||||
result = serving_responses_instance._validate_generator_input(
|
||||
engine_prompt)
|
||||
result = serving_responses_instance._validate_generator_input(engine_prompt)
|
||||
|
||||
# Should return None for valid input
|
||||
assert result is None
|
||||
|
||||
# create an invalid engine prompt
|
||||
invalid_prompt_token_ids = list(
|
||||
range(200)) # 100 tokens >= 100 max_model_len
|
||||
engine_prompt = EngineTokensPrompt(
|
||||
prompt_token_ids=invalid_prompt_token_ids)
|
||||
invalid_prompt_token_ids = list(range(200)) # 100 tokens >= 100 max_model_len
|
||||
engine_prompt = EngineTokensPrompt(prompt_token_ids=invalid_prompt_token_ids)
|
||||
|
||||
# Call the method
|
||||
result = serving_responses_instance._validate_generator_input(
|
||||
engine_prompt)
|
||||
result = serving_responses_instance._validate_generator_input(engine_prompt)
|
||||
|
||||
# Should return an ErrorResponse
|
||||
assert result is not None
|
||||
|
||||
@@ -24,16 +24,13 @@ async def test_shutdown_on_engine_failure():
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
async with remote_server.get_async_client() as client:
|
||||
|
||||
with pytest.raises(
|
||||
(openai.APIConnectionError, openai.InternalServerError)):
|
||||
with pytest.raises((openai.APIConnectionError, openai.InternalServerError)):
|
||||
# Asking for lots of prompt logprobs will currently crash the
|
||||
# engine. This may change in the future when that bug is fixed
|
||||
prompt = "Hello " * 4000
|
||||
await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=prompt,
|
||||
extra_body={"prompt_logprobs": 10})
|
||||
model=MODEL_NAME, prompt=prompt, extra_body={"prompt_logprobs": 10}
|
||||
)
|
||||
|
||||
# Now the server should shut down
|
||||
return_code = remote_server.proc.wait(timeout=8)
|
||||
|
||||
@@ -29,7 +29,7 @@ def server():
|
||||
"--max-num-seqs",
|
||||
"32",
|
||||
"--model-impl",
|
||||
"terratorch"
|
||||
"terratorch",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
@@ -39,7 +39,6 @@ def server():
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_single_request(server: RemoteOpenAIServer, model_name: str):
|
||||
|
||||
pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
|
||||
location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)
|
||||
|
||||
@@ -47,40 +46,39 @@ async def test_single_request(server: RemoteOpenAIServer, model_name: str):
|
||||
torch.save(pixel_values, buffer_tiff)
|
||||
buffer_tiff.seek(0)
|
||||
binary_data = buffer_tiff.read()
|
||||
base64_tensor_embedding = base64.b64encode(binary_data).decode('utf-8')
|
||||
base64_tensor_embedding = base64.b64encode(binary_data).decode("utf-8")
|
||||
|
||||
buffer_coord = io.BytesIO()
|
||||
torch.save(location_coords, buffer_coord)
|
||||
buffer_coord.seek(0)
|
||||
binary_data = buffer_coord.read()
|
||||
base64_coord_embedding = base64.b64encode(binary_data).decode('utf-8')
|
||||
base64_coord_embedding = base64.b64encode(binary_data).decode("utf-8")
|
||||
|
||||
prompt = {
|
||||
"model":
|
||||
model_name,
|
||||
"additional_data": {
|
||||
"prompt_token_ids": [1]
|
||||
},
|
||||
"encoding_format":
|
||||
"base64",
|
||||
"messages": [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_embeds",
|
||||
"image_embeds": {
|
||||
"pixel_values": base64_tensor_embedding,
|
||||
"location_coords": base64_coord_embedding,
|
||||
},
|
||||
}],
|
||||
}]
|
||||
"model": model_name,
|
||||
"additional_data": {"prompt_token_ids": [1]},
|
||||
"encoding_format": "base64",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_embeds",
|
||||
"image_embeds": {
|
||||
"pixel_values": base64_tensor_embedding,
|
||||
"location_coords": base64_coord_embedding,
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# test single pooling
|
||||
response = requests.post(server.url_for("pooling"), json=prompt)
|
||||
response.raise_for_status()
|
||||
|
||||
output = response.json()["data"][0]['data']
|
||||
output = response.json()["data"][0]["data"]
|
||||
|
||||
np_response = np.frombuffer(base64.b64decode(output), dtype=np.float32)
|
||||
|
||||
|
||||
@@ -20,14 +20,12 @@ def test_sleep_mode():
|
||||
"--enable-sleep-mode",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME,
|
||||
args,
|
||||
env_dict={
|
||||
"VLLM_SERVER_DEV_MODE": "1",
|
||||
"CUDA_VISIBLE_DEVICES": "0"
|
||||
}) as remote_server:
|
||||
response = requests.post(remote_server.url_for("sleep"),
|
||||
params={"level": "1"})
|
||||
with RemoteOpenAIServer(
|
||||
MODEL_NAME,
|
||||
args,
|
||||
env_dict={"VLLM_SERVER_DEV_MODE": "1", "CUDA_VISIBLE_DEVICES": "0"},
|
||||
) as remote_server:
|
||||
response = requests.post(remote_server.url_for("sleep"), params={"level": "1"})
|
||||
assert response.status_code == 200
|
||||
response = requests.get(remote_server.url_for("is_sleeping"))
|
||||
assert response.status_code == 200
|
||||
@@ -40,12 +38,12 @@ def test_sleep_mode():
|
||||
assert response.json().get("is_sleeping") is False
|
||||
|
||||
# test wake up with tags
|
||||
response = requests.post(remote_server.url_for("sleep"),
|
||||
params={"level": "1"})
|
||||
response = requests.post(remote_server.url_for("sleep"), params={"level": "1"})
|
||||
assert response.status_code == 200
|
||||
|
||||
response = requests.post(remote_server.url_for("wake_up"),
|
||||
params={"tags": ["weights"]})
|
||||
response = requests.post(
|
||||
remote_server.url_for("wake_up"), params={"tags": ["weights"]}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# is sleeping should be false after waking up any part of the engine
|
||||
@@ -53,8 +51,9 @@ def test_sleep_mode():
|
||||
assert response.status_code == 200
|
||||
assert response.json().get("is_sleeping") is True
|
||||
|
||||
response = requests.post(remote_server.url_for("wake_up"),
|
||||
params={"tags": ["kv_cache"]})
|
||||
response = requests.post(
|
||||
remote_server.url_for("wake_up"), params={"tags": ["kv_cache"]}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response = requests.get(remote_server.url_for("is_sleeping"))
|
||||
|
||||
@@ -11,7 +11,10 @@ import torch.cuda
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.model_executor.model_loader.tensorizer import (
|
||||
TensorizerConfig, tensorize_lora_adapter, tensorize_vllm_model)
|
||||
TensorizerConfig,
|
||||
tensorize_lora_adapter,
|
||||
tensorize_vllm_model,
|
||||
)
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
@@ -29,21 +32,20 @@ def cleanup():
|
||||
_cleanup()
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
@pytest.fixture(scope="module")
|
||||
def tmp_dir():
|
||||
with tempfile.TemporaryDirectory() as path:
|
||||
yield path
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
@pytest.fixture(scope="module")
|
||||
def model_uri(tmp_dir):
|
||||
yield f"{tmp_dir}/model.tensors"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def tensorize_model_and_lora(tmp_dir, model_uri):
|
||||
tensorizer_config = TensorizerConfig(tensorizer_uri=model_uri,
|
||||
lora_dir=tmp_dir)
|
||||
tensorizer_config = TensorizerConfig(tensorizer_uri=model_uri, lora_dir=tmp_dir)
|
||||
args = EngineArgs(model=MODEL_NAME)
|
||||
|
||||
tensorize_lora_adapter(LORA_PATH, tensorizer_config)
|
||||
@@ -66,8 +68,11 @@ def server(model_uri, tensorize_model_and_lora):
|
||||
|
||||
## Start OpenAI API server
|
||||
args = [
|
||||
"--load-format", "tensorizer", "--served-model-name", MODEL_NAME,
|
||||
"--enable-lora"
|
||||
"--load-format",
|
||||
"tensorizer",
|
||||
"--served-model-name",
|
||||
MODEL_NAME,
|
||||
"--enable-lora",
|
||||
]
|
||||
|
||||
model_dir = os.path.dirname(model_uri)
|
||||
@@ -85,10 +90,9 @@ async def client(server):
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
|
||||
_cleanup()
|
||||
completion = await client.completions.create(model=model_name,
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
completion = await client.completions.create(
|
||||
model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=0.0
|
||||
)
|
||||
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 1
|
||||
@@ -97,4 +101,5 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
|
||||
assert len(completion.choices[0].text) >= 5
|
||||
assert completion.choices[0].finish_reason == "length"
|
||||
assert completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=5, prompt_tokens=6, total_tokens=11)
|
||||
completion_tokens=5, prompt_tokens=6, total_tokens=11
|
||||
)
|
||||
|
||||
@@ -6,8 +6,7 @@ import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_weights_from_hf)
|
||||
from vllm.model_executor.model_loader.weight_utils import download_weights_from_hf
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
@@ -23,7 +22,8 @@ def server():
|
||||
MODEL_NAME,
|
||||
allow_patterns=["*"],
|
||||
cache_dir=MODEL_PATH,
|
||||
ignore_patterns=["tokenizer*", "vocab*", "*.safetensors"])
|
||||
ignore_patterns=["tokenizer*", "vocab*", "*.safetensors"],
|
||||
)
|
||||
args = [
|
||||
"--max-model-len",
|
||||
"2048",
|
||||
@@ -61,13 +61,14 @@ async def test_token_in_token_out_and_logprobs(server):
|
||||
)
|
||||
|
||||
# Verify all fields are present
|
||||
assert (completion.choices[0].token_ids is not None
|
||||
and 0 < len(completion.choices[0].token_ids) <= 20)
|
||||
assert (
|
||||
completion.choices[0].token_ids is not None
|
||||
and 0 < len(completion.choices[0].token_ids) <= 20
|
||||
)
|
||||
assert completion.choices[0].prompt_token_ids is not None
|
||||
|
||||
# Decode prompt tokens
|
||||
if completion.choices[0].prompt_token_ids:
|
||||
prompt_text = tokenizer.decode(
|
||||
completion.choices[0].prompt_token_ids)
|
||||
prompt_text = tokenizer.decode(completion.choices[0].prompt_token_ids)
|
||||
# The decoded prompt should match or close to original prompt
|
||||
assert prompt_text == text
|
||||
|
||||
@@ -53,19 +53,20 @@ async def test_tokenize_completions(
|
||||
model_name: str,
|
||||
tokenizer_name: str,
|
||||
):
|
||||
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
|
||||
tokenizer_mode="fast")
|
||||
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast")
|
||||
|
||||
for add_special in [False, True]:
|
||||
prompt = "vllm1 This is a test prompt."
|
||||
tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
|
||||
|
||||
response = requests.post(server.url_for("tokenize"),
|
||||
json={
|
||||
"add_special_tokens": add_special,
|
||||
"model": model_name,
|
||||
"prompt": prompt
|
||||
})
|
||||
response = requests.post(
|
||||
server.url_for("tokenize"),
|
||||
json={
|
||||
"add_special_tokens": add_special,
|
||||
"model": model_name,
|
||||
"prompt": prompt,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
@@ -86,48 +87,39 @@ async def test_tokenize_chat(
|
||||
model_name: str,
|
||||
tokenizer_name: str,
|
||||
):
|
||||
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
|
||||
tokenizer_mode="fast")
|
||||
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast")
|
||||
|
||||
for add_generation in [False, True]:
|
||||
for add_special in [False, True]:
|
||||
conversation = [{
|
||||
"role": "user",
|
||||
"content": "Hi there!"
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": "Nice to meet you!"
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Can I ask a question? vllm1"
|
||||
}]
|
||||
conversation = [
|
||||
{"role": "user", "content": "Hi there!"},
|
||||
{"role": "assistant", "content": "Nice to meet you!"},
|
||||
{"role": "user", "content": "Can I ask a question? vllm1"},
|
||||
]
|
||||
for continue_final in [False, True]:
|
||||
if add_generation and continue_final:
|
||||
continue
|
||||
if continue_final:
|
||||
conversation.append({
|
||||
"role": "assistant",
|
||||
"content": "Sure,"
|
||||
})
|
||||
conversation.append({"role": "assistant", "content": "Sure,"})
|
||||
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
add_generation_prompt=add_generation,
|
||||
continue_final_message=continue_final,
|
||||
conversation=conversation,
|
||||
tokenize=False)
|
||||
tokens = tokenizer.encode(prompt,
|
||||
add_special_tokens=add_special)
|
||||
tokenize=False,
|
||||
)
|
||||
tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
|
||||
|
||||
response = requests.post(server.url_for("tokenize"),
|
||||
json={
|
||||
"add_generation_prompt":
|
||||
add_generation,
|
||||
"continue_final_message":
|
||||
continue_final,
|
||||
"add_special_tokens": add_special,
|
||||
"messages": conversation,
|
||||
"model": model_name
|
||||
})
|
||||
response = requests.post(
|
||||
server.url_for("tokenize"),
|
||||
json={
|
||||
"add_generation_prompt": add_generation,
|
||||
"continue_final_message": continue_final,
|
||||
"add_special_tokens": add_special,
|
||||
"messages": conversation,
|
||||
"model": model_name,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
@@ -148,41 +140,35 @@ async def test_tokenize_chat_with_tools(
|
||||
model_name: str,
|
||||
tokenizer_name: str,
|
||||
):
|
||||
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
|
||||
tokenizer_mode="fast")
|
||||
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast")
|
||||
|
||||
for add_generation in [False, True]:
|
||||
for add_special in [False, True]:
|
||||
conversation = [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"What's the weather like in Paris today?",
|
||||
}]
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Paris today?",
|
||||
}
|
||||
]
|
||||
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string"
|
||||
}
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"location": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}]
|
||||
}
|
||||
]
|
||||
|
||||
for continue_final in [False, True]:
|
||||
if add_generation and continue_final:
|
||||
continue
|
||||
if continue_final:
|
||||
conversation.append({
|
||||
"role": "assistant",
|
||||
"content": "Sure,"
|
||||
})
|
||||
conversation.append({"role": "assistant", "content": "Sure,"})
|
||||
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
add_generation_prompt=add_generation,
|
||||
@@ -191,8 +177,7 @@ async def test_tokenize_chat_with_tools(
|
||||
tools=tools,
|
||||
tokenize=False,
|
||||
)
|
||||
tokens = tokenizer.encode(prompt,
|
||||
add_special_tokens=add_special)
|
||||
tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
|
||||
|
||||
response = requests.post(
|
||||
server.url_for("tokenize"),
|
||||
@@ -225,17 +210,12 @@ async def test_tokenize_with_return_token_strs(
|
||||
model_name: str,
|
||||
tokenizer_name: str,
|
||||
):
|
||||
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
|
||||
tokenizer_mode="fast")
|
||||
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast")
|
||||
|
||||
prompt = "This is a token_strs test prompt! vllm1"
|
||||
response = requests.post(
|
||||
server.url_for("tokenize"),
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"model": model_name,
|
||||
"return_token_strs": True
|
||||
},
|
||||
json={"prompt": prompt, "model": model_name, "return_token_strs": True},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -260,17 +240,14 @@ async def test_detokenize(
|
||||
model_name: str,
|
||||
tokenizer_name: str,
|
||||
):
|
||||
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
|
||||
tokenizer_mode="fast")
|
||||
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast")
|
||||
|
||||
prompt = "This is a test prompt. vllm1"
|
||||
tokens = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
|
||||
response = requests.post(server.url_for("detokenize"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"tokens": tokens
|
||||
})
|
||||
response = requests.post(
|
||||
server.url_for("detokenize"), json={"model": model_name, "tokens": tokens}
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
assert response.json() == {"prompt": prompt}
|
||||
@@ -319,14 +296,15 @@ async def test_tokenizer_info_schema(server: RemoteOpenAIServer):
|
||||
}
|
||||
for field, expected_type in field_types.items():
|
||||
if field in result and result[field] is not None:
|
||||
assert isinstance(
|
||||
result[field],
|
||||
expected_type), (f"{field} should be {expected_type.__name__}")
|
||||
assert isinstance(result[field], expected_type), (
|
||||
f"{field} should be {expected_type.__name__}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenizer_info_added_tokens_structure(
|
||||
server: RemoteOpenAIServer, ):
|
||||
server: RemoteOpenAIServer,
|
||||
):
|
||||
"""Test added_tokens_decoder structure if present."""
|
||||
response = requests.get(server.url_for("tokenizer_info"))
|
||||
response.raise_for_status()
|
||||
@@ -337,25 +315,23 @@ async def test_tokenizer_info_added_tokens_structure(
|
||||
assert isinstance(token_id, str), "Token IDs should be strings"
|
||||
assert isinstance(token_info, dict), "Token info should be a dict"
|
||||
assert "content" in token_info, "Token info should have content"
|
||||
assert "special" in token_info, (
|
||||
"Token info should have special flag")
|
||||
assert isinstance(token_info["special"],
|
||||
bool), ("Special flag should be boolean")
|
||||
assert "special" in token_info, "Token info should have special flag"
|
||||
assert isinstance(token_info["special"], bool), (
|
||||
"Special flag should be boolean"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenizer_info_consistency_with_tokenize(
|
||||
server: RemoteOpenAIServer, ):
|
||||
server: RemoteOpenAIServer,
|
||||
):
|
||||
"""Test that tokenizer info is consistent with tokenization endpoint."""
|
||||
info_response = requests.get(server.url_for("tokenizer_info"))
|
||||
info_response.raise_for_status()
|
||||
info = info_response.json()
|
||||
tokenize_response = requests.post(
|
||||
server.url_for("tokenize"),
|
||||
json={
|
||||
"model": MODEL_NAME,
|
||||
"prompt": "Hello world!"
|
||||
},
|
||||
json={"model": MODEL_NAME, "prompt": "Hello world!"},
|
||||
)
|
||||
tokenize_response.raise_for_status()
|
||||
tokenize_result = tokenize_response.json()
|
||||
@@ -363,7 +339,8 @@ async def test_tokenizer_info_consistency_with_tokenize(
|
||||
tokenize_max_len = tokenize_result.get("max_model_len")
|
||||
if info_max_len and tokenize_max_len:
|
||||
assert info_max_len >= tokenize_max_len, (
|
||||
"Info max length should be >= tokenize max length")
|
||||
"Info max length should be >= tokenize max length"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -374,6 +351,5 @@ async def test_tokenizer_info_chat_template(server: RemoteOpenAIServer):
|
||||
result = response.json()
|
||||
chat_template = result.get("chat_template")
|
||||
if chat_template:
|
||||
assert isinstance(chat_template,
|
||||
str), ("Chat template should be a string")
|
||||
assert isinstance(chat_template, str), "Chat template should be a string"
|
||||
assert chat_template.strip(), "Chat template should not be empty"
|
||||
|
||||
@@ -17,8 +17,12 @@ from ...utils import RemoteOpenAIServer
|
||||
MODEL_NAME = "openai/whisper-large-v3-turbo"
|
||||
SERVER_ARGS = ["--enforce-eager"]
|
||||
MISTRAL_FORMAT_ARGS = [
|
||||
"--tokenizer_mode", "mistral", "--config_format", "mistral",
|
||||
"--load_format", "mistral"
|
||||
"--tokenizer_mode",
|
||||
"mistral",
|
||||
"--config_format",
|
||||
"mistral",
|
||||
"--load_format",
|
||||
"mistral",
|
||||
]
|
||||
|
||||
|
||||
@@ -36,8 +40,8 @@ async def client(server):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"])
|
||||
"model_name", ["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"]
|
||||
)
|
||||
async def test_basic_audio(mary_had_lamb, model_name):
|
||||
server_args = ["--enforce-eager"]
|
||||
|
||||
@@ -52,10 +56,11 @@ async def test_basic_audio(mary_had_lamb, model_name):
|
||||
file=mary_had_lamb,
|
||||
language="en",
|
||||
response_format="text",
|
||||
temperature=0.0)
|
||||
temperature=0.0,
|
||||
)
|
||||
out = json.loads(transcription)
|
||||
out_text = out['text']
|
||||
out_usage = out['usage']
|
||||
out_text = out["text"]
|
||||
out_usage = out["usage"]
|
||||
assert "Mary had a little lamb," in out_text
|
||||
assert out_usage["seconds"] == 16, out_usage["seconds"]
|
||||
|
||||
@@ -74,8 +79,9 @@ async def test_basic_audio_gemma(foscolo):
|
||||
file=foscolo,
|
||||
language="it",
|
||||
response_format="text",
|
||||
temperature=0.0)
|
||||
out = json.loads(transcription)['text']
|
||||
temperature=0.0,
|
||||
)
|
||||
out = json.loads(transcription)["text"]
|
||||
assert "da cui vergine nacque Venere" in out
|
||||
|
||||
|
||||
@@ -85,24 +91,21 @@ async def test_non_asr_model(winning_call):
|
||||
model_name = "JackFram/llama-68m"
|
||||
with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server:
|
||||
client = remote_server.get_async_client()
|
||||
res = await client.audio.transcriptions.create(model=model_name,
|
||||
file=winning_call,
|
||||
language="en",
|
||||
temperature=0.0)
|
||||
res = await client.audio.transcriptions.create(
|
||||
model=model_name, file=winning_call, language="en", temperature=0.0
|
||||
)
|
||||
err = res.error
|
||||
assert err["code"] == 400 and not res.text
|
||||
assert err[
|
||||
"message"] == "The model does not support Transcriptions API"
|
||||
assert err["message"] == "The model does not support Transcriptions API"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bad_requests(mary_had_lamb, client):
|
||||
# invalid language
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
await client.audio.transcriptions.create(model=MODEL_NAME,
|
||||
file=mary_had_lamb,
|
||||
language="hh",
|
||||
temperature=0.0)
|
||||
await client.audio.transcriptions.create(
|
||||
model=MODEL_NAME, file=mary_had_lamb, language="hh", temperature=0.0
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -114,17 +117,18 @@ async def test_long_audio_request(mary_had_lamb, client):
|
||||
repeated_audio = np.tile(audio, 10)
|
||||
# Repeated audio to buffer
|
||||
buffer = io.BytesIO()
|
||||
sf.write(buffer, repeated_audio, sr, format='WAV')
|
||||
sf.write(buffer, repeated_audio, sr, format="WAV")
|
||||
buffer.seek(0)
|
||||
transcription = await client.audio.transcriptions.create(
|
||||
model=MODEL_NAME,
|
||||
file=buffer,
|
||||
language="en",
|
||||
response_format="text",
|
||||
temperature=0.0)
|
||||
temperature=0.0,
|
||||
)
|
||||
out = json.loads(transcription)
|
||||
out_text = out['text']
|
||||
out_usage = out['usage']
|
||||
out_text = out["text"]
|
||||
out_usage = out["usage"]
|
||||
counts = out_text.count("Mary had a little lamb")
|
||||
assert counts == 10, counts
|
||||
assert out_usage["seconds"] == 161, out_usage["seconds"]
|
||||
@@ -135,10 +139,8 @@ async def test_completion_endpoints(client):
|
||||
# text to text model
|
||||
res = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=[{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}])
|
||||
messages=[{"role": "system", "content": "You are a helpful assistant."}],
|
||||
)
|
||||
err = res.error
|
||||
assert err["code"] == 400
|
||||
assert err["message"] == "The model does not support Chat Completions API"
|
||||
@@ -157,16 +159,19 @@ async def test_streaming_response(winning_call, client):
|
||||
file=winning_call,
|
||||
response_format="json",
|
||||
language="en",
|
||||
temperature=0.0)
|
||||
res = await client.audio.transcriptions.create(model=MODEL_NAME,
|
||||
file=winning_call,
|
||||
language="en",
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
timeout=30)
|
||||
temperature=0.0,
|
||||
)
|
||||
res = await client.audio.transcriptions.create(
|
||||
model=MODEL_NAME,
|
||||
file=winning_call,
|
||||
language="en",
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
timeout=30,
|
||||
)
|
||||
# Reconstruct from chunks and validate
|
||||
async for chunk in res:
|
||||
text = chunk.choices[0]['delta']['content']
|
||||
text = chunk.choices[0]["delta"]["content"]
|
||||
transcription += text
|
||||
|
||||
assert transcription == res_no_stream.text
|
||||
@@ -180,9 +185,9 @@ async def test_stream_options(winning_call, client):
|
||||
language="en",
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
extra_body=dict(stream_include_usage=True,
|
||||
stream_continuous_usage_stats=True),
|
||||
timeout=30)
|
||||
extra_body=dict(stream_include_usage=True, stream_continuous_usage_stats=True),
|
||||
timeout=30,
|
||||
)
|
||||
final = False
|
||||
continuous = True
|
||||
async for chunk in res:
|
||||
@@ -190,7 +195,7 @@ async def test_stream_options(winning_call, client):
|
||||
# final usage sent
|
||||
final = True
|
||||
else:
|
||||
continuous = continuous and hasattr(chunk, 'usage')
|
||||
continuous = continuous and hasattr(chunk, "usage")
|
||||
assert final and continuous
|
||||
|
||||
|
||||
@@ -198,27 +203,31 @@ async def test_stream_options(winning_call, client):
|
||||
async def test_sampling_params(mary_had_lamb, client):
|
||||
"""
|
||||
Compare sampling with params and greedy sampling to assert results
|
||||
are different when extreme sampling parameters values are picked.
|
||||
are different when extreme sampling parameters values are picked.
|
||||
"""
|
||||
transcription = await client.audio.transcriptions.create(
|
||||
model=MODEL_NAME,
|
||||
file=mary_had_lamb,
|
||||
language="en",
|
||||
temperature=0.8,
|
||||
extra_body=dict(seed=42,
|
||||
repetition_penalty=1.9,
|
||||
top_k=12,
|
||||
top_p=0.4,
|
||||
min_p=0.5,
|
||||
frequency_penalty=1.8,
|
||||
presence_penalty=2.0))
|
||||
extra_body=dict(
|
||||
seed=42,
|
||||
repetition_penalty=1.9,
|
||||
top_k=12,
|
||||
top_p=0.4,
|
||||
min_p=0.5,
|
||||
frequency_penalty=1.8,
|
||||
presence_penalty=2.0,
|
||||
),
|
||||
)
|
||||
|
||||
greedy_transcription = await client.audio.transcriptions.create(
|
||||
model=MODEL_NAME,
|
||||
file=mary_had_lamb,
|
||||
language="en",
|
||||
temperature=0.0,
|
||||
extra_body=dict(seed=42))
|
||||
extra_body=dict(seed=42),
|
||||
)
|
||||
|
||||
assert greedy_transcription.text != transcription.text
|
||||
|
||||
@@ -226,15 +235,16 @@ async def test_sampling_params(mary_had_lamb, client):
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_prompt(mary_had_lamb, client):
|
||||
prompt = "This is a speech, recorded in a phonograph."
|
||||
#Prompts should not omit the part of original prompt while transcribing.
|
||||
# Prompts should not omit the part of original prompt while transcribing.
|
||||
prefix = "The first words I spoke in the original phonograph"
|
||||
transcription = await client.audio.transcriptions.create(
|
||||
model=MODEL_NAME,
|
||||
file=mary_had_lamb,
|
||||
language="en",
|
||||
response_format="text",
|
||||
temperature=0.0)
|
||||
out = json.loads(transcription)['text']
|
||||
temperature=0.0,
|
||||
)
|
||||
out = json.loads(transcription)["text"]
|
||||
assert prefix in out
|
||||
transcription_wprompt = await client.audio.transcriptions.create(
|
||||
model=MODEL_NAME,
|
||||
@@ -242,6 +252,7 @@ async def test_audio_prompt(mary_had_lamb, client):
|
||||
language="en",
|
||||
response_format="text",
|
||||
prompt=prompt,
|
||||
temperature=0.0)
|
||||
out_prompt = json.loads(transcription_wprompt)['text']
|
||||
temperature=0.0,
|
||||
)
|
||||
out_prompt = json.loads(transcription_wprompt)["text"]
|
||||
assert prefix in out_prompt
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import io
|
||||
|
||||
# imports for structured outputs tests
|
||||
import json
|
||||
|
||||
@@ -17,8 +18,9 @@ from ...utils import RemoteOpenAIServer
|
||||
SERVER_ARGS = ["--enforce-eager"]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module",
|
||||
params=["openai/whisper-small", "google/gemma-3n-E2B-it"])
|
||||
@pytest.fixture(
|
||||
scope="module", params=["openai/whisper-small", "google/gemma-3n-E2B-it"]
|
||||
)
|
||||
def server(request):
|
||||
# Parametrize over model name
|
||||
with RemoteOpenAIServer(request.param, SERVER_ARGS) as remote_server:
|
||||
@@ -38,9 +40,9 @@ async def test_non_asr_model(foscolo):
|
||||
model_name = "JackFram/llama-68m"
|
||||
with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server:
|
||||
client = remote_server.get_async_client()
|
||||
res = await client.audio.translations.create(model=model_name,
|
||||
file=foscolo,
|
||||
temperature=0.0)
|
||||
res = await client.audio.translations.create(
|
||||
model=model_name, file=foscolo, temperature=0.0
|
||||
)
|
||||
err = res.error
|
||||
assert err["code"] == 400 and not res.text
|
||||
assert err["message"] == "The model does not support Translations API"
|
||||
@@ -56,8 +58,9 @@ async def test_basic_audio(foscolo, client_and_model):
|
||||
response_format="text",
|
||||
# TODO remove `language="it"` once language detection is implemented
|
||||
extra_body=dict(language="it", to_language="en"),
|
||||
temperature=0.0)
|
||||
out = json.loads(translation)['text'].strip().lower()
|
||||
temperature=0.0,
|
||||
)
|
||||
out = json.loads(translation)["text"].strip().lower()
|
||||
assert "greek sea" in out
|
||||
|
||||
|
||||
@@ -72,8 +75,9 @@ async def test_audio_prompt(foscolo, client_and_model):
|
||||
prompt=prompt,
|
||||
extra_body=dict(language="it", to_language="en"),
|
||||
response_format="text",
|
||||
temperature=0.0)
|
||||
out = json.loads(transcription)['text']
|
||||
temperature=0.0,
|
||||
)
|
||||
out = json.loads(transcription)["text"]
|
||||
assert "Nor will I ever touch the sacred" not in out
|
||||
assert prompt not in out
|
||||
|
||||
@@ -87,7 +91,8 @@ async def test_streaming_response(foscolo, client_and_model, server):
|
||||
file=foscolo,
|
||||
response_format="json",
|
||||
extra_body=dict(language="it", to_language="en", seed=42),
|
||||
temperature=0.0)
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
# Stream via HTTPX since OpenAI translation client doesn't expose streaming
|
||||
server, model_name = server
|
||||
@@ -104,16 +109,14 @@ async def test_streaming_response(foscolo, client_and_model, server):
|
||||
foscolo.seek(0)
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
files = {"file": foscolo}
|
||||
async with http_client.stream("POST",
|
||||
url,
|
||||
headers=headers,
|
||||
data=data,
|
||||
files=files) as response:
|
||||
async with http_client.stream(
|
||||
"POST", url, headers=headers, data=data, files=files
|
||||
) as response:
|
||||
async for line in response.aiter_lines():
|
||||
if not line:
|
||||
continue
|
||||
if line.startswith("data: "):
|
||||
line = line[len("data: "):]
|
||||
line = line[len("data: ") :]
|
||||
if line.strip() == "[DONE]":
|
||||
break
|
||||
chunk = json.loads(line)
|
||||
@@ -124,9 +127,10 @@ async def test_streaming_response(foscolo, client_and_model, server):
|
||||
# NOTE There's a small non-deterministic issue here, likely in the attn
|
||||
# computation, which will cause a few tokens to be different, while still
|
||||
# being very close semantically.
|
||||
assert sum([
|
||||
x == y for x, y in zip(res_stream, res_no_stream.text.split())
|
||||
]) >= len(res_stream) * 0.9
|
||||
assert (
|
||||
sum([x == y for x, y in zip(res_stream, res_no_stream.text.split())])
|
||||
>= len(res_stream) * 0.9
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -148,16 +152,14 @@ async def test_stream_options(foscolo, server):
|
||||
continuous = True
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
files = {"file": foscolo}
|
||||
async with http_client.stream("POST",
|
||||
url,
|
||||
headers=headers,
|
||||
data=data,
|
||||
files=files) as response:
|
||||
async with http_client.stream(
|
||||
"POST", url, headers=headers, data=data, files=files
|
||||
) as response:
|
||||
async for line in response.aiter_lines():
|
||||
if not line:
|
||||
continue
|
||||
if line.startswith("data: "):
|
||||
line = line[len("data: "):]
|
||||
line = line[len("data: ") :]
|
||||
if line.strip() == "[DONE]":
|
||||
break
|
||||
chunk = json.loads(line)
|
||||
@@ -180,13 +182,14 @@ async def test_long_audio_request(foscolo, client_and_model):
|
||||
repeated_audio = np.tile(audio, 2)
|
||||
# Repeated audio to buffer
|
||||
buffer = io.BytesIO()
|
||||
sf.write(buffer, repeated_audio, sr, format='WAV')
|
||||
sf.write(buffer, repeated_audio, sr, format="WAV")
|
||||
buffer.seek(0)
|
||||
translation = await client.audio.translations.create(
|
||||
model=model_name,
|
||||
file=buffer,
|
||||
extra_body=dict(language="it", to_language="en"),
|
||||
response_format="text",
|
||||
temperature=0.0)
|
||||
out = json.loads(translation)['text'].strip().lower()
|
||||
temperature=0.0,
|
||||
)
|
||||
out = json.loads(translation)["text"].strip().lower()
|
||||
assert out.count("greek sea") == 2
|
||||
|
||||
@@ -58,24 +58,18 @@ def base64_encoded_video() -> dict[str, str]:
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||
async def test_single_chat_session_video(client: openai.AsyncOpenAI,
|
||||
model_name: str, video_url: str):
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video_url",
|
||||
"video_url": {
|
||||
"url": video_url
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this video?"
|
||||
},
|
||||
],
|
||||
}]
|
||||
async def test_single_chat_session_video(
|
||||
client: openai.AsyncOpenAI, model_name: str, video_url: str
|
||||
):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video_url", "video_url": {"url": video_url}},
|
||||
{"type": "text", "text": "What's in this video?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# test single completion
|
||||
chat_completion = await client.chat.completions.create(
|
||||
@@ -84,13 +78,15 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI,
|
||||
max_completion_tokens=10,
|
||||
logprobs=True,
|
||||
temperature=0.0,
|
||||
top_logprobs=5)
|
||||
top_logprobs=5,
|
||||
)
|
||||
assert len(chat_completion.choices) == 1
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
assert choice.finish_reason == "length"
|
||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=10, prompt_tokens=6287, total_tokens=6297)
|
||||
completion_tokens=10, prompt_tokens=6287, total_tokens=6297
|
||||
)
|
||||
|
||||
message = choice.message
|
||||
message = chat_completion.choices[0].message
|
||||
@@ -112,54 +108,44 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI,
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||
async def test_error_on_invalid_video_url_type(client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
video_url: str):
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video_url",
|
||||
"video_url": video_url
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this video?"
|
||||
},
|
||||
],
|
||||
}]
|
||||
async def test_error_on_invalid_video_url_type(
|
||||
client: openai.AsyncOpenAI, model_name: str, video_url: str
|
||||
):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video_url", "video_url": video_url},
|
||||
{"type": "text", "text": "What's in this video?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# video_url should be a dict {"url": "some url"}, not directly a string
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
_ = await client.chat.completions.create(model=model_name,
|
||||
messages=messages,
|
||||
max_completion_tokens=10,
|
||||
temperature=0.0)
|
||||
_ = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_completion_tokens=10,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||
async def test_single_chat_session_video_beamsearch(client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
video_url: str):
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video_url",
|
||||
"video_url": {
|
||||
"url": video_url
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this video?"
|
||||
},
|
||||
],
|
||||
}]
|
||||
async def test_single_chat_session_video_beamsearch(
|
||||
client: openai.AsyncOpenAI, model_name: str, video_url: str
|
||||
):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video_url", "video_url": {"url": video_url}},
|
||||
{"type": "text", "text": "What's in this video?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
chat_completion = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
@@ -168,36 +154,38 @@ async def test_single_chat_session_video_beamsearch(client: openai.AsyncOpenAI,
|
||||
max_completion_tokens=10,
|
||||
logprobs=True,
|
||||
top_logprobs=5,
|
||||
extra_body=dict(use_beam_search=True))
|
||||
extra_body=dict(use_beam_search=True),
|
||||
)
|
||||
assert len(chat_completion.choices) == 2
|
||||
assert chat_completion.choices[
|
||||
0].message.content != chat_completion.choices[1].message.content
|
||||
assert (
|
||||
chat_completion.choices[0].message.content
|
||||
!= chat_completion.choices[1].message.content
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||
async def test_single_chat_session_video_base64encoded(
|
||||
client: openai.AsyncOpenAI, model_name: str, video_url: str,
|
||||
base64_encoded_video: dict[str, str]):
|
||||
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video_url",
|
||||
"video_url": {
|
||||
"url":
|
||||
f"data:video/jpeg;base64,{base64_encoded_video[video_url]}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this video?"
|
||||
},
|
||||
],
|
||||
}]
|
||||
client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
video_url: str,
|
||||
base64_encoded_video: dict[str, str],
|
||||
):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video_url",
|
||||
"video_url": {
|
||||
"url": f"data:video/jpeg;base64,{base64_encoded_video[video_url]}"
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "What's in this video?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# test single completion
|
||||
chat_completion = await client.chat.completions.create(
|
||||
@@ -206,13 +194,15 @@ async def test_single_chat_session_video_base64encoded(
|
||||
max_completion_tokens=10,
|
||||
logprobs=True,
|
||||
temperature=0.0,
|
||||
top_logprobs=5)
|
||||
top_logprobs=5,
|
||||
)
|
||||
assert len(chat_completion.choices) == 1
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
assert choice.finish_reason == "length"
|
||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=10, prompt_tokens=6287, total_tokens=6297)
|
||||
completion_tokens=10, prompt_tokens=6287, total_tokens=6297
|
||||
)
|
||||
|
||||
message = choice.message
|
||||
message = chat_completion.choices[0].message
|
||||
@@ -236,58 +226,54 @@ async def test_single_chat_session_video_base64encoded(
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||
async def test_single_chat_session_video_base64encoded_beamsearch(
|
||||
client: openai.AsyncOpenAI, model_name: str, video_url: str,
|
||||
base64_encoded_video: dict[str, str]):
|
||||
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video_url",
|
||||
"video_url": {
|
||||
"url":
|
||||
f"data:video/jpeg;base64,{base64_encoded_video[video_url]}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this video?"
|
||||
},
|
||||
],
|
||||
}]
|
||||
client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
video_url: str,
|
||||
base64_encoded_video: dict[str, str],
|
||||
):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video_url",
|
||||
"video_url": {
|
||||
"url": f"data:video/jpeg;base64,{base64_encoded_video[video_url]}"
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "What's in this video?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
chat_completion = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
n=2,
|
||||
max_completion_tokens=10,
|
||||
extra_body=dict(use_beam_search=True))
|
||||
extra_body=dict(use_beam_search=True),
|
||||
)
|
||||
assert len(chat_completion.choices) == 2
|
||||
assert chat_completion.choices[
|
||||
0].message.content != chat_completion.choices[1].message.content
|
||||
assert (
|
||||
chat_completion.choices[0].message.content
|
||||
!= chat_completion.choices[1].message.content
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||
async def test_chat_streaming_video(client: openai.AsyncOpenAI,
|
||||
model_name: str, video_url: str):
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video_url",
|
||||
"video_url": {
|
||||
"url": video_url
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this video?"
|
||||
},
|
||||
],
|
||||
}]
|
||||
async def test_chat_streaming_video(
|
||||
client: openai.AsyncOpenAI, model_name: str, video_url: str
|
||||
):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video_url", "video_url": {"url": video_url}},
|
||||
{"type": "text", "text": "What's in this video?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# test single completion
|
||||
chat_completion = await client.chat.completions.create(
|
||||
@@ -327,27 +313,23 @@ async def test_chat_streaming_video(client: openai.AsyncOpenAI,
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize(
|
||||
"video_urls",
|
||||
[TEST_VIDEO_URLS[:i] for i in range(2, len(TEST_VIDEO_URLS))])
|
||||
async def test_multi_video_input(client: openai.AsyncOpenAI, model_name: str,
|
||||
video_urls: list[str]):
|
||||
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
*({
|
||||
"type": "video_url",
|
||||
"video_url": {
|
||||
"url": video_url
|
||||
}
|
||||
} for video_url in video_urls),
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this video?"
|
||||
},
|
||||
],
|
||||
}]
|
||||
"video_urls", [TEST_VIDEO_URLS[:i] for i in range(2, len(TEST_VIDEO_URLS))]
|
||||
)
|
||||
async def test_multi_video_input(
|
||||
client: openai.AsyncOpenAI, model_name: str, video_urls: list[str]
|
||||
):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*(
|
||||
{"type": "video_url", "video_url": {"url": video_url}}
|
||||
for video_url in video_urls
|
||||
),
|
||||
{"type": "text", "text": "What's in this video?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
if len(video_urls) > MAXIMUM_VIDEOS:
|
||||
with pytest.raises(openai.BadRequestError): # test multi-video input
|
||||
|
||||
@@ -71,26 +71,30 @@ async def client(server):
|
||||
@pytest.fixture(scope="session")
|
||||
def base64_encoded_image(local_asset_server) -> dict[str, str]:
|
||||
return {
|
||||
image_asset:
|
||||
encode_image_base64(local_asset_server.get_image_asset(image_asset))
|
||||
image_asset: encode_image_base64(
|
||||
local_asset_server.get_image_asset(image_asset)
|
||||
)
|
||||
for image_asset in TEST_IMAGE_ASSETS
|
||||
}
|
||||
|
||||
|
||||
def get_hf_prompt_tokens(model_name, content, image_url):
|
||||
processor = AutoProcessor.from_pretrained(model_name,
|
||||
trust_remote_code=True,
|
||||
num_crops=4)
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_name, trust_remote_code=True, num_crops=4
|
||||
)
|
||||
|
||||
placeholder = "<|image_1|>\n"
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": f"{placeholder}{content}",
|
||||
}]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{placeholder}{content}",
|
||||
}
|
||||
]
|
||||
images = [fetch_image(image_url)]
|
||||
|
||||
prompt = processor.tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True)
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
inputs = processor(prompt, images, return_tensors="pt")
|
||||
|
||||
return inputs.input_ids.shape[1]
|
||||
@@ -99,25 +103,19 @@ def get_hf_prompt_tokens(model_name, content, image_url):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True)
|
||||
async def test_single_chat_session_image(client: openai.AsyncOpenAI,
|
||||
model_name: str, image_url: str):
|
||||
async def test_single_chat_session_image(
|
||||
client: openai.AsyncOpenAI, model_name: str, image_url: str
|
||||
):
|
||||
content_text = "What's in this image?"
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": content_text
|
||||
},
|
||||
],
|
||||
}]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
{"type": "text", "text": content_text},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
max_completion_tokens = 10
|
||||
# test single completion
|
||||
@@ -127,17 +125,18 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
logprobs=True,
|
||||
temperature=0.0,
|
||||
top_logprobs=5)
|
||||
top_logprobs=5,
|
||||
)
|
||||
assert len(chat_completion.choices) == 1
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
assert choice.finish_reason == "length"
|
||||
hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text,
|
||||
image_url)
|
||||
hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, image_url)
|
||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=max_completion_tokens,
|
||||
prompt_tokens=hf_prompt_tokens,
|
||||
total_tokens=hf_prompt_tokens + max_completion_tokens)
|
||||
total_tokens=hf_prompt_tokens + max_completion_tokens,
|
||||
)
|
||||
|
||||
message = choice.message
|
||||
message = chat_completion.choices[0].message
|
||||
@@ -159,55 +158,45 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True)
|
||||
async def test_error_on_invalid_image_url_type(client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
image_url: str):
|
||||
async def test_error_on_invalid_image_url_type(
|
||||
client: openai.AsyncOpenAI, model_name: str, image_url: str
|
||||
):
|
||||
content_text = "What's in this image?"
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": image_url
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": content_text
|
||||
},
|
||||
],
|
||||
}]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": image_url},
|
||||
{"type": "text", "text": content_text},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# image_url should be a dict {"url": "some url"}, not directly a string
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
_ = await client.chat.completions.create(model=model_name,
|
||||
messages=messages,
|
||||
max_completion_tokens=10,
|
||||
temperature=0.0)
|
||||
_ = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_completion_tokens=10,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True)
|
||||
async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
image_url: str):
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this image?"
|
||||
},
|
||||
],
|
||||
}]
|
||||
async def test_single_chat_session_image_beamsearch(
|
||||
client: openai.AsyncOpenAI, model_name: str, image_url: str
|
||||
):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
chat_completion = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
@@ -216,10 +205,13 @@ async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI,
|
||||
max_completion_tokens=10,
|
||||
logprobs=True,
|
||||
top_logprobs=5,
|
||||
extra_body=dict(use_beam_search=True))
|
||||
extra_body=dict(use_beam_search=True),
|
||||
)
|
||||
assert len(chat_completion.choices) == 2
|
||||
assert chat_completion.choices[
|
||||
0].message.content != chat_completion.choices[1].message.content
|
||||
assert (
|
||||
chat_completion.choices[0].message.content
|
||||
!= chat_completion.choices[1].message.content
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -227,27 +219,27 @@ async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI,
|
||||
@pytest.mark.parametrize("raw_image_url", TEST_IMAGE_ASSETS)
|
||||
@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True)
|
||||
async def test_single_chat_session_image_base64encoded(
|
||||
client: openai.AsyncOpenAI, model_name: str, raw_image_url: str,
|
||||
image_url: str, base64_encoded_image: dict[str, str]):
|
||||
|
||||
client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
raw_image_url: str,
|
||||
image_url: str,
|
||||
base64_encoded_image: dict[str, str],
|
||||
):
|
||||
content_text = "What's in this image?"
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url":
|
||||
f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": content_text
|
||||
},
|
||||
],
|
||||
}]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}"
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": content_text},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
max_completion_tokens = 10
|
||||
# test single completion
|
||||
@@ -257,17 +249,18 @@ async def test_single_chat_session_image_base64encoded(
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
logprobs=True,
|
||||
temperature=0.0,
|
||||
top_logprobs=5)
|
||||
top_logprobs=5,
|
||||
)
|
||||
assert len(chat_completion.choices) == 1
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
assert choice.finish_reason == "length"
|
||||
hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text,
|
||||
image_url)
|
||||
hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, image_url)
|
||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=max_completion_tokens,
|
||||
prompt_tokens=hf_prompt_tokens,
|
||||
total_tokens=hf_prompt_tokens + max_completion_tokens)
|
||||
total_tokens=hf_prompt_tokens + max_completion_tokens,
|
||||
)
|
||||
|
||||
message = choice.message
|
||||
message = chat_completion.choices[0].message
|
||||
@@ -291,36 +284,37 @@ async def test_single_chat_session_image_base64encoded(
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("image_idx", list(range(len(TEST_IMAGE_ASSETS))))
|
||||
async def test_single_chat_session_image_base64encoded_beamsearch(
|
||||
client: openai.AsyncOpenAI, model_name: str, image_idx: int,
|
||||
base64_encoded_image: dict[str, str]):
|
||||
client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
image_idx: int,
|
||||
base64_encoded_image: dict[str, str],
|
||||
):
|
||||
# NOTE: This test also validates that we pass MM data through beam search
|
||||
raw_image_url = TEST_IMAGE_ASSETS[image_idx]
|
||||
expected_res = EXPECTED_MM_BEAM_SEARCH_RES[image_idx]
|
||||
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url":
|
||||
f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this image?"
|
||||
},
|
||||
],
|
||||
}]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}"
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
chat_completion = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
n=2,
|
||||
max_completion_tokens=10,
|
||||
temperature=0.0,
|
||||
extra_body=dict(use_beam_search=True))
|
||||
extra_body=dict(use_beam_search=True),
|
||||
)
|
||||
assert len(chat_completion.choices) == 2
|
||||
for actual, expected_str in zip(chat_completion.choices, expected_res):
|
||||
assert actual.message.content == expected_str
|
||||
@@ -329,24 +323,18 @@ async def test_single_chat_session_image_base64encoded_beamsearch(
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True)
|
||||
async def test_chat_streaming_image(client: openai.AsyncOpenAI,
|
||||
model_name: str, image_url: str):
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this image?"
|
||||
},
|
||||
],
|
||||
}]
|
||||
async def test_chat_streaming_image(
|
||||
client: openai.AsyncOpenAI, model_name: str, image_url: str
|
||||
):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# test single completion
|
||||
chat_completion = await client.chat.completions.create(
|
||||
@@ -388,26 +376,23 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI,
|
||||
@pytest.mark.parametrize(
|
||||
"image_urls",
|
||||
[TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))],
|
||||
indirect=True)
|
||||
async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
|
||||
image_urls: list[str]):
|
||||
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
*({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
} for image_url in image_urls),
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this image?"
|
||||
},
|
||||
],
|
||||
}]
|
||||
indirect=True,
|
||||
)
|
||||
async def test_multi_image_input(
|
||||
client: openai.AsyncOpenAI, model_name: str, image_urls: list[str]
|
||||
):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*(
|
||||
{"type": "image_url", "image_url": {"url": image_url}}
|
||||
for image_url in image_urls
|
||||
),
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
if len(image_urls) > MAXIMUM_IMAGES:
|
||||
with pytest.raises(openai.BadRequestError): # test multi-image input
|
||||
@@ -443,7 +428,8 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
|
||||
@pytest.mark.parametrize(
|
||||
"image_urls",
|
||||
[TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))],
|
||||
indirect=True)
|
||||
indirect=True,
|
||||
)
|
||||
async def test_completions_with_image(
|
||||
client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
@@ -452,13 +438,9 @@ async def test_completions_with_image(
|
||||
for image_url in image_urls:
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role":
|
||||
"user",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
@@ -468,7 +450,7 @@ async def test_completions_with_image(
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url,
|
||||
}
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -485,7 +467,8 @@ async def test_completions_with_image(
|
||||
@pytest.mark.parametrize(
|
||||
"image_urls",
|
||||
[TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))],
|
||||
indirect=True)
|
||||
indirect=True,
|
||||
)
|
||||
async def test_completions_with_image_with_uuid(
|
||||
client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
@@ -494,13 +477,9 @@ async def test_completions_with_image_with_uuid(
|
||||
for image_url in image_urls:
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role":
|
||||
"user",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
@@ -511,7 +490,7 @@ async def test_completions_with_image_with_uuid(
|
||||
"image_url": {
|
||||
"url": image_url,
|
||||
},
|
||||
"uuid": image_url
|
||||
"uuid": image_url,
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -525,34 +504,25 @@ async def test_completions_with_image_with_uuid(
|
||||
# Second request, with empty image but the same uuid.
|
||||
chat_completion_with_empty_image = await client.chat.completions.create(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role":
|
||||
"user",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Describe this image.",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {},
|
||||
"uuid": image_url
|
||||
},
|
||||
{"type": "image_url", "image_url": {}, "uuid": image_url},
|
||||
],
|
||||
},
|
||||
],
|
||||
model=model_name,
|
||||
)
|
||||
assert chat_completion_with_empty_image.choices[
|
||||
0].message.content is not None
|
||||
assert chat_completion_with_empty_image.choices[0].message.content is not None
|
||||
assert isinstance(
|
||||
chat_completion_with_empty_image.choices[0].message.content, str)
|
||||
assert len(
|
||||
chat_completion_with_empty_image.choices[0].message.content) > 0
|
||||
chat_completion_with_empty_image.choices[0].message.content, str
|
||||
)
|
||||
assert len(chat_completion_with_empty_image.choices[0].message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -564,13 +534,9 @@ async def test_completions_with_empty_image_with_uuid_without_cache_hit(
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
_ = await client.chat.completions.create(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role":
|
||||
"user",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
@@ -579,7 +545,7 @@ async def test_completions_with_empty_image_with_uuid_without_cache_hit(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {},
|
||||
"uuid": "uuid_not_previously_seen"
|
||||
"uuid": "uuid_not_previously_seen",
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -593,7 +559,8 @@ async def test_completions_with_empty_image_with_uuid_without_cache_hit(
|
||||
@pytest.mark.parametrize(
|
||||
"image_urls",
|
||||
[TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))],
|
||||
indirect=True)
|
||||
indirect=True,
|
||||
)
|
||||
async def test_completions_with_image_with_incorrect_uuid_format(
|
||||
client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
@@ -602,13 +569,9 @@ async def test_completions_with_image_with_incorrect_uuid_format(
|
||||
for image_url in image_urls:
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role":
|
||||
"user",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
|
||||
@@ -6,8 +6,7 @@ import json
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import (
|
||||
Hermes2ProToolParser)
|
||||
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
from ....utils import RemoteOpenAIServer
|
||||
@@ -27,61 +26,64 @@ SERVER_ARGS = [
|
||||
f"{LORA_MODEL}",
|
||||
]
|
||||
|
||||
TOOLS = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
TOOLS = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}]
|
||||
}
|
||||
]
|
||||
|
||||
PRODUCT_TOOLS = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_product_info",
|
||||
"description": "Get detailed information of a product based on its "
|
||||
"product ID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"inserted": {
|
||||
"type": "boolean",
|
||||
"description": "inserted.",
|
||||
},
|
||||
"product_id": {
|
||||
"type": "integer",
|
||||
"description": "The product ID of the product.",
|
||||
PRODUCT_TOOLS = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_product_info",
|
||||
"description": "Get detailed information of a product based on its "
|
||||
"product ID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"inserted": {
|
||||
"type": "boolean",
|
||||
"description": "inserted.",
|
||||
},
|
||||
"product_id": {
|
||||
"type": "integer",
|
||||
"description": "The product ID of the product.",
|
||||
},
|
||||
},
|
||||
"required": ["product_id", "inserted"],
|
||||
},
|
||||
"required": ["product_id", "inserted"],
|
||||
},
|
||||
},
|
||||
}]
|
||||
}
|
||||
]
|
||||
|
||||
MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}]
|
||||
|
||||
PRODUCT_MESSAGES = [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Hi! Do you have any detailed information about the product id "
|
||||
"7355608 and inserted true?",
|
||||
}]
|
||||
PRODUCT_MESSAGES = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi! Do you have any detailed information about the product id "
|
||||
"7355608 and inserted true?",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -150,7 +152,8 @@ async def test_streaming_tool_call():
|
||||
tool_call_chunks[index]["name"] += tool_chunk.function.name
|
||||
if tool_chunk.function.arguments:
|
||||
tool_call_chunks[index]["arguments"] += (
|
||||
tool_chunk.function.arguments)
|
||||
tool_chunk.function.arguments
|
||||
)
|
||||
|
||||
assert len(tool_call_chunks) == 1
|
||||
reconstructed_tool_call = tool_call_chunks[0]
|
||||
@@ -240,7 +243,8 @@ async def test_streaming_product_tool_call():
|
||||
tool_call_chunks[index]["name"] += tool_chunk.function.name
|
||||
if tool_chunk.function.arguments:
|
||||
tool_call_chunks[index]["arguments"] += (
|
||||
tool_chunk.function.arguments)
|
||||
tool_chunk.function.arguments
|
||||
)
|
||||
|
||||
assert len(tool_call_chunks) == 1
|
||||
reconstructed_tool_call = tool_call_chunks[0]
|
||||
@@ -291,9 +295,7 @@ def test_hermes_parser_streaming_just_forward_text(
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = (
|
||||
"""This is some prior text that has nothing to do with tool calling."""
|
||||
)
|
||||
text = """This is some prior text that has nothing to do with tool calling."""
|
||||
tokens = qwen_tokenizer.encode(text)
|
||||
previous_text = ""
|
||||
delta_messages = []
|
||||
@@ -348,8 +350,9 @@ def test_hermes_parser_streaming_failure_case_bug_19056(
|
||||
delta_messages.append(delta)
|
||||
|
||||
assert delta_messages[0].tool_calls[0].function.name == "final_answer"
|
||||
tool_call_args = "".join(delta.tool_calls[0].function.arguments or ""
|
||||
for delta in delta_messages)
|
||||
tool_call_args = "".join(
|
||||
delta.tool_calls[0].function.arguments or "" for delta in delta_messages
|
||||
)
|
||||
assert tool_call_args == '{"trigger": true}'
|
||||
|
||||
|
||||
@@ -383,13 +386,13 @@ def test_hermes_parser_streaming(
|
||||
if delta is not None:
|
||||
delta_messages.append(delta)
|
||||
print(delta_messages)
|
||||
assert (delta_messages[0].tool_calls[0].function.name ==
|
||||
"get_current_temperature")
|
||||
tool_call_args = "".join(delta.tool_calls[0].function.arguments or ""
|
||||
for delta in delta_messages)
|
||||
assert delta_messages[0].tool_calls[0].function.name == "get_current_temperature"
|
||||
tool_call_args = "".join(
|
||||
delta.tool_calls[0].function.arguments or "" for delta in delta_messages
|
||||
)
|
||||
assert tool_call_args == (
|
||||
'{"location":"San Francisco, California, United States", '
|
||||
'"unit": "celsius"}')
|
||||
'{"location":"San Francisco, California, United States", "unit": "celsius"}'
|
||||
)
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_no_tool_call(
|
||||
|
||||
@@ -8,15 +8,18 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
run_tool_extraction, run_tool_extraction_streaming)
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
|
||||
def make_tool_call(name, arguments):
|
||||
return ToolCall(type="function",
|
||||
function=FunctionCall(name=name,
|
||||
arguments=json.dumps(arguments)))
|
||||
return ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(name=name, arguments=json.dumps(arguments)),
|
||||
)
|
||||
|
||||
|
||||
# TODO: add reason prefix and suffix.
|
||||
@@ -29,70 +32,68 @@ def make_tool_call(name, arguments):
|
||||
("How can I help you today?", [], "How can I help you today?"),
|
||||
# Single tool call, no content
|
||||
(
|
||||
"<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}]</tool_calls>", #noqa: E501
|
||||
[
|
||||
make_tool_call("get_weather", {
|
||||
"city": "San Francisco",
|
||||
"metric": "celsius"
|
||||
})
|
||||
],
|
||||
None),
|
||||
# Multiple tool calls
|
||||
(
|
||||
"<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}, {\"name\": \"register_user\", \"arguments\": {\"name\": \"John Doe\", \"age\": 37, \"address\": {\"city\": \"San Francisco\", \"state\": \"CA\"}, \"role\": null, \"passed_test\": true, \"aliases\": [\"John\", \"Johnny\"]}}]</tool_calls>", #noqa: E501
|
||||
[
|
||||
make_tool_call("get_weather", {
|
||||
"city": "San Francisco",
|
||||
"metric": "celsius"
|
||||
}),
|
||||
make_tool_call(
|
||||
"register_user", {
|
||||
"name": "John Doe",
|
||||
"age": 37,
|
||||
"address": {
|
||||
"city": "San Francisco",
|
||||
"state": "CA"
|
||||
},
|
||||
"role": None,
|
||||
"passed_test": True,
|
||||
"aliases": ["John", "Johnny"]
|
||||
})
|
||||
],
|
||||
None),
|
||||
# Content before tool call
|
||||
(
|
||||
"I will call the tool now. <tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Boston\"}}]</tool_calls>", #noqa: E501
|
||||
[make_tool_call("get_weather", {"city": "Boston"})],
|
||||
"I will call the tool now. "),
|
||||
# Content after tool call (should be stripped)
|
||||
(
|
||||
"<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Seattle\"}}]</tool_calls>\nThank you!", #noqa: E501
|
||||
[make_tool_call("get_weather", {"city": "Seattle"})],
|
||||
None),
|
||||
(
|
||||
"<tool_calls>[{\"name\": \"complex_tool\", \"arguments\": {\"level1\": {\"level2\": {\"level3\": {\"value\": 123}}}}}]</tool_calls>",
|
||||
'<tool_calls>[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}]</tool_calls>', # noqa: E501
|
||||
[
|
||||
make_tool_call(
|
||||
"complex_tool",
|
||||
{"level1": {
|
||||
"level2": {
|
||||
"level3": {
|
||||
"value": 123
|
||||
}
|
||||
}
|
||||
}})
|
||||
"get_weather", {"city": "San Francisco", "metric": "celsius"}
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
])
|
||||
def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls,
|
||||
expected_content):
|
||||
# Multiple tool calls
|
||||
(
|
||||
'<tool_calls>[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}, {"name": "register_user", "arguments": {"name": "John Doe", "age": 37, "address": {"city": "San Francisco", "state": "CA"}, "role": null, "passed_test": true, "aliases": ["John", "Johnny"]}}]</tool_calls>', # noqa: E501
|
||||
[
|
||||
make_tool_call(
|
||||
"get_weather", {"city": "San Francisco", "metric": "celsius"}
|
||||
),
|
||||
make_tool_call(
|
||||
"register_user",
|
||||
{
|
||||
"name": "John Doe",
|
||||
"age": 37,
|
||||
"address": {"city": "San Francisco", "state": "CA"},
|
||||
"role": None,
|
||||
"passed_test": True,
|
||||
"aliases": ["John", "Johnny"],
|
||||
},
|
||||
),
|
||||
],
|
||||
None,
|
||||
),
|
||||
# Content before tool call
|
||||
(
|
||||
'I will call the tool now. <tool_calls>[{"name": "get_weather", "arguments": {"city": "Boston"}}]</tool_calls>', # noqa: E501
|
||||
[make_tool_call("get_weather", {"city": "Boston"})],
|
||||
"I will call the tool now. ",
|
||||
),
|
||||
# Content after tool call (should be stripped)
|
||||
(
|
||||
'<tool_calls>[{"name": "get_weather", "arguments": {"city": "Seattle"}}]</tool_calls>\nThank you!', # noqa: E501
|
||||
[make_tool_call("get_weather", {"city": "Seattle"})],
|
||||
None,
|
||||
),
|
||||
(
|
||||
'<tool_calls>[{"name": "complex_tool", "arguments": {"level1": {"level2": {"level3": {"value": 123}}}}}]</tool_calls>',
|
||||
[
|
||||
make_tool_call(
|
||||
"complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}}
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_hunyuan_a13b_tool_parser_extract(
|
||||
model_output, expected_tool_calls, expected_content
|
||||
):
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
|
||||
"hunyuan_a13b")(mock_tokenizer)
|
||||
content, tool_calls = run_tool_extraction(tool_parser,
|
||||
model_output,
|
||||
streaming=False)
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")(
|
||||
mock_tokenizer
|
||||
)
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=False
|
||||
)
|
||||
|
||||
# align the random id.
|
||||
for idx in range(len(tool_calls)):
|
||||
@@ -102,49 +103,74 @@ def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls,
|
||||
|
||||
|
||||
# Streaming test: simulate incremental output
|
||||
@pytest.mark.parametrize("model_deltas,expected_tool_calls", [
|
||||
([
|
||||
"<tool_calls>[{\"name\": \"get_weather\", ",
|
||||
"\"arguments\": {\"city\": \"San Francisco\", ",
|
||||
"\"metric\": \"celsius\"}}]", "</tool_calls>"
|
||||
], [
|
||||
make_tool_call("get_weather", {
|
||||
"city": "San Francisco",
|
||||
"metric": "celsius"
|
||||
})
|
||||
]),
|
||||
([
|
||||
"<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":",
|
||||
" {\"city\": \"Boston\"}", "}]", "</tool_calls>"
|
||||
], [make_tool_call("get_weather", {"city": "Boston"})]),
|
||||
([
|
||||
"", "<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":",
|
||||
" {\"city\": \"Boston\"}", "}]", "</tool_calls>", "\n</answer>"
|
||||
], [make_tool_call("get_weather", {"city": "Boston"})]),
|
||||
pytest.param([
|
||||
"<tool_calls>[{\"name\": \"complex_tool\",", " \"arguments\": ",
|
||||
" {\"level1\": {\"level2\": ", "{\"level3\": {\"value\": 123}}}}}",
|
||||
"]</tool_calls>"
|
||||
], [
|
||||
make_tool_call("complex_tool",
|
||||
{"level1": {
|
||||
"level2": {
|
||||
"level3": {
|
||||
"value": 123
|
||||
}
|
||||
}
|
||||
}})
|
||||
@pytest.mark.parametrize(
|
||||
"model_deltas,expected_tool_calls",
|
||||
[
|
||||
(
|
||||
[
|
||||
'<tool_calls>[{"name": "get_weather", ',
|
||||
'"arguments": {"city": "San Francisco", ',
|
||||
'"metric": "celsius"}}]',
|
||||
"</tool_calls>",
|
||||
],
|
||||
[
|
||||
make_tool_call(
|
||||
"get_weather", {"city": "San Francisco", "metric": "celsius"}
|
||||
)
|
||||
],
|
||||
),
|
||||
(
|
||||
[
|
||||
'<tool_calls>[{"name":',
|
||||
' "get_weather",',
|
||||
' "arguments":',
|
||||
' {"city": "Boston"}',
|
||||
"}]",
|
||||
"</tool_calls>",
|
||||
],
|
||||
[make_tool_call("get_weather", {"city": "Boston"})],
|
||||
),
|
||||
(
|
||||
[
|
||||
"",
|
||||
'<tool_calls>[{"name":',
|
||||
' "get_weather",',
|
||||
' "arguments":',
|
||||
' {"city": "Boston"}',
|
||||
"}]",
|
||||
"</tool_calls>",
|
||||
"\n</answer>",
|
||||
],
|
||||
[make_tool_call("get_weather", {"city": "Boston"})],
|
||||
),
|
||||
pytest.param(
|
||||
[
|
||||
'<tool_calls>[{"name": "complex_tool",',
|
||||
' "arguments": ',
|
||||
' {"level1": {"level2": ',
|
||||
'{"level3": {"value": 123}}}}}',
|
||||
"]</tool_calls>",
|
||||
],
|
||||
[
|
||||
make_tool_call(
|
||||
"complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}}
|
||||
)
|
||||
],
|
||||
marks=pytest.mark.xfail(
|
||||
reason="stream parsing not support nested json yet."
|
||||
),
|
||||
),
|
||||
],
|
||||
marks=pytest.mark.xfail(
|
||||
reason="stream parsing not support nested json yet.")),
|
||||
])
|
||||
)
|
||||
def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls):
|
||||
mock_tokenizer = MagicMock()
|
||||
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
|
||||
"hunyuan_a13b")(mock_tokenizer)
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")(
|
||||
mock_tokenizer
|
||||
)
|
||||
reconstructor = run_tool_extraction_streaming(
|
||||
tool_parser, model_deltas, assert_one_tool_per_delta=False)
|
||||
tool_parser, model_deltas, assert_one_tool_per_delta=False
|
||||
)
|
||||
|
||||
# align the random id.
|
||||
for idx in range(len(reconstructor.tool_calls)):
|
||||
|
||||
@@ -5,8 +5,7 @@ import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
|
||||
from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import (
|
||||
Llama3JsonToolParser)
|
||||
from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import Llama3JsonToolParser
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -18,8 +17,10 @@ def parser():
|
||||
|
||||
def test_extract_tool_calls_simple(parser):
|
||||
# Test with a simple tool call
|
||||
model_output = ('Here is the result: {"name": "getOpenIncidentsTool", '
|
||||
'"parameters": {}} Would you like to know more?')
|
||||
model_output = (
|
||||
'Here is the result: {"name": "getOpenIncidentsTool", '
|
||||
'"parameters": {}} Would you like to know more?'
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert isinstance(result, ExtractedToolCallInformation)
|
||||
@@ -34,8 +35,8 @@ def test_extract_tool_calls_simple(parser):
|
||||
def test_extract_tool_calls_with_arguments(parser):
|
||||
# Test with a tool call that has arguments
|
||||
model_output = (
|
||||
'{"name": "searchTool", "parameters": {"query": "test query", '
|
||||
'"limit": 10}}')
|
||||
'{"name": "searchTool", "parameters": {"query": "test query", "limit": 10}}'
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is True
|
||||
@@ -81,7 +82,8 @@ def test_extract_tool_calls_multiple_json(parser):
|
||||
model_output = (
|
||||
'{"name": "searchTool", "parameters": {"query": "test1"}}; '
|
||||
'{"name": "getOpenIncidentsTool", "parameters": {}}; '
|
||||
'{"name": "searchTool", "parameters": {"query": "test2"}}')
|
||||
'{"name": "searchTool", "parameters": {"query": "test2"}}'
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is True
|
||||
@@ -105,7 +107,8 @@ def test_extract_tool_calls_multiple_json_with_whitespace(parser):
|
||||
model_output = (
|
||||
'{"name": "searchTool", "parameters": {"query": "test1"}} ; '
|
||||
'{"name": "getOpenIncidentsTool", "parameters": {}} ; '
|
||||
'{"name": "searchTool", "parameters": {"query": "test2"}}')
|
||||
'{"name": "searchTool", "parameters": {"query": "test2"}}'
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is True
|
||||
@@ -118,11 +121,12 @@ def test_extract_tool_calls_multiple_json_with_whitespace(parser):
|
||||
def test_extract_tool_calls_multiple_json_with_surrounding_text(parser):
|
||||
# Test with multiple JSONs and surrounding text
|
||||
model_output = (
|
||||
'Here are the results: '
|
||||
"Here are the results: "
|
||||
'{"name": "searchTool", "parameters": {"query": "test1"}}; '
|
||||
'{"name": "getOpenIncidentsTool", "parameters": {}}; '
|
||||
'{"name": "searchTool", "parameters": {"query": "test2"}} '
|
||||
'Would you like to know more?')
|
||||
"Would you like to know more?"
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is True
|
||||
|
||||
@@ -6,7 +6,9 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
run_tool_extraction, run_tool_extraction_streaming)
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
@@ -16,12 +18,14 @@ SIMPLE_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "LA", "metric": "C"}',
|
||||
)
|
||||
MORE_TYPES_FUNCTION_OUTPUT = ("[register_user(name='Doe', "
|
||||
"age=9, "
|
||||
"address={'city': 'LA', 'state': 'CA'}, "
|
||||
"role=None, "
|
||||
"passed_test=True, "
|
||||
"aliases=['John', 'Johnny'])]")
|
||||
MORE_TYPES_FUNCTION_OUTPUT = (
|
||||
"[register_user(name='Doe', "
|
||||
"age=9, "
|
||||
"address={'city': 'LA', 'state': 'CA'}, "
|
||||
"role=None, "
|
||||
"passed_test=True, "
|
||||
"aliases=['John', 'Johnny'])]"
|
||||
)
|
||||
MORE_TYPES_FUNCTION_CALL = FunctionCall(
|
||||
name="register_user",
|
||||
arguments='{"name": "Doe", '
|
||||
@@ -34,7 +38,7 @@ MORE_TYPES_FUNCTION_CALL = FunctionCall(
|
||||
PARAMETERLESS_FUNCTION_OUTPUT = "[get_weather()]"
|
||||
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{}',
|
||||
arguments="{}",
|
||||
)
|
||||
EMPTY_DICT_FUNCTION_OUTPUT = "[do_something_cool(additional_data={})]"
|
||||
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
|
||||
@@ -47,25 +51,28 @@ EMPTY_LIST_FUNCTION_CALL = FunctionCall(
|
||||
arguments='{"steps": []}',
|
||||
)
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT = (
|
||||
r"[get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')]")
|
||||
r"[get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')]"
|
||||
)
|
||||
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
|
||||
)
|
||||
PYTHON_TAG_FUNCTION_OUTPUT = (
|
||||
"<|python_start|>[get_weather(city='LA', metric='C')]<|python_end|>")
|
||||
"<|python_start|>[get_weather(city='LA', metric='C')]<|python_end|>"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
|
||||
def test_no_tool_call(streaming: bool):
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
|
||||
"llama4_pythonic")(mock_tokenizer)
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
mock_tokenizer
|
||||
)
|
||||
model_output = "How can I help you today?"
|
||||
|
||||
content, tool_calls = run_tool_extraction(tool_parser,
|
||||
model_output,
|
||||
streaming=streaming)
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
|
||||
assert content == model_output
|
||||
assert len(tool_calls) == 0
|
||||
@@ -75,98 +82,139 @@ test_str = "<|python_start|>"
|
||||
test_str += "[get_weather(city='LA', metric='C'),"
|
||||
test_str += "register_user(name='Doe', age=9)]"
|
||||
TEST_CASES = [
|
||||
pytest.param(True,
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT,
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="simple_streaming"),
|
||||
pytest.param(False,
|
||||
SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL],
|
||||
id="simple_nonstreaming"),
|
||||
pytest.param(True,
|
||||
MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_streaming"),
|
||||
pytest.param(False,
|
||||
MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_nonstreaming"),
|
||||
pytest.param(True,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_streaming"),
|
||||
pytest.param(False,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_nonstreaming"),
|
||||
pytest.param(True,
|
||||
EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_streaming"),
|
||||
pytest.param(False,
|
||||
EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_nonstreaming"),
|
||||
pytest.param(True,
|
||||
EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_streaming"),
|
||||
pytest.param(False,
|
||||
EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_nonstreaming"),
|
||||
pytest.param(True,
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT,
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_streaming"),
|
||||
pytest.param(False,
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT,
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_nonstreaming"),
|
||||
pytest.param(
|
||||
True,
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT,
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="simple_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False, SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], id="simple_nonstreaming"
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MORE_TYPES_FUNCTION_OUTPUT,
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MORE_TYPES_FUNCTION_OUTPUT,
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT,
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT,
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
EMPTY_DICT_FUNCTION_OUTPUT,
|
||||
[EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
EMPTY_DICT_FUNCTION_OUTPUT,
|
||||
[EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
EMPTY_LIST_FUNCTION_OUTPUT,
|
||||
[EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
EMPTY_LIST_FUNCTION_OUTPUT,
|
||||
[EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT,
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT,
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
"[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]",
|
||||
[
|
||||
SIMPLE_FUNCTION_CALL,
|
||||
FunctionCall(name="register_user",
|
||||
arguments='{"name": "Doe", "age": 9}')
|
||||
FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
|
||||
],
|
||||
id="parallel_calls_streaming"),
|
||||
id="parallel_calls_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
"[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]",
|
||||
[
|
||||
SIMPLE_FUNCTION_CALL,
|
||||
FunctionCall(name="register_user",
|
||||
arguments='{"name": "Doe", "age": 9}')
|
||||
FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
|
||||
],
|
||||
id="parallel_calls_nonstreaming"),
|
||||
pytest.param(True,
|
||||
PYTHON_TAG_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL],
|
||||
id="python_tag_streaming"),
|
||||
pytest.param(False,
|
||||
PYTHON_TAG_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL],
|
||||
id="python_tag_nonstreaming"),
|
||||
pytest.param(True,
|
||||
test_str, [
|
||||
SIMPLE_FUNCTION_CALL,
|
||||
FunctionCall(name="register_user",
|
||||
arguments='{"name": "Doe", "age": 9}')
|
||||
],
|
||||
id="parallel_calls_streaming"),
|
||||
pytest.param(False,
|
||||
"<|python_start|>[get_weather(city='LA', metric='C'), " +
|
||||
"register_user(name='Doe', age=9)]", [
|
||||
SIMPLE_FUNCTION_CALL,
|
||||
FunctionCall(name="register_user",
|
||||
arguments='{"name": "Doe", "age": 9}')
|
||||
],
|
||||
id="parallel_calls_nonstreaming"),
|
||||
id="parallel_calls_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
PYTHON_TAG_FUNCTION_OUTPUT,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
id="python_tag_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
PYTHON_TAG_FUNCTION_OUTPUT,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
id="python_tag_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
test_str,
|
||||
[
|
||||
SIMPLE_FUNCTION_CALL,
|
||||
FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
|
||||
],
|
||||
id="parallel_calls_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
"<|python_start|>[get_weather(city='LA', metric='C'), "
|
||||
+ "register_user(name='Doe', age=9)]",
|
||||
[
|
||||
SIMPLE_FUNCTION_CALL,
|
||||
FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
|
||||
],
|
||||
id="parallel_calls_nonstreaming",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls",
|
||||
TEST_CASES)
|
||||
def test_tool_call(streaming: bool, model_output: str,
|
||||
expected_tool_calls: list[FunctionCall]):
|
||||
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
|
||||
def test_tool_call(
|
||||
streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall]
|
||||
):
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
|
||||
"llama4_pythonic")(mock_tokenizer)
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
mock_tokenizer
|
||||
)
|
||||
|
||||
content, tool_calls = run_tool_extraction(tool_parser,
|
||||
model_output,
|
||||
streaming=streaming)
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
|
||||
assert len(tool_calls) == len(expected_tool_calls)
|
||||
for actual, expected in zip(tool_calls, expected_tool_calls):
|
||||
@@ -176,8 +224,9 @@ def test_tool_call(streaming: bool, model_output: str,
|
||||
|
||||
def test_streaming_tool_call_with_large_steps():
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
|
||||
"llama4_pythonic")(mock_tokenizer)
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
mock_tokenizer
|
||||
)
|
||||
model_output_deltas = [
|
||||
"<|python_start|>[get_weather(city='LA', metric='C'), "
|
||||
"get_weather(), "
|
||||
@@ -185,7 +234,8 @@ def test_streaming_tool_call_with_large_steps():
|
||||
]
|
||||
|
||||
reconstructor = run_tool_extraction_streaming(
|
||||
tool_parser, model_output_deltas, assert_one_tool_per_delta=False)
|
||||
tool_parser, model_output_deltas, assert_one_tool_per_delta=False
|
||||
)
|
||||
|
||||
assert reconstructor.other_content == ""
|
||||
assert len(reconstructor.tool_calls) == 3
|
||||
@@ -198,8 +248,9 @@ def test_streaming_tool_call_with_large_steps():
|
||||
def test_regex_timeout_handling(streaming: bool):
|
||||
"""test regex timeout is handled gracefully"""
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
|
||||
"llama4_pythonic")(mock_tokenizer)
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
mock_tokenizer
|
||||
)
|
||||
|
||||
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
|
||||
|
||||
@@ -207,10 +258,10 @@ def test_regex_timeout_handling(streaming: bool):
|
||||
mock_regex = MagicMock()
|
||||
mock_regex.match.side_effect = TimeoutError("Regex timeout")
|
||||
|
||||
with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex):
|
||||
content, tool_calls = run_tool_extraction(tool_parser,
|
||||
fake_problematic_input,
|
||||
streaming=streaming)
|
||||
with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex):
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, fake_problematic_input, streaming=streaming
|
||||
)
|
||||
|
||||
# should treat as regular text when regex times out
|
||||
assert content == fake_problematic_input
|
||||
|
||||
@@ -6,7 +6,9 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
run_tool_extraction, run_tool_extraction_streaming)
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
@@ -22,7 +24,8 @@ MORE_TYPES_FUNCTION_OUTPUT = (
|
||||
"address={'city': 'San Francisco', 'state': 'CA'}, "
|
||||
"role=None, "
|
||||
"passed_test=True, "
|
||||
"aliases=['John', 'Johnny'])")
|
||||
"aliases=['John', 'Johnny'])"
|
||||
)
|
||||
MORE_TYPES_FUNCTION_CALL = FunctionCall(
|
||||
name="register_user",
|
||||
arguments='{"name": "John Doe", '
|
||||
@@ -35,7 +38,7 @@ MORE_TYPES_FUNCTION_CALL = FunctionCall(
|
||||
PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
|
||||
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{}',
|
||||
arguments="{}",
|
||||
)
|
||||
EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
|
||||
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
|
||||
@@ -48,7 +51,8 @@ EMPTY_LIST_FUNCTION_CALL = FunctionCall(
|
||||
arguments='{"steps": []}',
|
||||
)
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT = (
|
||||
r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')")
|
||||
r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')"
|
||||
)
|
||||
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
|
||||
@@ -59,80 +63,118 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
|
||||
def test_no_tool_call(streaming: bool):
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
|
||||
mock_tokenizer)
|
||||
mock_tokenizer
|
||||
)
|
||||
model_output = "How can I help you today?"
|
||||
|
||||
content, tool_calls = run_tool_extraction(tool_parser,
|
||||
model_output,
|
||||
streaming=streaming)
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
|
||||
assert content == model_output
|
||||
assert len(tool_calls) == 0
|
||||
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(True,
|
||||
f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL],
|
||||
id="simple_streaming"),
|
||||
pytest.param(False,
|
||||
f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL],
|
||||
id="simple_nonstreaming"),
|
||||
pytest.param(True,
|
||||
f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_streaming"),
|
||||
pytest.param(False,
|
||||
f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_nonstreaming"),
|
||||
pytest.param(True,
|
||||
f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_streaming"),
|
||||
pytest.param(False,
|
||||
f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_nonstreaming"),
|
||||
pytest.param(True,
|
||||
f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_streaming"),
|
||||
pytest.param(False,
|
||||
f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_nonstreaming"),
|
||||
pytest.param(True,
|
||||
f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_streaming"),
|
||||
pytest.param(False,
|
||||
f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_nonstreaming"),
|
||||
pytest.param(True,
|
||||
f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_streaming"),
|
||||
pytest.param(False,
|
||||
f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_nonstreaming"),
|
||||
pytest.param(True,
|
||||
f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
|
||||
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
|
||||
id="parallel_calls_streaming"),
|
||||
pytest.param(False,
|
||||
f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
|
||||
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
|
||||
id="parallel_calls_nonstreaming"),
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{SIMPLE_FUNCTION_OUTPUT}]",
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
id="simple_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{SIMPLE_FUNCTION_OUTPUT}]",
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
id="simple_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{MORE_TYPES_FUNCTION_OUTPUT}]",
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{MORE_TYPES_FUNCTION_OUTPUT}]",
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{EMPTY_DICT_FUNCTION_OUTPUT}]",
|
||||
[EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{EMPTY_DICT_FUNCTION_OUTPUT}]",
|
||||
[EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{EMPTY_LIST_FUNCTION_OUTPUT}]",
|
||||
[EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{EMPTY_LIST_FUNCTION_OUTPUT}]",
|
||||
[EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
|
||||
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
|
||||
id="parallel_calls_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
|
||||
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
|
||||
id="parallel_calls_nonstreaming",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls",
|
||||
TEST_CASES)
|
||||
def test_tool_call(streaming: bool, model_output: str,
|
||||
expected_tool_calls: list[FunctionCall]):
|
||||
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
|
||||
def test_tool_call(
|
||||
streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall]
|
||||
):
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
|
||||
mock_tokenizer)
|
||||
mock_tokenizer
|
||||
)
|
||||
|
||||
content, tool_calls = run_tool_extraction(tool_parser,
|
||||
model_output,
|
||||
streaming=streaming)
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
|
||||
assert content is None
|
||||
assert len(tool_calls) == len(expected_tool_calls)
|
||||
@@ -144,7 +186,8 @@ def test_tool_call(streaming: bool, model_output: str,
|
||||
def test_streaming_tool_call_with_large_steps():
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
|
||||
mock_tokenizer)
|
||||
mock_tokenizer
|
||||
)
|
||||
model_output_deltas = [
|
||||
"[get_weather(city='San",
|
||||
" Francisco', metric='celsius'), "
|
||||
@@ -153,7 +196,8 @@ def test_streaming_tool_call_with_large_steps():
|
||||
]
|
||||
|
||||
reconstructor = run_tool_extraction_streaming(
|
||||
tool_parser, model_output_deltas, assert_one_tool_per_delta=False)
|
||||
tool_parser, model_output_deltas, assert_one_tool_per_delta=False
|
||||
)
|
||||
|
||||
assert reconstructor.other_content == ""
|
||||
assert len(reconstructor.tool_calls) == 3
|
||||
@@ -166,8 +210,9 @@ def test_streaming_tool_call_with_large_steps():
|
||||
def test_regex_timeout_handling(streaming: bool):
|
||||
"""test regex timeout is handled gracefully"""
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
|
||||
"llama4_pythonic")(mock_tokenizer)
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
mock_tokenizer
|
||||
)
|
||||
|
||||
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
|
||||
|
||||
@@ -175,10 +220,10 @@ def test_regex_timeout_handling(streaming: bool):
|
||||
mock_regex = MagicMock()
|
||||
mock_regex.match.side_effect = TimeoutError("Regex timeout")
|
||||
|
||||
with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex):
|
||||
content, tool_calls = run_tool_extraction(tool_parser,
|
||||
fake_problematic_input,
|
||||
streaming=streaming)
|
||||
with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex):
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, fake_problematic_input, streaming=streaming
|
||||
)
|
||||
|
||||
# should treat as regular text when regex times out
|
||||
assert content == fake_problematic_input
|
||||
|
||||
@@ -4,15 +4,17 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import Union
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
||||
|
||||
|
||||
class StreamingToolReconstructor:
|
||||
|
||||
def __init__(self, assert_one_tool_per_delta: bool = True):
|
||||
self.tool_calls: list[ToolCall] = []
|
||||
self.other_content: str = ""
|
||||
@@ -23,49 +25,60 @@ class StreamingToolReconstructor:
|
||||
self.other_content += delta.content
|
||||
else:
|
||||
assert delta.tool_calls, (
|
||||
"Streaming results should have either content or tool calls "
|
||||
"(or both)")
|
||||
"Streaming results should have either content or tool calls (or both)"
|
||||
)
|
||||
if self._assert_one_tool_per_delta:
|
||||
# Note: This isn't strictly required by the API and may not be
|
||||
# possible to adhere to depending on the token space and number of
|
||||
# tokens per streamed response from the model, but it is required
|
||||
# by tool_use tests, so we enforce it here by default also.
|
||||
assert len(delta.tool_calls) < 2, (
|
||||
"Streaming should include only one tool call per update.")
|
||||
"Streaming should include only one tool call per update."
|
||||
)
|
||||
for call_delta in delta.tool_calls:
|
||||
assert call_delta.type is None or call_delta.type == "function", (
|
||||
"Streaming tool calls should only emit function calls. Got "
|
||||
f"{call_delta.type}")
|
||||
current_tool_call = self.tool_calls[
|
||||
call_delta.index] if call_delta.index < len(
|
||||
self.tool_calls) else None
|
||||
f"{call_delta.type}"
|
||||
)
|
||||
current_tool_call = (
|
||||
self.tool_calls[call_delta.index]
|
||||
if call_delta.index < len(self.tool_calls)
|
||||
else None
|
||||
)
|
||||
if current_tool_call:
|
||||
assert (not call_delta.function.name), (
|
||||
assert not call_delta.function.name, (
|
||||
"Streaming tool calls should emit the full function name "
|
||||
f"exactly once. Got {call_delta.function.name}")
|
||||
assert (not call_delta.id), (
|
||||
f"exactly once. Got {call_delta.function.name}"
|
||||
)
|
||||
assert not call_delta.id, (
|
||||
"Streaming tool calls must emit function id only once. Got "
|
||||
f"{call_delta.id}")
|
||||
assert (call_delta.index == len(self.tool_calls) - 1), (
|
||||
f"{call_delta.id}"
|
||||
)
|
||||
assert call_delta.index == len(self.tool_calls) - 1, (
|
||||
f"Incorrect index for tool delta. Got {call_delta.index}, "
|
||||
f"expected {len(self.tool_calls) - 1}")
|
||||
current_tool_call.function.arguments += (
|
||||
call_delta.function.arguments)
|
||||
f"expected {len(self.tool_calls) - 1}"
|
||||
)
|
||||
current_tool_call.function.arguments += call_delta.function.arguments
|
||||
else:
|
||||
assert call_delta.id is not None, (
|
||||
"Streaming tool calls must have an id on first appearance")
|
||||
"Streaming tool calls must have an id on first appearance"
|
||||
)
|
||||
assert call_delta.function.name is not None, (
|
||||
"Streaming tool calls must have a function name on first "
|
||||
"appearance")
|
||||
"Streaming tool calls must have a function name on first appearance"
|
||||
)
|
||||
assert call_delta.index == len(self.tool_calls), (
|
||||
f"Incorrect index for tool delta. Got {call_delta.index}, "
|
||||
f"expected {len(self.tool_calls)}")
|
||||
f"expected {len(self.tool_calls)}"
|
||||
)
|
||||
self.tool_calls.append(
|
||||
ToolCall(id=call_delta.id,
|
||||
function=FunctionCall(
|
||||
name=call_delta.function.name,
|
||||
arguments=call_delta.function.arguments
|
||||
or "")))
|
||||
ToolCall(
|
||||
id=call_delta.id,
|
||||
function=FunctionCall(
|
||||
name=call_delta.function.name,
|
||||
arguments=call_delta.function.arguments or "",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def run_tool_extraction(
|
||||
@@ -80,11 +93,11 @@ def run_tool_extraction(
|
||||
tool_parser,
|
||||
model_output,
|
||||
request,
|
||||
assert_one_tool_per_delta=assert_one_tool_per_delta)
|
||||
assert_one_tool_per_delta=assert_one_tool_per_delta,
|
||||
)
|
||||
return reconstructor.other_content or None, reconstructor.tool_calls
|
||||
else:
|
||||
extracted = run_tool_extraction_nonstreaming(tool_parser, model_output,
|
||||
request)
|
||||
extracted = run_tool_extraction_nonstreaming(tool_parser, model_output, request)
|
||||
assert extracted.tools_called == bool(extracted.tool_calls)
|
||||
return extracted.content, extracted.tool_calls
|
||||
|
||||
@@ -92,7 +105,7 @@ def run_tool_extraction(
|
||||
def run_tool_extraction_nonstreaming(
|
||||
tool_parser: ToolParser,
|
||||
model_output: str,
|
||||
request: Union[ChatCompletionRequest, None] = None
|
||||
request: Union[ChatCompletionRequest, None] = None,
|
||||
) -> ExtractedToolCallInformation:
|
||||
request = request or ChatCompletionRequest(messages=[], model="test-model")
|
||||
return tool_parser.extract_tool_calls(model_output, request)
|
||||
@@ -106,7 +119,8 @@ def run_tool_extraction_streaming(
|
||||
) -> StreamingToolReconstructor:
|
||||
request = request or ChatCompletionRequest(messages=[], model="test-model")
|
||||
reconstructor = StreamingToolReconstructor(
|
||||
assert_one_tool_per_delta=assert_one_tool_per_delta)
|
||||
assert_one_tool_per_delta=assert_one_tool_per_delta
|
||||
)
|
||||
previous_text = ""
|
||||
previous_tokens: list[int] = []
|
||||
for delta in model_deltas:
|
||||
@@ -118,8 +132,14 @@ def run_tool_extraction_streaming(
|
||||
current_text = previous_text + delta
|
||||
current_tokens = previous_tokens + token_delta
|
||||
delta_message = tool_parser.extract_tool_calls_streaming(
|
||||
previous_text, current_text, delta, previous_tokens,
|
||||
current_tokens, token_delta, request)
|
||||
previous_text,
|
||||
current_text,
|
||||
delta,
|
||||
previous_tokens,
|
||||
current_tokens,
|
||||
token_delta,
|
||||
request,
|
||||
)
|
||||
if delta_message is not None:
|
||||
reconstructor.append_delta(delta_message)
|
||||
previous_text = current_text
|
||||
|
||||
Reference in New Issue
Block a user