[Refactor] Relocate completion and chat completion tests (#37125)

Signed-off-by: sfeng33 <4florafeng@gmail.com>
Flora Feng
2026-03-16 23:31:23 -04:00
committed by GitHub
parent f04d5226f8
commit 384dc7f77b
26 changed files with 41 additions and 48 deletions

View File

@@ -0,0 +1,266 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import Any
from unittest.mock import MagicMock
import pytest
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
from vllm.entrypoints.openai.engine.protocol import GenerationError
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"
BASE_MODEL_PATHS = [
BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
]
@dataclass
class MockHFConfig:
model_type: str = "any"
@dataclass
class MockModelConfig:
task = "generate"
runner_type = "generate"
model = MODEL_NAME
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
max_model_len = 100
tokenizer_revision = None
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
logits_processors: list[str] | None = None
diff_sampling_param: dict | None = None
allowed_local_media_path: str = ""
allowed_media_domains: list[str] | None = None
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
@dataclass
class MockParallelConfig:
_api_process_rank: int = 0
@dataclass
class MockVllmConfig:
model_config: MockModelConfig
parallel_config: MockParallelConfig
def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
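"""Build an OpenAIServingCompletion wired to the mocked engine, model registry, and render helper."""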
models = OpenAIServingModels(
engine_client=engine,
base_model_paths=BASE_MODEL_PATHS,
)
serving_render = OpenAIServingRender(
model_config=engine.model_config,
renderer=engine.renderer,
io_processor=engine.io_processor,
model_registry=models.registry,
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
)
return OpenAIServingCompletion(
engine,
models,
openai_serving_render=serving_render,
request_logger=None,
)
def _build_renderer(model_config: MockModelConfig):
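"""Build an HfRenderer for the mocked model config via tokenizer_args_from_config."""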
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer.from_config(
MockVllmConfig(model_config, parallel_config=MockParallelConfig()),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
@pytest.mark.asyncio
async def test_completion_error_non_stream():
"""test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_completion = _build_serving_completion(mock_engine)
completion_output = CompletionOutput(
index=0,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None,
finish_reason="error",
)
request_output = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output],
finished=True,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
async def mock_generate(*args, **kwargs):
yield request_output
mock_engine.generate = MagicMock(side_effect=mock_generate)
request = CompletionRequest(
model=MODEL_NAME,
prompt="Test prompt",
max_tokens=10,
stream=False,
)
with pytest.raises(GenerationError):
await serving_completion.create_completion(request)
@pytest.mark.asyncio
async def test_completion_error_stream():
"""test finish_reason='error' returns 500 InternalServerError (streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_completion = _build_serving_completion(mock_engine)
completion_output_1 = CompletionOutput(
index=0,
text="Hello",
token_ids=[100],
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
)
request_output_1 = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output_1],
finished=False,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
completion_output_2 = CompletionOutput(
index=0,
text="Hello",
token_ids=[100],
cumulative_logprob=None,
logprobs=None,
finish_reason="error",
)
request_output_2 = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output_2],
finished=True,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
async def mock_generate(*args, **kwargs):
yield request_output_1
yield request_output_2
mock_engine.generate = MagicMock(side_effect=mock_generate)
request = CompletionRequest(
model=MODEL_NAME,
prompt="Test prompt",
max_tokens=10,
stream=True,
)
response = await serving_completion.create_completion(request)
chunks = []
async for chunk in response:
chunks.append(chunk)
assert len(chunks) >= 2
assert any("Internal server error" in chunk for chunk in chunks), (
f"Expected error message in chunks: {chunks}"
)
assert chunks[-1] == "data: [DONE]\n\n"
def test_json_schema_response_format_missing_schema():
"""When response_format type is 'json_schema' but the json_schema field
is not provided, request construction should raise a validation error
so the API returns 400 instead of 500."""
with pytest.raises(Exception, match="json_schema.*must be provided"):
CompletionRequest(
model=MODEL_NAME,
prompt="Test prompt",
max_tokens=10,
response_format={"type": "json_schema"},
)
def test_negative_prompt_token_ids_nested():
"""Negative token IDs in prompt (nested list) should raise validation error."""
with pytest.raises(Exception, match="greater than or equal to 0"):
CompletionRequest(
model=MODEL_NAME,
prompt=[[-1]],
max_tokens=10,
)
def test_negative_prompt_token_ids_flat():
"""Negative token IDs in prompt (flat list) should raise validation error."""
with pytest.raises(Exception, match="greater than or equal to 0"):
CompletionRequest(
model=MODEL_NAME,
prompt=[-1],
max_tokens=10,
)

View File

@@ -0,0 +1,307 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import io
import json
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
import torch
from openai import BadRequestError
from transformers import AutoConfig
from tests.utils import RemoteOpenAIServer
# any small generative model should work here
MODEL_NAME = "facebook/opt-125m"
LORA_SERVING_MODEL_NAME = "opt125m-lora"
CONFIG = AutoConfig.from_pretrained(MODEL_NAME)
@pytest.fixture(scope="module", params=["use-lora"])
def default_server_args(
request: pytest.FixtureRequest, opt125_lora_files: str
) -> list[str]:
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"2048",
"--max-num-seqs",
"128",
"--enforce-eager",
# Prompt Embeds server args
"--enable-prompt-embeds",
]
if request.param == "use-lora":
lora_module_1 = {
"name": LORA_SERVING_MODEL_NAME,
"path": opt125_lora_files,
"base_model_name": MODEL_NAME,
}
args.extend(
[
"--enable-lora",
"--lora-module",
json.dumps(lora_module_1),
"--max-lora-rank",
"64",
"--max-cpu-loras",
"2",
]
)
return args
EXAMPLE_PROMPTS = [
"Hello, my name is",
"What is an LLM?",
]
def _encode_embeds(embeds: torch.Tensor):
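"""Serialize a tensor with torch.save and return it as a base64 string."""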
buffer = io.BytesIO()
torch.save(embeds, buffer)
return base64.b64encode(buffer.getvalue()).decode("utf-8")
@pytest.fixture(scope="module")
def example_prompt_embeds(hf_runner):
"""Create example embeddings and return them as base64 encoded string."""
with hf_runner(MODEL_NAME) as hf_model:
example_embeddings = hf_model.get_prompt_embeddings(EXAMPLE_PROMPTS)
return [_encode_embeds(item) for item in example_embeddings]
@pytest.fixture(scope="module", params=["", "--disable-frontend-multiprocessing"])
def server_with_prompt_embeds(default_server_args, request):
if request.param:
default_server_args.append(request.param)
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client_with_prompt_embeds(server_with_prompt_embeds):
async with server_with_prompt_embeds.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME, LORA_SERVING_MODEL_NAME])
async def test_completions_with_prompt_embeds(
example_prompt_embeds,
client_with_prompt_embeds: openai.AsyncOpenAI,
model_name: str,
):
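"""Exercise prompt_embeds completions: single input, batch, streaming, and mixed text + embeds requests."""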
encoded_embeds, encoded_embeds2 = example_prompt_embeds
# Test case: Single prompt embeds input
completion = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt=None,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds},
)
assert len(completion.choices[0].text) >= 1
assert completion.choices[0].prompt_logprobs is None
# Test case: batch completion with prompt_embeds
completion = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt=None,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]},
)
assert len(completion.choices) == 2
assert len(completion.choices[0].text) >= 1
assert len(completion.choices[1].text) >= 1
# Test case: streaming with prompt_embeds
single_completion = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt=None,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds},
)
single_output = single_completion.choices[0].text
stream = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt=None,
max_tokens=5,
temperature=0.0,
stream=True,
extra_body={"prompt_embeds": encoded_embeds},
)
chunks = []
finish_reason_count = 0
async for chunk in stream:
chunks.append(chunk.choices[0].text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
assert finish_reason_count == 1
assert chunk.choices[0].finish_reason == "length"
assert chunk.choices[0].text
assert "".join(chunks) == single_output
# Test case: batch streaming with prompt_embeds
stream = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt=None,
max_tokens=5,
temperature=0.0,
stream=True,
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]},
)
chunks_stream_embeds: list[list[str]] = [[], []]
finish_reason_count = 0
async for chunk in stream:
chunks_stream_embeds[chunk.choices[0].index].append(chunk.choices[0].text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
assert finish_reason_count == 2
assert chunk.choices[0].finish_reason == "length"
assert chunk.choices[0].text
assert len(chunks_stream_embeds[0]) > 0
assert len(chunks_stream_embeds[1]) > 0
# Test case: mixed text and prompt_embeds
completion_mixed = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="This is a prompt",
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds},
)
assert len(completion_mixed.choices) == 2
completion_text_only = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="This is a prompt",
max_tokens=5,
temperature=0.0,
)
completion_embeds_only = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt=None,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds},
)
# The choice generated from prompt_embeds comes first, followed by the text-prompt choice
assert completion_mixed.choices[0].text == completion_embeds_only.choices[0].text
assert completion_mixed.choices[1].text == completion_text_only.choices[0].text
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME, LORA_SERVING_MODEL_NAME])
async def test_completions_errors_with_prompt_embeds(
client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str
):
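"""Invalid base64 in prompt_embeds should be rejected with a BadRequestError."""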
# Test error case: invalid prompt_embeds
with pytest.raises(BadRequestError):
await client_with_prompt_embeds.completions.create(
prompt=None,
model=model_name,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": "invalid_base64"},
)
@pytest.mark.asyncio
@pytest.mark.parametrize("logprobs_arg", [1, 0])
@pytest.mark.parametrize("model_name", [MODEL_NAME, LORA_SERVING_MODEL_NAME])
async def test_completions_with_logprobs_and_prompt_embeds(
example_prompt_embeds,
client_with_prompt_embeds: openai.AsyncOpenAI,
logprobs_arg: int,
model_name: str,
):
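"""Logprobs should be returned for prompt_embeds requests, both single and batch."""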
encoded_embeds, encoded_embeds2 = example_prompt_embeds
# Test case: Logprobs using prompt_embeds
completion = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt=None,
max_tokens=5,
temperature=0.0,
echo=False,
logprobs=logprobs_arg,
extra_body={"prompt_embeds": encoded_embeds},
)
logprobs = completion.choices[0].logprobs
assert logprobs is not None
assert len(logprobs.text_offset) == 5
assert len(logprobs.token_logprobs) == 5
assert len(logprobs.top_logprobs) == 5
for top_logprobs in logprobs.top_logprobs[1:]:
assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1
assert len(logprobs.tokens) == 5
# Test case: Log probs with batch completion and prompt_embeds
completion = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt=None,
max_tokens=5,
temperature=0.0,
echo=False,
logprobs=logprobs_arg,
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]},
)
assert len(completion.choices) == 2
for choice in completion.choices:
logprobs = choice.logprobs
assert logprobs is not None
assert len(logprobs.text_offset) == 5
assert len(logprobs.token_logprobs) == 5
assert len(logprobs.top_logprobs) == 5
for top_logprobs in logprobs.top_logprobs[1:]:
assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1
assert len(logprobs.tokens) == 5
@pytest.mark.asyncio
async def test_prompt_logprobs_raises_error(
example_prompt_embeds,
client_with_prompt_embeds: openai.AsyncOpenAI,
):
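"""Requesting prompt_logprobs together with prompt_embeds should be rejected."""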
encoded_embeds, _ = example_prompt_embeds
with pytest.raises(BadRequestError, match="not compatible"):
await client_with_prompt_embeds.completions.create(
model=MODEL_NAME,
prompt=None,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds, "prompt_logprobs": True},
)
@pytest.mark.asyncio
async def test_empty_prompt_embeds(
client_with_prompt_embeds: openai.AsyncOpenAI,
) -> None:
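"""An empty prompt_embeds list alongside a text prompt should be accepted."""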
await client_with_prompt_embeds.completions.create(
model=MODEL_NAME,
prompt="Hello",
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": []},
)

View File

@@ -0,0 +1,259 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import suppress
from dataclasses import dataclass, field
from http import HTTPStatus
from unittest.mock import AsyncMock, MagicMock
import pytest
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
MOCK_RESOLVER_NAME = "mock_test_resolver"
@dataclass
class MockHFConfig:
model_type: str = "any"
@dataclass
class MockModelConfig:
"""Minimal mock ModelConfig for testing."""
model: str = MODEL_NAME
runner_type = "generate"
tokenizer: str = MODEL_NAME
trust_remote_code: bool = False
tokenizer_mode: str = "auto"
max_model_len: int = 100
tokenizer_revision: str | None = None
multimodal_config: MultiModalConfig = field(default_factory=MultiModalConfig)
hf_config: MockHFConfig = field(default_factory=MockHFConfig)
logits_processors: list[str] | None = None
diff_sampling_param: dict | None = None
allowed_local_media_path: str = ""
allowed_media_domains: list[str] | None = None
encoder_config = None
generation_config: str = "auto"
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
@dataclass
class MockParallelConfig:
_api_process_rank: int = 0
@dataclass
class MockVllmConfig:
model_config: MockModelConfig
parallel_config: MockParallelConfig
class MockLoRAResolver(LoRAResolver):
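"""Resolver stub: maps 'test-lora' and 'invalid-lora' to fake LoRARequests, everything else to None."""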
async def resolve_lora(
self, base_model_name: str, lora_name: str
) -> LoRARequest | None:
if lora_name == "test-lora":
return LoRARequest(
lora_name="test-lora",
lora_int_id=1,
lora_path="/fake/path/test-lora",
)
elif lora_name == "invalid-lora":
return LoRARequest(
lora_name="invalid-lora",
lora_int_id=2,
lora_path="/fake/path/invalid-lora",
)
return None
@pytest.fixture(autouse=True)
def register_mock_resolver():
"""Fixture to register and unregister the mock LoRA resolver."""
resolver = MockLoRAResolver()
LoRAResolverRegistry.register_resolver(MOCK_RESOLVER_NAME, resolver)
yield
# Cleanup: remove the resolver after the test runs
if MOCK_RESOLVER_NAME in LoRAResolverRegistry.resolvers:
del LoRAResolverRegistry.resolvers[MOCK_RESOLVER_NAME]
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer.from_config(
MockVllmConfig(model_config, parallel_config=MockParallelConfig()),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
@pytest.fixture
def mock_serving_setup():
"""Provides a mocked engine and serving completion instance."""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
async def mock_add_lora_side_effect(lora_request: LoRARequest):
"""Simulate engine behavior when adding LoRAs."""
if lora_request.lora_name == "test-lora":
# Simulate successful addition
return True
if lora_request.lora_name == "invalid-lora":
# Simulate failure during addition (e.g. invalid format)
raise ValueError(f"Simulated failure adding LoRA: {lora_request.lora_name}")
return True
mock_engine.add_lora = AsyncMock(side_effect=mock_add_lora_side_effect)
async def mock_generate(*args, **kwargs):
for _ in []:
yield _
mock_engine.generate = MagicMock(spec=AsyncLLM.generate, side_effect=mock_generate)
mock_engine.generate.reset_mock()
mock_engine.add_lora.reset_mock()
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
models = OpenAIServingModels(
engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
)
serving_render = OpenAIServingRender(
model_config=mock_engine.model_config,
renderer=mock_engine.renderer,
io_processor=mock_engine.io_processor,
model_registry=models.registry,
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
)
serving_completion = OpenAIServingCompletion(
mock_engine, models, openai_serving_render=serving_render, request_logger=None
)
return mock_engine, serving_completion
@pytest.mark.asyncio
async def test_serving_completion_with_lora_resolver(mock_serving_setup, monkeypatch):
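"""A request naming an unloaded LoRA should be resolved, added via add_lora, and passed to generate."""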
monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")
mock_engine, serving_completion = mock_serving_setup
lora_model_name = "test-lora"
req_found = CompletionRequest(
model=lora_model_name,
prompt="Generate with LoRA",
)
# Suppress potential errors during the mocked generate call,
# as we are primarily checking for add_lora and generate calls
with suppress(Exception):
await serving_completion.create_completion(req_found)
mock_engine.add_lora.assert_awaited_once()
called_lora_request = mock_engine.add_lora.call_args[0][0]
assert isinstance(called_lora_request, LoRARequest)
assert called_lora_request.lora_name == lora_model_name
mock_engine.generate.assert_called_once()
called_lora_request = mock_engine.generate.call_args[1]["lora_request"]
assert isinstance(called_lora_request, LoRARequest)
assert called_lora_request.lora_name == lora_model_name
@pytest.mark.asyncio
async def test_serving_completion_resolver_not_found(mock_serving_setup, monkeypatch):
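"""A model name no resolver can handle should return a 404 ErrorResponse without touching the engine."""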
monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")
mock_engine, serving_completion = mock_serving_setup
non_existent_model = "non-existent-lora-adapter"
req = CompletionRequest(
model=non_existent_model,
prompt="what is 1+1?",
)
response = await serving_completion.create_completion(req)
mock_engine.add_lora.assert_not_awaited()
mock_engine.generate.assert_not_called()
assert isinstance(response, ErrorResponse)
assert response.error.code == HTTPStatus.NOT_FOUND.value
assert non_existent_model in response.error.message
@pytest.mark.asyncio
async def test_serving_completion_resolver_add_lora_fails(
mock_serving_setup, monkeypatch
):
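"""If add_lora fails for a resolved adapter, the request should return a 400 ErrorResponse and generate should not run."""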
monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")
mock_engine, serving_completion = mock_serving_setup
invalid_model = "invalid-lora"
req = CompletionRequest(
model=invalid_model,
prompt="what is 1+1?",
)
response = await serving_completion.create_completion(req)
# Assert add_lora was called before the failure
mock_engine.add_lora.assert_awaited_once()
called_lora_request = mock_engine.add_lora.call_args[0][0]
assert isinstance(called_lora_request, LoRARequest)
assert called_lora_request.lora_name == invalid_model
# Assert generate was *not* called due to the failure
mock_engine.generate.assert_not_called()
# Assert the correct error response
assert isinstance(response, ErrorResponse)
assert response.error.code == HTTPStatus.BAD_REQUEST.value
assert invalid_model in response.error.message
@pytest.mark.asyncio
async def test_serving_completion_flag_not_set(mock_serving_setup):
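"""Without VLLM_ALLOW_RUNTIME_LORA_UPDATING set, unknown LoRA models are neither resolved nor loaded."""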
mock_engine, serving_completion = mock_serving_setup
lora_model_name = "test-lora"
req_found = CompletionRequest(
model=lora_model_name,
prompt="Generate with LoRA",
)
await serving_completion.create_completion(req_found)
mock_engine.add_lora.assert_not_called()
mock_engine.generate.assert_not_called()

View File

@@ -0,0 +1,113 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import io
from unittest.mock import Mock
# imports for structured outputs tests
import openai
import pybase64
import pytest
import regex as re
import torch
from tests.utils import RemoteOpenAIServer
from vllm.config import ModelConfig
from vllm.renderers.embed_utils import safe_load_prompt_embeds
@pytest.mark.asyncio
async def test_empty_prompt():
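"""A request with neither prompt nor prompt_embeds should be rejected with a 400 error."""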
model_name = "gpt2"
server_args = ["--enforce-eager"]
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
with pytest.raises(
openai.BadRequestError,
match="Either prompt or prompt_embeds must be provided and non-empty.",
):
await client.completions.create(
model=model_name,
prompt=None,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": []},
)
@pytest.mark.asyncio
async def test_out_of_vocab_token_ids():
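"""Prompt token IDs outside the model vocabulary should be rejected with a 400 error."""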
model_name = "gpt2"
server_args = ["--enforce-eager"]
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
with pytest.raises(
openai.BadRequestError, match=re.compile(".*out of vocabulary.*").pattern
):
await client.completions.create(
model=model_name, prompt=[999999], max_tokens=5, temperature=0.0
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize(
"layout", [torch.strided, torch.sparse_coo, torch.sparse_csc, torch.sparse_csr]
)
@pytest.mark.parametrize("seq_len", [2, 10])
@pytest.mark.parametrize("hidden_size", [2, 10])
def test_load_prompt_embeds(
dtype: torch.dtype, layout: torch.layout, seq_len: int, hidden_size: int
):
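"""safe_load_prompt_embeds should return a dense CPU tensor equal to the original for every tested dtype and layout."""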
model_config = Mock(spec=ModelConfig)
model_config.enable_prompt_embeds = True
# Construct arbitrary tensors of various dtypes, layouts, and sizes.
# We check different layouts to make sure that if a user sends sparse
# tensors (e.g. to reduce the transmission size of prompt embeddings),
# they are cast to dense/strided before being passed into the engine.
# We avoid non-CPU tensors in this test so we don't preemptively
# initialize CUDA and break other tests in the suite that fork
# processes, and we only want to use devices that are actually
# available in the environment the test runs on. For simplicity,
# we just test against CPU.
tensor = torch.randn((seq_len, hidden_size), dtype=dtype)
if layout == torch.strided:
tensor = tensor.contiguous()
elif layout == torch.sparse_coo:
tensor = tensor.to_sparse_coo()
elif layout == torch.sparse_csc:
tensor = tensor.to_sparse_csc()
elif layout == torch.sparse_csr:
tensor = tensor.to_sparse_csr()
buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)
encoded_tensor = pybase64.b64encode(buffer.getvalue())
loaded_tensor = safe_load_prompt_embeds(model_config, encoded_tensor)
assert loaded_tensor.device.type == "cpu"
assert loaded_tensor.layout == torch.strided
torch.testing.assert_close(
loaded_tensor, tensor.to("cpu").to_dense(), equal_nan=True
)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("seq_len", [2])
@pytest.mark.parametrize("hidden_size", [2])
def test_disable_prompt_embeds(dtype: torch.dtype, seq_len: int, hidden_size: int):
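"""prompt_embeds payloads should be rejected when --enable-prompt-embeds is not set."""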
model_config = Mock(spec=ModelConfig)
model_config.enable_prompt_embeds = False
tensor = torch.randn((seq_len, hidden_size), dtype=dtype)
buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)
encoded_tensor = pybase64.b64encode(buffer.getvalue())
with pytest.raises(ValueError, match="--enable-prompt-embeds"):
safe_load_prompt_embeds(model_config, encoded_tensor)

View File

@@ -0,0 +1,564 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for shutdown behavior, timeout, and signal handling."""
import asyncio
import signal
import subprocess
import sys
import time
from dataclasses import dataclass, field
import httpx
import openai
import psutil
import pytest
from tests.utils import RemoteOpenAIServer
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_port
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
# GPU initialization might take longer
_IS_ROCM = current_platform.is_rocm()
_SERVER_STARTUP_TIMEOUT = 120
_PROCESS_EXIT_TIMEOUT = 15
_SHUTDOWN_DETECTION_TIMEOUT = 10
_CHILD_CLEANUP_TIMEOUT = 10
def _get_child_pids(parent_pid: int) -> list[int]:
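"""Return the PIDs of all (recursive) children of the given process."""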
try:
parent = psutil.Process(parent_pid)
return [c.pid for c in parent.children(recursive=True)]
except psutil.NoSuchProcess:
return []
async def _assert_children_cleaned_up(
child_pids: list[int],
timeout: float = _CHILD_CLEANUP_TIMEOUT,
):
"""Wait for child processes to exit and fail if any remain."""
if not child_pids:
return
deadline = time.time() + timeout
while time.time() < deadline:
still_alive = []
for pid in child_pids:
try:
p = psutil.Process(pid)
if p.is_running() and p.status() != psutil.STATUS_ZOMBIE:
still_alive.append(pid)
except psutil.NoSuchProcess:
pass
if not still_alive:
return
await asyncio.sleep(0.5)
pytest.fail(
f"Child processes {still_alive} still alive after {timeout}s. "
f"Process cleanup may not be working correctly."
)
@dataclass
class ShutdownState:
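"""Observations collected by the concurrent request loop during shutdown."""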
got_503: bool = False
got_500: bool = False
requests_after_sigterm: int = 0
aborted_requests: int = 0
connection_errors: int = 0
stop_requesting: bool = False
errors: list[str] = field(default_factory=list)
async def _concurrent_request_loop(
client: openai.AsyncOpenAI,
state: ShutdownState,
sigterm_sent: asyncio.Event | None = None,
concurrency: int = 10,
):
"""Run multiple concurrent requests to keep the server busy."""
async def single_request():
while not state.stop_requesting:
try:
response = await client.completions.create(
model=MODEL_NAME,
prompt="Write a story: ",
max_tokens=200,
)
if sigterm_sent is not None and sigterm_sent.is_set():
state.requests_after_sigterm += 1
# Check if any choice has finish_reason='abort'
if any(choice.finish_reason == "abort" for choice in response.choices):
state.aborted_requests += 1
except openai.APIStatusError as e:
if e.status_code == 503:
state.got_503 = True
elif e.status_code == 500:
state.got_500 = True
else:
state.errors.append(f"API error: {e}")
except (openai.APIConnectionError, httpx.RemoteProtocolError):
state.connection_errors += 1
if sigterm_sent is not None and sigterm_sent.is_set():
break
except Exception as e:
state.errors.append(f"Unexpected error: {e}")
break
await asyncio.sleep(0.01)
tasks = [asyncio.create_task(single_request()) for _ in range(concurrency)]
try:
await asyncio.gather(*tasks, return_exceptions=True)
finally:
for t in tasks:
if not t.done():
t.cancel()
@pytest.mark.asyncio
async def test_shutdown_on_engine_failure():
"""Verify that API returns connection error when server process is killed.
Starts a vLLM server, kills it to simulate a crash, then verifies that
subsequent API calls fail appropriately.
"""
port = get_open_port()
proc = subprocess.Popen(
[
# dtype, max-len etc set so that this can run in CI
sys.executable,
"-m",
"vllm.entrypoints.openai.api_server",
"--model",
MODEL_NAME,
"--dtype",
"bfloat16",
"--max-model-len",
"128",
"--enforce-eager",
"--port",
str(port),
"--gpu-memory-utilization",
"0.05",
"--max-num-seqs",
"2",
"--disable-frontend-multiprocessing",
],
# ROCm: Disable stdout/stderr pipe capture. Subprocess hangs when
# stdout/stderr pipes are enabled during ROCm GPU initialization.
stdout=None if _IS_ROCM else subprocess.PIPE,
stderr=None if _IS_ROCM else subprocess.PIPE,
text=None if _IS_ROCM else True,
preexec_fn=lambda: signal.signal(signal.SIGINT, signal.SIG_IGN),
)
# Wait for server startup
start_time = time.time()
client = openai.AsyncOpenAI(
base_url=f"http://localhost:{port}/v1",
api_key="dummy",
max_retries=0,
timeout=10,
)
# Poll until server is ready
while time.time() - start_time < _SERVER_STARTUP_TIMEOUT:
try:
await client.completions.create(
model=MODEL_NAME, prompt="Hello", max_tokens=1
)
break
except Exception:
time.sleep(0.5)
if proc.poll() is not None:
if _IS_ROCM:
pytest.fail(f"Server died during startup: {proc.returncode}")
else:
stdout, stderr = proc.communicate(timeout=1)
pytest.fail(
f"Server died during startup. "
f"stdout: {stdout}, stderr: {stderr}"
)
else:
proc.terminate()
proc.wait(timeout=_PROCESS_EXIT_TIMEOUT)
pytest.fail(f"Server failed to start in {_SERVER_STARTUP_TIMEOUT} seconds")
# Kill server to simulate crash
proc.terminate()
time.sleep(1)
# Verify API calls now fail
with pytest.raises((openai.APIConnectionError, openai.APIStatusError)):
await client.completions.create(
model=MODEL_NAME, prompt="This should fail", max_tokens=1
)
return_code = proc.wait(timeout=_PROCESS_EXIT_TIMEOUT)
assert return_code is not None
@pytest.mark.asyncio
async def test_wait_timeout_completes_requests():
"""Verify wait timeout: new requests rejected, in-flight requests complete."""
server_args = [
"--dtype",
"bfloat16",
"--max-model-len",
"256",
"--enforce-eager",
"--gpu-memory-utilization",
"0.05",
"--max-num-seqs",
"4",
"--shutdown-timeout",
"30",
]
with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server:
client = remote_server.get_async_client()
proc = remote_server.proc
child_pids = _get_child_pids(proc.pid)
state = ShutdownState()
sigterm_sent = asyncio.Event()
request_task = asyncio.create_task(
_concurrent_request_loop(client, state, sigterm_sent, concurrency=10)
)
await asyncio.sleep(0.5)
proc.send_signal(signal.SIGTERM)
sigterm_sent.set()
try:
await asyncio.wait_for(request_task, timeout=_SHUTDOWN_DETECTION_TIMEOUT)
except asyncio.TimeoutError:
pass
finally:
state.stop_requesting = True
if not request_task.done():
request_task.cancel()
await asyncio.gather(request_task, return_exceptions=True)
# wait timeout should complete in-flight requests
assert state.requests_after_sigterm > 0, (
f"Wait timeout should complete in-flight requests. "
f"503: {state.got_503}, 500: {state.got_500}, "
f"conn_errors: {state.connection_errors}, errors: {state.errors}"
)
# server must stop accepting new requests (503, 500, or connection close)
assert state.got_503 or state.got_500 or state.connection_errors > 0, (
f"Server should stop accepting requests. "
f"completed: {state.requests_after_sigterm}, errors: {state.errors}"
)
await _assert_children_cleaned_up(child_pids)
@pytest.mark.asyncio
@pytest.mark.parametrize("wait_for_engine_idle", [0.0, 2.0])
async def test_abort_timeout_exits_quickly(wait_for_engine_idle: float):
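"""With --shutdown-timeout 0, SIGTERM should make the server exit almost immediately, idle or not."""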
server_args = [
"--dtype",
"bfloat16",
"--max-model-len",
"256",
"--enforce-eager",
"--gpu-memory-utilization",
"0.05",
"--max-num-seqs",
"4",
"--shutdown-timeout",
"0",
]
with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server:
proc = remote_server.proc
child_pids = _get_child_pids(proc.pid)
if wait_for_engine_idle > 0:
client = remote_server.get_async_client()
# Send requests to ensure engine is fully initialized
for _ in range(2):
await client.completions.create(
model=MODEL_NAME,
prompt="Test request: ",
max_tokens=10,
)
# Wait for engine to become idle
await asyncio.sleep(wait_for_engine_idle)
start_time = time.time()
proc.send_signal(signal.SIGTERM)
# abort timeout (0) should exit promptly
for _ in range(20):
if proc.poll() is not None:
break
time.sleep(0.1)
if proc.poll() is None:
proc.kill()
proc.wait(timeout=5)
pytest.fail("Process did not exit after SIGTERM with abort timeout")
exit_time = time.time() - start_time
assert exit_time < 2, f"Abort-timeout shutdown took too long: {exit_time:.1f}s"
assert proc.returncode in (0, -15, None), f"Unexpected: {proc.returncode}"
await _assert_children_cleaned_up(child_pids)
@pytest.mark.asyncio
async def test_wait_timeout_with_short_duration():
"""Verify server exits cleanly with a short wait timeout."""
wait_timeout = 3
server_args = [
"--dtype",
"bfloat16",
"--max-model-len",
"256",
"--enforce-eager",
"--gpu-memory-utilization",
"0.05",
"--max-num-seqs",
"4",
"--shutdown-timeout",
str(wait_timeout),
]
with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server:
client = remote_server.get_async_client()
proc = remote_server.proc
child_pids = _get_child_pids(proc.pid)
state = ShutdownState()
request_task = asyncio.create_task(
_concurrent_request_loop(client, state, concurrency=3)
)
await asyncio.sleep(0.5)
start_time = time.time()
proc.send_signal(signal.SIGTERM)
# server should exit within wait_timeout + buffer
max_wait = wait_timeout + 15
for _ in range(int(max_wait * 10)):
if proc.poll() is not None:
break
time.sleep(0.1)
exit_time = time.time() - start_time
state.stop_requesting = True
if not request_task.done():
request_task.cancel()
await asyncio.gather(request_task, return_exceptions=True)
if proc.poll() is None:
proc.kill()
proc.wait(timeout=5)
pytest.fail(f"Process did not exit within {max_wait}s after SIGTERM")
assert exit_time < wait_timeout + 10, (
f"Took too long to exit ({exit_time:.1f}s), expected <{wait_timeout + 10}s"
)
assert proc.returncode in (0, -15, None), f"Unexpected: {proc.returncode}"
await _assert_children_cleaned_up(child_pids)
@pytest.mark.asyncio
async def test_abort_timeout_fails_inflight_requests():
"""Verify abort timeout (0) immediately aborts in-flight requests."""
server_args = [
"--dtype",
"bfloat16",
"--max-model-len",
"256",
"--enforce-eager",
"--gpu-memory-utilization",
"0.05",
"--max-num-seqs",
"4",
"--shutdown-timeout",
"0",
]
with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server:
client = remote_server.get_async_client()
proc = remote_server.proc
child_pids = _get_child_pids(proc.pid)
state = ShutdownState()
sigterm_sent = asyncio.Event()
request_task = asyncio.create_task(
_concurrent_request_loop(client, state, sigterm_sent, concurrency=10)
)
await asyncio.sleep(0.5)
proc.send_signal(signal.SIGTERM)
sigterm_sent.set()
try:
await asyncio.wait_for(request_task, timeout=5)
except asyncio.TimeoutError:
pass
finally:
state.stop_requesting = True
if not request_task.done():
request_task.cancel()
await asyncio.gather(request_task, return_exceptions=True)
# With abort timeout (0), requests should be aborted (finish_reason='abort')
# or rejected (connection errors or API errors)
assert (
state.aborted_requests > 0
or state.connection_errors > 0
or state.got_500
or state.got_503
), (
f"Abort timeout should cause request aborts or failures. "
f"aborted: {state.aborted_requests}, "
f"503: {state.got_503}, 500: {state.got_500}, "
f"conn_errors: {state.connection_errors}, "
f"completed: {state.requests_after_sigterm}"
)
# Verify fast shutdown
start_time = time.time()
for _ in range(100):
if proc.poll() is not None:
break
time.sleep(0.1)
exit_time = time.time() - start_time
assert exit_time < 10, f"Abort timeout shutdown took too long: {exit_time:.1f}s"
await _assert_children_cleaned_up(child_pids)
@pytest.mark.asyncio
async def test_request_rejection_during_shutdown():
"""Verify new requests are rejected with error during shutdown."""
server_args = [
"--dtype",
"bfloat16",
"--max-model-len",
"256",
"--enforce-eager",
"--gpu-memory-utilization",
"0.05",
"--max-num-seqs",
"4",
"--shutdown-timeout",
"30",
]
with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server:
client = remote_server.get_async_client()
proc = remote_server.proc
child_pids = _get_child_pids(proc.pid)
proc.send_signal(signal.SIGTERM)
await asyncio.sleep(1.0)
# Try to send new requests - they should be rejected
rejected_count = 0
for _ in range(10):
try:
await client.completions.create(
model=MODEL_NAME, prompt="Hello", max_tokens=10
)
except (
openai.APIStatusError,
openai.APIConnectionError,
httpx.RemoteProtocolError,
):
rejected_count += 1
await asyncio.sleep(0.1)
assert rejected_count > 0, (
f"Expected requests to be rejected during shutdown, "
f"but {rejected_count} were rejected out of 10"
)
await _assert_children_cleaned_up(child_pids)
@pytest.mark.asyncio
async def test_multi_api_server_shutdown():
"""Verify shutdown works with multiple API servers."""
server_args = [
"--dtype",
"bfloat16",
"--max-model-len",
"256",
"--enforce-eager",
"--gpu-memory-utilization",
"0.05",
"--max-num-seqs",
"4",
"--shutdown-timeout",
"30",
"--api-server-count",
"2",
]
with RemoteOpenAIServer(MODEL_NAME, server_args, auto_port=True) as remote_server:
client = remote_server.get_async_client()
proc = remote_server.proc
child_pids = _get_child_pids(proc.pid)
assert len(child_pids) >= 2, (
f"Expected at least 2 child processes, got {len(child_pids)}"
)
state = ShutdownState()
sigterm_sent = asyncio.Event()
# Start concurrent requests across both API servers
request_task = asyncio.create_task(
_concurrent_request_loop(client, state, sigterm_sent, concurrency=8)
)
await asyncio.sleep(0.5)
# Send SIGTERM to parent - should propagate to all children
proc.send_signal(signal.SIGTERM)
sigterm_sent.set()
try:
await asyncio.wait_for(request_task, timeout=_SHUTDOWN_DETECTION_TIMEOUT)
except asyncio.TimeoutError:
pass
finally:
state.stop_requesting = True
if not request_task.done():
request_task.cancel()
await asyncio.gather(request_task, return_exceptions=True)
for _ in range(300): # up to 30 seconds
if proc.poll() is not None:
break
time.sleep(0.1)
if proc.poll() is None:
proc.kill()
proc.wait(timeout=5)
pytest.fail("Process did not exit after SIGTERM")
await _assert_children_cleaned_up(child_pids)

View File

@@ -0,0 +1,107 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import os
import tempfile
import openai
import pytest
import pytest_asyncio
import torch.cuda
from tests.utils import RemoteOpenAIServer
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig,
tensorize_lora_adapter,
tensorize_vllm_model,
)
from vllm.platforms import current_platform
MODEL_NAME = "unsloth/llama-3.2-1b-Instruct"
LORA_PATH = "davzoku/finqa_adapter_1b"
def _cleanup():
gc.collect()
torch.accelerator.empty_cache()
@pytest.fixture(autouse=True)
def cleanup():
_cleanup()
@pytest.fixture(scope="module")
def tmp_dir():
with tempfile.TemporaryDirectory() as path:
yield path
@pytest.fixture(scope="module")
def model_uri(tmp_dir):
yield f"{tmp_dir}/model.tensors"
@pytest.fixture(scope="module")
def tensorize_model_and_lora(tmp_dir, model_uri):
tensorizer_config = TensorizerConfig(tensorizer_uri=model_uri, lora_dir=tmp_dir)
args = EngineArgs(model=MODEL_NAME)
tensorize_lora_adapter(LORA_PATH, tensorizer_config)
tensorize_vllm_model(args, tensorizer_config)
# Manually invoke _cleanup() here, as the autouse cleanup()
# fixture is not guaranteed to run after this module-scoped
# fixture when it is used by a test.
_cleanup()
yield
@pytest.fixture(scope="module")
def server(model_uri, tensorize_model_and_lora):
# Here the directory containing model_uri holds the model.tensors
# file and all necessary model artifacts, in particular an HF
# `config.json` file, so Tensorizer can infer the `TensorizerConfig`
# and --model-loader-extra-config can be omitted entirely.
## Start OpenAI API server
args = [
"--load-format",
"tensorizer",
"--served-model-name",
MODEL_NAME,
"--enable-lora",
]
if current_platform.is_rocm():
args += ["--attention-backend", "TRITON_ATTN"]
model_dir = os.path.dirname(model_uri)
with RemoteOpenAIServer(model_dir, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
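"""A basic completion against the tensorized model should succeed with the expected usage counts."""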
_cleanup()
completion = await client.completions.create(
model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=0.0
)
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
assert completion.model == MODEL_NAME
assert len(completion.choices) == 1
assert len(completion.choices[0].text) >= 5
assert completion.choices[0].finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11
)

View File

@@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import tempfile
import pytest
from tests.utils import RemoteOpenAIServer
from vllm.model_executor.model_loader.weight_utils import download_weights_from_hf
from vllm.tokenizers import get_tokenizer
MODEL_NAME = "Qwen/Qwen3-0.6B"
MODEL_PATH = os.path.join(tempfile.gettempdir(), "qwen3_06b")
@pytest.fixture(scope="module")
def server():
global MODEL_PATH
MODEL_PATH = download_weights_from_hf(
MODEL_NAME,
allow_patterns=["*"],
cache_dir=MODEL_PATH,
ignore_patterns=["tokenizer*", "vocab*", "*.safetensors"],
)
args = [
"--max-model-len",
"2048",
"--max-num-seqs",
"128",
"--enforce-eager",
"--skip-tokenizer-init",
"--load-format",
"dummy",
]
with RemoteOpenAIServer(MODEL_PATH, args) as remote_server:
yield remote_server
@pytest.mark.asyncio
async def test_token_in_token_out_and_logprobs(server):
"""
Test token-in-token-out and token_ids align with prompt_logprobs
& logprobs when return_tokens_as_token_ids is enabled.
"""
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
text = "Hello, world! How are you today?"
token_ids = tokenizer.encode(text)
async with server.get_async_client() as client:
# Test with return_token_ids enabled (echo=True so prompt tokens are returned too)
completion = await client.completions.create(
model=MODEL_PATH,
prompt=token_ids,
max_tokens=20,
temperature=0,
echo=True,
extra_body={
"return_token_ids": True,
},
)
# Verify all fields are present
assert (
completion.choices[0].token_ids is not None
and 0 < len(completion.choices[0].token_ids) <= 20
)
assert completion.choices[0].prompt_token_ids is not None
# Decode prompt tokens
if completion.choices[0].prompt_token_ids:
prompt_text = tokenizer.decode(completion.choices[0].prompt_token_ids)
# The decoded prompt should exactly match the original prompt
assert prompt_text == text