Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -5,12 +5,15 @@ import pytest
import torch
from transformers import AutoTokenizer
from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
NUM_SAMPLE_LOGPROBS_UNDER_TEST, PROMPT_LEN,
TOKENIZER_NAME,
DummyOutputProcessorTestVectors,
generate_dummy_prompt_logprobs_tensors,
generate_dummy_sample_logprobs)
from tests.v1.engine.utils import (
NUM_PROMPT_LOGPROBS_UNDER_TEST,
NUM_SAMPLE_LOGPROBS_UNDER_TEST,
PROMPT_LEN,
TOKENIZER_NAME,
DummyOutputProcessorTestVectors,
generate_dummy_prompt_logprobs_tensors,
generate_dummy_sample_logprobs,
)
from vllm.engine.arg_utils import EngineArgs
from ...distributed.conftest import publisher_config, random_port # noqa: F401
@@ -31,9 +34,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
vllm_config = EngineArgs(model=TOKENIZER_NAME).create_engine_config()
# Tokenize prompts under test & create dummy generated tokens
prompt_tokens = [
tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS
]
prompt_tokens = [tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS]
generation_tokens = [
tokenizer(text).input_ids[PROMPT_LEN:] for text in FULL_STRINGS
]
@@ -42,9 +43,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
tokenizer.decode(prompt_tokens, skip_special_tokens=True)
for prompt_tokens in prompt_tokens
]
prompt_strings_len = [
len(prompt_string) for prompt_string in prompt_strings
]
prompt_strings_len = [len(prompt_string) for prompt_string in prompt_strings]
return DummyOutputProcessorTestVectors(
tokenizer=tokenizer,
vllm_config=vllm_config,
@@ -58,7 +57,8 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
for text, prompt_len in zip(FULL_STRINGS, prompt_strings_len)
],
prompt_logprobs=[],
generation_logprobs=[])
generation_logprobs=[],
)
@pytest.fixture
@@ -76,12 +76,16 @@ def dummy_test_vectors() -> DummyOutputProcessorTestVectors:
generate_dummy_sample_logprobs(
sampled_tokens_list=tokens_list,
num_logprobs=NUM_SAMPLE_LOGPROBS_UNDER_TEST,
tokenizer=dtv.tokenizer) for tokens_list in dtv.generation_tokens
tokenizer=dtv.tokenizer,
)
for tokens_list in dtv.generation_tokens
]
dtv.prompt_logprobs = [
generate_dummy_prompt_logprobs_tensors(
prompt_tokens_list=tokens_list,
num_logprobs=NUM_PROMPT_LOGPROBS_UNDER_TEST,
tokenizer=dtv.tokenizer) for tokens_list in dtv.prompt_tokens
tokenizer=dtv.tokenizer,
)
for tokens_list in dtv.prompt_tokens
]
return dtv

View File

