Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -10,16 +10,20 @@ import openai
 import pytest
 import pytest_asyncio
 
-from tests.utils import (RemoteOpenAIServerCustom,
-                         create_new_process_for_each_test)
+from tests.utils import RemoteOpenAIServerCustom, create_new_process_for_each_test
 
-# yapf: disable
-from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG,
-                                              DUMMY_LOGITPROC_FQCN,
-                                              DUMMY_LOGITPROC_MODULE,
-                                              MAX_TOKENS, MODEL_NAME,
-                                              TEMP_GREEDY, dummy_module)
+from tests.v1.logits_processors.utils import (
+    DUMMY_LOGITPROC_ARG,
+    DUMMY_LOGITPROC_FQCN,
+    DUMMY_LOGITPROC_MODULE,
+    MAX_TOKENS,
+    MODEL_NAME,
+    TEMP_GREEDY,
+    dummy_module,
+    prompts,
+)
 from tests.v1.logits_processors.utils import entry_points as fake_entry_points
-from tests.v1.logits_processors.utils import prompts
-# yapf: enable
 
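A note on the import block above: ruff's formatter, like Black, breaks an over-long import one name per line and honors the "magic trailing comma", so the combined `tests.v1.logits_processors.utils` import ending in a trailing comma stays exploded rather than being re-packed. A minimal illustration with generic names (not code from this file):

```python
# If the import fits on one line and has no trailing comma,
# `ruff format` collapses it:
from pkg import alpha, beta, gamma

# A trailing comma after the last name is treated as "magic":
# `ruff format` then keeps one name per line, matching the new
# import style introduced by this commit.
from pkg import (
    alpha,
    beta,
    gamma,
)
```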
@@ -33,11 +37,12 @@ def _server_with_logitproc_entrypoint(
 
     # Patch `entry_points` to inject logitproc entrypoint
     import importlib.metadata
+
     importlib.metadata.entry_points = fake_entry_points  # type: ignore
     from vllm.entrypoints.cli import main
 
     # fork is required for workers to see entrypoint patch
-    os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "fork"
+    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "fork"
     if env_dict is not None:
         os.environ.update(env_dict)
 
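For context on the `fake_entry_points` patch shown above: the test swaps out `importlib.metadata.entry_points` before vLLM's CLI starts, so a plugin scan discovers a logits processor that was never installed as a real package. A rough, self-contained sketch of that technique, assuming Python 3.10+ and using illustrative names (this is not the helper defined in `tests.v1.logits_processors.utils`):

```python
import importlib.metadata

# A fake entry point that advertises a plugin class by "module:attr" path.
# The group name and target below are illustrative, not the ones vLLM uses.
_FAKE_EP = importlib.metadata.EntryPoint(
    name="dummy_logitproc",
    value="my_pkg.my_module:DummyLogitsProcessor",
    group="my_plugin_group",
)

_real_entry_points = importlib.metadata.entry_points


def fake_entry_points(group=None, **kwargs):
    """Return the real entry points, plus the fake one for its group."""
    found = list(_real_entry_points(group=group, **kwargs)) if group else []
    if group == _FAKE_EP.group:
        found.append(_FAKE_EP)
    return found


# Any later call to importlib.metadata.entry_points(group=...), for example
# by a plugin scanner, now also sees the injected entry point.
importlib.metadata.entry_points = fake_entry_points  # type: ignore[assignment]
```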
@@ -55,10 +60,11 @@ def _server_with_logitproc_module(
 
     # Patch `modules` to inject dummy logitproc module
    from vllm.entrypoints.cli import main
+
     sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module
 
     # fork is required for workers to see entrypoint patch
-    os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "fork"
+    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "fork"
     if env_dict is not None:
         os.environ.update(env_dict)
 
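The repeated "fork is required for workers to see entrypoint patch" comment comes down to process start methods: a forked worker inherits the parent's memory, including in-process monkeypatches, while a spawned worker starts a fresh interpreter, re-imports everything, and loses them. A small standalone illustration of the difference (not vLLM code; `fork` is POSIX-only):

```python
import multiprocessing

# Stand-in for an in-process "patch": module-level state set by the parent.
PATCHED = {"entry_points": False}


def child():
    # With the "fork" start method the child inherits the parent's memory,
    # so the patch applied below is visible here.  With "spawn" the child
    # re-imports this module from scratch and would see False instead.
    print("child sees patch:", PATCHED["entry_points"])


if __name__ == "__main__":
    PATCHED["entry_points"] = True              # parent-side monkeypatch
    ctx = multiprocessing.get_context("fork")   # POSIX-only start method
    proc = ctx.Process(target=child)
    proc.start()
    proc.join()
```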
@@ -80,8 +86,9 @@ def default_server_args():
     ]
 
 
-@pytest.fixture(scope="function",
-                params=[[], ["--logits-processors", DUMMY_LOGITPROC_FQCN]])
+@pytest.fixture(
+    scope="function", params=[[], ["--logits-processors", DUMMY_LOGITPROC_FQCN]]
+)
 def server(default_server_args, request, monkeypatch):
     """Consider two server configurations:
     (1) --logits-processors cli arg specifies dummy logits processor via fully-
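The reformatted decorator above uses pytest fixture parametrization: the fixture body runs once per entry of `params`, and each run reads its entry from `request.param`, so every dependent test executes under both server configurations. A minimal sketch of the pattern with generic names (not the test's actual fixture):

```python
import pytest


@pytest.fixture(scope="function", params=[[], ["--extra-flag", "value"]])
def cli_args(request):
    # request.param is one element of `params`; every test that uses this
    # fixture runs once per element (here: once with no extra flags and
    # once with ["--extra-flag", "value"]).
    return ["--base-flag"] + request.param


def test_runs_under_both_configurations(cli_args):
    assert cli_args[0] == "--base-flag"
```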
@@ -102,8 +109,7 @@ def server(default_server_args, request, monkeypatch):
         args = default_server_args
         _server_fxn = _server_with_logitproc_entrypoint
 
-    with RemoteOpenAIServerCustom(MODEL_NAME, args,
-                                  _server_fxn) as remote_server:
+    with RemoteOpenAIServerCustom(MODEL_NAME, args, _server_fxn) as remote_server:
         yield remote_server
 
 
@@ -133,7 +139,7 @@ api_keyword_args = {
 )
 async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str):
     """Test custom logitsprocs when starting OpenAI server from CLI
-    
+
     Launch vLLM OpenAI-compatible server, configured to load a custom logitproc
     that has a well-defined behavior (mask out all tokens except one
     `target_token`).
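For intuition, the "mask out all tokens except one `target_token`" behavior described in the docstring amounts to setting every other logit to negative infinity, so greedy decoding can only ever emit that token. A rough sketch of the idea (illustrative only, not the dummy logits processor defined in the test utils):

```python
import torch


def mask_to_single_token(logits: torch.Tensor, target_token: int) -> torch.Tensor:
    """Return logits under which only `target_token` can be sampled."""
    masked = torch.full_like(logits, float("-inf"))
    masked[..., target_token] = logits[..., target_token]
    return masked


# Under greedy (near-zero temperature) decoding every generated token is now
# `target_token`, which is what the test asserts on further below.
logits = torch.randn(1, 32_000)  # fake batch of logits over a 32k vocab
out = mask_to_single_token(logits, target_token=128)
assert out.argmax(dim=-1).item() == 128
```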
@@ -157,9 +163,7 @@ async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str):
             # For requests which activate the dummy logitproc, choose one of
             # two `target_token` values which are known not to be EOS tokens
             request_keyword_args["extra_body"] = {
-                "vllm_xargs": {
-                    DUMMY_LOGITPROC_ARG: target_token
-                }
+                "vllm_xargs": {DUMMY_LOGITPROC_ARG: target_token}
             }
         batch = await client.completions.create(
             model=model_name,
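The `extra_body` field above is the OpenAI Python client's escape hatch for non-standard parameters: its contents are merged into the JSON request body, and vLLM reads per-request options for custom logits processors from the `vllm_xargs` key. A hedged sketch of the request shape, with placeholder model name, base URL, and argument key (the test's real key is `DUMMY_LOGITPROC_ARG`, whose value is defined in the test utils):

```python
import openai

client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")


async def send_request():
    # Entries under extra_body are merged into the HTTP JSON payload as-is;
    # vLLM looks inside "vllm_xargs" for per-request logits-processor options.
    return await client.completions.create(
        model="my-model",
        prompt="Hello",
        max_tokens=16,
        temperature=0.0,
        logprobs=1,
        extra_body={"vllm_xargs": {"target_token": 128}},
    )
```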
@@ -173,8 +177,7 @@ async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str):
         choices: openai.types.CompletionChoice = batch.choices
         toks = choices[0].logprobs.tokens
         if not all([x == toks[0] for x in toks]):
-            raise AssertionError(
-                f"Generated {toks} should all be {toks[0]}")
+            raise AssertionError(f"Generated {toks} should all be {toks[0]}")
 
         # Alternate whether to activate dummy logitproc for each request
         use_dummy_logitproc = not use_dummy_logitproc