@@ -21,16 +21,16 @@ from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import LoggingStatLogger
if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.",
allow_module_level=True)
pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)
TEXT_ENGINE_ARGS = AsyncEngineArgs(
model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
)
VISION_ENGINE_ARGS = AsyncEngineArgs(model="Qwen/Qwen2-VL-2B-Instruct",
enforce_eager=True)
VISION_ENGINE_ARGS = AsyncEngineArgs(
model="Qwen/Qwen2-VL-2B-Instruct", enforce_eager=True
)
TEXT_PROMPT = "Hello my name is Robert and"
@@ -38,12 +38,11 @@ VISION_PROMPT_TEMPLATE = (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
"\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
"What is in the image?<|im_end|>\n"
"<|im_start|>assistant\n")
"<|im_start|>assistant\n"
)
VISION_PROMPT = {
"prompt": VISION_PROMPT_TEMPLATE,
"multi_modal_data": {
"image": ImageAsset("stop_sign").pil_image
},
"multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
}
@@ -70,10 +69,9 @@ async def generate(
n=n,
prompt_logprobs=prompt_logprobs,
)
async for out in engine.generate(request_id=request_id,
prompt=prompt,
sampling_params=sampling_params):
async for out in engine.generate(
request_id=request_id, prompt=prompt, sampling_params=sampling_params
):
num_tokens = sum(len(output.token_ids) for output in out.outputs)
if output_kind == RequestOutputKind.DELTA:
count += num_tokens
@@ -89,7 +87,8 @@ async def generate(
@pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.parametrize(
"engine_args,prompt",
[(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
@@ -121,25 +120,29 @@ async def test_load(
for request_id in request_ids:
tasks.append(
asyncio.create_task(
generate(engine, request_id, prompt, output_kind,
NUM_EXPECTED_TOKENS)))
generate(
engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS
)
)
)
# Confirm that we got all the EXPECTED tokens from the requests.
done, pending = await asyncio.wait(tasks,
return_when=asyncio.FIRST_EXCEPTION)
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
for task in pending:
task.cancel()
for task in done:
num_generated_tokens, request_id = await task
assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
f"{request_id} generated {num_generated_tokens} but "
f"expected {NUM_EXPECTED_TOKENS}")
f"expected {NUM_EXPECTED_TOKENS}"
)
assert not engine.output_processor.has_unfinished_requests()
@pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.parametrize(
"engine_args,prompt",
[(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
@@ -151,7 +154,6 @@ async def test_abort(
engine_args: AsyncEngineArgs,
prompt: PromptType,
):
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
@@ -170,14 +172,17 @@ async def test_abort(
# Create concurrent requests.
tasks: list[asyncio.Task] = []
for idx, request_id in enumerate(request_ids):
max_tokens = (NUM_EXPECTED_TOKENS_LONG if
(idx
in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS)
max_tokens = (
NUM_EXPECTED_TOKENS_LONG
if (idx in REQUEST_IDS_TO_ABORT)
else NUM_EXPECTED_TOKENS
)
n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
tasks.append(
asyncio.create_task(
generate(engine, request_id, prompt, output_kind,
max_tokens, n)))
generate(engine, request_id, prompt, output_kind, max_tokens, n)
)
)
# API server cancels requests when they disconnect.
for idx in REQUEST_IDS_TO_ABORT:
@@ -197,7 +202,8 @@ async def test_abort(
expected_tokens = NUM_EXPECTED_TOKENS * n
assert num_generated_tokens == expected_tokens, (
f"{request_id} generated {num_generated_tokens} but "
f"expected {expected_tokens}")
f"expected {expected_tokens}"
)
# Make sure all aborted requests were really aborted.
assert not engine.output_processor.has_unfinished_requests()
@@ -205,21 +211,21 @@ async def test_abort(
# Confirm we can do another generation.
request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
task = asyncio.create_task(
generate(engine, request_id, prompt, output_kind,
NUM_EXPECTED_TOKENS))
generate(engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS)
)
num_generated_tokens, request_id = await task
assert num_generated_tokens == NUM_EXPECTED_TOKENS
assert not engine.output_processor.has_unfinished_requests()
@pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.asyncio
async def test_multi_abort(
monkeypatch: pytest.MonkeyPatch,
output_kind: RequestOutputKind,
):
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
@@ -238,14 +244,19 @@ async def test_multi_abort(
# Create concurrent requests.
tasks: list[asyncio.Task] = []
for idx, request_id in enumerate(request_ids):
max_tokens = (NUM_EXPECTED_TOKENS_LONG if
(idx
in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS)
max_tokens = (
NUM_EXPECTED_TOKENS_LONG
if (idx in REQUEST_IDS_TO_ABORT)
else NUM_EXPECTED_TOKENS
)
n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
tasks.append(
asyncio.create_task(
generate(engine, request_id, TEXT_PROMPT, output_kind,
max_tokens, n)))
generate(
engine, request_id, TEXT_PROMPT, output_kind, max_tokens, n
)
)
)
# Let requests start
await asyncio.sleep(0.5)
@@ -261,25 +272,26 @@ async def test_multi_abort(
for idx, result in enumerate(results):
if idx in REQUEST_IDS_TO_ABORT:
# Aborted requests should return partial results
assert isinstance(
result, tuple
), f"Request {idx} should have completed with partial results"
assert isinstance(result, tuple), (
f"Request {idx} should have completed with partial results"
)
num_generated_tokens, request_id = result
# Should have generated some tokens before abort
assert num_generated_tokens > 0, (
f"Aborted request "
f"{request_id} should have generated some tokens")
f"Aborted request {request_id} should have generated some tokens"
)
else:
# Non-aborted requests should complete normally
assert isinstance(
result,
tuple), f"Request {idx} should have completed successfully"
assert isinstance(result, tuple), (
f"Request {idx} should have completed successfully"
)
num_generated_tokens, request_id = result
n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
expected_tokens = NUM_EXPECTED_TOKENS * n
assert num_generated_tokens == expected_tokens, (
f"{request_id} generated {num_generated_tokens} but "
f"expected {expected_tokens}")
f"expected {expected_tokens}"
)
# Make sure all aborted requests were cleaned up
assert not engine.output_processor.has_unfinished_requests()
@@ -297,7 +309,6 @@ async def test_finished_flag(
engine_args: AsyncEngineArgs,
prompt: PromptType,
):
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
@@ -314,9 +325,9 @@ async def test_finished_flag(
)
outputs = [
out
async for out in engine.generate(request_id="request-33",
prompt=prompt,
sampling_params=sampling_params)
async for out in engine.generate(
request_id="request-33", prompt=prompt, sampling_params=sampling_params
)
]
# Assert only the last output has the finished flag set
@@ -329,9 +340,9 @@ async def test_finished_flag(
[(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
engine_args: AsyncEngineArgs,
prompt: PromptType):
async def test_mid_stream_cancellation(
monkeypatch: pytest.MonkeyPatch, engine_args: AsyncEngineArgs, prompt: PromptType
):
"""Test that requests can be cancelled mid-stream."""
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
@@ -358,7 +369,9 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
RequestOutputKind.DELTA,
NUM_TOKENS,
cancel_after=NUM_EXPECTED_TOKENS,
)))
)
)
)
# Wait for all tasks to complete
results = await asyncio.gather(*tasks)
@@ -367,7 +380,8 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
for num_generated_tokens, request_id in results:
assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
f"{request_id} generated {num_generated_tokens} tokens but "
f"expected to cancel after {NUM_EXPECTED_TOKENS}")
f"expected to cancel after {NUM_EXPECTED_TOKENS}"
)
# Make sure no requests are left hanging
assert not engine.output_processor.has_unfinished_requests()
@@ -375,15 +389,16 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
# Confirm we can reuse the request id after the cancellations.
request_id = request_ids[0]
task = asyncio.create_task(
generate(engine, request_id, prompt, RequestOutputKind.DELTA,
NUM_EXPECTED_TOKENS))
generate(
engine, request_id, prompt, RequestOutputKind.DELTA, NUM_EXPECTED_TOKENS
)
)
num_generated_tokens, request_id = await task
assert num_generated_tokens == NUM_EXPECTED_TOKENS
assert not engine.output_processor.has_unfinished_requests()
class MockLoggingStatLogger(LoggingStatLogger):
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
super().__init__(vllm_config, engine_index)
self.log = MagicMock()
@@ -410,8 +425,7 @@ async def test_customize_loggers(monkeypatch):
stat_loggers = engine.logger_manager.per_engine_logger_dict
assert len(stat_loggers) == 1
assert len(
stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger
assert len(stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger
stat_loggers[0][0].log.assert_called_once()
@@ -424,24 +438,30 @@ async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown)
sampling_params = SamplingParams(max_tokens=100,
output_kind=RequestOutputKind.DELTA,
temperature=1.0,
seed=33)
sampling_params = SamplingParams(
max_tokens=100,
output_kind=RequestOutputKind.DELTA,
temperature=1.0,
seed=33,
)
# Test with valid DP rank.
async for _ in engine.generate(request_id="request-34",
prompt=TEXT_PROMPT,
sampling_params=sampling_params,
data_parallel_rank=0):
async for _ in engine.generate(
request_id="request-34",
prompt=TEXT_PROMPT,
sampling_params=sampling_params,
data_parallel_rank=0,
):
pass
# Test with out-of-range DP rank.
with pytest.raises(ValueError):
async for _ in engine.generate(request_id="request-35",
prompt=TEXT_PROMPT,
sampling_params=sampling_params,
data_parallel_rank=1):
async for _ in engine.generate(
request_id="request-35",
prompt=TEXT_PROMPT,
sampling_params=sampling_params,
data_parallel_rank=1,
):
pass
@@ -465,10 +485,14 @@ async def test_check_health(monkeypatch: pytest.MonkeyPatch):
await engine.check_health()
# Test 2: Mock the errored property to simulate a dead engine
with patch.object(type(engine),
'errored',
new_callable=lambda: property(lambda self: True)
), pytest.raises(EngineDeadError):
with (
patch.object(
type(engine),
"errored",
new_callable=lambda: property(lambda self: True),
),
pytest.raises(EngineDeadError),
):
await engine.check_health()
# Test 3: Verify healthy engine still works after mock
@@ -476,7 +500,8 @@ async def test_check_health(monkeypatch: pytest.MonkeyPatch):
@pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.asyncio
async def test_abort_final_output(
monkeypatch: pytest.MonkeyPatch,
@@ -504,8 +529,8 @@ async def test_abort_final_output(
outputs: list[RequestOutput] = []
generated = asyncio.create_task(
collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params,
outputs))
collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params, outputs)
)
# Let it generate some tokens
await asyncio.sleep(0.5)
@@ -525,14 +550,13 @@ async def test_abort_final_output(
assert final_output.outputs[0].stop_reason is None
# Verify num_cached_tokens is set correctly
assert hasattr(final_output, 'num_cached_tokens')
assert hasattr(final_output, "num_cached_tokens")
assert final_output.num_cached_tokens >= 0
# If we got intermediate outputs, verify they are consistent
if output_kind == RequestOutputKind.DELTA:
# For DELTA, sum all intermediate tokens should <= final tokens
token_count = sum(
len(output.outputs[0].token_ids) for output in outputs)
token_count = sum(len(output.outputs[0].token_ids) for output in outputs)
assert token_count > 0
# This would ordinarily be 0, but could end up > 0 if the
# final abort is coalesced with another chunk in the output queue.
@@ -554,9 +578,9 @@ async def collect_outputs(
) -> Optional[RequestOutput]:
"""Helper to collect outputs and return the final one."""
final_output: Optional[RequestOutput] = None
async for output in engine.generate(request_id=request_id,
prompt=prompt,
sampling_params=sampling_params):
async for output in engine.generate(
request_id=request_id, prompt=prompt, sampling_params=sampling_params
):
if not output.finished:
outputs_list.append(output)
final_output = output

View File

@@ -22,8 +22,9 @@ def test_prefix_caching_from_cli():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([])
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert (vllm_config.cache_config.enable_prefix_caching
), "V1 turns on prefix caching by default."
assert vllm_config.cache_config.enable_prefix_caching, (
"V1 turns on prefix caching by default."
)
# Turn it off possible with flag.
args = parser.parse_args(["--no-enable-prefix-caching"])
@@ -41,8 +42,7 @@ def test_prefix_caching_from_cli():
# set hash algorithm to sha256_cbor
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256_cbor"])
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == \
"sha256_cbor"
assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256_cbor"
# set hash algorithm to sha256
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"])
@@ -57,10 +57,10 @@ def test_prefix_caching_from_cli():
def test_defaults_with_usage_context():
engine_args = EngineArgs(model="facebook/opt-125m")
vllm_config: VllmConfig = engine_args.create_engine_config(
UsageContext.LLM_CLASS)
vllm_config: VllmConfig = engine_args.create_engine_config(UsageContext.LLM_CLASS)
from vllm.platforms import current_platform
device_name = current_platform.get_device_name().lower()
if "h100" in device_name or "h200" in device_name:
# For H100 and H200, we use larger default values.
@@ -76,7 +76,6 @@ def test_defaults_with_usage_context():
assert vllm_config.scheduler_config.max_num_batched_tokens == default_llm_tokens # noqa: E501
engine_args = EngineArgs(model="facebook/opt-125m")
vllm_config = engine_args.create_engine_config(
UsageContext.OPENAI_API_SERVER)
vllm_config = engine_args.create_engine_config(UsageContext.OPENAI_API_SERVER)
assert vllm_config.scheduler_config.max_num_seqs == default_max_num_seqs
assert vllm_config.scheduler_config.max_num_batched_tokens == default_server_tokens # noqa: E501

View File

@@ -22,8 +22,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from ...utils import create_new_process_for_each_test, multi_gpu_test
if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.",
allow_module_level=True)
pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
@@ -48,7 +47,6 @@ def make_request() -> EngineCoreRequest:
@create_new_process_for_each_test()
def test_engine_core(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
"""Setup the EngineCore."""
@@ -57,14 +55,13 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
executor_class = Executor.get_class(vllm_config)
with set_default_torch_num_threads(1):
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
engine_core = EngineCore(
vllm_config=vllm_config, executor_class=executor_class, log_stats=True
)
"""Test basic request lifecycle."""
# First request.
engine_core.add_request(
*engine_core.preprocess_add_request(make_request()))
engine_core.add_request(*engine_core.preprocess_add_request(make_request()))
assert len(engine_core.scheduler.waiting) == 1
assert len(engine_core.scheduler.running) == 0
@@ -73,8 +70,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
assert len(engine_core.scheduler.running) == 1
# Second request.
engine_core.add_request(
*engine_core.preprocess_add_request(make_request()))
engine_core.add_request(*engine_core.preprocess_add_request(make_request()))
assert len(engine_core.scheduler.waiting) == 1
assert len(engine_core.scheduler.running) == 1
@@ -83,10 +79,8 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
assert len(engine_core.scheduler.running) == 2
# Add two requests in a row.
engine_core.add_request(
*engine_core.preprocess_add_request(make_request()))
engine_core.add_request(
*engine_core.preprocess_add_request(make_request()))
engine_core.add_request(*engine_core.preprocess_add_request(make_request()))
engine_core.add_request(*engine_core.preprocess_add_request(make_request()))
assert len(engine_core.scheduler.waiting) == 2
assert len(engine_core.scheduler.running) == 2
@@ -196,9 +190,9 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
executor_class = Executor.get_class(vllm_config)
with set_default_torch_num_threads(1):
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
engine_core = EngineCore(
vllm_config=vllm_config, executor_class=executor_class, log_stats=True
)
"""Test basic request lifecycle."""
# First request.
request: EngineCoreRequest = make_request()
@@ -238,17 +232,14 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
Test that the engine can handle multiple concurrent batches.
"""
def make_request_with_max_tokens(req_id: str,
max_tokens: int) -> EngineCoreRequest:
def make_request_with_max_tokens(req_id: str, max_tokens: int) -> EngineCoreRequest:
request = make_request()
request.request_id = req_id
request.sampling_params.max_tokens = max_tokens
return request
class DummyExecutor(UniProcExecutor):
def initialize_from_config(
self, kv_cache_configs: list[KVCacheConfig]) -> None:
def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None:
super().initialize_from_config(kv_cache_configs)
# Create a thread pool with a single worker
@@ -265,8 +256,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
assert non_block
def _execute():
output = self.collective_rpc("execute_model",
args=(scheduler_output, ))
output = self.collective_rpc("execute_model", args=(scheduler_output,))
# Make a copy because output[0] may be reused
# by the next batch.
return copy.deepcopy(output[0])
@@ -279,7 +269,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
return 2
def shutdown(self):
if hasattr(self, 'thread_pool'):
if hasattr(self, "thread_pool"):
self.thread_pool.shutdown(wait=False)
with monkeypatch.context() as m:
@@ -297,9 +287,9 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
)
vllm_config = engine_args.create_engine_config()
with set_default_torch_num_threads(1):
engine_core = EngineCore(vllm_config=vllm_config,
log_stats=False,
executor_class=DummyExecutor)
engine_core = EngineCore(
vllm_config=vllm_config, log_stats=False, executor_class=DummyExecutor
)
assert engine_core.batch_queue is not None
# Add two requests in a row. Each request have 12 prompt tokens.
@@ -314,8 +304,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
scheduler_output = engine_core.batch_queue[-1][1]
assert scheduler_output.num_scheduled_tokens["0"] == 10
# num_computed_tokens should have been updated immediately.
assert engine_core.scheduler.requests[
req0.request_id].num_computed_tokens == 10
assert engine_core.scheduler.requests[req0.request_id].num_computed_tokens == 10
# Schedule Batch 2: (2, req0), (8, req1)
assert engine_core.step_with_batch_queue()[0] == {}
@@ -366,8 +355,10 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
assert output is not None
assert len(output[0].outputs) == 1
if req_id in engine_core.scheduler.requests:
assert engine_core.scheduler.requests[
req_id].num_tokens == expected_num_tokens[req_id]
assert (
engine_core.scheduler.requests[req_id].num_tokens
== expected_num_tokens[req_id]
)
expected_num_tokens[req_id] += 1
req_id = (req_id + 1) % 2
@@ -391,17 +382,19 @@ def test_engine_core_tp(monkeypatch: pytest.MonkeyPatch):
executor_class = Executor.get_class(vllm_config)
with set_default_torch_num_threads(1):
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
engine_core = EngineCore(
vllm_config=vllm_config, executor_class=executor_class, log_stats=True
)
def get_worker_cache_config_field(worker, key: str):
return getattr(worker.cache_config, key)
num_gpu_blocks = engine_core.collective_rpc(
get_worker_cache_config_field, args=("num_gpu_blocks", ))
get_worker_cache_config_field, args=("num_gpu_blocks",)
)
num_cpu_blocks = engine_core.collective_rpc(
get_worker_cache_config_field, args=("num_cpu_blocks", ))
get_worker_cache_config_field, args=("num_cpu_blocks",)
)
assert all(x is not None for x in num_gpu_blocks)
assert all(x is not None for x in num_cpu_blocks)
@@ -417,40 +410,35 @@ def test_engine_core_invalid_request_id_type(monkeypatch: pytest.MonkeyPatch):
executor_class = Executor.get_class(vllm_config)
with set_default_torch_num_threads(1):
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
engine_core = EngineCore(
vllm_config=vllm_config, executor_class=executor_class, log_stats=True
)
# Test with UUID object (common mistake)
uuid_request = make_request()
uuid_request.request_id = uuid.uuid4() # UUID object instead of string
with pytest.raises(TypeError,
match="request_id must be a string, got.*UUID"):
engine_core.add_request(
*engine_core.preprocess_add_request(uuid_request))
with pytest.raises(TypeError, match="request_id must be a string, got.*UUID"):
engine_core.add_request(*engine_core.preprocess_add_request(uuid_request))
# Test with integer
int_request = make_request()
int_request.request_id = 12345
with pytest.raises(TypeError,
match="request_id must be a string, got.*int"):
engine_core.add_request(
*engine_core.preprocess_add_request(int_request))
with pytest.raises(TypeError, match="request_id must be a string, got.*int"):
engine_core.add_request(*engine_core.preprocess_add_request(int_request))
# Test with None
none_request = make_request()
none_request.request_id = None
with pytest.raises(TypeError,
match="request_id must be a string, got.*NoneType"):
engine_core.add_request(
*engine_core.preprocess_add_request(none_request))
with pytest.raises(
TypeError, match="request_id must be a string, got.*NoneType"
):
engine_core.add_request(*engine_core.preprocess_add_request(none_request))
# Verify engine is still functional after errors
valid_request = make_request()
engine_core.add_request(
*engine_core.preprocess_add_request(valid_request))
engine_core.add_request(*engine_core.preprocess_add_request(valid_request))
assert len(engine_core.scheduler.waiting) == 1
assert len(engine_core.scheduler.running) == 0

View File

@@ -17,16 +17,14 @@ from transformers import AutoTokenizer
from tests.utils import multi_gpu_test
from vllm import SamplingParams
from vllm.distributed.kv_events import (BlockStored, KVEventBatch,
ZmqEventPublisher)
from vllm.distributed.kv_events import BlockStored, KVEventBatch, ZmqEventPublisher
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.usage.usage_lib import UsageContext
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
SyncMPClient)
from vllm.v1.engine.core_client import AsyncMPClient, EngineCoreClient, SyncMPClient
from vllm.v1.engine.utils import CoreEngineProcManager
from vllm.v1.executor.abstract import Executor
@@ -34,8 +32,7 @@ from ...distributed.conftest import MockSubscriber
from ...utils import create_new_process_for_each_test
if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.",
allow_module_level=True)
pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
@@ -44,8 +41,8 @@ PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
def make_request(
params: SamplingParams,
prompt_tokens_ids: Optional[list[int]] = None) -> EngineCoreRequest:
params: SamplingParams, prompt_tokens_ids: Optional[list[int]] = None
) -> EngineCoreRequest:
if not prompt_tokens_ids:
prompt_tokens_ids = PROMPT_TOKENS
@@ -64,7 +61,6 @@ def make_request(
def loop_until_done(client: EngineCoreClient, outputs: dict):
while True:
engine_core_outputs = client.get_output().outputs
@@ -82,7 +78,6 @@ def loop_until_done(client: EngineCoreClient, outputs: dict):
async def loop_until_done_async(client: EngineCoreClient, outputs: dict):
while True:
engine_core_outputs = (await client.get_output_async()).outputs
@@ -100,7 +95,6 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict):
async def loop_until_fully_done_async(client: EngineCoreClient, outputs: dict):
while True:
engine_core_outputs = (await client.get_output_async()).outputs
@@ -119,10 +113,9 @@ async def loop_until_fully_done_async(client: EngineCoreClient, outputs: dict):
# Dummy utility function to monkey-patch into engine core.
def echo(self,
msg: str,
err_msg: Optional[str] = None,
sleep: Optional[float] = None) -> str:
def echo(
self, msg: str, err_msg: Optional[str] = None, sleep: Optional[float] = None
) -> str:
print(f"echo util function called: {msg}, {err_msg}")
if sleep is not None:
time.sleep(sleep)
@@ -133,9 +126,9 @@ def echo(self,
@create_new_process_for_each_test()
@pytest.mark.parametrize("multiprocessing_mode", [True, False])
def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
multiprocessing_mode: bool):
def test_engine_core_client(
monkeypatch: pytest.MonkeyPatch, multiprocessing_mode: bool
):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
@@ -143,8 +136,7 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
m.setattr(EngineCore, "echo", echo, raising=False)
engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
vllm_config = engine_args.create_engine_config(
UsageContext.UNKNOWN_CONTEXT)
vllm_config = engine_args.create_engine_config(UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config)
with set_default_torch_num_threads(1):
@@ -172,7 +164,8 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
for req_id in request_ids:
assert len(outputs[req_id]) == MAX_TOKENS, (
f"{outputs[req_id]=}, {MAX_TOKENS=}")
f"{outputs[req_id]=}, {MAX_TOKENS=}"
)
"""Abort Request Cycle."""
# Note: this code pathway will only work for multiprocessing
@@ -191,10 +184,12 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
for idx, req_id in enumerate(request_ids):
if idx % 2 == 0:
assert len(outputs[req_id]) < MAX_TOKENS, (
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
f"{len(outputs[req_id])=}, {MAX_TOKENS=}"
)
else:
assert len(outputs[req_id]) == MAX_TOKENS, (
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
f"{len(outputs[req_id])=}, {MAX_TOKENS=}"
)
"""Abort after request is finished."""
# Note: this code pathway will only work for multiprocessing
@@ -202,7 +197,7 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
request = requests[0]
client.add_request(request)
time.sleep(10.)
time.sleep(10.0)
client.abort_requests([request.request_id])
@@ -222,7 +217,6 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
@@ -231,7 +225,8 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT)
usage_context=UsageContext.UNKNOWN_CONTEXT
)
executor_class = Executor.get_class(vllm_config)
with set_default_torch_num_threads(1):
@@ -261,7 +256,8 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
for req_id in request_ids:
assert len(outputs[req_id]) == MAX_TOKENS, (
f"{outputs[req_id]=}, {MAX_TOKENS=}")
f"{outputs[req_id]=}, {MAX_TOKENS=}"
)
"""Abort Request Cycle."""
# Add requests to the engine.
@@ -277,10 +273,12 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
for idx, req_id in enumerate(request_ids):
if idx % 2 == 0:
assert len(outputs[req_id]) < MAX_TOKENS, (
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
f"{len(outputs[req_id])=}, {MAX_TOKENS=}"
)
else:
assert len(outputs[req_id]) == MAX_TOKENS, (
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
f"{len(outputs[req_id])=}, {MAX_TOKENS=}"
)
"""Utility method invocation"""
core_client: AsyncMPClient = client
@@ -296,8 +294,8 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
# Test that cancelling the utility call doesn't destabilize the
# engine.
util_task = asyncio.create_task(
core_client.call_utility_async("echo", "testarg2", None,
0.5)) # sleep for 0.5 sec
core_client.call_utility_async("echo", "testarg2", None, 0.5)
) # sleep for 0.5 sec
await asyncio.sleep(0.05)
cancelled = util_task.cancel()
assert cancelled
@@ -305,9 +303,9 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
# Ensure client is still functional. The engine runs utility
# methods in a single thread so this request won't be processed
# until the cancelled sleeping one is complete.
result = await asyncio.wait_for(core_client.call_utility_async(
"echo", "testarg3"),
timeout=1.0)
result = await asyncio.wait_for(
core_client.call_utility_async("echo", "testarg3"), timeout=1.0
)
assert result == "testarg3"
finally:
client.shutdown()
@@ -353,8 +351,7 @@ def echo_dc_nested(
msg: str,
structure_type: str = "list_of_dicts",
) -> Any:
print(f"echo dc nested util function called: {msg}, "
f"structure: {structure_type}")
print(f"echo dc nested util function called: {msg}, structure: {structure_type}")
val = None if msg is None else MyDataclass(msg)
if structure_type == "list_of_dicts": # noqa
@@ -373,8 +370,8 @@ def echo_dc_nested(
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_util_method_custom_return(
monkeypatch: pytest.MonkeyPatch):
monkeypatch: pytest.MonkeyPatch,
):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
@@ -386,7 +383,8 @@ async def test_engine_core_client_util_method_custom_return(
engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT)
usage_context=UsageContext.UNKNOWN_CONTEXT
)
executor_class = Executor.get_class(vllm_config)
with set_default_torch_num_threads(1):
@@ -402,22 +400,17 @@ async def test_engine_core_client_util_method_custom_return(
# Test utility method returning custom / non-native data type.
core_client: AsyncMPClient = client
result = await core_client.call_utility_async(
"echo_dc", "testarg2", False)
assert isinstance(result,
MyDataclass) and result.message == "testarg2"
result = await core_client.call_utility_async(
"echo_dc", "testarg2", True)
result = await core_client.call_utility_async("echo_dc", "testarg2", False)
assert isinstance(result, MyDataclass) and result.message == "testarg2"
result = await core_client.call_utility_async("echo_dc", "testarg2", True)
assert isinstance(result, list) and all(
isinstance(r, MyDataclass) and r.message == "testarg2"
for r in result)
isinstance(r, MyDataclass) and r.message == "testarg2" for r in result
)
# Test returning None and list of Nones
result = await core_client.call_utility_async(
"echo_dc", None, False)
result = await core_client.call_utility_async("echo_dc", None, False)
assert result is None
result = await core_client.call_utility_async(
"echo_dc", None, True)
result = await core_client.call_utility_async("echo_dc", None, True)
assert isinstance(result, list) and all(r is None for r in result)
finally:
@@ -426,8 +419,8 @@ async def test_engine_core_client_util_method_custom_return(
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_util_method_custom_dict_return(
monkeypatch: pytest.MonkeyPatch):
monkeypatch: pytest.MonkeyPatch,
):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
@@ -439,7 +432,8 @@ async def test_engine_core_client_util_method_custom_dict_return(
engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT)
usage_context=UsageContext.UNKNOWN_CONTEXT
)
executor_class = Executor.get_class(vllm_config)
with set_default_torch_num_threads(1):
@@ -457,22 +451,21 @@ async def test_engine_core_client_util_method_custom_dict_return(
# Test single object return
result = await core_client.call_utility_async(
"echo_dc_dict", "testarg3", False)
assert isinstance(result,
MyDataclass) and result.message == "testarg3"
"echo_dc_dict", "testarg3", False
)
assert isinstance(result, MyDataclass) and result.message == "testarg3"
# Test dict return with custom value types
result = await core_client.call_utility_async(
"echo_dc_dict", "testarg3", True)
"echo_dc_dict", "testarg3", True
)
assert isinstance(result, dict) and len(result) == 3
for key, val in result.items():
assert key in ["key1", "key2", "key3"]
assert isinstance(val,
MyDataclass) and val.message == "testarg3"
assert isinstance(val, MyDataclass) and val.message == "testarg3"
# Test returning dict with None values
result = await core_client.call_utility_async(
"echo_dc_dict", None, True)
result = await core_client.call_utility_async("echo_dc_dict", None, True)
assert isinstance(result, dict) and len(result) == 3
for key, val in result.items():
assert key in ["key1", "key2", "key3"]
@@ -484,8 +477,8 @@ async def test_engine_core_client_util_method_custom_dict_return(
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_util_method_nested_structures(
monkeypatch: pytest.MonkeyPatch):
monkeypatch: pytest.MonkeyPatch,
):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
@@ -497,7 +490,8 @@ async def test_engine_core_client_util_method_nested_structures(
engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT)
usage_context=UsageContext.UNKNOWN_CONTEXT
)
executor_class = Executor.get_class(vllm_config)
with set_default_torch_num_threads(1):
@@ -514,42 +508,48 @@ async def test_engine_core_client_util_method_nested_structures(
# Test list of dicts: [{"a": val, "b": val}, {"c": val, "d": val}]
result = await core_client.call_utility_async(
"echo_dc_nested", "nested1", "list_of_dicts")
"echo_dc_nested", "nested1", "list_of_dicts"
)
assert isinstance(result, list) and len(result) == 2
for i, item in enumerate(result):
assert isinstance(item, dict)
if i == 0:
assert "a" in item and "b" in item
assert isinstance(
item["a"],
MyDataclass) and item["a"].message == "nested1"
assert isinstance(
item["b"],
MyDataclass) and item["b"].message == "nested1"
assert (
isinstance(item["a"], MyDataclass)
and item["a"].message == "nested1"
)
assert (
isinstance(item["b"], MyDataclass)
and item["b"].message == "nested1"
)
else:
assert "c" in item and "d" in item
assert isinstance(
item["c"],
MyDataclass) and item["c"].message == "nested1"
assert isinstance(
item["d"],
MyDataclass) and item["d"].message == "nested1"
assert (
isinstance(item["c"], MyDataclass)
and item["c"].message == "nested1"
)
assert (
isinstance(item["d"], MyDataclass)
and item["d"].message == "nested1"
)
# Test dict of lists: {"list1": [val, val], "list2": [val, val]}
result = await core_client.call_utility_async(
"echo_dc_nested", "nested2", "dict_of_lists")
"echo_dc_nested", "nested2", "dict_of_lists"
)
assert isinstance(result, dict) and len(result) == 2
assert "list1" in result and "list2" in result
for key, lst in result.items():
assert isinstance(lst, list) and len(lst) == 2
for item in lst:
assert isinstance(
item, MyDataclass) and item.message == "nested2"
assert isinstance(item, MyDataclass) and item.message == "nested2"
# Test deeply nested: {"outer": [{"inner": [val, val]},
# {"inner": [val]}]}
result = await core_client.call_utility_async(
"echo_dc_nested", "nested3", "deep_nested")
"echo_dc_nested", "nested3", "deep_nested"
)
assert isinstance(result, dict) and "outer" in result
outer_list = result["outer"]
assert isinstance(outer_list, list) and len(outer_list) == 2
@@ -560,21 +560,22 @@ async def test_engine_core_client_util_method_nested_structures(
inner_list1 = inner_dict1["inner"]
assert isinstance(inner_list1, list) and len(inner_list1) == 2
for item in inner_list1:
assert isinstance(item,
MyDataclass) and item.message == "nested3"
assert isinstance(item, MyDataclass) and item.message == "nested3"
# Second dict in outer list should have "inner" with 1 item
inner_dict2 = outer_list[1]
assert isinstance(inner_dict2, dict) and "inner" in inner_dict2
inner_list2 = inner_dict2["inner"]
assert isinstance(inner_list2, list) and len(inner_list2) == 1
assert isinstance(
inner_list2[0],
MyDataclass) and inner_list2[0].message == "nested3"
assert (
isinstance(inner_list2[0], MyDataclass)
and inner_list2[0].message == "nested3"
)
# Test with None values in nested structures
result = await core_client.call_utility_async(
"echo_dc_nested", None, "list_of_dicts")
"echo_dc_nested", None, "list_of_dicts"
)
assert isinstance(result, list) and len(result) == 2
for item in result:
assert isinstance(item, dict)
@@ -595,7 +596,6 @@ def test_kv_cache_events(
multiprocessing_mode: bool,
publisher_config,
):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
block_size = 16
@@ -609,8 +609,7 @@ def test_kv_cache_events(
)
engine_args.kv_events_config = publisher_config
vllm_config = engine_args.create_engine_config(
UsageContext.UNKNOWN_CONTEXT)
vllm_config = engine_args.create_engine_config(UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config)
with set_default_torch_num_threads(1):
@@ -622,9 +621,9 @@ def test_kv_cache_events(
log_stats=False,
)
endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
subscriber = MockSubscriber(endpoint,
topic=publisher_config.topic,
decode_type=KVEventBatch)
subscriber = MockSubscriber(
endpoint, topic=publisher_config.topic, decode_type=KVEventBatch
)
try:
custom_tokens = list(range(num_blocks * block_size))
@@ -641,22 +640,25 @@ def test_kv_cache_events(
seq, received = result
assert seq == 0, "Sequence number mismatch"
assert (len(received.events) == 1
), "We should have exactly one BlockStored event"
assert len(received.events) == 1, (
"We should have exactly one BlockStored event"
)
event = received.events[0]
assert isinstance(
event, BlockStored), "We should have a BlockStored event"
assert (len(event.block_hashes) == num_blocks
), "We should have a BlockStored event with 2 block_hashes"
assert (event.block_size == block_size
), "Block size should be the same as the block size"
assert (event.parent_block_hash
is None), "Parent block hash should be None"
assert isinstance(event, BlockStored), "We should have a BlockStored event"
assert len(event.block_hashes) == num_blocks, (
"We should have a BlockStored event with 2 block_hashes"
)
assert event.block_size == block_size, (
"Block size should be the same as the block size"
)
assert event.parent_block_hash is None, "Parent block hash should be None"
assert event.lora_id is None, "Lora id should be None"
assert (len(event.token_ids) == num_blocks * block_size
), "Token ids should be the same as the custom tokens"
assert (event.token_ids == custom_tokens
), "Token ids should be the same as the custom tokens"
assert len(event.token_ids) == num_blocks * block_size, (
"Token ids should be the same as the custom tokens"
)
assert event.token_ids == custom_tokens, (
"Token ids should be the same as the custom tokens"
)
finally:
client.shutdown()
subscriber.close()
@@ -674,7 +676,6 @@ async def test_kv_cache_events_dp(
multiprocessing_mode: bool,
publisher_config,
):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
block_size = 16
@@ -692,8 +693,7 @@ async def test_kv_cache_events_dp(
)
engine_args.kv_events_config = publisher_config
vllm_config = engine_args.create_engine_config(
UsageContext.UNKNOWN_CONTEXT)
vllm_config = engine_args.create_engine_config(UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config)
with set_default_torch_num_threads(1):
@@ -710,13 +710,12 @@ async def test_kv_cache_events_dp(
base_endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
endpoints = []
for i in range(dp_size):
offset_endpoint = ZmqEventPublisher.offset_endpoint_port(
base_endpoint, i)
offset_endpoint = ZmqEventPublisher.offset_endpoint_port(base_endpoint, i)
endpoints.append(offset_endpoint)
subscriber = MockSubscriber(endpoints,
topic=publisher_config.topic,
decode_type=KVEventBatch)
subscriber = MockSubscriber(
endpoints, topic=publisher_config.topic, decode_type=KVEventBatch
)
try:
custom_tokens = list(range(num_blocks * block_size))
@@ -734,15 +733,12 @@ async def test_kv_cache_events_dp(
await asyncio.sleep(0.1)
# Initialize outputs dict for all requests
outputs: dict[str, list] = {
req_id: []
for req_id in all_request_ids
}
outputs: dict[str, list] = {req_id: [] for req_id in all_request_ids}
print("processing requests...")
await asyncio.wait_for(loop_until_fully_done_async(
client, outputs),
timeout=20.0)
await asyncio.wait_for(
loop_until_fully_done_async(client, outputs), timeout=20.0
)
# Receive from subscriber until no more messages
print("collecting results...")
@@ -755,13 +751,11 @@ async def test_kv_cache_events_dp(
results.append(result)
# Collect all events and data_parallel_ranks from all results
all_dp_ranks = [
received.data_parallel_rank for (_, received) in results
]
all_dp_ranks = [received.data_parallel_rank for (_, received) in results]
unique_dps = set(all_dp_ranks)
assert (
len(unique_dps) == 2
), f"Expected 2 unique data_parallel_ranks, got {len(unique_dps)}"
assert len(unique_dps) == 2, (
f"Expected 2 unique data_parallel_ranks, got {len(unique_dps)}"
)
finally:
client.shutdown()
@@ -770,7 +764,6 @@ async def test_kv_cache_events_dp(
@pytest.mark.timeout(20)
def test_startup_failure(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m, pytest.raises(Exception) as e_info:
m.setenv("VLLM_USE_V1", "1")
@@ -787,7 +780,8 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch):
t = time.time()
engine_args = EngineArgs(model=MODEL_NAME)
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT)
usage_context=UsageContext.UNKNOWN_CONTEXT
)
executor_class = Executor.get_class(vllm_config)
print(f"VllmConfig creation took {time.time() - t:.2f} seconds.")
@@ -815,8 +809,7 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch):
@create_new_process_for_each_test()
def test_engine_core_proc_instantiation_cuda_empty(
monkeypatch: pytest.MonkeyPatch):
def test_engine_core_proc_instantiation_cuda_empty(monkeypatch: pytest.MonkeyPatch):
"""
Test that EngineCoreProc can be instantiated when CUDA_VISIBLE_DEVICES
is empty. This ensures the engine frontend does not need access to GPUs.
@@ -833,17 +826,13 @@ def test_engine_core_proc_instantiation_cuda_empty(
# Only implement the methods that are actually called during init
from vllm.v1.kv_cache_interface import FullAttentionSpec
mock_spec = FullAttentionSpec(block_size=16,
num_kv_heads=1,
head_size=64,
dtype=torch.float16)
mock_executor.get_kv_cache_specs.return_value = [{
"default": mock_spec
}]
mock_executor.determine_available_memory.return_value = [
1024 * 1024 * 1024
]
mock_spec = FullAttentionSpec(
block_size=16, num_kv_heads=1, head_size=64, dtype=torch.float16
)
mock_executor.get_kv_cache_specs.return_value = [{"default": mock_spec}]
mock_executor.determine_available_memory.return_value = [1024 * 1024 * 1024]
mock_executor.initialize_from_config.return_value = None
mock_executor.max_concurrent_batches = 1
@@ -857,19 +846,22 @@ def test_engine_core_proc_instantiation_cuda_empty(
from vllm.v1.engine.utils import EngineZmqAddresses
def mock_startup_handshake(self, handshake_socket, local_client,
headless, parallel_config):
return EngineZmqAddresses(inputs=["tcp://127.0.0.1:5555"],
outputs=["tcp://127.0.0.1:5556"],
coordinator_input=None,
coordinator_output=None)
def mock_startup_handshake(
self, handshake_socket, local_client, headless, parallel_config
):
return EngineZmqAddresses(
inputs=["tcp://127.0.0.1:5555"],
outputs=["tcp://127.0.0.1:5556"],
coordinator_input=None,
coordinator_output=None,
)
# Background processes are not important here
m.setattr(EngineCoreProc, "startup_handshake", mock_startup_handshake)
vllm_config = EngineArgs(
model="deepseek-ai/DeepSeek-V2-Lite",
trust_remote_code=True).create_engine_config()
model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True
).create_engine_config()
engine_core_proc = EngineCoreProc(
vllm_config=vllm_config,
local_client=True,

View File

@@ -40,23 +40,139 @@ def test_fast_inc_detok_invalid_utf8_err_case():
detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)
assert detokenizer.__class__.__name__ == "FastIncrementalDetokenizer", \
assert detokenizer.__class__.__name__ == "FastIncrementalDetokenizer", (
"Should use FastIncrementalDetokenizer by default"
)
# Process tokens incrementally
test_tokens = [
236840, 107, 138, 236782, 107, 140, 236775, 6265, 1083, 623, 121908,
147418, 827, 107, 140, 236775, 6265, 236779, 2084, 1083, 623, 203292,
827, 107, 140, 236775, 6265, 236779, 7777, 1083, 623, 121908, 147418,
569, 537, 236789, 65880, 569, 537, 236789, 62580, 853, 115693, 210118,
35178, 16055, 1270, 759, 215817, 4758, 1925, 1117, 827, 107, 140,
236775, 5654, 1083, 623, 110733, 46291, 827, 107, 140, 236775, 5654,
236779, 2084, 1083, 623, 136955, 56731, 827, 107, 140, 236775, 5654,
236779, 7777, 1083, 623, 194776, 2947, 496, 109811, 1608, 890, 215817,
4758, 1925, 1117, 2789, 432, 398, 602, 31118, 569, 124866, 134772, 509,
19478, 1640, 33779, 236743, 236770, 236819, 236825, 236771, 432, 398,
432, 237167, 827, 107, 140, 236775, 77984, 1083, 623, 2709, 236745,
2555, 513, 236789, 602, 31118, 569
236840,
107,
138,
236782,
107,
140,
236775,
6265,
1083,
623,
121908,
147418,
827,
107,
140,
236775,
6265,
236779,
2084,
1083,
623,
203292,
827,
107,
140,
236775,
6265,
236779,
7777,
1083,
623,
121908,
147418,
569,
537,
236789,
65880,
569,
537,
236789,
62580,
853,
115693,
210118,
35178,
16055,
1270,
759,
215817,
4758,
1925,
1117,
827,
107,
140,
236775,
5654,
1083,
623,
110733,
46291,
827,
107,
140,
236775,
5654,
236779,
2084,
1083,
623,
136955,
56731,
827,
107,
140,
236775,
5654,
236779,
7777,
1083,
623,
194776,
2947,
496,
109811,
1608,
890,
215817,
4758,
1925,
1117,
2789,
432,
398,
602,
31118,
569,
124866,
134772,
509,
19478,
1640,
33779,
236743,
236770,
236819,
236825,
236771,
432,
398,
432,
237167,
827,
107,
140,
236775,
77984,
1083,
623,
2709,
236745,
2555,
513,
236789,
602,
31118,
569,
]
output = ""
@@ -66,8 +182,7 @@ def test_fast_inc_detok_invalid_utf8_err_case():
finished = i == len(test_tokens) - 1
output += detokenizer.get_next_output_text(finished, delta=True)
# fmt: off
# fmt: off
assert output == r'''[
{
"source": "Résultats",

View File

@@ -43,7 +43,8 @@ def _vllm_model(
# env var adjustment via monkeypatch
scope="function",
# Prefix caching
params=[False, True])
params=[False, True],
)
def vllm_model(vllm_runner, request, monkeypatch):
"""VllmRunner test fixture parameterized by APC True/False."""
with _vllm_model(request.param, vllm_runner, monkeypatch) as vllm_model:
@@ -62,14 +63,15 @@ def vllm_model_apc(vllm_runner, monkeypatch):
# env var adjustment via monkeypatch
scope="function",
# Prefix caching
params=[False, True])
params=[False, True],
)
def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch):
"""VllmRunner test fixture with APC."""
with _vllm_model(
request.param,
vllm_runner,
monkeypatch,
skip_tokenizer_init=True,
request.param,
vllm_runner,
monkeypatch,
skip_tokenizer_init=True,
) as vllm_model:
yield vllm_model
@@ -97,9 +99,11 @@ def _get_test_sampling_params(
top_p=0.95,
n=n,
seed=seed,
structured_outputs=StructuredOutputsParams(
regex="[0-9]+") if structured_outputs else None,
) for n in n_list
structured_outputs=StructuredOutputsParams(regex="[0-9]+")
if structured_outputs
else None,
)
for n in n_list
], n_list
@@ -132,23 +136,20 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
for out, n in zip(outputs, n_list):
completion_counts: dict[str, int] = {}
# Assert correct number of completions
assert len(out.outputs) == n, (
f"{len(out.outputs)} completions; {n} expected.")
assert len(out.outputs) == n, f"{len(out.outputs)} completions; {n} expected."
for idx in range(n):
comp = out.outputs[idx]
# Assert correct completion indices
assert comp.index == idx, (f"Index {comp.index}; expected {idx}.")
assert comp.index == idx, f"Index {comp.index}; expected {idx}."
text = comp.text
completion_counts[text] = completion_counts.get(text, 0) + 1
# Assert unique completions
if len(completion_counts) != n:
repeats = {
txt: num
for (txt, num) in completion_counts.items() if num > 1
}
repeats = {txt: num for (txt, num) in completion_counts.items() if num > 1}
raise AssertionError(
f"{len(completion_counts)} unique completions; expected"
f" {n}. Repeats: {repeats}")
f" {n}. Repeats: {repeats}"
)
def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
@@ -162,13 +163,12 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
}
monkeypatch.setenv("VLLM_USE_V1", "1")
with vllm_runner(
MODEL,
speculative_config=speculative_config,
disable_log_stats=False,
MODEL,
speculative_config=speculative_config,
disable_log_stats=False,
) as vllm_model:
llm: LLM = vllm_model.llm
sampling_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens)
sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = llm.generate(example_prompts, sampling_params)
n_prompts = len(example_prompts)
@@ -192,15 +192,14 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
num_requests_running = find_metric("vllm:num_requests_running")
assert len(num_requests_running) == 1
assert isinstance(num_requests_running[0], Gauge)
assert num_requests_running[0].value == .0
assert num_requests_running[0].value == 0.0
generation_tokens = find_metric("vllm:generation_tokens")
assert len(generation_tokens) == 1
assert isinstance(generation_tokens[0], Counter)
assert generation_tokens[0].value == total_tokens
request_generation_tokens = find_metric(
"vllm:request_generation_tokens")
request_generation_tokens = find_metric("vllm:request_generation_tokens")
assert len(request_generation_tokens) == 1
assert isinstance(request_generation_tokens[0], Histogram)
assert "+Inf" in request_generation_tokens[0].buckets
@@ -209,15 +208,15 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
assert request_generation_tokens[0].sum == total_tokens
num_accepted_tokens_per_pos = find_metric(
"vllm:spec_decode_num_accepted_tokens_per_pos")
"vllm:spec_decode_num_accepted_tokens_per_pos"
)
assert len(num_accepted_tokens_per_pos) == 1
assert isinstance(num_accepted_tokens_per_pos[0], Vector)
assert len(num_accepted_tokens_per_pos[0].values) == 5
@pytest.mark.parametrize("model", ["meta-llama/Llama-3.2-1B-Instruct"])
def test_skip_tokenizer_initialization(model: str,
monkeypatch: pytest.MonkeyPatch):
def test_skip_tokenizer_initialization(model: str, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_V1", "1")
# This test checks if the flag skip_tokenizer_init skips the initialization
# of tokenizer and detokenizer. The generated output is expected to contain
@@ -232,8 +231,9 @@ def test_skip_tokenizer_initialization(model: str,
with pytest.raises(ValueError, match="cannot pass text prompts when"):
llm.generate("abc", sampling_params)
outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
sampling_params=sampling_params)
outputs = llm.generate(
{"prompt_token_ids": [1, 2, 3]}, sampling_params=sampling_params
)
assert len(outputs) > 0
completions = outputs[0].outputs
assert len(completions) > 0

View File

@@ -7,19 +7,20 @@ from typing import Optional
import pytest
from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
NUM_SAMPLE_LOGPROBS_UNDER_TEST,
STOP_STRINGS,
DummyOutputProcessorTestVectors,
MockEngineCore)
from tests.v1.engine.utils import (
NUM_PROMPT_LOGPROBS_UNDER_TEST,
NUM_SAMPLE_LOGPROBS_UNDER_TEST,
STOP_STRINGS,
DummyOutputProcessorTestVectors,
MockEngineCore,
)
from vllm import PoolingParams
from vllm.logprobs import PromptLogprobs, SampleLogprobs
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import (OutputProcessor,
RequestOutputCollector)
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
from vllm.v1.metrics.stats import IterationStats
@@ -40,33 +41,34 @@ def _ref_convert_id_to_token(
@pytest.mark.parametrize(
"request_output_kind",
[RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
def test_incremental_detokenization(request_output_kind: RequestOutputKind,
dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens)
"request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
def test_incremental_detokenization(
request_output_kind: RequestOutputKind, dummy_test_vectors
):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
engine_core = MockEngineCore(tokens_list=dummy_test_vectors.generation_tokens)
# Make N requests.
requests = [
EngineCoreRequest(request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
output_kind=request_output_kind,
stop=[],
include_stop_str_in_output=False,
),
pooling_params=None)
EngineCoreRequest(
request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
output_kind=request_output_kind,
stop=[],
include_stop_str_in_output=False,
),
pooling_params=None,
)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
@@ -102,8 +104,8 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
# Confirmed tracked values matches what we expected.
for idx, (ref_gen_str, ref_gen_toks) in enumerate(
zip(dummy_test_vectors.generation_strings,
dummy_test_vectors.generation_tokens)):
zip(dummy_test_vectors.generation_strings, dummy_test_vectors.generation_tokens)
):
gen_str = gen_strings[f"request-{idx}"]
gen_toks = gen_tokens[f"request-{idx}"]
@@ -134,9 +136,11 @@ def _validate_logprobs(
ref_prompt_logprobs = dtv.prompt_logprobs[req_idx]
if num_sample_logprobs is not None:
# Validate sample logprobs
assert logprobs is not None, (f"Request {req_id} requires sample"
" logprobs but sample logprobs are"
" None.")
assert logprobs is not None, (
f"Request {req_id} requires sample"
" logprobs but sample logprobs are"
" None."
)
# Require num sampled tokens to match num
# sampled logprobs - especially important
# to check since the detokenizer can cause
@@ -147,44 +151,51 @@ def _validate_logprobs(
assert num_new_tokens == len_sample_logprobs, (
f"Request {req_id} has {num_new_tokens}"
" completion tokens but has"
f" {len_sample_logprobs} sample logprobs.")
f" {len_sample_logprobs} sample logprobs."
)
ref_cumulative_logprob = 0.0
for idx, (sampled_token,
pos_logprob_dict) in enumerate(zip(new_tokens,
logprobs)):
for idx, (sampled_token, pos_logprob_dict) in enumerate(
zip(new_tokens, logprobs)
):
# Break out the reference log probability value &
# logprob token id tensors associated with this
# position in the completion. Also break out the
# sampled token ranks
(ref_pos_logprob_toks, ref_pos_logprob_vals,
ref_sampled_token_rank) = ref_logprobs[idx]
(ref_pos_logprob_toks, ref_pos_logprob_vals, ref_sampled_token_rank) = (
ref_logprobs[idx]
)
# For each position in the completion sequence,
# ensure the actual sampled token is among the
# logprobs
assert sampled_token in pos_logprob_dict, (
f"Sampled token {sampled_token} not"
f" present in logprob at index {idx}")
f" present in logprob at index {idx}"
)
# Validate number of sample logprobs
num_lp_toks = len(pos_logprob_dict)
assert (num_lp_toks == num_sample_logprobs
or num_lp_toks == num_sample_logprobs +
1), ("Valid numbers of sample logprobs are"
f" {num_sample_logprobs} or"
f" {num_sample_logprobs+1} but"
f" {num_lp_toks} logprobs found at"
f" position {idx}. Logprobs dict:"
f" {pos_logprob_dict}")
assert (
num_lp_toks == num_sample_logprobs
or num_lp_toks == num_sample_logprobs + 1
), (
"Valid numbers of sample logprobs are"
f" {num_sample_logprobs} or"
f" {num_sample_logprobs + 1} but"
f" {num_lp_toks} logprobs found at"
f" position {idx}. Logprobs dict:"
f" {pos_logprob_dict}"
)
# Validate sampled token logprob rank
smp_lp = pos_logprob_dict[sampled_token]
smp_lp_rank = smp_lp.rank
assert (ref_sampled_token_rank == smp_lp_rank), (
assert ref_sampled_token_rank == smp_lp_rank, (
"Sampled token logprob rank"
f" {smp_lp_rank} does not match"
" correct value"
f" {ref_sampled_token_rank}"
f" in Logprob {smp_lp}")
f" in Logprob {smp_lp}"
)
# Validate that the logprob processor yields
# the correct log probabilities and valid
@@ -198,7 +209,8 @@ def _validate_logprobs(
ref_tok_id = ref_pos_logprob_toks[jdx]
assert ref_tok_id in pos_logprob_dict, (
f"Expected token {ref_tok_id} to be"
f" in logprob dict but it is not.")
f" in logprob dict but it is not."
)
# Extract actually-generated logprob
# info
@@ -208,40 +220,43 @@ def _validate_logprobs(
# A "top" (rank 1) logprob must be
# present
rank_one_appears = (True
if lp_rank == 1 else rank_one_appears)
rank_one_appears = True if lp_rank == 1 else rank_one_appears
# Rank must be >= 1
assert lp_rank >= 1, (f"Logprob {lp} has invalid"
f" rank {lp_rank} < 1."
f" Logprob dict: {pos_logprob_dict}")
assert lp_rank >= 1, (
f"Logprob {lp} has invalid"
f" rank {lp_rank} < 1."
f" Logprob dict: {pos_logprob_dict}"
)
# Validate log probability
assert math.isclose(lp_val, ref_lp_val), (
f"Token id {ref_tok_id} appears in logprobs dict"
f" at position {idx} in completion with log"
f" probability {lp_val} but {ref_lp_val} was"
f" expected. Logprob: {lp}")
f" expected. Logprob: {lp}"
)
assert rank_one_appears, (f"No Logprob has rank 1"
" in the following Logprob"
f" dict: {pos_logprob_dict}")
assert rank_one_appears, (
f"No Logprob has rank 1"
" in the following Logprob"
f" dict: {pos_logprob_dict}"
)
# Validate logprobs detokenization
for lp_tok in pos_logprob_dict:
# Confirm that sample logprob decoded token matches
# the logprob token id at this sequence position
decoded_token = pos_logprob_dict[lp_tok].decoded_token
ref_decoded_token = _ref_convert_id_to_token(
dtv.tokenizer, lp_tok)
ref_decoded_token = _ref_convert_id_to_token(dtv.tokenizer, lp_tok)
assert decoded_token == ref_decoded_token, (
f"Sampled logprob token id {lp_tok} decodes to"
f" {ref_decoded_token} but Logprob decoded"
f" token is {decoded_token} instead"
f" (at position {idx})")
f" (at position {idx})"
)
ref_cumulative_logprob += pos_logprob_dict[
sampled_token].logprob
ref_cumulative_logprob += pos_logprob_dict[sampled_token].logprob
# Assert that cumulative logprobs are correct
assert math.isclose(cumulative_logprob, ref_cumulative_logprob)
else:
@@ -254,7 +269,8 @@ def _validate_logprobs(
assert prompt_logprobs is not None, (
f"Request {req_id} requires prompt"
" logprobs but prompt logprobs are"
" None.")
" None."
)
# Require num prompt tokens to match num
# prompt logprobs
num_prompt_tokens = len(prompt_token_ids)
@@ -262,56 +278,70 @@ def _validate_logprobs(
assert num_prompt_tokens == len_prompt_logprobs, (
f"Request {req_id} has {num_prompt_tokens}"
" prompt tokens but has"
f" {len_prompt_logprobs} prompt logprobs.")
f" {len_prompt_logprobs} prompt logprobs."
)
# First prompt logprob is None
first_plp_dict = prompt_logprobs[0]
assert first_plp_dict is None, (
f"Request {req_id} first prompt logprob"
f" should be None but has following value"
f" instead: {first_plp_dict}")
f" instead: {first_plp_dict}"
)
# Break out the reference prompt log prob value &
# logprob token id matrices for the whole prompt.
# Also break out the prompt token rank vector
(ref_prompt_logprob_toks, ref_prompt_logprob_vals,
ref_prompt_token_ranks) = ref_prompt_logprobs
(
ref_prompt_logprob_toks,
ref_prompt_logprob_vals,
ref_prompt_token_ranks,
) = ref_prompt_logprobs
for idx, (prompt_token, pos_logprob_dict) in enumerate(
zip(prompt_token_ids[1:], prompt_logprobs[1:])):
zip(prompt_token_ids[1:], prompt_logprobs[1:])
):
# Break out the reference prompt log prob value
# vector, prompt logprob token id vector, and
# prompt token rank at the current position.
(ref_pos_prompt_logprob_toks, ref_pos_prompt_logprob_vals,
ref_pos_prompt_token_rank) = (ref_prompt_logprob_toks[idx, :],
ref_prompt_logprob_vals[idx, :],
ref_prompt_token_ranks[idx])
(
ref_pos_prompt_logprob_toks,
ref_pos_prompt_logprob_vals,
ref_pos_prompt_token_rank,
) = (
ref_prompt_logprob_toks[idx, :],
ref_prompt_logprob_vals[idx, :],
ref_prompt_token_ranks[idx],
)
# For each position in the prompt sequence,
# ensure the actual prompt token is among the
# logprobs
assert prompt_token in pos_logprob_dict, (
f"Prompt token {prompt_token} not"
f" present in logprob at index {idx}")
f"Prompt token {prompt_token} not present in logprob at index {idx}"
)
# Validate number of prompt logprobs
num_plp_toks = len(pos_logprob_dict)
assert (num_plp_toks == num_prompt_logprobs
or num_plp_toks == num_prompt_logprobs +
1), ("Valid numbers of prompt logprobs are"
f" {num_prompt_logprobs} or"
f" {num_prompt_logprobs+1} but"
f" {num_plp_toks} logprobs found at"
f" position {idx}. Logprobs dict:"
f" {pos_logprob_dict}")
assert (
num_plp_toks == num_prompt_logprobs
or num_plp_toks == num_prompt_logprobs + 1
), (
"Valid numbers of prompt logprobs are"
f" {num_prompt_logprobs} or"
f" {num_prompt_logprobs + 1} but"
f" {num_plp_toks} logprobs found at"
f" position {idx}. Logprobs dict:"
f" {pos_logprob_dict}"
)
# Validate prompt token logprob rank
prmpt_tok_lp = pos_logprob_dict[prompt_token]
prmpt_tok_lp_rank = prmpt_tok_lp.rank
ref_prmpt_tok_lp_rank = ref_pos_prompt_token_rank
assert (ref_prmpt_tok_lp_rank == prmpt_tok_lp_rank), (
assert ref_prmpt_tok_lp_rank == prmpt_tok_lp_rank, (
"Prompt token logprob rank"
f" {prmpt_tok_lp_rank} does not match"
" correct value"
f" {ref_prmpt_tok_lp_rank}"
f" in Logprob {prmpt_tok_lp}")
f" in Logprob {prmpt_tok_lp}"
)
# Validate that the logprob processor yields
# the correct prompt log probs and valid
@@ -325,7 +355,8 @@ def _validate_logprobs(
ref_tok_id = int(ref_pos_prompt_logprob_toks[jdx])
assert ref_tok_id in pos_logprob_dict, (
f"Expected token {ref_tok_id} to be"
f" in logprob dict but it is not.")
f" in logprob dict but it is not."
)
# Extract actually-generated logprob
# info
@@ -335,87 +366,93 @@ def _validate_logprobs(
# A "top" (rank 1) logprob must be
# present
rank_one_appears = (True
if plp_rank == 1 else rank_one_appears)
rank_one_appears = True if plp_rank == 1 else rank_one_appears
# Rank must be >= 1
assert plp_rank >= 1, (
f"Logprob {plp} has invalid"
f" rank {plp_rank} < 1."
f" Logprob dict: {pos_logprob_dict}")
f" Logprob dict: {pos_logprob_dict}"
)
# Validate log probability
assert math.isclose(plp_val, ref_plp_val), (
f"Token id {ref_tok_id} appears in logprobs dict"
f" at position {idx} in completion with log"
f" probability {plp_val} but {ref_plp_val} was"
f" expected. Logprob: {plp}")
f" expected. Logprob: {plp}"
)
assert rank_one_appears, (f"No Logprob has rank 1"
" in the following Logprob"
f" dict: {pos_logprob_dict}")
assert rank_one_appears, (
f"No Logprob has rank 1"
" in the following Logprob"
f" dict: {pos_logprob_dict}"
)
# Validate prompt logprob detokenization
for plp_tok in pos_logprob_dict:
# Confirm that prompt logprob decoded token matches
# the logprob token id at this sequence position
decoded_token = pos_logprob_dict[plp_tok].decoded_token
ref_decoded_token = _ref_convert_id_to_token(
dtv.tokenizer, plp_tok)
ref_decoded_token = _ref_convert_id_to_token(dtv.tokenizer, plp_tok)
assert decoded_token == ref_decoded_token, (
f"Prompt logprob token id {plp_tok} decodes to"
f" {ref_decoded_token} but Logprob decoded"
f" token is {decoded_token} instead"
f" (at position {idx})")
f" (at position {idx})"
)
else:
# Prompt logprobs disabled for this request
assert prompt_logprobs is None
@pytest.mark.parametrize(
"request_output_kind",
[RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize("num_sample_logprobs",
[None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
@pytest.mark.parametrize("num_prompt_logprobs",
[None, NUM_PROMPT_LOGPROBS_UNDER_TEST])
def test_logprobs_processor(request_output_kind: RequestOutputKind,
num_sample_logprobs: Optional[int],
num_prompt_logprobs: Optional[int],
dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False)
"request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.parametrize("num_sample_logprobs", [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
@pytest.mark.parametrize("num_prompt_logprobs", [None, NUM_PROMPT_LOGPROBS_UNDER_TEST])
def test_logprobs_processor(
request_output_kind: RequestOutputKind,
num_sample_logprobs: Optional[int],
num_prompt_logprobs: Optional[int],
dummy_test_vectors,
):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
generated_logprobs_raw=None if num_sample_logprobs is None else
dummy_test_vectors.generation_logprobs,
generated_logprobs_raw=None
if num_sample_logprobs is None
else dummy_test_vectors.generation_logprobs,
prompt_logprobs_raw=None
if num_prompt_logprobs is None else dummy_test_vectors.prompt_logprobs)
if num_prompt_logprobs is None
else dummy_test_vectors.prompt_logprobs,
)
# Make N requests.
request_id_list = [
f"request-{idx}"
for idx in range(len(dummy_test_vectors.prompt_strings))
f"request-{idx}" for idx in range(len(dummy_test_vectors.prompt_strings))
]
requests = [
EngineCoreRequest(request_id=request_id_list[idx],
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
output_kind=request_output_kind,
stop=[],
include_stop_str_in_output=False,
logprobs=num_sample_logprobs,
prompt_logprobs=num_prompt_logprobs,
),
pooling_params=None)
EngineCoreRequest(
request_id=request_id_list[idx],
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
output_kind=request_output_kind,
stop=[],
include_stop_str_in_output=False,
logprobs=num_sample_logprobs,
prompt_logprobs=num_prompt_logprobs,
),
pooling_params=None,
)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
@@ -446,7 +483,8 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
prompt_logprobs = request_output.prompt_logprobs
logprobs = request_output.outputs[0].logprobs
gen_cumulative_logprobs[request_id] = request_output.outputs[
0].cumulative_logprob
0
].cumulative_logprob
if request_id not in gen_logprobs:
# Start tracking sample and prompt logprobs for this request
gen_tokens[request_id] = new_tokens
@@ -463,10 +501,16 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
plp.extend(prompt_logprobs)
# Confirmed tracked logprobs match what we expect
_validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs,
gen_cumulative_logprobs, dummy_test_vectors,
request_id_list, num_sample_logprobs,
num_prompt_logprobs)
_validate_logprobs(
gen_tokens,
gen_logprobs,
gen_prompt_logprobs,
gen_cumulative_logprobs,
dummy_test_vectors,
request_id_list,
num_sample_logprobs,
num_prompt_logprobs,
)
assert output_processor.get_num_unfinished_requests() == 0
assert not output_processor.has_unfinished_requests()
@@ -474,15 +518,23 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
@pytest.mark.parametrize(
"include_stop_str_in_output,stop_token_type,ignore_eos,num_sample_logprobs",
[(False, "stop_token_ids", False, None),
(True, "stop_token_ids", False, None),
(False, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
(True, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
(False, "eos_token_id", False, None), (True, "eos_token_id", False, None),
(False, "eos_token_id", True, None)])
def test_stop_token(include_stop_str_in_output: bool,
num_sample_logprobs: Optional[int], stop_token_type: str,
ignore_eos: bool, dummy_test_vectors):
[
(False, "stop_token_ids", False, None),
(True, "stop_token_ids", False, None),
(False, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
(True, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
(False, "eos_token_id", False, None),
(True, "eos_token_id", False, None),
(False, "eos_token_id", True, None),
],
)
def test_stop_token(
include_stop_str_in_output: bool,
num_sample_logprobs: Optional[int],
stop_token_type: str,
ignore_eos: bool,
dummy_test_vectors,
):
"""Test output processor EOS/stop token handling.
Send mock engine core request to mock engine core and pass core outputs
@@ -523,9 +575,10 @@ def test_stop_token(include_stop_str_in_output: bool,
dummy_test_vectors: dummy engine core outputs and other data structures
"""
model_id = dummy_test_vectors.tokenizer.name_or_path
if model_id != 'meta-llama/Llama-3.2-1B':
raise AssertionError("Test requires meta-llama/Llama-3.2-1B but "
f"{model_id} is in use.")
if model_id != "meta-llama/Llama-3.2-1B":
raise AssertionError(
f"Test requires meta-llama/Llama-3.2-1B but {model_id} is in use."
)
do_logprobs = num_sample_logprobs is not None
# EOS under test; if False, stop_token_ids under test
is_eos_test = stop_token_type == "eos_token_id"
@@ -536,18 +589,16 @@ def test_stop_token(include_stop_str_in_output: bool,
) # '<|end_of_text|>'
stop_token_ids = [128009] if not is_eos_test else None # '<|eot_id|>'
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False)
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
# Dummy engine core outputs, with control tokens suffixed to test stops
suffix_token = ([eos_token_id] if is_eos_test else stop_token_ids)
suffix_token = [eos_token_id] if is_eos_test else stop_token_ids
assert suffix_token is not None and isinstance(suffix_token[0], int)
generation_string = dummy_test_vectors.generation_strings[0]
generation_tokens = (dummy_test_vectors.generation_tokens[0] +
2 * suffix_token)
generation_tokens = dummy_test_vectors.generation_tokens[0] + 2 * suffix_token
if do_logprobs:
generation_logprobs = (
dummy_test_vectors.generation_logprobs[0] +
2 * [dummy_test_vectors.generation_logprobs[0][-1]])
generation_logprobs = dummy_test_vectors.generation_logprobs[0] + 2 * [
dummy_test_vectors.generation_logprobs[0][-1]
]
prompt_string = dummy_test_vectors.prompt_strings[0]
prompt_tokens = dummy_test_vectors.prompt_tokens[0]
engine_core = MockEngineCore(
@@ -556,7 +607,8 @@ def test_stop_token(include_stop_str_in_output: bool,
prompt_logprobs_raw=None,
eos_token_id=eos_token_id,
stop_token_ids=stop_token_ids,
ignore_eos=ignore_eos)
ignore_eos=ignore_eos,
)
# Make request.
request_id = "request-0"
@@ -580,7 +632,8 @@ def test_stop_token(include_stop_str_in_output: bool,
prompt_logprobs=None,
ignore_eos=ignore_eos,
),
pooling_params=None)
pooling_params=None,
)
# Add request to the detokenizer.
output_processor.add_request(request, prompt_string)
@@ -605,7 +658,7 @@ def test_stop_token(include_stop_str_in_output: bool,
# Update tracking.
request_output = request_outputs[0]
if request_output.finished:
finish_reason = ("length" if is_eos_ignore_test else "stop")
finish_reason = "length" if is_eos_ignore_test else "stop"
assert request_output.outputs[0].finish_reason == finish_reason
gen_string += request_output.outputs[0].text
@@ -614,7 +667,7 @@ def test_stop_token(include_stop_str_in_output: bool,
gen_logprobs.extend(request_output.outputs[0].logprobs)
# Validate generated text
control_token = '<|end_of_text|>' if is_eos_test else '<|eot_id|>'
control_token = "<|end_of_text|>" if is_eos_test else "<|eot_id|>"
if is_eos_ignore_test:
# Length-based stop; expect full string
ref_str = generation_string + 2 * control_token
@@ -624,14 +677,15 @@ def test_stop_token(include_stop_str_in_output: bool,
else:
# Stop token triggered but not in output
ref_str = generation_string
assert gen_string == ref_str, (f"{gen_string=}, {ref_str=}")
assert gen_string == ref_str, f"{gen_string=}, {ref_str=}"
if do_logprobs:
# Validate number of sample logprobs
num_tokens = len(gen_tokens)
num_logprobs = len(gen_logprobs)
assert num_tokens == num_logprobs, (
f"Token count ({num_tokens}) != logprobs count ({num_logprobs})")
f"Token count ({num_tokens}) != logprobs count ({num_logprobs})"
)
# Check requests are finished
assert output_processor.get_num_unfinished_requests() == 0
@@ -639,22 +693,24 @@ def test_stop_token(include_stop_str_in_output: bool,
@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
@pytest.mark.parametrize("num_sample_logprobs",
[None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
def test_stop_string(include_stop_str_in_output: bool,
num_sample_logprobs: Optional[int], dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=False)
@pytest.mark.parametrize("num_sample_logprobs", [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
def test_stop_string(
include_stop_str_in_output: bool,
num_sample_logprobs: Optional[int],
dummy_test_vectors,
):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
generated_logprobs_raw=dummy_test_vectors.generation_logprobs
if num_sample_logprobs else None,
prompt_logprobs_raw=None)
if num_sample_logprobs
else None,
prompt_logprobs_raw=None,
)
# Make N requests.
request_id_list = [
f"request-{idx}"
for idx in range(len(dummy_test_vectors.prompt_strings))
f"request-{idx}" for idx in range(len(dummy_test_vectors.prompt_strings))
]
requests = [
EngineCoreRequest(
@@ -675,7 +731,8 @@ def test_stop_string(include_stop_str_in_output: bool,
logprobs=num_sample_logprobs,
prompt_logprobs=None,
),
pooling_params=None)
pooling_params=None,
)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
@@ -715,7 +772,8 @@ def test_stop_string(include_stop_str_in_output: bool,
prompt_logprobs = request_output.prompt_logprobs
logprobs = request_output.outputs[0].logprobs
gen_cumulative_logprobs[request_id] = request_output.outputs[
0].cumulative_logprob
0
].cumulative_logprob
if request_id not in gen_strings:
gen_strings[request_id] = new_text
gen_tokens[request_id] = new_tokens
@@ -733,8 +791,8 @@ def test_stop_string(include_stop_str_in_output: bool,
# Confirmed tracked values matches what we expected.
for idx, (ref_gen_str, stop_str) in enumerate(
zip(dummy_test_vectors.generation_strings, STOP_STRINGS)):
zip(dummy_test_vectors.generation_strings, STOP_STRINGS)
):
# Request should be aborted.
request_id = f"request-{idx}"
assert request_id in aborted
@@ -748,24 +806,28 @@ def test_stop_string(include_stop_str_in_output: bool,
ref_str_inc_stop = ref_gen_str[:stop_str_idx] + stop_str
if include_stop_str_in_output:
assert gen_str == ref_str_inc_stop, (
f"{gen_str=}, {ref_str_inc_stop=}")
assert gen_str == ref_str_inc_stop, f"{gen_str=}, {ref_str_inc_stop=}"
else:
assert gen_str == ref_str_exc_stop, (
f"{gen_str=}, {ref_str_exc_stop=}")
assert gen_str == ref_str_exc_stop, f"{gen_str=}, {ref_str_exc_stop=}"
# Confirmed tracked logprobs match what we expect
_validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs,
gen_cumulative_logprobs, dummy_test_vectors,
request_id_list, num_sample_logprobs, None)
_validate_logprobs(
gen_tokens,
gen_logprobs,
gen_prompt_logprobs,
gen_cumulative_logprobs,
dummy_test_vectors,
request_id_list,
num_sample_logprobs,
None,
)
assert output_processor.get_num_unfinished_requests() == 0
assert not output_processor.has_unfinished_requests()
def test_iteration_stats(dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=True)
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
engine_core_timestamp = time.monotonic()
@@ -782,7 +844,8 @@ def test_iteration_stats(dummy_test_vectors):
data_parallel_rank=None,
sampling_params=SamplingParams(),
pooling_params=None,
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
# Add all requests except one to the OutputProcessor.
@@ -794,12 +857,13 @@ def test_iteration_stats(dummy_test_vectors):
# First iteration has 2 prefills.
outputs = engine_core.get_outputs()[:num_active]
iteration_stats = IterationStats()
output_processor.process_outputs(outputs, engine_core_timestamp,
iteration_stats)
total_prompt_tokens = sum([
len(prompt_tokens)
for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active]
])
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
total_prompt_tokens = sum(
[
len(prompt_tokens)
for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active]
]
)
assert iteration_stats.num_prompt_tokens == total_prompt_tokens
assert iteration_stats.num_generation_tokens == num_active
@@ -807,8 +871,7 @@ def test_iteration_stats(dummy_test_vectors):
# Just decodes in this step.
outputs = engine_core.get_outputs()[:num_active]
iteration_stats = IterationStats()
output_processor.process_outputs(outputs, engine_core_timestamp,
iteration_stats)
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
assert iteration_stats.num_prompt_tokens == 0
assert iteration_stats.num_generation_tokens == num_active
@@ -818,8 +881,7 @@ def test_iteration_stats(dummy_test_vectors):
num_active += 1
outputs = engine_core.get_outputs()[:num_active]
iteration_stats = IterationStats()
output_processor.process_outputs(outputs, engine_core_timestamp,
iteration_stats)
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1])
assert iteration_stats.num_prompt_tokens == total_prompt_tokens
@@ -828,8 +890,7 @@ def test_iteration_stats(dummy_test_vectors):
# Just decodes in this step.
outputs = engine_core.get_outputs()[:num_active]
iteration_stats = IterationStats()
output_processor.process_outputs(outputs, engine_core_timestamp,
iteration_stats)
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
assert iteration_stats.num_prompt_tokens == 0
assert iteration_stats.num_generation_tokens == num_active
@@ -853,16 +914,13 @@ async def test_request_output_collector():
text=TEXT,
token_ids=[idx],
cumulative_logprob=(idx + 1 * 1.0),
logprobs=[{
"a": idx,
"b": idx
}],
finish_reason="length" if
(idx == NUM_REQS - 1) else None,
logprobs=[{"a": idx, "b": idx}],
finish_reason="length" if (idx == NUM_REQS - 1) else None,
)
],
finished=(idx == NUM_REQS - 1),
) for idx in range(NUM_REQS)
)
for idx in range(NUM_REQS)
]
collector = RequestOutputCollector(RequestOutputKind.DELTA)
@@ -888,8 +946,7 @@ async def test_request_output_collector():
assert not output.finished
# Text, token_ids, and logprobs should get merged.
assert output.outputs[0].text == TEXT * num_to_put
for tok_0, tok_1 in zip(output.outputs[0].token_ids,
list(range(num_to_put))):
for tok_0, tok_1 in zip(output.outputs[0].token_ids, list(range(num_to_put))):
assert tok_0 == tok_1
assert len(output.outputs[0].logprobs) == num_to_put
@@ -910,8 +967,7 @@ async def test_request_output_collector():
assert output.outputs[0].finish_reason == "length"
# Text, token_ids, and logprobs should get merged.
assert output.outputs[0].text == TEXT * num_to_put
for tok_0, tok_1 in zip(output.outputs[0].token_ids,
list(range(num_to_put))):
for tok_0, tok_1 in zip(output.outputs[0].token_ids, list(range(num_to_put))):
assert tok_0 == tok_1
assert len(output.outputs[0].logprobs) == num_to_put
@@ -1003,8 +1059,7 @@ async def test_cumulative_output_collector_n():
@pytest.mark.parametrize("runner", ["generate", "pooling"])
def test_abort_requests(runner: str, dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=True)
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
requests = [
EngineCoreRequest(
request_id=f"request-{idx}",
@@ -1016,9 +1071,9 @@ def test_abort_requests(runner: str, dummy_test_vectors):
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams() if runner == "generate" else None,
pooling_params=PoolingParams(
task="embed") if runner == "pooling" else None,
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
pooling_params=PoolingParams(task="embed") if runner == "pooling" else None,
)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
for request in requests:

View File

@@ -16,35 +16,33 @@ baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays
# Mock processor for testing
def _mk_processor(monkeypatch,
*,
mm_cache_gb: float = 4.0,
enable_prefix_caching: bool = True) -> Processor:
def _mk_processor(
monkeypatch, *, mm_cache_gb: float = 4.0, enable_prefix_caching: bool = True
) -> Processor:
"""
Create a Processor instance with minimal configuration suitable for unit
tests without accessing external resources.
"""
monkeypatch.setattr(ModelConfig,
"try_get_generation_config",
lambda self: {},
raising=True)
monkeypatch.setattr(ModelConfig,
"__post_init__",
lambda self, *args: None,
raising=True)
monkeypatch.setattr(ModelConfig,
"verify_with_parallel_config",
lambda self, parallel_config: None,
raising=True)
monkeypatch.setattr(processor_mod,
"processor_cache_from_config",
lambda vllm_config, mm_registry: None,
raising=True)
monkeypatch.setattr(
ModelConfig, "try_get_generation_config", lambda self: {}, raising=True
)
monkeypatch.setattr(
ModelConfig, "__post_init__", lambda self, *args: None, raising=True
)
monkeypatch.setattr(
ModelConfig,
"verify_with_parallel_config",
lambda self, parallel_config: None,
raising=True,
)
monkeypatch.setattr(
processor_mod,
"processor_cache_from_config",
lambda vllm_config, mm_registry: None,
raising=True,
)
monkeypatch.setattr(VllmConfig,
"__post_init__",
lambda self: None,
raising=True)
monkeypatch.setattr(VllmConfig, "__post_init__", lambda self: None, raising=True)
model_config = ModelConfig(
skip_tokenizer_init=True,
@@ -57,12 +55,10 @@ def _mk_processor(monkeypatch,
# Minimal multimodal_config to satisfy references in
# Processor.process_inputs.
class _MockMMConfig:
def __init__(self, gb: float):
self.mm_processor_cache_gb = gb
model_config.multimodal_config = _MockMMConfig(
mm_cache_gb) # type: ignore[attr-defined]
model_config.multimodal_config = _MockMMConfig(mm_cache_gb) # type: ignore[attr-defined]
vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching),
@@ -79,13 +75,9 @@ def test_multi_modal_uuids_length_mismatch_raises(monkeypatch):
prompt = {
"prompt": "USER: <image>\nDescribe\nASSISTANT:",
"multi_modal_data": {
"image": [cherry_pil_image, stop_pil_image]
},
"multi_modal_data": {"image": [cherry_pil_image, stop_pil_image]},
# Mismatch: 2 items but only 1 uuid provided
"multi_modal_uuids": {
"image": ["hash_cherry"]
},
"multi_modal_uuids": {"image": ["hash_cherry"]},
}
with pytest.raises(ValueError, match="must have same length as data"):
@@ -104,16 +96,13 @@ def test_multi_modal_uuids_missing_modality_raises(monkeypatch):
# Two modalities provided in data
"multi_modal_data": {
"image": [cherry_pil_image],
"video": [baby_reading_np_ndarrays]
"video": [baby_reading_np_ndarrays],
},
# Only image uuids provided; video missing should raise
"multi_modal_uuids": {
"image": ["hash_cherry"]
},
"multi_modal_uuids": {"image": ["hash_cherry"]},
}
with pytest.raises(ValueError,
match="must be provided if multi_modal_data"):
with pytest.raises(ValueError, match="must be provided if multi_modal_data"):
processor.process_inputs(
request_id="req-2",
prompt=prompt, # type: ignore[arg-type]
@@ -130,28 +119,28 @@ def test_multi_modal_uuids_missing_modality_raises(monkeypatch):
],
)
def test_multi_modal_uuids_accepts_none_and_passes_through(
monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool):
processor = _mk_processor(monkeypatch,
mm_cache_gb=mm_cache_gb,
enable_prefix_caching=enable_prefix_caching)
monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool
):
processor = _mk_processor(
monkeypatch,
mm_cache_gb=mm_cache_gb,
enable_prefix_caching=enable_prefix_caching,
)
# Capture the overrides passed to InputPreprocessor.preprocess
captured: dict[str, object] = {}
def fake_preprocess(prompt,
*,
tokenization_kwargs=None,
lora_request=None,
mm_uuids=None):
def fake_preprocess(
prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
):
captured["mm_uuids"] = mm_uuids
# Minimal processed inputs for decoder-only flow
return {"type": "token", "prompt_token_ids": [1]}
# Monkeypatch only the bound preprocess method on this instance
monkeypatch.setattr(processor.input_preprocessor,
"preprocess",
fake_preprocess,
raising=True)
monkeypatch.setattr(
processor.input_preprocessor, "preprocess", fake_preprocess, raising=True
)
# Use a consistent two-image scenario across all configurations
mm_uuids = {"image": [None, "hash_stop"], "video": None}
@@ -176,24 +165,19 @@ def test_multi_modal_uuids_accepts_none_and_passes_through(
def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
# When both processor cache is 0 and prefix caching disabled, the
# processor builds overrides from request id instead of using user UUIDs.
processor = _mk_processor(monkeypatch,
mm_cache_gb=0.0,
enable_prefix_caching=False)
processor = _mk_processor(monkeypatch, mm_cache_gb=0.0, enable_prefix_caching=False)
captured: dict[str, object] = {}
def fake_preprocess(prompt,
*,
tokenization_kwargs=None,
lora_request=None,
mm_uuids=None):
def fake_preprocess(
prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
):
captured["mm_uuids"] = mm_uuids
return {"type": "token", "prompt_token_ids": [1]}
monkeypatch.setattr(processor.input_preprocessor,
"preprocess",
fake_preprocess,
raising=True)
monkeypatch.setattr(
processor.input_preprocessor, "preprocess", fake_preprocess, raising=True
)
request_id = "req-42"
mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": "hash_video"}

View File

@@ -82,11 +82,12 @@ def _create_random_top_logprob_test_matrix(
def _create_random_top_token_test_vector(
num_logprobs: int,
lower: int,
upper: int,
sampled_token_id: int,
adjust_num_logprobs: bool = True) -> tuple[torch.Tensor, int]:
num_logprobs: int,
lower: int,
upper: int,
sampled_token_id: int,
adjust_num_logprobs: bool = True,
) -> tuple[torch.Tensor, int]:
"""Create a random vector of top logprob token indices
Use to create fake sample logprobs for testing. The sampled token
@@ -127,8 +128,9 @@ def _create_random_top_token_test_vector(
# Check if the sampled_token_id occurs in choice_tensor[1:]
if sampled_token_id in choice_tensor[1:]:
sampled_token_rank = (choice_tensor[1:] == sampled_token_id).nonzero(
as_tuple=True)[0].item()
sampled_token_rank = (
(choice_tensor[1:] == sampled_token_id).nonzero(as_tuple=True)[0].item()
)
else:
# If not found, assign a random int between num_logprobs and 50700
sampled_token_rank = random.randint(num_logprobs, 50700)
@@ -164,9 +166,12 @@ def _create_random_top_token_test_matrix(
num_elements = shape[0] * shape[1]
choice_tensor = torch.randperm(upper - lower)[:num_elements] + lower
matrix = torch.cat(
(torch.tensor(tokens_list, dtype=torch.int).unsqueeze(-1),
choice_tensor.view(shape)),
dim=1)
(
torch.tensor(tokens_list, dtype=torch.int).unsqueeze(-1),
choice_tensor.view(shape),
),
dim=1,
)
# Initialize the tensor for storing the ranks
prompt_token_ranks = torch.empty(shape[0], dtype=torch.int)
@@ -174,8 +179,7 @@ def _create_random_top_token_test_matrix(
# Iterate over each row to check presence of
# tokens_list[rdx] and determine its index
for rdx in range(shape[0]):
row = matrix[rdx,
1:] # Skip the first column as it contains the token list
row = matrix[rdx, 1:] # Skip the first column as it contains the token list
token_index = (row == tokens_list[rdx]).nonzero(as_tuple=True)[0]
if token_index.numel() > 0:
prompt_token_ranks[rdx] = token_index.item()
@@ -229,19 +233,21 @@ def generate_dummy_sample_logprobs(
(
token_vector,
sampled_token_rank,
) = _create_random_top_token_test_vector(num_logprobs, 0,
len(tokenizer.vocab) - 1,
sampled_token_id)
) = _create_random_top_token_test_vector(
num_logprobs, 0, len(tokenizer.vocab) - 1, sampled_token_id
)
res.append(
(token_vector,
_create_random_top_logprob_test_vector(num_logprobs + 1, -100,
0), sampled_token_rank))
(
token_vector,
_create_random_top_logprob_test_vector(num_logprobs + 1, -100, 0),
sampled_token_rank,
)
)
# Convert tensors in the list tuples to Python lists
res_list_format = [
(log_probs_tensor.tolist(), token_ids_tensor.tolist(),
sampled_token_rank)
(log_probs_tensor.tolist(), token_ids_tensor.tolist(), sampled_token_rank)
for log_probs_tensor, token_ids_tensor, sampled_token_rank in res
]
@@ -282,18 +288,24 @@ def generate_dummy_prompt_logprobs_tensors(
token_vector,
prompt_token_ranks,
) = _create_random_top_token_test_matrix(
(num_prompt_logprobs, num_logprobs), 0,
len(tokenizer.vocab) - 1, prompt_tokens_list[1:])
(num_prompt_logprobs, num_logprobs),
0,
len(tokenizer.vocab) - 1,
prompt_tokens_list[1:],
)
return LogprobsTensors(
token_vector,
_create_random_top_logprob_test_matrix(
(num_prompt_logprobs, num_logprobs + 1), -100, 0),
prompt_token_ranks)
(num_prompt_logprobs, num_logprobs + 1), -100, 0
),
prompt_token_ranks,
)
@dataclass
class DummyOutputProcessorTestVectors:
"""Dummy test vectors for output processor tests"""
tokenizer: GeneralTokenizerType
vllm_config: EngineArgs
full_tokens: list[list[int]] # Prompt + generated tokens
@@ -320,9 +332,9 @@ class MockEngineCore:
# For each request, for each sampled token offset,
# a tuple of
# (list of topk token ids, list of sample logprob vals, rank)
generated_logprobs_raw: Optional[list[list[tuple[list[int],
list[float],
int]]]] = None,
generated_logprobs_raw: Optional[
list[list[tuple[list[int], list[float], int]]]
] = None,
# For each request, a tuple of
# (prompt logprob val matrix, prompt logprob tok id matrix);
# each matrix has dimensions
@@ -355,7 +367,8 @@ class MockEngineCore:
if do_logprobs:
assert self.generated_logprobs_raw is not None
(logprobs_token_ids_, logprobs_, sampled_token_ranks_) = (
self.generated_logprobs_raw[req_idx][token_idx])
self.generated_logprobs_raw[req_idx][token_idx]
)
logprobs = LogprobsLists(
[logprobs_token_ids_],
[logprobs_],