Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -5,12 +5,15 @@ import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
|
||||
NUM_SAMPLE_LOGPROBS_UNDER_TEST, PROMPT_LEN,
|
||||
TOKENIZER_NAME,
|
||||
DummyOutputProcessorTestVectors,
|
||||
generate_dummy_prompt_logprobs_tensors,
|
||||
generate_dummy_sample_logprobs)
|
||||
from tests.v1.engine.utils import (
|
||||
NUM_PROMPT_LOGPROBS_UNDER_TEST,
|
||||
NUM_SAMPLE_LOGPROBS_UNDER_TEST,
|
||||
PROMPT_LEN,
|
||||
TOKENIZER_NAME,
|
||||
DummyOutputProcessorTestVectors,
|
||||
generate_dummy_prompt_logprobs_tensors,
|
||||
generate_dummy_sample_logprobs,
|
||||
)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
|
||||
from ...distributed.conftest import publisher_config, random_port # noqa: F401
|
||||
@@ -31,9 +34,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
|
||||
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
|
||||
vllm_config = EngineArgs(model=TOKENIZER_NAME).create_engine_config()
|
||||
# Tokenize prompts under test & create dummy generated tokens
|
||||
prompt_tokens = [
|
||||
tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS
|
||||
]
|
||||
prompt_tokens = [tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS]
|
||||
generation_tokens = [
|
||||
tokenizer(text).input_ids[PROMPT_LEN:] for text in FULL_STRINGS
|
||||
]
|
||||
@@ -42,9 +43,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
|
||||
tokenizer.decode(prompt_tokens, skip_special_tokens=True)
|
||||
for prompt_tokens in prompt_tokens
|
||||
]
|
||||
prompt_strings_len = [
|
||||
len(prompt_string) for prompt_string in prompt_strings
|
||||
]
|
||||
prompt_strings_len = [len(prompt_string) for prompt_string in prompt_strings]
|
||||
return DummyOutputProcessorTestVectors(
|
||||
tokenizer=tokenizer,
|
||||
vllm_config=vllm_config,
|
||||
@@ -58,7 +57,8 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
|
||||
for text, prompt_len in zip(FULL_STRINGS, prompt_strings_len)
|
||||
],
|
||||
prompt_logprobs=[],
|
||||
generation_logprobs=[])
|
||||
generation_logprobs=[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -76,12 +76,16 @@ def dummy_test_vectors() -> DummyOutputProcessorTestVectors:
|
||||
generate_dummy_sample_logprobs(
|
||||
sampled_tokens_list=tokens_list,
|
||||
num_logprobs=NUM_SAMPLE_LOGPROBS_UNDER_TEST,
|
||||
tokenizer=dtv.tokenizer) for tokens_list in dtv.generation_tokens
|
||||
tokenizer=dtv.tokenizer,
|
||||
)
|
||||
for tokens_list in dtv.generation_tokens
|
||||
]
|
||||
dtv.prompt_logprobs = [
|
||||
generate_dummy_prompt_logprobs_tensors(
|
||||
prompt_tokens_list=tokens_list,
|
||||
num_logprobs=NUM_PROMPT_LOGPROBS_UNDER_TEST,
|
||||
tokenizer=dtv.tokenizer) for tokens_list in dtv.prompt_tokens
|
||||
tokenizer=dtv.tokenizer,
|
||||
)
|
||||
for tokens_list in dtv.prompt_tokens
|
||||
]
|
||||
return dtv
|
||||
|
||||
@@ -21,16 +21,16 @@ from vllm.v1.engine.async_llm import AsyncLLM
|
||||
from vllm.v1.metrics.loggers import LoggingStatLogger
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip(reason="V1 currently only supported on CUDA.",
|
||||
allow_module_level=True)
|
||||
pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)
|
||||
|
||||
TEXT_ENGINE_ARGS = AsyncEngineArgs(
|
||||
model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
enforce_eager=True,
|
||||
)
|
||||
|
||||
VISION_ENGINE_ARGS = AsyncEngineArgs(model="Qwen/Qwen2-VL-2B-Instruct",
|
||||
enforce_eager=True)
|
||||
VISION_ENGINE_ARGS = AsyncEngineArgs(
|
||||
model="Qwen/Qwen2-VL-2B-Instruct", enforce_eager=True
|
||||
)
|
||||
|
||||
TEXT_PROMPT = "Hello my name is Robert and"
|
||||
|
||||
@@ -38,12 +38,11 @@ VISION_PROMPT_TEMPLATE = (
|
||||
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
|
||||
"\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
|
||||
"What is in the image?<|im_end|>\n"
|
||||
"<|im_start|>assistant\n")
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
VISION_PROMPT = {
|
||||
"prompt": VISION_PROMPT_TEMPLATE,
|
||||
"multi_modal_data": {
|
||||
"image": ImageAsset("stop_sign").pil_image
|
||||
},
|
||||
"multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
|
||||
}
|
||||
|
||||
|
||||
@@ -70,10 +69,9 @@ async def generate(
|
||||
n=n,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
)
|
||||
async for out in engine.generate(request_id=request_id,
|
||||
prompt=prompt,
|
||||
sampling_params=sampling_params):
|
||||
|
||||
async for out in engine.generate(
|
||||
request_id=request_id, prompt=prompt, sampling_params=sampling_params
|
||||
):
|
||||
num_tokens = sum(len(output.token_ids) for output in out.outputs)
|
||||
if output_kind == RequestOutputKind.DELTA:
|
||||
count += num_tokens
|
||||
@@ -89,7 +87,8 @@ async def generate(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
|
||||
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"engine_args,prompt",
|
||||
[(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
|
||||
@@ -121,25 +120,29 @@ async def test_load(
|
||||
for request_id in request_ids:
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
generate(engine, request_id, prompt, output_kind,
|
||||
NUM_EXPECTED_TOKENS)))
|
||||
generate(
|
||||
engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Confirm that we got all the EXPECTED tokens from the requests.
|
||||
done, pending = await asyncio.wait(tasks,
|
||||
return_when=asyncio.FIRST_EXCEPTION)
|
||||
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
for task in done:
|
||||
num_generated_tokens, request_id = await task
|
||||
assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
|
||||
f"{request_id} generated {num_generated_tokens} but "
|
||||
f"expected {NUM_EXPECTED_TOKENS}")
|
||||
f"expected {NUM_EXPECTED_TOKENS}"
|
||||
)
|
||||
|
||||
assert not engine.output_processor.has_unfinished_requests()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
|
||||
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"engine_args,prompt",
|
||||
[(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
|
||||
@@ -151,7 +154,6 @@ async def test_abort(
|
||||
engine_args: AsyncEngineArgs,
|
||||
prompt: PromptType,
|
||||
):
|
||||
|
||||
with monkeypatch.context() as m, ExitStack() as after:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
@@ -170,14 +172,17 @@ async def test_abort(
|
||||
# Create concurrent requests.
|
||||
tasks: list[asyncio.Task] = []
|
||||
for idx, request_id in enumerate(request_ids):
|
||||
max_tokens = (NUM_EXPECTED_TOKENS_LONG if
|
||||
(idx
|
||||
in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS)
|
||||
max_tokens = (
|
||||
NUM_EXPECTED_TOKENS_LONG
|
||||
if (idx in REQUEST_IDS_TO_ABORT)
|
||||
else NUM_EXPECTED_TOKENS
|
||||
)
|
||||
n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
generate(engine, request_id, prompt, output_kind,
|
||||
max_tokens, n)))
|
||||
generate(engine, request_id, prompt, output_kind, max_tokens, n)
|
||||
)
|
||||
)
|
||||
|
||||
# API server cancels requests when they disconnect.
|
||||
for idx in REQUEST_IDS_TO_ABORT:
|
||||
@@ -197,7 +202,8 @@ async def test_abort(
|
||||
expected_tokens = NUM_EXPECTED_TOKENS * n
|
||||
assert num_generated_tokens == expected_tokens, (
|
||||
f"{request_id} generated {num_generated_tokens} but "
|
||||
f"expected {expected_tokens}")
|
||||
f"expected {expected_tokens}"
|
||||
)
|
||||
|
||||
# Make sure all aborted requests were really aborted.
|
||||
assert not engine.output_processor.has_unfinished_requests()
|
||||
@@ -205,21 +211,21 @@ async def test_abort(
|
||||
# Confirm we can do another generation.
|
||||
request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
|
||||
task = asyncio.create_task(
|
||||
generate(engine, request_id, prompt, output_kind,
|
||||
NUM_EXPECTED_TOKENS))
|
||||
generate(engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS)
|
||||
)
|
||||
num_generated_tokens, request_id = await task
|
||||
assert num_generated_tokens == NUM_EXPECTED_TOKENS
|
||||
assert not engine.output_processor.has_unfinished_requests()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
|
||||
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_abort(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
output_kind: RequestOutputKind,
|
||||
):
|
||||
|
||||
with monkeypatch.context() as m, ExitStack() as after:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
@@ -238,14 +244,19 @@ async def test_multi_abort(
|
||||
# Create concurrent requests.
|
||||
tasks: list[asyncio.Task] = []
|
||||
for idx, request_id in enumerate(request_ids):
|
||||
max_tokens = (NUM_EXPECTED_TOKENS_LONG if
|
||||
(idx
|
||||
in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS)
|
||||
max_tokens = (
|
||||
NUM_EXPECTED_TOKENS_LONG
|
||||
if (idx in REQUEST_IDS_TO_ABORT)
|
||||
else NUM_EXPECTED_TOKENS
|
||||
)
|
||||
n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
generate(engine, request_id, TEXT_PROMPT, output_kind,
|
||||
max_tokens, n)))
|
||||
generate(
|
||||
engine, request_id, TEXT_PROMPT, output_kind, max_tokens, n
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Let requests start
|
||||
await asyncio.sleep(0.5)
|
||||
@@ -261,25 +272,26 @@ async def test_multi_abort(
|
||||
for idx, result in enumerate(results):
|
||||
if idx in REQUEST_IDS_TO_ABORT:
|
||||
# Aborted requests should return partial results
|
||||
assert isinstance(
|
||||
result, tuple
|
||||
), f"Request {idx} should have completed with partial results"
|
||||
assert isinstance(result, tuple), (
|
||||
f"Request {idx} should have completed with partial results"
|
||||
)
|
||||
num_generated_tokens, request_id = result
|
||||
# Should have generated some tokens before abort
|
||||
assert num_generated_tokens > 0, (
|
||||
f"Aborted request "
|
||||
f"{request_id} should have generated some tokens")
|
||||
f"Aborted request {request_id} should have generated some tokens"
|
||||
)
|
||||
else:
|
||||
# Non-aborted requests should complete normally
|
||||
assert isinstance(
|
||||
result,
|
||||
tuple), f"Request {idx} should have completed successfully"
|
||||
assert isinstance(result, tuple), (
|
||||
f"Request {idx} should have completed successfully"
|
||||
)
|
||||
num_generated_tokens, request_id = result
|
||||
n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
|
||||
expected_tokens = NUM_EXPECTED_TOKENS * n
|
||||
assert num_generated_tokens == expected_tokens, (
|
||||
f"{request_id} generated {num_generated_tokens} but "
|
||||
f"expected {expected_tokens}")
|
||||
f"expected {expected_tokens}"
|
||||
)
|
||||
|
||||
# Make sure all aborted requests were cleaned up
|
||||
assert not engine.output_processor.has_unfinished_requests()
|
||||
@@ -297,7 +309,6 @@ async def test_finished_flag(
|
||||
engine_args: AsyncEngineArgs,
|
||||
prompt: PromptType,
|
||||
):
|
||||
|
||||
with monkeypatch.context() as m, ExitStack() as after:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
@@ -314,9 +325,9 @@ async def test_finished_flag(
|
||||
)
|
||||
outputs = [
|
||||
out
|
||||
async for out in engine.generate(request_id="request-33",
|
||||
prompt=prompt,
|
||||
sampling_params=sampling_params)
|
||||
async for out in engine.generate(
|
||||
request_id="request-33", prompt=prompt, sampling_params=sampling_params
|
||||
)
|
||||
]
|
||||
|
||||
# Assert only the last output has the finished flag set
|
||||
@@ -329,9 +340,9 @@ async def test_finished_flag(
|
||||
[(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
|
||||
engine_args: AsyncEngineArgs,
|
||||
prompt: PromptType):
|
||||
async def test_mid_stream_cancellation(
|
||||
monkeypatch: pytest.MonkeyPatch, engine_args: AsyncEngineArgs, prompt: PromptType
|
||||
):
|
||||
"""Test that requests can be cancelled mid-stream."""
|
||||
with monkeypatch.context() as m, ExitStack() as after:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
@@ -358,7 +369,9 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
|
||||
RequestOutputKind.DELTA,
|
||||
NUM_TOKENS,
|
||||
cancel_after=NUM_EXPECTED_TOKENS,
|
||||
)))
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
results = await asyncio.gather(*tasks)
|
||||
@@ -367,7 +380,8 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
|
||||
for num_generated_tokens, request_id in results:
|
||||
assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
|
||||
f"{request_id} generated {num_generated_tokens} tokens but "
|
||||
f"expected to cancel after {NUM_EXPECTED_TOKENS}")
|
||||
f"expected to cancel after {NUM_EXPECTED_TOKENS}"
|
||||
)
|
||||
|
||||
# Make sure no requests are left hanging
|
||||
assert not engine.output_processor.has_unfinished_requests()
|
||||
@@ -375,15 +389,16 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
|
||||
# Confirm we can reuse the request id after the cancellations.
|
||||
request_id = request_ids[0]
|
||||
task = asyncio.create_task(
|
||||
generate(engine, request_id, prompt, RequestOutputKind.DELTA,
|
||||
NUM_EXPECTED_TOKENS))
|
||||
generate(
|
||||
engine, request_id, prompt, RequestOutputKind.DELTA, NUM_EXPECTED_TOKENS
|
||||
)
|
||||
)
|
||||
num_generated_tokens, request_id = await task
|
||||
assert num_generated_tokens == NUM_EXPECTED_TOKENS
|
||||
assert not engine.output_processor.has_unfinished_requests()
|
||||
|
||||
|
||||
class MockLoggingStatLogger(LoggingStatLogger):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
|
||||
super().__init__(vllm_config, engine_index)
|
||||
self.log = MagicMock()
|
||||
@@ -410,8 +425,7 @@ async def test_customize_loggers(monkeypatch):
|
||||
|
||||
stat_loggers = engine.logger_manager.per_engine_logger_dict
|
||||
assert len(stat_loggers) == 1
|
||||
assert len(
|
||||
stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger
|
||||
assert len(stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger
|
||||
stat_loggers[0][0].log.assert_called_once()
|
||||
|
||||
|
||||
@@ -424,24 +438,30 @@ async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
|
||||
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
sampling_params = SamplingParams(max_tokens=100,
|
||||
output_kind=RequestOutputKind.DELTA,
|
||||
temperature=1.0,
|
||||
seed=33)
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=100,
|
||||
output_kind=RequestOutputKind.DELTA,
|
||||
temperature=1.0,
|
||||
seed=33,
|
||||
)
|
||||
|
||||
# Test with valid DP rank.
|
||||
async for _ in engine.generate(request_id="request-34",
|
||||
prompt=TEXT_PROMPT,
|
||||
sampling_params=sampling_params,
|
||||
data_parallel_rank=0):
|
||||
async for _ in engine.generate(
|
||||
request_id="request-34",
|
||||
prompt=TEXT_PROMPT,
|
||||
sampling_params=sampling_params,
|
||||
data_parallel_rank=0,
|
||||
):
|
||||
pass
|
||||
|
||||
# Test with out-of-range DP rank.
|
||||
with pytest.raises(ValueError):
|
||||
async for _ in engine.generate(request_id="request-35",
|
||||
prompt=TEXT_PROMPT,
|
||||
sampling_params=sampling_params,
|
||||
data_parallel_rank=1):
|
||||
async for _ in engine.generate(
|
||||
request_id="request-35",
|
||||
prompt=TEXT_PROMPT,
|
||||
sampling_params=sampling_params,
|
||||
data_parallel_rank=1,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@@ -465,10 +485,14 @@ async def test_check_health(monkeypatch: pytest.MonkeyPatch):
|
||||
await engine.check_health()
|
||||
|
||||
# Test 2: Mock the errored property to simulate a dead engine
|
||||
with patch.object(type(engine),
|
||||
'errored',
|
||||
new_callable=lambda: property(lambda self: True)
|
||||
), pytest.raises(EngineDeadError):
|
||||
with (
|
||||
patch.object(
|
||||
type(engine),
|
||||
"errored",
|
||||
new_callable=lambda: property(lambda self: True),
|
||||
),
|
||||
pytest.raises(EngineDeadError),
|
||||
):
|
||||
await engine.check_health()
|
||||
|
||||
# Test 3: Verify healthy engine still works after mock
|
||||
@@ -476,7 +500,8 @@ async def test_check_health(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
|
||||
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_abort_final_output(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
@@ -504,8 +529,8 @@ async def test_abort_final_output(
|
||||
|
||||
outputs: list[RequestOutput] = []
|
||||
generated = asyncio.create_task(
|
||||
collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params,
|
||||
outputs))
|
||||
collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params, outputs)
|
||||
)
|
||||
|
||||
# Let it generate some tokens
|
||||
await asyncio.sleep(0.5)
|
||||
@@ -525,14 +550,13 @@ async def test_abort_final_output(
|
||||
assert final_output.outputs[0].stop_reason is None
|
||||
|
||||
# Verify num_cached_tokens is set correctly
|
||||
assert hasattr(final_output, 'num_cached_tokens')
|
||||
assert hasattr(final_output, "num_cached_tokens")
|
||||
assert final_output.num_cached_tokens >= 0
|
||||
|
||||
# If we got intermediate outputs, verify they are consistent
|
||||
if output_kind == RequestOutputKind.DELTA:
|
||||
# For DELTA, the sum of all intermediate tokens should be <= final tokens
|
||||
token_count = sum(
|
||||
len(output.outputs[0].token_ids) for output in outputs)
|
||||
token_count = sum(len(output.outputs[0].token_ids) for output in outputs)
|
||||
assert token_count > 0
|
||||
# This would ordinarily be 0, but could end up > 0 if the
|
||||
# final abort is coalesced with another chunk in the output queue.
|
||||
@@ -554,9 +578,9 @@ async def collect_outputs(
|
||||
) -> Optional[RequestOutput]:
|
||||
"""Helper to collect outputs and return the final one."""
|
||||
final_output: Optional[RequestOutput] = None
|
||||
async for output in engine.generate(request_id=request_id,
|
||||
prompt=prompt,
|
||||
sampling_params=sampling_params):
|
||||
async for output in engine.generate(
|
||||
request_id=request_id, prompt=prompt, sampling_params=sampling_params
|
||||
):
|
||||
if not output.finished:
|
||||
outputs_list.append(output)
|
||||
final_output = output
|
||||
|
||||
@@ -22,8 +22,9 @@ def test_prefix_caching_from_cli():
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
args = parser.parse_args([])
|
||||
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
|
||||
assert (vllm_config.cache_config.enable_prefix_caching
|
||||
), "V1 turns on prefix caching by default."
|
||||
assert vllm_config.cache_config.enable_prefix_caching, (
|
||||
"V1 turns on prefix caching by default."
|
||||
)
|
||||
|
||||
# Turning it off is possible with a flag.
|
||||
args = parser.parse_args(["--no-enable-prefix-caching"])
|
||||
@@ -41,8 +42,7 @@ def test_prefix_caching_from_cli():
|
||||
# set hash algorithm to sha256_cbor
|
||||
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256_cbor"])
|
||||
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
|
||||
assert vllm_config.cache_config.prefix_caching_hash_algo == \
|
||||
"sha256_cbor"
|
||||
assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256_cbor"
|
||||
|
||||
# set hash algorithm to sha256
|
||||
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"])
|
||||
@@ -57,10 +57,10 @@ def test_prefix_caching_from_cli():
|
||||
|
||||
def test_defaults_with_usage_context():
|
||||
engine_args = EngineArgs(model="facebook/opt-125m")
|
||||
vllm_config: VllmConfig = engine_args.create_engine_config(
|
||||
UsageContext.LLM_CLASS)
|
||||
vllm_config: VllmConfig = engine_args.create_engine_config(UsageContext.LLM_CLASS)
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
device_name = current_platform.get_device_name().lower()
|
||||
if "h100" in device_name or "h200" in device_name:
|
||||
# For H100 and H200, we use larger default values.
|
||||
@@ -76,7 +76,6 @@ def test_defaults_with_usage_context():
|
||||
assert vllm_config.scheduler_config.max_num_batched_tokens == default_llm_tokens # noqa: E501
|
||||
|
||||
engine_args = EngineArgs(model="facebook/opt-125m")
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
UsageContext.OPENAI_API_SERVER)
|
||||
vllm_config = engine_args.create_engine_config(UsageContext.OPENAI_API_SERVER)
|
||||
assert vllm_config.scheduler_config.max_num_seqs == default_max_num_seqs
|
||||
assert vllm_config.scheduler_config.max_num_batched_tokens == default_server_tokens # noqa: E501
|
||||
|
||||
@@ -22,8 +22,7 @@ from vllm.v1.outputs import ModelRunnerOutput
|
||||
from ...utils import create_new_process_for_each_test, multi_gpu_test
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip(reason="V1 currently only supported on CUDA.",
|
||||
allow_module_level=True)
|
||||
pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)
|
||||
|
||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
@@ -48,7 +47,6 @@ def make_request() -> EngineCoreRequest:
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
def test_engine_core(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
"""Setup the EngineCore."""
|
||||
@@ -57,14 +55,13 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
with set_default_torch_num_threads(1):
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=True)
|
||||
engine_core = EngineCore(
|
||||
vllm_config=vllm_config, executor_class=executor_class, log_stats=True
|
||||
)
|
||||
"""Test basic request lifecycle."""
|
||||
|
||||
# First request.
|
||||
engine_core.add_request(
|
||||
*engine_core.preprocess_add_request(make_request()))
|
||||
engine_core.add_request(*engine_core.preprocess_add_request(make_request()))
|
||||
assert len(engine_core.scheduler.waiting) == 1
|
||||
assert len(engine_core.scheduler.running) == 0
|
||||
|
||||
@@ -73,8 +70,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
|
||||
assert len(engine_core.scheduler.running) == 1
|
||||
|
||||
# Second request.
|
||||
engine_core.add_request(
|
||||
*engine_core.preprocess_add_request(make_request()))
|
||||
engine_core.add_request(*engine_core.preprocess_add_request(make_request()))
|
||||
assert len(engine_core.scheduler.waiting) == 1
|
||||
assert len(engine_core.scheduler.running) == 1
|
||||
|
||||
@@ -83,10 +79,8 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
|
||||
assert len(engine_core.scheduler.running) == 2
|
||||
|
||||
# Add two requests in a row.
|
||||
engine_core.add_request(
|
||||
*engine_core.preprocess_add_request(make_request()))
|
||||
engine_core.add_request(
|
||||
*engine_core.preprocess_add_request(make_request()))
|
||||
engine_core.add_request(*engine_core.preprocess_add_request(make_request()))
|
||||
engine_core.add_request(*engine_core.preprocess_add_request(make_request()))
|
||||
assert len(engine_core.scheduler.waiting) == 2
|
||||
assert len(engine_core.scheduler.running) == 2
|
||||
|
||||
@@ -196,9 +190,9 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
with set_default_torch_num_threads(1):
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=True)
|
||||
engine_core = EngineCore(
|
||||
vllm_config=vllm_config, executor_class=executor_class, log_stats=True
|
||||
)
|
||||
"""Test basic request lifecycle."""
|
||||
# First request.
|
||||
request: EngineCoreRequest = make_request()
|
||||
@@ -238,17 +232,14 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
|
||||
Test that the engine can handle multiple concurrent batches.
|
||||
"""
|
||||
|
||||
def make_request_with_max_tokens(req_id: str,
|
||||
max_tokens: int) -> EngineCoreRequest:
|
||||
def make_request_with_max_tokens(req_id: str, max_tokens: int) -> EngineCoreRequest:
|
||||
request = make_request()
|
||||
request.request_id = req_id
|
||||
request.sampling_params.max_tokens = max_tokens
|
||||
return request
|
||||
|
||||
class DummyExecutor(UniProcExecutor):
|
||||
|
||||
def initialize_from_config(
|
||||
self, kv_cache_configs: list[KVCacheConfig]) -> None:
|
||||
def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None:
|
||||
super().initialize_from_config(kv_cache_configs)
|
||||
|
||||
# Create a thread pool with a single worker
|
||||
@@ -265,8 +256,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
|
||||
assert non_block
|
||||
|
||||
def _execute():
|
||||
output = self.collective_rpc("execute_model",
|
||||
args=(scheduler_output, ))
|
||||
output = self.collective_rpc("execute_model", args=(scheduler_output,))
|
||||
# Make a copy because output[0] may be reused
|
||||
# by the next batch.
|
||||
return copy.deepcopy(output[0])
|
||||
@@ -279,7 +269,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
|
||||
return 2
|
||||
|
||||
def shutdown(self):
|
||||
if hasattr(self, 'thread_pool'):
|
||||
if hasattr(self, "thread_pool"):
|
||||
self.thread_pool.shutdown(wait=False)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
@@ -297,9 +287,9 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
|
||||
)
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
with set_default_torch_num_threads(1):
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
log_stats=False,
|
||||
executor_class=DummyExecutor)
|
||||
engine_core = EngineCore(
|
||||
vllm_config=vllm_config, log_stats=False, executor_class=DummyExecutor
|
||||
)
|
||||
assert engine_core.batch_queue is not None
|
||||
|
||||
# Add two requests in a row. Each request has 12 prompt tokens.
|
||||
@@ -314,8 +304,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
|
||||
scheduler_output = engine_core.batch_queue[-1][1]
|
||||
assert scheduler_output.num_scheduled_tokens["0"] == 10
|
||||
# num_computed_tokens should have been updated immediately.
|
||||
assert engine_core.scheduler.requests[
|
||||
req0.request_id].num_computed_tokens == 10
|
||||
assert engine_core.scheduler.requests[req0.request_id].num_computed_tokens == 10
|
||||
|
||||
# Schedule Batch 2: (2, req0), (8, req1)
|
||||
assert engine_core.step_with_batch_queue()[0] == {}
|
||||
@@ -366,8 +355,10 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
|
||||
assert output is not None
|
||||
assert len(output[0].outputs) == 1
|
||||
if req_id in engine_core.scheduler.requests:
|
||||
assert engine_core.scheduler.requests[
|
||||
req_id].num_tokens == expected_num_tokens[req_id]
|
||||
assert (
|
||||
engine_core.scheduler.requests[req_id].num_tokens
|
||||
== expected_num_tokens[req_id]
|
||||
)
|
||||
expected_num_tokens[req_id] += 1
|
||||
req_id = (req_id + 1) % 2
|
||||
|
||||
@@ -391,17 +382,19 @@ def test_engine_core_tp(monkeypatch: pytest.MonkeyPatch):
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
with set_default_torch_num_threads(1):
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=True)
|
||||
engine_core = EngineCore(
|
||||
vllm_config=vllm_config, executor_class=executor_class, log_stats=True
|
||||
)
|
||||
|
||||
def get_worker_cache_config_field(worker, key: str):
|
||||
return getattr(worker.cache_config, key)
|
||||
|
||||
num_gpu_blocks = engine_core.collective_rpc(
|
||||
get_worker_cache_config_field, args=("num_gpu_blocks", ))
|
||||
get_worker_cache_config_field, args=("num_gpu_blocks",)
|
||||
)
|
||||
num_cpu_blocks = engine_core.collective_rpc(
|
||||
get_worker_cache_config_field, args=("num_cpu_blocks", ))
|
||||
get_worker_cache_config_field, args=("num_cpu_blocks",)
|
||||
)
|
||||
assert all(x is not None for x in num_gpu_blocks)
|
||||
assert all(x is not None for x in num_cpu_blocks)
|
||||
|
||||
@@ -417,40 +410,35 @@ def test_engine_core_invalid_request_id_type(monkeypatch: pytest.MonkeyPatch):
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
with set_default_torch_num_threads(1):
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=True)
|
||||
engine_core = EngineCore(
|
||||
vllm_config=vllm_config, executor_class=executor_class, log_stats=True
|
||||
)
|
||||
|
||||
# Test with UUID object (common mistake)
|
||||
uuid_request = make_request()
|
||||
uuid_request.request_id = uuid.uuid4() # UUID object instead of string
|
||||
|
||||
with pytest.raises(TypeError,
|
||||
match="request_id must be a string, got.*UUID"):
|
||||
engine_core.add_request(
|
||||
*engine_core.preprocess_add_request(uuid_request))
|
||||
with pytest.raises(TypeError, match="request_id must be a string, got.*UUID"):
|
||||
engine_core.add_request(*engine_core.preprocess_add_request(uuid_request))
|
||||
|
||||
# Test with integer
|
||||
int_request = make_request()
|
||||
int_request.request_id = 12345
|
||||
|
||||
with pytest.raises(TypeError,
|
||||
match="request_id must be a string, got.*int"):
|
||||
engine_core.add_request(
|
||||
*engine_core.preprocess_add_request(int_request))
|
||||
with pytest.raises(TypeError, match="request_id must be a string, got.*int"):
|
||||
engine_core.add_request(*engine_core.preprocess_add_request(int_request))
|
||||
|
||||
# Test with None
|
||||
none_request = make_request()
|
||||
none_request.request_id = None
|
||||
|
||||
with pytest.raises(TypeError,
|
||||
match="request_id must be a string, got.*NoneType"):
|
||||
engine_core.add_request(
|
||||
*engine_core.preprocess_add_request(none_request))
|
||||
with pytest.raises(
|
||||
TypeError, match="request_id must be a string, got.*NoneType"
|
||||
):
|
||||
engine_core.add_request(*engine_core.preprocess_add_request(none_request))
|
||||
|
||||
# Verify engine is still functional after errors
|
||||
valid_request = make_request()
|
||||
engine_core.add_request(
|
||||
*engine_core.preprocess_add_request(valid_request))
|
||||
engine_core.add_request(*engine_core.preprocess_add_request(valid_request))
|
||||
assert len(engine_core.scheduler.waiting) == 1
|
||||
assert len(engine_core.scheduler.running) == 0
|
||||
|
||||
@@ -17,16 +17,14 @@ from transformers import AutoTokenizer
|
||||
|
||||
from tests.utils import multi_gpu_test
|
||||
from vllm import SamplingParams
|
||||
from vllm.distributed.kv_events import (BlockStored, KVEventBatch,
|
||||
ZmqEventPublisher)
|
||||
from vllm.distributed.kv_events import BlockStored, KVEventBatch, ZmqEventPublisher
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import set_default_torch_num_threads
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core import EngineCore
|
||||
from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
|
||||
SyncMPClient)
|
||||
from vllm.v1.engine.core_client import AsyncMPClient, EngineCoreClient, SyncMPClient
|
||||
from vllm.v1.engine.utils import CoreEngineProcManager
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
|
||||
@@ -34,8 +32,7 @@ from ...distributed.conftest import MockSubscriber
|
||||
from ...utils import create_new_process_for_each_test
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip(reason="V1 currently only supported on CUDA.",
|
||||
allow_module_level=True)
|
||||
pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)
|
||||
|
||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
@@ -44,8 +41,8 @@ PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
|
||||
|
||||
|
||||
def make_request(
|
||||
params: SamplingParams,
|
||||
prompt_tokens_ids: Optional[list[int]] = None) -> EngineCoreRequest:
|
||||
params: SamplingParams, prompt_tokens_ids: Optional[list[int]] = None
|
||||
) -> EngineCoreRequest:
|
||||
if not prompt_tokens_ids:
|
||||
prompt_tokens_ids = PROMPT_TOKENS
|
||||
|
||||
@@ -64,7 +61,6 @@ def make_request(
|
||||
|
||||
|
||||
def loop_until_done(client: EngineCoreClient, outputs: dict):
|
||||
|
||||
while True:
|
||||
engine_core_outputs = client.get_output().outputs
|
||||
|
||||
@@ -82,7 +78,6 @@ def loop_until_done(client: EngineCoreClient, outputs: dict):
|
||||
|
||||
|
||||
async def loop_until_done_async(client: EngineCoreClient, outputs: dict):
|
||||
|
||||
while True:
|
||||
engine_core_outputs = (await client.get_output_async()).outputs
|
||||
|
||||
@@ -100,7 +95,6 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict):
|
||||
|
||||
|
||||
async def loop_until_fully_done_async(client: EngineCoreClient, outputs: dict):
|
||||
|
||||
while True:
|
||||
engine_core_outputs = (await client.get_output_async()).outputs
|
||||
|
||||
@@ -119,10 +113,9 @@ async def loop_until_fully_done_async(client: EngineCoreClient, outputs: dict):
|
||||
|
||||
|
||||
# Dummy utility function to monkey-patch into engine core.
|
||||
def echo(self,
|
||||
msg: str,
|
||||
err_msg: Optional[str] = None,
|
||||
sleep: Optional[float] = None) -> str:
|
||||
def echo(
|
||||
self, msg: str, err_msg: Optional[str] = None, sleep: Optional[float] = None
|
||||
) -> str:
|
||||
print(f"echo util function called: {msg}, {err_msg}")
|
||||
if sleep is not None:
|
||||
time.sleep(sleep)
|
||||
@@ -133,9 +126,9 @@ def echo(self,
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@pytest.mark.parametrize("multiprocessing_mode", [True, False])
|
||||
def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
|
||||
multiprocessing_mode: bool):
|
||||
|
||||
def test_engine_core_client(
|
||||
monkeypatch: pytest.MonkeyPatch, multiprocessing_mode: bool
|
||||
):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
@@ -143,8 +136,7 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
|
||||
m.setattr(EngineCore, "echo", echo, raising=False)
|
||||
|
||||
engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
UsageContext.UNKNOWN_CONTEXT)
|
||||
vllm_config = engine_args.create_engine_config(UsageContext.UNKNOWN_CONTEXT)
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
with set_default_torch_num_threads(1):
|
||||
@@ -172,7 +164,8 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
for req_id in request_ids:
|
||||
assert len(outputs[req_id]) == MAX_TOKENS, (
|
||||
f"{outputs[req_id]=}, {MAX_TOKENS=}")
|
||||
f"{outputs[req_id]=}, {MAX_TOKENS=}"
|
||||
)
|
||||
"""Abort Request Cycle."""
|
||||
|
||||
# Note: this code pathway will only work for multiprocessing
|
||||
@@ -191,10 +184,12 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
|
||||
for idx, req_id in enumerate(request_ids):
|
||||
if idx % 2 == 0:
|
||||
assert len(outputs[req_id]) < MAX_TOKENS, (
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}"
|
||||
)
|
||||
else:
|
||||
assert len(outputs[req_id]) == MAX_TOKENS, (
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}"
|
||||
)
|
||||
"""Abort after request is finished."""
|
||||
|
||||
# Note: this code pathway will only work for multiprocessing
|
||||
@@ -202,7 +197,7 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
request = requests[0]
|
||||
client.add_request(request)
|
||||
time.sleep(10.)
|
||||
time.sleep(10.0)
|
||||
|
||||
client.abort_requests([request.request_id])
|
||||
|
||||
@@ -222,7 +217,6 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
@@ -231,7 +225,8 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT)
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT
|
||||
)
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
with set_default_torch_num_threads(1):
|
||||
@@ -261,7 +256,8 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
for req_id in request_ids:
|
||||
assert len(outputs[req_id]) == MAX_TOKENS, (
|
||||
f"{outputs[req_id]=}, {MAX_TOKENS=}")
|
||||
f"{outputs[req_id]=}, {MAX_TOKENS=}"
|
||||
)
|
||||
"""Abort Request Cycle."""
|
||||
|
||||
# Add requests to the engine.
|
||||
@@ -277,10 +273,12 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
|
||||
for idx, req_id in enumerate(request_ids):
|
||||
if idx % 2 == 0:
|
||||
assert len(outputs[req_id]) < MAX_TOKENS, (
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}"
|
||||
)
|
||||
else:
|
||||
assert len(outputs[req_id]) == MAX_TOKENS, (
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}"
|
||||
)
|
||||
"""Utility method invocation"""
|
||||
|
||||
core_client: AsyncMPClient = client
|
||||
@@ -296,8 +294,8 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
|
||||
# Test that cancelling the utility call doesn't destabilize the
|
||||
# engine.
|
||||
util_task = asyncio.create_task(
|
||||
core_client.call_utility_async("echo", "testarg2", None,
|
||||
0.5)) # sleep for 0.5 sec
|
||||
core_client.call_utility_async("echo", "testarg2", None, 0.5)
|
||||
) # sleep for 0.5 sec
|
||||
await asyncio.sleep(0.05)
|
||||
cancelled = util_task.cancel()
|
||||
assert cancelled
|
||||
@@ -305,9 +303,9 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
|
||||
# Ensure client is still functional. The engine runs utility
|
||||
# methods in a single thread so this request won't be processed
|
||||
# until the cancelled sleeping one is complete.
|
||||
result = await asyncio.wait_for(core_client.call_utility_async(
|
||||
"echo", "testarg3"),
|
||||
timeout=1.0)
|
||||
result = await asyncio.wait_for(
|
||||
core_client.call_utility_async("echo", "testarg3"), timeout=1.0
|
||||
)
|
||||
assert result == "testarg3"
|
||||
finally:
|
||||
client.shutdown()
|
||||
@@ -353,8 +351,7 @@ def echo_dc_nested(
|
||||
msg: str,
|
||||
structure_type: str = "list_of_dicts",
|
||||
) -> Any:
|
||||
print(f"echo dc nested util function called: {msg}, "
|
||||
f"structure: {structure_type}")
|
||||
print(f"echo dc nested util function called: {msg}, structure: {structure_type}")
|
||||
val = None if msg is None else MyDataclass(msg)
|
||||
|
||||
if structure_type == "list_of_dicts": # noqa
|
||||
@@ -373,8 +370,8 @@ def echo_dc_nested(
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_engine_core_client_util_method_custom_return(
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
@@ -386,7 +383,8 @@ async def test_engine_core_client_util_method_custom_return(
|
||||
|
||||
engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT)
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT
|
||||
)
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
with set_default_torch_num_threads(1):
|
||||
@@ -402,22 +400,17 @@ async def test_engine_core_client_util_method_custom_return(
|
||||
# Test utility method returning custom / non-native data type.
|
||||
core_client: AsyncMPClient = client
|
||||
|
||||
result = await core_client.call_utility_async(
|
||||
"echo_dc", "testarg2", False)
|
||||
assert isinstance(result,
|
||||
MyDataclass) and result.message == "testarg2"
|
||||
result = await core_client.call_utility_async(
|
||||
"echo_dc", "testarg2", True)
|
||||
result = await core_client.call_utility_async("echo_dc", "testarg2", False)
|
||||
assert isinstance(result, MyDataclass) and result.message == "testarg2"
|
||||
result = await core_client.call_utility_async("echo_dc", "testarg2", True)
|
||||
assert isinstance(result, list) and all(
|
||||
isinstance(r, MyDataclass) and r.message == "testarg2"
|
||||
for r in result)
|
||||
isinstance(r, MyDataclass) and r.message == "testarg2" for r in result
|
||||
)
|
||||
|
||||
# Test returning None and list of Nones
|
||||
result = await core_client.call_utility_async(
|
||||
"echo_dc", None, False)
|
||||
result = await core_client.call_utility_async("echo_dc", None, False)
|
||||
assert result is None
|
||||
result = await core_client.call_utility_async(
|
||||
"echo_dc", None, True)
|
||||
result = await core_client.call_utility_async("echo_dc", None, True)
|
||||
assert isinstance(result, list) and all(r is None for r in result)
|
||||
|
||||
finally:
|
||||
@@ -426,8 +419,8 @@ async def test_engine_core_client_util_method_custom_return(
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_engine_core_client_util_method_custom_dict_return(
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
@@ -439,7 +432,8 @@ async def test_engine_core_client_util_method_custom_dict_return(
|
||||
|
||||
engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT)
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT
|
||||
)
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
with set_default_torch_num_threads(1):
|
||||
@@ -457,22 +451,21 @@ async def test_engine_core_client_util_method_custom_dict_return(
|
||||
|
||||
# Test single object return
|
||||
result = await core_client.call_utility_async(
|
||||
"echo_dc_dict", "testarg3", False)
|
||||
assert isinstance(result,
|
||||
MyDataclass) and result.message == "testarg3"
|
||||
"echo_dc_dict", "testarg3", False
|
||||
)
|
||||
assert isinstance(result, MyDataclass) and result.message == "testarg3"
|
||||
|
||||
# Test dict return with custom value types
|
||||
result = await core_client.call_utility_async(
|
||||
"echo_dc_dict", "testarg3", True)
|
||||
"echo_dc_dict", "testarg3", True
|
||||
)
|
||||
assert isinstance(result, dict) and len(result) == 3
|
||||
for key, val in result.items():
|
||||
assert key in ["key1", "key2", "key3"]
|
||||
assert isinstance(val,
|
||||
MyDataclass) and val.message == "testarg3"
|
||||
assert isinstance(val, MyDataclass) and val.message == "testarg3"
|
||||
|
||||
# Test returning dict with None values
|
||||
result = await core_client.call_utility_async(
|
||||
"echo_dc_dict", None, True)
|
||||
result = await core_client.call_utility_async("echo_dc_dict", None, True)
|
||||
assert isinstance(result, dict) and len(result) == 3
|
||||
for key, val in result.items():
|
||||
assert key in ["key1", "key2", "key3"]
|
||||
@@ -484,8 +477,8 @@ async def test_engine_core_client_util_method_custom_dict_return(
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_engine_core_client_util_method_nested_structures(
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
@@ -497,7 +490,8 @@ async def test_engine_core_client_util_method_nested_structures(
|
||||
|
||||
engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT)
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT
|
||||
)
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
with set_default_torch_num_threads(1):
|
||||
@@ -514,42 +508,48 @@ async def test_engine_core_client_util_method_nested_structures(
|
||||
|
||||
# Test list of dicts: [{"a": val, "b": val}, {"c": val, "d": val}]
|
||||
result = await core_client.call_utility_async(
|
||||
"echo_dc_nested", "nested1", "list_of_dicts")
|
||||
"echo_dc_nested", "nested1", "list_of_dicts"
|
||||
)
|
||||
assert isinstance(result, list) and len(result) == 2
|
||||
for i, item in enumerate(result):
|
||||
assert isinstance(item, dict)
|
||||
if i == 0:
|
||||
assert "a" in item and "b" in item
|
||||
assert isinstance(
|
||||
item["a"],
|
||||
MyDataclass) and item["a"].message == "nested1"
|
||||
assert isinstance(
|
||||
item["b"],
|
||||
MyDataclass) and item["b"].message == "nested1"
|
||||
assert (
|
||||
isinstance(item["a"], MyDataclass)
|
||||
and item["a"].message == "nested1"
|
||||
)
|
||||
assert (
|
||||
isinstance(item["b"], MyDataclass)
|
||||
and item["b"].message == "nested1"
|
||||
)
|
||||
else:
|
||||
assert "c" in item and "d" in item
|
||||
assert isinstance(
|
||||
item["c"],
|
||||
MyDataclass) and item["c"].message == "nested1"
|
||||
assert isinstance(
|
||||
item["d"],
|
||||
MyDataclass) and item["d"].message == "nested1"
|
||||
assert (
|
||||
isinstance(item["c"], MyDataclass)
|
||||
and item["c"].message == "nested1"
|
||||
)
|
||||
assert (
|
||||
isinstance(item["d"], MyDataclass)
|
||||
and item["d"].message == "nested1"
|
||||
)
|
||||
|
||||
# Test dict of lists: {"list1": [val, val], "list2": [val, val]}
|
||||
result = await core_client.call_utility_async(
|
||||
"echo_dc_nested", "nested2", "dict_of_lists")
|
||||
"echo_dc_nested", "nested2", "dict_of_lists"
|
||||
)
|
||||
assert isinstance(result, dict) and len(result) == 2
|
||||
assert "list1" in result and "list2" in result
|
||||
for key, lst in result.items():
|
||||
assert isinstance(lst, list) and len(lst) == 2
|
||||
for item in lst:
|
||||
assert isinstance(
|
||||
item, MyDataclass) and item.message == "nested2"
|
||||
assert isinstance(item, MyDataclass) and item.message == "nested2"
|
||||
|
||||
# Test deeply nested: {"outer": [{"inner": [val, val]},
|
||||
# {"inner": [val]}]}
|
||||
result = await core_client.call_utility_async(
|
||||
"echo_dc_nested", "nested3", "deep_nested")
|
||||
"echo_dc_nested", "nested3", "deep_nested"
|
||||
)
|
||||
assert isinstance(result, dict) and "outer" in result
|
||||
outer_list = result["outer"]
|
||||
assert isinstance(outer_list, list) and len(outer_list) == 2
|
||||
@@ -560,21 +560,22 @@ async def test_engine_core_client_util_method_nested_structures(
|
||||
inner_list1 = inner_dict1["inner"]
|
||||
assert isinstance(inner_list1, list) and len(inner_list1) == 2
|
||||
for item in inner_list1:
|
||||
assert isinstance(item,
|
||||
MyDataclass) and item.message == "nested3"
|
||||
assert isinstance(item, MyDataclass) and item.message == "nested3"
|
||||
|
||||
# Second dict in outer list should have "inner" with 1 item
|
||||
inner_dict2 = outer_list[1]
|
||||
assert isinstance(inner_dict2, dict) and "inner" in inner_dict2
|
||||
inner_list2 = inner_dict2["inner"]
|
||||
assert isinstance(inner_list2, list) and len(inner_list2) == 1
|
||||
assert isinstance(
|
||||
inner_list2[0],
|
||||
MyDataclass) and inner_list2[0].message == "nested3"
|
||||
assert (
|
||||
isinstance(inner_list2[0], MyDataclass)
|
||||
and inner_list2[0].message == "nested3"
|
||||
)
|
||||
|
||||
# Test with None values in nested structures
|
||||
result = await core_client.call_utility_async(
|
||||
"echo_dc_nested", None, "list_of_dicts")
|
||||
"echo_dc_nested", None, "list_of_dicts"
|
||||
)
|
||||
assert isinstance(result, list) and len(result) == 2
|
||||
for item in result:
|
||||
assert isinstance(item, dict)
|
||||
@@ -595,7 +596,6 @@ def test_kv_cache_events(
|
||||
multiprocessing_mode: bool,
|
||||
publisher_config,
|
||||
):
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
block_size = 16
|
||||
@@ -609,8 +609,7 @@ def test_kv_cache_events(
|
||||
)
|
||||
engine_args.kv_events_config = publisher_config
|
||||
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
UsageContext.UNKNOWN_CONTEXT)
|
||||
vllm_config = engine_args.create_engine_config(UsageContext.UNKNOWN_CONTEXT)
|
||||
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
with set_default_torch_num_threads(1):
|
||||
@@ -622,9 +621,9 @@ def test_kv_cache_events(
|
||||
log_stats=False,
|
||||
)
|
||||
endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
|
||||
subscriber = MockSubscriber(endpoint,
|
||||
topic=publisher_config.topic,
|
||||
decode_type=KVEventBatch)
|
||||
subscriber = MockSubscriber(
|
||||
endpoint, topic=publisher_config.topic, decode_type=KVEventBatch
|
||||
)
|
||||
|
||||
try:
|
||||
custom_tokens = list(range(num_blocks * block_size))
|
||||
@@ -641,22 +640,25 @@ def test_kv_cache_events(
|
||||
seq, received = result
|
||||
|
||||
assert seq == 0, "Sequence number mismatch"
|
||||
assert (len(received.events) == 1
|
||||
), "We should have exactly one BlockStored event"
|
||||
assert len(received.events) == 1, (
|
||||
"We should have exactly one BlockStored event"
|
||||
)
|
||||
event = received.events[0]
|
||||
assert isinstance(
|
||||
event, BlockStored), "We should have a BlockStored event"
|
||||
assert (len(event.block_hashes) == num_blocks
|
||||
), "We should have a BlockStored event with 2 block_hashes"
|
||||
assert (event.block_size == block_size
|
||||
), "Block size should be the same as the block size"
|
||||
assert (event.parent_block_hash
|
||||
is None), "Parent block hash should be None"
|
||||
assert isinstance(event, BlockStored), "We should have a BlockStored event"
|
||||
assert len(event.block_hashes) == num_blocks, (
|
||||
"We should have a BlockStored event with 2 block_hashes"
|
||||
)
|
||||
assert event.block_size == block_size, (
|
||||
"Block size should be the same as the block size"
|
||||
)
|
||||
assert event.parent_block_hash is None, "Parent block hash should be None"
|
||||
assert event.lora_id is None, "Lora id should be None"
|
||||
assert (len(event.token_ids) == num_blocks * block_size
|
||||
), "Token ids should be the same as the custom tokens"
|
||||
assert (event.token_ids == custom_tokens
|
||||
), "Token ids should be the same as the custom tokens"
|
||||
assert len(event.token_ids) == num_blocks * block_size, (
|
||||
"Token ids should be the same as the custom tokens"
|
||||
)
|
||||
assert event.token_ids == custom_tokens, (
|
||||
"Token ids should be the same as the custom tokens"
|
||||
)
|
||||
finally:
|
||||
client.shutdown()
|
||||
subscriber.close()
|
||||
@@ -674,7 +676,6 @@ async def test_kv_cache_events_dp(
|
||||
multiprocessing_mode: bool,
|
||||
publisher_config,
|
||||
):
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
block_size = 16
|
||||
@@ -692,8 +693,7 @@ async def test_kv_cache_events_dp(
|
||||
)
|
||||
engine_args.kv_events_config = publisher_config
|
||||
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
UsageContext.UNKNOWN_CONTEXT)
|
||||
vllm_config = engine_args.create_engine_config(UsageContext.UNKNOWN_CONTEXT)
|
||||
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
with set_default_torch_num_threads(1):
|
||||
@@ -710,13 +710,12 @@ async def test_kv_cache_events_dp(
|
||||
base_endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
|
||||
endpoints = []
|
||||
for i in range(dp_size):
|
||||
offset_endpoint = ZmqEventPublisher.offset_endpoint_port(
|
||||
base_endpoint, i)
|
||||
offset_endpoint = ZmqEventPublisher.offset_endpoint_port(base_endpoint, i)
|
||||
endpoints.append(offset_endpoint)
|
||||
|
||||
subscriber = MockSubscriber(endpoints,
|
||||
topic=publisher_config.topic,
|
||||
decode_type=KVEventBatch)
|
||||
subscriber = MockSubscriber(
|
||||
endpoints, topic=publisher_config.topic, decode_type=KVEventBatch
|
||||
)
|
||||
|
||||
try:
|
||||
custom_tokens = list(range(num_blocks * block_size))
|
||||
@@ -734,15 +733,12 @@ async def test_kv_cache_events_dp(
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Initialize outputs dict for all requests
|
||||
outputs: dict[str, list] = {
|
||||
req_id: []
|
||||
for req_id in all_request_ids
|
||||
}
|
||||
outputs: dict[str, list] = {req_id: [] for req_id in all_request_ids}
|
||||
|
||||
print("processing requests...")
|
||||
await asyncio.wait_for(loop_until_fully_done_async(
|
||||
client, outputs),
|
||||
timeout=20.0)
|
||||
await asyncio.wait_for(
|
||||
loop_until_fully_done_async(client, outputs), timeout=20.0
|
||||
)
|
||||
|
||||
# Receive from subscriber until no more messages
|
||||
print("collecting results...")
|
||||
@@ -755,13 +751,11 @@ async def test_kv_cache_events_dp(
|
||||
results.append(result)
|
||||
|
||||
# Collect all events and data_parallel_ranks from all results
|
||||
all_dp_ranks = [
|
||||
received.data_parallel_rank for (_, received) in results
|
||||
]
|
||||
all_dp_ranks = [received.data_parallel_rank for (_, received) in results]
|
||||
unique_dps = set(all_dp_ranks)
|
||||
assert (
|
||||
len(unique_dps) == 2
|
||||
), f"Expected 2 unique data_parallel_ranks, got {len(unique_dps)}"
|
||||
assert len(unique_dps) == 2, (
|
||||
f"Expected 2 unique data_parallel_ranks, got {len(unique_dps)}"
|
||||
)
|
||||
|
||||
finally:
|
||||
client.shutdown()
|
||||
@@ -770,7 +764,6 @@ async def test_kv_cache_events_dp(
|
||||
|
||||
@pytest.mark.timeout(20)
|
||||
def test_startup_failure(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
with monkeypatch.context() as m, pytest.raises(Exception) as e_info:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
@@ -787,7 +780,8 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch):
|
||||
t = time.time()
|
||||
engine_args = EngineArgs(model=MODEL_NAME)
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT)
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT
|
||||
)
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
print(f"VllmConfig creation took {time.time() - t:.2f} seconds.")
|
||||
|
||||
@@ -815,8 +809,7 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
def test_engine_core_proc_instantiation_cuda_empty(
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
def test_engine_core_proc_instantiation_cuda_empty(monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Test that EngineCoreProc can be instantiated when CUDA_VISIBLE_DEVICES
|
||||
is empty. This ensures the engine frontend does not need access to GPUs.
|
||||
@@ -833,17 +826,13 @@ def test_engine_core_proc_instantiation_cuda_empty(
|
||||
|
||||
# Only implement the methods that are actually called during init
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
mock_spec = FullAttentionSpec(block_size=16,
|
||||
num_kv_heads=1,
|
||||
head_size=64,
|
||||
dtype=torch.float16)
|
||||
|
||||
mock_executor.get_kv_cache_specs.return_value = [{
|
||||
"default": mock_spec
|
||||
}]
|
||||
mock_executor.determine_available_memory.return_value = [
|
||||
1024 * 1024 * 1024
|
||||
]
|
||||
mock_spec = FullAttentionSpec(
|
||||
block_size=16, num_kv_heads=1, head_size=64, dtype=torch.float16
|
||||
)
|
||||
|
||||
mock_executor.get_kv_cache_specs.return_value = [{"default": mock_spec}]
|
||||
mock_executor.determine_available_memory.return_value = [1024 * 1024 * 1024]
|
||||
mock_executor.initialize_from_config.return_value = None
|
||||
mock_executor.max_concurrent_batches = 1
|
||||
|
||||
@@ -857,19 +846,22 @@ def test_engine_core_proc_instantiation_cuda_empty(
|
||||
|
||||
from vllm.v1.engine.utils import EngineZmqAddresses
|
||||
|
||||
def mock_startup_handshake(self, handshake_socket, local_client,
|
||||
headless, parallel_config):
|
||||
return EngineZmqAddresses(inputs=["tcp://127.0.0.1:5555"],
|
||||
outputs=["tcp://127.0.0.1:5556"],
|
||||
coordinator_input=None,
|
||||
coordinator_output=None)
|
||||
def mock_startup_handshake(
|
||||
self, handshake_socket, local_client, headless, parallel_config
|
||||
):
|
||||
return EngineZmqAddresses(
|
||||
inputs=["tcp://127.0.0.1:5555"],
|
||||
outputs=["tcp://127.0.0.1:5556"],
|
||||
coordinator_input=None,
|
||||
coordinator_output=None,
|
||||
)
|
||||
|
||||
# Background processes are not important here
|
||||
m.setattr(EngineCoreProc, "startup_handshake", mock_startup_handshake)
|
||||
|
||||
vllm_config = EngineArgs(
|
||||
model="deepseek-ai/DeepSeek-V2-Lite",
|
||||
trust_remote_code=True).create_engine_config()
|
||||
model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True
|
||||
).create_engine_config()
|
||||
engine_core_proc = EngineCoreProc(
|
||||
vllm_config=vllm_config,
|
||||
local_client=True,
|
||||
|
||||
@@ -40,23 +40,139 @@ def test_fast_inc_detok_invalid_utf8_err_case():
|
||||
|
||||
detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)
|
||||
|
||||
assert detokenizer.__class__.__name__ == "FastIncrementalDetokenizer", \
|
||||
assert detokenizer.__class__.__name__ == "FastIncrementalDetokenizer", (
|
||||
"Should use FastIncrementalDetokenizer by default"
|
||||
)
|
||||
|
||||
# Process tokens incrementally
|
||||
test_tokens = [
|
||||
236840, 107, 138, 236782, 107, 140, 236775, 6265, 1083, 623, 121908,
|
||||
147418, 827, 107, 140, 236775, 6265, 236779, 2084, 1083, 623, 203292,
|
||||
827, 107, 140, 236775, 6265, 236779, 7777, 1083, 623, 121908, 147418,
|
||||
569, 537, 236789, 65880, 569, 537, 236789, 62580, 853, 115693, 210118,
|
||||
35178, 16055, 1270, 759, 215817, 4758, 1925, 1117, 827, 107, 140,
|
||||
236775, 5654, 1083, 623, 110733, 46291, 827, 107, 140, 236775, 5654,
|
||||
236779, 2084, 1083, 623, 136955, 56731, 827, 107, 140, 236775, 5654,
|
||||
236779, 7777, 1083, 623, 194776, 2947, 496, 109811, 1608, 890, 215817,
|
||||
4758, 1925, 1117, 2789, 432, 398, 602, 31118, 569, 124866, 134772, 509,
|
||||
19478, 1640, 33779, 236743, 236770, 236819, 236825, 236771, 432, 398,
|
||||
432, 237167, 827, 107, 140, 236775, 77984, 1083, 623, 2709, 236745,
|
||||
2555, 513, 236789, 602, 31118, 569
|
||||
236840,
|
||||
107,
|
||||
138,
|
||||
236782,
|
||||
107,
|
||||
140,
|
||||
236775,
|
||||
6265,
|
||||
1083,
|
||||
623,
|
||||
121908,
|
||||
147418,
|
||||
827,
|
||||
107,
|
||||
140,
|
||||
236775,
|
||||
6265,
|
||||
236779,
|
||||
2084,
|
||||
1083,
|
||||
623,
|
||||
203292,
|
||||
827,
|
||||
107,
|
||||
140,
|
||||
236775,
|
||||
6265,
|
||||
236779,
|
||||
7777,
|
||||
1083,
|
||||
623,
|
||||
121908,
|
||||
147418,
|
||||
569,
|
||||
537,
|
||||
236789,
|
||||
65880,
|
||||
569,
|
||||
537,
|
||||
236789,
|
||||
62580,
|
||||
853,
|
||||
115693,
|
||||
210118,
|
||||
35178,
|
||||
16055,
|
||||
1270,
|
||||
759,
|
||||
215817,
|
||||
4758,
|
||||
1925,
|
||||
1117,
|
||||
827,
|
||||
107,
|
||||
140,
|
||||
236775,
|
||||
5654,
|
||||
1083,
|
||||
623,
|
||||
110733,
|
||||
46291,
|
||||
827,
|
||||
107,
|
||||
140,
|
||||
236775,
|
||||
5654,
|
||||
236779,
|
||||
2084,
|
||||
1083,
|
||||
623,
|
||||
136955,
|
||||
56731,
|
||||
827,
|
||||
107,
|
||||
140,
|
||||
236775,
|
||||
5654,
|
||||
236779,
|
||||
7777,
|
||||
1083,
|
||||
623,
|
||||
194776,
|
||||
2947,
|
||||
496,
|
||||
109811,
|
||||
1608,
|
||||
890,
|
||||
215817,
|
||||
4758,
|
||||
1925,
|
||||
1117,
|
||||
2789,
|
||||
432,
|
||||
398,
|
||||
602,
|
||||
31118,
|
||||
569,
|
||||
124866,
|
||||
134772,
|
||||
509,
|
||||
19478,
|
||||
1640,
|
||||
33779,
|
||||
236743,
|
||||
236770,
|
||||
236819,
|
||||
236825,
|
||||
236771,
|
||||
432,
|
||||
398,
|
||||
432,
|
||||
237167,
|
||||
827,
|
||||
107,
|
||||
140,
|
||||
236775,
|
||||
77984,
|
||||
1083,
|
||||
623,
|
||||
2709,
|
||||
236745,
|
||||
2555,
|
||||
513,
|
||||
236789,
|
||||
602,
|
||||
31118,
|
||||
569,
|
||||
]
|
||||
|
||||
output = ""
|
||||
@@ -66,8 +182,7 @@ def test_fast_inc_detok_invalid_utf8_err_case():
|
||||
finished = i == len(test_tokens) - 1
|
||||
output += detokenizer.get_next_output_text(finished, delta=True)
|
||||
|
||||
|
||||
# fmt: off
|
||||
# fmt: off
|
||||
assert output == r'''[
|
||||
{
|
||||
"source": "Résultats",
|
||||
|
||||
@@ -43,7 +43,8 @@ def _vllm_model(
|
||||
# env var adjustment via monkeypatch
|
||||
scope="function",
|
||||
# Prefix caching
|
||||
params=[False, True])
|
||||
params=[False, True],
|
||||
)
|
||||
def vllm_model(vllm_runner, request, monkeypatch):
|
||||
"""VllmRunner test fixture parameterized by APC True/False."""
|
||||
with _vllm_model(request.param, vllm_runner, monkeypatch) as vllm_model:
|
||||
@@ -62,14 +63,15 @@ def vllm_model_apc(vllm_runner, monkeypatch):
|
||||
# env var adjustment via monkeypatch
|
||||
scope="function",
|
||||
# Prefix caching
|
||||
params=[False, True])
|
||||
params=[False, True],
|
||||
)
|
||||
def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch):
|
||||
"""VllmRunner test fixture with APC."""
|
||||
with _vllm_model(
|
||||
request.param,
|
||||
vllm_runner,
|
||||
monkeypatch,
|
||||
skip_tokenizer_init=True,
|
||||
request.param,
|
||||
vllm_runner,
|
||||
monkeypatch,
|
||||
skip_tokenizer_init=True,
|
||||
) as vllm_model:
|
||||
yield vllm_model
|
||||
|
||||
@@ -97,9 +99,11 @@ def _get_test_sampling_params(
|
||||
top_p=0.95,
|
||||
n=n,
|
||||
seed=seed,
|
||||
structured_outputs=StructuredOutputsParams(
|
||||
regex="[0-9]+") if structured_outputs else None,
|
||||
) for n in n_list
|
||||
structured_outputs=StructuredOutputsParams(regex="[0-9]+")
|
||||
if structured_outputs
|
||||
else None,
|
||||
)
|
||||
for n in n_list
|
||||
], n_list
|
||||
|
||||
|
||||
@@ -132,23 +136,20 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
|
||||
for out, n in zip(outputs, n_list):
|
||||
completion_counts: dict[str, int] = {}
|
||||
# Assert correct number of completions
|
||||
assert len(out.outputs) == n, (
|
||||
f"{len(out.outputs)} completions; {n} expected.")
|
||||
assert len(out.outputs) == n, f"{len(out.outputs)} completions; {n} expected."
|
||||
for idx in range(n):
|
||||
comp = out.outputs[idx]
|
||||
# Assert correct completion indices
|
||||
assert comp.index == idx, (f"Index {comp.index}; expected {idx}.")
|
||||
assert comp.index == idx, f"Index {comp.index}; expected {idx}."
|
||||
text = comp.text
|
||||
completion_counts[text] = completion_counts.get(text, 0) + 1
|
||||
# Assert unique completions
|
||||
if len(completion_counts) != n:
|
||||
repeats = {
|
||||
txt: num
|
||||
for (txt, num) in completion_counts.items() if num > 1
|
||||
}
|
||||
repeats = {txt: num for (txt, num) in completion_counts.items() if num > 1}
|
||||
raise AssertionError(
|
||||
f"{len(completion_counts)} unique completions; expected"
|
||||
f" {n}. Repeats: {repeats}")
|
||||
f" {n}. Repeats: {repeats}"
|
||||
)
|
||||
|
||||
|
||||
def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
|
||||
@@ -162,13 +163,12 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
|
||||
}
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
with vllm_runner(
|
||||
MODEL,
|
||||
speculative_config=speculative_config,
|
||||
disable_log_stats=False,
|
||||
MODEL,
|
||||
speculative_config=speculative_config,
|
||||
disable_log_stats=False,
|
||||
) as vllm_model:
|
||||
llm: LLM = vllm_model.llm
|
||||
sampling_params = SamplingParams(temperature=0.0,
|
||||
max_tokens=max_tokens)
|
||||
sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
||||
outputs = llm.generate(example_prompts, sampling_params)
|
||||
|
||||
n_prompts = len(example_prompts)
|
||||
@@ -192,15 +192,14 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
|
||||
num_requests_running = find_metric("vllm:num_requests_running")
|
||||
assert len(num_requests_running) == 1
|
||||
assert isinstance(num_requests_running[0], Gauge)
|
||||
assert num_requests_running[0].value == .0
|
||||
assert num_requests_running[0].value == 0.0
|
||||
|
||||
generation_tokens = find_metric("vllm:generation_tokens")
|
||||
assert len(generation_tokens) == 1
|
||||
assert isinstance(generation_tokens[0], Counter)
|
||||
assert generation_tokens[0].value == total_tokens
|
||||
|
||||
request_generation_tokens = find_metric(
|
||||
"vllm:request_generation_tokens")
|
||||
request_generation_tokens = find_metric("vllm:request_generation_tokens")
|
||||
assert len(request_generation_tokens) == 1
|
||||
assert isinstance(request_generation_tokens[0], Histogram)
|
||||
assert "+Inf" in request_generation_tokens[0].buckets
|
||||
@@ -209,15 +208,15 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
|
||||
assert request_generation_tokens[0].sum == total_tokens
|
||||
|
||||
num_accepted_tokens_per_pos = find_metric(
|
||||
"vllm:spec_decode_num_accepted_tokens_per_pos")
|
||||
"vllm:spec_decode_num_accepted_tokens_per_pos"
|
||||
)
|
||||
assert len(num_accepted_tokens_per_pos) == 1
|
||||
assert isinstance(num_accepted_tokens_per_pos[0], Vector)
|
||||
assert len(num_accepted_tokens_per_pos[0].values) == 5
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["meta-llama/Llama-3.2-1B-Instruct"])
|
||||
def test_skip_tokenizer_initialization(model: str,
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
def test_skip_tokenizer_initialization(model: str, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
# This test checks if the flag skip_tokenizer_init skips the initialization
|
||||
# of tokenizer and detokenizer. The generated output is expected to contain
|
||||
@@ -232,8 +231,9 @@ def test_skip_tokenizer_initialization(model: str,
|
||||
with pytest.raises(ValueError, match="cannot pass text prompts when"):
|
||||
llm.generate("abc", sampling_params)
|
||||
|
||||
outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
|
||||
sampling_params=sampling_params)
|
||||
outputs = llm.generate(
|
||||
{"prompt_token_ids": [1, 2, 3]}, sampling_params=sampling_params
|
||||
)
|
||||
assert len(outputs) > 0
|
||||
completions = outputs[0].outputs
|
||||
assert len(completions) > 0
|
||||
|
||||
@@ -7,19 +7,20 @@ from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
|
||||
NUM_SAMPLE_LOGPROBS_UNDER_TEST,
|
||||
STOP_STRINGS,
|
||||
DummyOutputProcessorTestVectors,
|
||||
MockEngineCore)
|
||||
from tests.v1.engine.utils import (
|
||||
NUM_PROMPT_LOGPROBS_UNDER_TEST,
|
||||
NUM_SAMPLE_LOGPROBS_UNDER_TEST,
|
||||
STOP_STRINGS,
|
||||
DummyOutputProcessorTestVectors,
|
||||
MockEngineCore,
|
||||
)
|
||||
from vllm import PoolingParams
|
||||
from vllm.logprobs import PromptLogprobs, SampleLogprobs
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.output_processor import (OutputProcessor,
|
||||
RequestOutputCollector)
|
||||
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
|
||||
from vllm.v1.metrics.stats import IterationStats
|
||||
|
||||
|
||||
@@ -40,33 +41,34 @@ def _ref_convert_id_to_token(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"request_output_kind",
|
||||
[RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
|
||||
def test_incremental_detokenization(request_output_kind: RequestOutputKind,
|
||||
dummy_test_vectors):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
|
||||
log_stats=False)
|
||||
engine_core = MockEngineCore(
|
||||
tokens_list=dummy_test_vectors.generation_tokens)
|
||||
"request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
|
||||
)
|
||||
def test_incremental_detokenization(
|
||||
request_output_kind: RequestOutputKind, dummy_test_vectors
|
||||
):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
|
||||
engine_core = MockEngineCore(tokens_list=dummy_test_vectors.generation_tokens)
|
||||
|
||||
# Make N requests.
|
||||
requests = [
|
||||
EngineCoreRequest(request_id=f"request-{idx}",
|
||||
prompt_token_ids=prompt_tokens,
|
||||
mm_features=None,
|
||||
eos_token_id=None,
|
||||
arrival_time=0,
|
||||
lora_request=None,
|
||||
cache_salt=None,
|
||||
data_parallel_rank=None,
|
||||
sampling_params=SamplingParams(
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=False,
|
||||
output_kind=request_output_kind,
|
||||
stop=[],
|
||||
include_stop_str_in_output=False,
|
||||
),
|
||||
pooling_params=None)
|
||||
EngineCoreRequest(
|
||||
request_id=f"request-{idx}",
|
||||
prompt_token_ids=prompt_tokens,
|
||||
mm_features=None,
|
||||
eos_token_id=None,
|
||||
arrival_time=0,
|
||||
lora_request=None,
|
||||
cache_salt=None,
|
||||
data_parallel_rank=None,
|
||||
sampling_params=SamplingParams(
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=False,
|
||||
output_kind=request_output_kind,
|
||||
stop=[],
|
||||
include_stop_str_in_output=False,
|
||||
),
|
||||
pooling_params=None,
|
||||
)
|
||||
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||
]
|
||||
|
||||
@@ -102,8 +104,8 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
|
||||
|
||||
# Confirmed tracked values matches what we expected.
|
||||
for idx, (ref_gen_str, ref_gen_toks) in enumerate(
|
||||
zip(dummy_test_vectors.generation_strings,
|
||||
dummy_test_vectors.generation_tokens)):
|
||||
zip(dummy_test_vectors.generation_strings, dummy_test_vectors.generation_tokens)
|
||||
):
|
||||
gen_str = gen_strings[f"request-{idx}"]
|
||||
gen_toks = gen_tokens[f"request-{idx}"]
|
||||
|
||||
@@ -134,9 +136,11 @@ def _validate_logprobs(
|
||||
ref_prompt_logprobs = dtv.prompt_logprobs[req_idx]
|
||||
if num_sample_logprobs is not None:
|
||||
# Validate sample logprobs
|
||||
assert logprobs is not None, (f"Request {req_id} requires sample"
|
||||
" logprobs but sample logprobs are"
|
||||
" None.")
|
||||
assert logprobs is not None, (
|
||||
f"Request {req_id} requires sample"
|
||||
" logprobs but sample logprobs are"
|
||||
" None."
|
||||
)
|
||||
# Require num sampled tokens to match num
|
||||
# sampled logprobs - especially important
|
||||
# to check since the detokenizer can cause
|
||||
@@ -147,44 +151,51 @@ def _validate_logprobs(
|
||||
assert num_new_tokens == len_sample_logprobs, (
|
||||
f"Request {req_id} has {num_new_tokens}"
|
||||
" completion tokens but has"
|
||||
f" {len_sample_logprobs} sample logprobs.")
|
||||
f" {len_sample_logprobs} sample logprobs."
|
||||
)
|
||||
ref_cumulative_logprob = 0.0
|
||||
for idx, (sampled_token,
|
||||
pos_logprob_dict) in enumerate(zip(new_tokens,
|
||||
logprobs)):
|
||||
for idx, (sampled_token, pos_logprob_dict) in enumerate(
|
||||
zip(new_tokens, logprobs)
|
||||
):
|
||||
# Break out the reference log probability value &
|
||||
# logprob token id tensors associated with this
|
||||
# position in the completion. Also break out the
|
||||
# sampled token ranks
|
||||
(ref_pos_logprob_toks, ref_pos_logprob_vals,
|
||||
ref_sampled_token_rank) = ref_logprobs[idx]
|
||||
(ref_pos_logprob_toks, ref_pos_logprob_vals, ref_sampled_token_rank) = (
|
||||
ref_logprobs[idx]
|
||||
)
|
||||
# For each position in the completion sequence,
|
||||
# ensure the actual sampled token is among the
|
||||
# logprobs
|
||||
assert sampled_token in pos_logprob_dict, (
|
||||
f"Sampled token {sampled_token} not"
|
||||
f" present in logprob at index {idx}")
|
||||
f" present in logprob at index {idx}"
|
||||
)
|
||||
|
||||
# Validate number of sample logprobs
|
||||
num_lp_toks = len(pos_logprob_dict)
|
||||
assert (num_lp_toks == num_sample_logprobs
|
||||
or num_lp_toks == num_sample_logprobs +
|
||||
1), ("Valid numbers of sample logprobs are"
|
||||
f" {num_sample_logprobs} or"
|
||||
f" {num_sample_logprobs+1} but"
|
||||
f" {num_lp_toks} logprobs found at"
|
||||
f" position {idx}. Logprobs dict:"
|
||||
f" {pos_logprob_dict}")
|
||||
assert (
|
||||
num_lp_toks == num_sample_logprobs
|
||||
or num_lp_toks == num_sample_logprobs + 1
|
||||
), (
|
||||
"Valid numbers of sample logprobs are"
|
||||
f" {num_sample_logprobs} or"
|
||||
f" {num_sample_logprobs + 1} but"
|
||||
f" {num_lp_toks} logprobs found at"
|
||||
f" position {idx}. Logprobs dict:"
|
||||
f" {pos_logprob_dict}"
|
||||
)
|
||||
|
||||
# Validate sampled token logprob rank
|
||||
smp_lp = pos_logprob_dict[sampled_token]
|
||||
smp_lp_rank = smp_lp.rank
|
||||
assert (ref_sampled_token_rank == smp_lp_rank), (
|
||||
assert ref_sampled_token_rank == smp_lp_rank, (
|
||||
"Sampled token logprob rank"
|
||||
f" {smp_lp_rank} does not match"
|
||||
" correct value"
|
||||
f" {ref_sampled_token_rank}"
|
||||
f" in Logprob {smp_lp}")
|
||||
f" in Logprob {smp_lp}"
|
||||
)
|
||||
|
||||
# Validate that the logprob processor yields
|
||||
# the correct log probabilities and valid
|
||||
@@ -198,7 +209,8 @@ def _validate_logprobs(
|
||||
ref_tok_id = ref_pos_logprob_toks[jdx]
|
||||
assert ref_tok_id in pos_logprob_dict, (
|
||||
f"Expected token {ref_tok_id} to be"
|
||||
f" in logprob dict but it is not.")
|
||||
f" in logprob dict but it is not."
|
||||
)
|
||||
|
||||
# Extract actually-generated logprob
|
||||
# info
|
||||
@@ -208,40 +220,43 @@ def _validate_logprobs(
|
||||
|
||||
# A "top" (rank 1) logprob must be
|
||||
# present
|
||||
rank_one_appears = (True
|
||||
if lp_rank == 1 else rank_one_appears)
|
||||
rank_one_appears = True if lp_rank == 1 else rank_one_appears
|
||||
|
||||
# Rank must be >= 1
|
||||
assert lp_rank >= 1, (f"Logprob {lp} has invalid"
|
||||
f" rank {lp_rank} < 1."
|
||||
f" Logprob dict: {pos_logprob_dict}")
|
||||
assert lp_rank >= 1, (
|
||||
f"Logprob {lp} has invalid"
|
||||
f" rank {lp_rank} < 1."
|
||||
f" Logprob dict: {pos_logprob_dict}"
|
||||
)
|
||||
|
||||
# Validate log probability
|
||||
assert math.isclose(lp_val, ref_lp_val), (
|
||||
f"Token id {ref_tok_id} appears in logprobs dict"
|
||||
f" at position {idx} in completion with log"
|
||||
f" probability {lp_val} but {ref_lp_val} was"
|
||||
f" expected. Logprob: {lp}")
|
||||
f" expected. Logprob: {lp}"
|
||||
)
|
||||
|
||||
assert rank_one_appears, (f"No Logprob has rank 1"
|
||||
" in the following Logprob"
|
||||
f" dict: {pos_logprob_dict}")
|
||||
assert rank_one_appears, (
|
||||
f"No Logprob has rank 1"
|
||||
" in the following Logprob"
|
||||
f" dict: {pos_logprob_dict}"
|
||||
)
|
||||
|
||||
# Validate logprobs detokenization
|
||||
for lp_tok in pos_logprob_dict:
|
||||
# Confirm that sample logprob decoded token matches
|
||||
# the logprob token id at this sequence position
|
||||
decoded_token = pos_logprob_dict[lp_tok].decoded_token
|
||||
ref_decoded_token = _ref_convert_id_to_token(
|
||||
dtv.tokenizer, lp_tok)
|
||||
ref_decoded_token = _ref_convert_id_to_token(dtv.tokenizer, lp_tok)
|
||||
assert decoded_token == ref_decoded_token, (
|
||||
f"Sampled logprob token id {lp_tok} decodes to"
|
||||
f" {ref_decoded_token} but Logprob decoded"
|
||||
f" token is {decoded_token} instead"
|
||||
f" (at position {idx})")
|
||||
f" (at position {idx})"
|
||||
)
|
||||
|
||||
ref_cumulative_logprob += pos_logprob_dict[
|
||||
sampled_token].logprob
|
||||
ref_cumulative_logprob += pos_logprob_dict[sampled_token].logprob
|
||||
# Assert that cumulative logprobs are correct
|
||||
assert math.isclose(cumulative_logprob, ref_cumulative_logprob)
|
||||
else:
|
||||
@@ -254,7 +269,8 @@ def _validate_logprobs(
|
||||
assert prompt_logprobs is not None, (
|
||||
f"Request {req_id} requires prompt"
|
||||
" logprobs but prompt logprobs are"
|
||||
" None.")
|
||||
" None."
|
||||
)
|
||||
# Require num prompt tokens to match num
|
||||
# prompt logprobs
|
||||
num_prompt_tokens = len(prompt_token_ids)
|
||||
@@ -262,56 +278,70 @@ def _validate_logprobs(
|
||||
assert num_prompt_tokens == len_prompt_logprobs, (
|
||||
f"Request {req_id} has {num_prompt_tokens}"
|
||||
" prompt tokens but has"
|
||||
f" {len_prompt_logprobs} prompt logprobs.")
|
||||
f" {len_prompt_logprobs} prompt logprobs."
|
||||
)
|
||||
# First prompt logprob is None
|
||||
first_plp_dict = prompt_logprobs[0]
|
||||
assert first_plp_dict is None, (
|
||||
f"Request {req_id} first prompt logprob"
|
||||
f" should be None but has following value"
|
||||
f" instead: {first_plp_dict}")
|
||||
f" instead: {first_plp_dict}"
|
||||
)
|
||||
# Break out the reference prompt log prob value &
|
||||
# logprob token id matrices for the whole prompt.
|
||||
# Also break out the prompt token rank vector
|
||||
(ref_prompt_logprob_toks, ref_prompt_logprob_vals,
|
||||
ref_prompt_token_ranks) = ref_prompt_logprobs
|
||||
(
|
||||
ref_prompt_logprob_toks,
|
||||
ref_prompt_logprob_vals,
|
||||
ref_prompt_token_ranks,
|
||||
) = ref_prompt_logprobs
|
||||
for idx, (prompt_token, pos_logprob_dict) in enumerate(
|
||||
zip(prompt_token_ids[1:], prompt_logprobs[1:])):
|
||||
|
||||
zip(prompt_token_ids[1:], prompt_logprobs[1:])
|
||||
):
|
||||
# Break out the reference prompt log prob value
|
||||
# vector, prompt logprob token id vector, and
|
||||
# prompt token rank at the current position.
|
||||
(ref_pos_prompt_logprob_toks, ref_pos_prompt_logprob_vals,
|
||||
ref_pos_prompt_token_rank) = (ref_prompt_logprob_toks[idx, :],
|
||||
ref_prompt_logprob_vals[idx, :],
|
||||
ref_prompt_token_ranks[idx])
|
||||
(
|
||||
ref_pos_prompt_logprob_toks,
|
||||
ref_pos_prompt_logprob_vals,
|
||||
ref_pos_prompt_token_rank,
|
||||
) = (
|
||||
ref_prompt_logprob_toks[idx, :],
|
||||
ref_prompt_logprob_vals[idx, :],
|
||||
ref_prompt_token_ranks[idx],
|
||||
)
|
||||
|
||||
# For each position in the prompt sequence,
|
||||
# ensure the actual prompt token is among the
|
||||
# logprobs
|
||||
assert prompt_token in pos_logprob_dict, (
|
||||
f"Prompt token {prompt_token} not"
|
||||
f" present in logprob at index {idx}")
|
||||
f"Prompt token {prompt_token} not present in logprob at index {idx}"
|
||||
)
|
||||
# Validate number of prompt logprobs
|
||||
num_plp_toks = len(pos_logprob_dict)
|
||||
assert (num_plp_toks == num_prompt_logprobs
|
||||
or num_plp_toks == num_prompt_logprobs +
|
||||
1), ("Valid numbers of prompt logprobs are"
|
||||
f" {num_prompt_logprobs} or"
|
||||
f" {num_prompt_logprobs+1} but"
|
||||
f" {num_plp_toks} logprobs found at"
|
||||
f" position {idx}. Logprobs dict:"
|
||||
f" {pos_logprob_dict}")
|
||||
assert (
|
||||
num_plp_toks == num_prompt_logprobs
|
||||
or num_plp_toks == num_prompt_logprobs + 1
|
||||
), (
|
||||
"Valid numbers of prompt logprobs are"
|
||||
f" {num_prompt_logprobs} or"
|
||||
f" {num_prompt_logprobs + 1} but"
|
||||
f" {num_plp_toks} logprobs found at"
|
||||
f" position {idx}. Logprobs dict:"
|
||||
f" {pos_logprob_dict}"
|
||||
)
|
||||
|
||||
# Validate prompt token logprob rank
|
||||
prmpt_tok_lp = pos_logprob_dict[prompt_token]
|
||||
prmpt_tok_lp_rank = prmpt_tok_lp.rank
|
||||
ref_prmpt_tok_lp_rank = ref_pos_prompt_token_rank
|
||||
assert (ref_prmpt_tok_lp_rank == prmpt_tok_lp_rank), (
|
||||
assert ref_prmpt_tok_lp_rank == prmpt_tok_lp_rank, (
|
||||
"Prompt token logprob rank"
|
||||
f" {prmpt_tok_lp_rank} does not match"
|
||||
" correct value"
|
||||
f" {ref_prmpt_tok_lp_rank}"
|
||||
f" in Logprob {prmpt_tok_lp}")
|
||||
f" in Logprob {prmpt_tok_lp}"
|
||||
)
|
||||
|
||||
# Validate that the logprob processor yields
|
||||
# the correct prompt log probs and valid
|
||||
@@ -325,7 +355,8 @@ def _validate_logprobs(
|
||||
ref_tok_id = int(ref_pos_prompt_logprob_toks[jdx])
|
||||
assert ref_tok_id in pos_logprob_dict, (
|
||||
f"Expected token {ref_tok_id} to be"
|
||||
f" in logprob dict but it is not.")
|
||||
f" in logprob dict but it is not."
|
||||
)
|
||||
|
||||
# Extract actually-generated logprob
|
||||
# info
|
||||
@@ -335,87 +366,93 @@ def _validate_logprobs(
|
||||
|
||||
# A "top" (rank 1) logprob must be
|
||||
# present
|
||||
rank_one_appears = (True
|
||||
if plp_rank == 1 else rank_one_appears)
|
||||
rank_one_appears = True if plp_rank == 1 else rank_one_appears
|
||||
|
||||
# Rank must be >= 1
|
||||
assert plp_rank >= 1, (
|
||||
f"Logprob {plp} has invalid"
|
||||
f" rank {plp_rank} < 1."
|
||||
f" Logprob dict: {pos_logprob_dict}")
|
||||
f" Logprob dict: {pos_logprob_dict}"
|
||||
)
|
||||
|
||||
# Validate log probability
|
||||
assert math.isclose(plp_val, ref_plp_val), (
|
||||
f"Token id {ref_tok_id} appears in logprobs dict"
|
||||
f" at position {idx} in completion with log"
|
||||
f" probability {plp_val} but {ref_plp_val} was"
|
||||
f" expected. Logprob: {plp}")
|
||||
f" expected. Logprob: {plp}"
|
||||
)
|
||||
|
||||
assert rank_one_appears, (f"No Logprob has rank 1"
|
||||
" in the following Logprob"
|
||||
f" dict: {pos_logprob_dict}")
|
||||
assert rank_one_appears, (
|
||||
f"No Logprob has rank 1"
|
||||
" in the following Logprob"
|
||||
f" dict: {pos_logprob_dict}"
|
||||
)
|
||||
|
||||
# Validate prompt logprob detokenization
|
||||
for plp_tok in pos_logprob_dict:
|
||||
# Confirm that prompt logprob decoded token matches
|
||||
# the logprob token id at this sequence position
|
||||
decoded_token = pos_logprob_dict[plp_tok].decoded_token
|
||||
ref_decoded_token = _ref_convert_id_to_token(
|
||||
dtv.tokenizer, plp_tok)
|
||||
ref_decoded_token = _ref_convert_id_to_token(dtv.tokenizer, plp_tok)
|
||||
assert decoded_token == ref_decoded_token, (
|
||||
f"Prompt logprob token id {plp_tok} decodes to"
|
||||
f" {ref_decoded_token} but Logprob decoded"
|
||||
f" token is {decoded_token} instead"
|
||||
f" (at position {idx})")
|
||||
f" (at position {idx})"
|
||||
)
|
||||
else:
|
||||
# Prompt logprobs disabled for this request
|
||||
assert prompt_logprobs is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"request_output_kind",
|
||||
[RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
|
||||
@pytest.mark.parametrize("num_sample_logprobs",
|
||||
[None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
|
||||
@pytest.mark.parametrize("num_prompt_logprobs",
|
||||
[None, NUM_PROMPT_LOGPROBS_UNDER_TEST])
|
||||
def test_logprobs_processor(request_output_kind: RequestOutputKind,
|
||||
num_sample_logprobs: Optional[int],
|
||||
num_prompt_logprobs: Optional[int],
|
||||
dummy_test_vectors):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
|
||||
log_stats=False)
|
||||
"request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
|
||||
)
|
||||
@pytest.mark.parametrize("num_sample_logprobs", [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
|
||||
@pytest.mark.parametrize("num_prompt_logprobs", [None, NUM_PROMPT_LOGPROBS_UNDER_TEST])
|
||||
def test_logprobs_processor(
|
||||
request_output_kind: RequestOutputKind,
|
||||
num_sample_logprobs: Optional[int],
|
||||
num_prompt_logprobs: Optional[int],
|
||||
dummy_test_vectors,
|
||||
):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
|
||||
engine_core = MockEngineCore(
|
||||
tokens_list=dummy_test_vectors.generation_tokens,
|
||||
generated_logprobs_raw=None if num_sample_logprobs is None else
|
||||
dummy_test_vectors.generation_logprobs,
|
||||
generated_logprobs_raw=None
|
||||
if num_sample_logprobs is None
|
||||
else dummy_test_vectors.generation_logprobs,
|
||||
prompt_logprobs_raw=None
|
||||
if num_prompt_logprobs is None else dummy_test_vectors.prompt_logprobs)
|
||||
if num_prompt_logprobs is None
|
||||
else dummy_test_vectors.prompt_logprobs,
|
||||
)
|
||||
|
||||
# Make N requests.
|
||||
request_id_list = [
|
||||
f"request-{idx}"
|
||||
for idx in range(len(dummy_test_vectors.prompt_strings))
|
||||
f"request-{idx}" for idx in range(len(dummy_test_vectors.prompt_strings))
|
||||
]
|
||||
requests = [
|
||||
EngineCoreRequest(request_id=request_id_list[idx],
|
||||
prompt_token_ids=prompt_tokens,
|
||||
mm_features=None,
|
||||
eos_token_id=None,
|
||||
arrival_time=0,
|
||||
lora_request=None,
|
||||
cache_salt=None,
|
||||
data_parallel_rank=None,
|
||||
sampling_params=SamplingParams(
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=False,
|
||||
output_kind=request_output_kind,
|
||||
stop=[],
|
||||
include_stop_str_in_output=False,
|
||||
logprobs=num_sample_logprobs,
|
||||
prompt_logprobs=num_prompt_logprobs,
|
||||
),
|
||||
pooling_params=None)
|
||||
EngineCoreRequest(
|
||||
request_id=request_id_list[idx],
|
||||
prompt_token_ids=prompt_tokens,
|
||||
mm_features=None,
|
||||
eos_token_id=None,
|
||||
arrival_time=0,
|
||||
lora_request=None,
|
||||
cache_salt=None,
|
||||
data_parallel_rank=None,
|
||||
sampling_params=SamplingParams(
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=False,
|
||||
output_kind=request_output_kind,
|
||||
stop=[],
|
||||
include_stop_str_in_output=False,
|
||||
logprobs=num_sample_logprobs,
|
||||
prompt_logprobs=num_prompt_logprobs,
|
||||
),
|
||||
pooling_params=None,
|
||||
)
|
||||
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||
]
|
||||
|
||||
@@ -446,7 +483,8 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
|
||||
prompt_logprobs = request_output.prompt_logprobs
|
||||
logprobs = request_output.outputs[0].logprobs
|
||||
gen_cumulative_logprobs[request_id] = request_output.outputs[
|
||||
0].cumulative_logprob
|
||||
0
|
||||
].cumulative_logprob
|
||||
if request_id not in gen_logprobs:
|
||||
# Start tracking sample and prompt logprobs for this request
|
||||
gen_tokens[request_id] = new_tokens
|
||||
@@ -463,10 +501,16 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
|
||||
plp.extend(prompt_logprobs)
|
||||
|
||||
# Confirmed tracked logprobs match what we expect
|
||||
_validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs,
|
||||
gen_cumulative_logprobs, dummy_test_vectors,
|
||||
request_id_list, num_sample_logprobs,
|
||||
num_prompt_logprobs)
|
||||
_validate_logprobs(
|
||||
gen_tokens,
|
||||
gen_logprobs,
|
||||
gen_prompt_logprobs,
|
||||
gen_cumulative_logprobs,
|
||||
dummy_test_vectors,
|
||||
request_id_list,
|
||||
num_sample_logprobs,
|
||||
num_prompt_logprobs,
|
||||
)
|
||||
|
||||
assert output_processor.get_num_unfinished_requests() == 0
|
||||
assert not output_processor.has_unfinished_requests()
|
||||
@@ -474,15 +518,23 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"include_stop_str_in_output,stop_token_type,ignore_eos,num_sample_logprobs",
|
||||
[(False, "stop_token_ids", False, None),
|
||||
(True, "stop_token_ids", False, None),
|
||||
(False, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
|
||||
(True, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
|
||||
(False, "eos_token_id", False, None), (True, "eos_token_id", False, None),
|
||||
(False, "eos_token_id", True, None)])
|
||||
def test_stop_token(include_stop_str_in_output: bool,
|
||||
num_sample_logprobs: Optional[int], stop_token_type: str,
|
||||
ignore_eos: bool, dummy_test_vectors):
|
||||
[
|
||||
(False, "stop_token_ids", False, None),
|
||||
(True, "stop_token_ids", False, None),
|
||||
(False, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
|
||||
(True, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
|
||||
(False, "eos_token_id", False, None),
|
||||
(True, "eos_token_id", False, None),
|
||||
(False, "eos_token_id", True, None),
|
||||
],
|
||||
)
|
||||
def test_stop_token(
|
||||
include_stop_str_in_output: bool,
|
||||
num_sample_logprobs: Optional[int],
|
||||
stop_token_type: str,
|
||||
ignore_eos: bool,
|
||||
dummy_test_vectors,
|
||||
):
|
||||
"""Test output processor EOS/stop token handling.
|
||||
|
||||
Send mock engine core request to mock engine core and pass core outputs
|
||||
@@ -523,9 +575,10 @@ def test_stop_token(include_stop_str_in_output: bool,
|
||||
dummy_test_vectors: dummy engine core outputs and other data structures
|
||||
"""
|
||||
model_id = dummy_test_vectors.tokenizer.name_or_path
|
||||
if model_id != 'meta-llama/Llama-3.2-1B':
|
||||
raise AssertionError("Test requires meta-llama/Llama-3.2-1B but "
|
||||
f"{model_id} is in use.")
|
||||
if model_id != "meta-llama/Llama-3.2-1B":
|
||||
raise AssertionError(
|
||||
f"Test requires meta-llama/Llama-3.2-1B but {model_id} is in use."
|
||||
)
|
||||
do_logprobs = num_sample_logprobs is not None
|
||||
# EOS under test; if False, stop_token_ids under test
|
||||
is_eos_test = stop_token_type == "eos_token_id"
|
||||
@@ -536,18 +589,16 @@ def test_stop_token(include_stop_str_in_output: bool,
|
||||
) # '<|end_of_text|>'
|
||||
stop_token_ids = [128009] if not is_eos_test else None # '<|eot_id|>'
|
||||
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
|
||||
log_stats=False)
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
|
||||
# Dummy engine core outputs, with control tokens suffixed to test stops
|
||||
suffix_token = ([eos_token_id] if is_eos_test else stop_token_ids)
|
||||
suffix_token = [eos_token_id] if is_eos_test else stop_token_ids
|
||||
assert suffix_token is not None and isinstance(suffix_token[0], int)
|
||||
generation_string = dummy_test_vectors.generation_strings[0]
|
||||
generation_tokens = (dummy_test_vectors.generation_tokens[0] +
|
||||
2 * suffix_token)
|
||||
generation_tokens = dummy_test_vectors.generation_tokens[0] + 2 * suffix_token
|
||||
if do_logprobs:
|
||||
generation_logprobs = (
|
||||
dummy_test_vectors.generation_logprobs[0] +
|
||||
2 * [dummy_test_vectors.generation_logprobs[0][-1]])
|
||||
generation_logprobs = dummy_test_vectors.generation_logprobs[0] + 2 * [
|
||||
dummy_test_vectors.generation_logprobs[0][-1]
|
||||
]
|
||||
prompt_string = dummy_test_vectors.prompt_strings[0]
|
||||
prompt_tokens = dummy_test_vectors.prompt_tokens[0]
|
||||
engine_core = MockEngineCore(
|
||||
@@ -556,7 +607,8 @@ def test_stop_token(include_stop_str_in_output: bool,
|
||||
prompt_logprobs_raw=None,
|
||||
eos_token_id=eos_token_id,
|
||||
stop_token_ids=stop_token_ids,
|
||||
ignore_eos=ignore_eos)
|
||||
ignore_eos=ignore_eos,
|
||||
)
|
||||
|
||||
# Make request.
|
||||
request_id = "request-0"
|
||||
@@ -580,7 +632,8 @@ def test_stop_token(include_stop_str_in_output: bool,
|
||||
prompt_logprobs=None,
|
||||
ignore_eos=ignore_eos,
|
||||
),
|
||||
pooling_params=None)
|
||||
pooling_params=None,
|
||||
)
|
||||
|
||||
# Add request to the detokenizer.
|
||||
output_processor.add_request(request, prompt_string)
|
||||
@@ -605,7 +658,7 @@ def test_stop_token(include_stop_str_in_output: bool,
|
||||
# Update tracking.
|
||||
request_output = request_outputs[0]
|
||||
if request_output.finished:
|
||||
finish_reason = ("length" if is_eos_ignore_test else "stop")
|
||||
finish_reason = "length" if is_eos_ignore_test else "stop"
|
||||
assert request_output.outputs[0].finish_reason == finish_reason
|
||||
|
||||
gen_string += request_output.outputs[0].text
|
||||
@@ -614,7 +667,7 @@ def test_stop_token(include_stop_str_in_output: bool,
|
||||
gen_logprobs.extend(request_output.outputs[0].logprobs)
|
||||
|
||||
# Validate generated text
|
||||
control_token = '<|end_of_text|>' if is_eos_test else '<|eot_id|>'
|
||||
control_token = "<|end_of_text|>" if is_eos_test else "<|eot_id|>"
|
||||
if is_eos_ignore_test:
|
||||
# Length-based stop; expect full string
|
||||
ref_str = generation_string + 2 * control_token
|
||||
@@ -624,14 +677,15 @@ def test_stop_token(include_stop_str_in_output: bool,
|
||||
else:
|
||||
# Stop token triggered but not in output
|
||||
ref_str = generation_string
|
||||
assert gen_string == ref_str, (f"{gen_string=}, {ref_str=}")
|
||||
assert gen_string == ref_str, f"{gen_string=}, {ref_str=}"
|
||||
|
||||
if do_logprobs:
|
||||
# Validate number of sample logprobs
|
||||
num_tokens = len(gen_tokens)
|
||||
num_logprobs = len(gen_logprobs)
|
||||
assert num_tokens == num_logprobs, (
|
||||
f"Token count ({num_tokens}) != logprobs count ({num_logprobs})")
|
||||
f"Token count ({num_tokens}) != logprobs count ({num_logprobs})"
|
||||
)
|
||||
|
||||
# Check requests are finished
|
||||
assert output_processor.get_num_unfinished_requests() == 0
|
||||
@@ -639,22 +693,24 @@ def test_stop_token(include_stop_str_in_output: bool,
|
||||
|
||||
|
||||
@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
|
||||
@pytest.mark.parametrize("num_sample_logprobs",
|
||||
[None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
|
||||
def test_stop_string(include_stop_str_in_output: bool,
|
||||
num_sample_logprobs: Optional[int], dummy_test_vectors):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
|
||||
log_stats=False)
|
||||
@pytest.mark.parametrize("num_sample_logprobs", [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
|
||||
def test_stop_string(
|
||||
include_stop_str_in_output: bool,
|
||||
num_sample_logprobs: Optional[int],
|
||||
dummy_test_vectors,
|
||||
):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
|
||||
engine_core = MockEngineCore(
|
||||
tokens_list=dummy_test_vectors.generation_tokens,
|
||||
generated_logprobs_raw=dummy_test_vectors.generation_logprobs
|
||||
if num_sample_logprobs else None,
|
||||
prompt_logprobs_raw=None)
|
||||
if num_sample_logprobs
|
||||
else None,
|
||||
prompt_logprobs_raw=None,
|
||||
)
|
||||
|
||||
# Make N requests.
|
||||
request_id_list = [
|
||||
f"request-{idx}"
|
||||
for idx in range(len(dummy_test_vectors.prompt_strings))
|
||||
f"request-{idx}" for idx in range(len(dummy_test_vectors.prompt_strings))
|
||||
]
|
||||
requests = [
|
||||
EngineCoreRequest(
|
||||
@@ -675,7 +731,8 @@ def test_stop_string(include_stop_str_in_output: bool,
|
||||
logprobs=num_sample_logprobs,
|
||||
prompt_logprobs=None,
|
||||
),
|
||||
pooling_params=None)
|
||||
pooling_params=None,
|
||||
)
|
||||
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||
]
|
||||
|
||||
@@ -715,7 +772,8 @@ def test_stop_string(include_stop_str_in_output: bool,
|
||||
prompt_logprobs = request_output.prompt_logprobs
|
||||
logprobs = request_output.outputs[0].logprobs
|
||||
gen_cumulative_logprobs[request_id] = request_output.outputs[
|
||||
0].cumulative_logprob
|
||||
0
|
||||
].cumulative_logprob
|
||||
if request_id not in gen_strings:
|
||||
gen_strings[request_id] = new_text
|
||||
gen_tokens[request_id] = new_tokens
|
||||
@@ -733,8 +791,8 @@ def test_stop_string(include_stop_str_in_output: bool,
|
||||
|
||||
# Confirmed tracked values matches what we expected.
|
||||
for idx, (ref_gen_str, stop_str) in enumerate(
|
||||
zip(dummy_test_vectors.generation_strings, STOP_STRINGS)):
|
||||
|
||||
zip(dummy_test_vectors.generation_strings, STOP_STRINGS)
|
||||
):
|
||||
# Request should be aborted.
|
||||
request_id = f"request-{idx}"
|
||||
assert request_id in aborted
|
||||
@@ -748,24 +806,28 @@ def test_stop_string(include_stop_str_in_output: bool,
|
||||
ref_str_inc_stop = ref_gen_str[:stop_str_idx] + stop_str
|
||||
|
||||
if include_stop_str_in_output:
|
||||
assert gen_str == ref_str_inc_stop, (
|
||||
f"{gen_str=}, {ref_str_inc_stop=}")
|
||||
assert gen_str == ref_str_inc_stop, f"{gen_str=}, {ref_str_inc_stop=}"
|
||||
else:
|
||||
assert gen_str == ref_str_exc_stop, (
|
||||
f"{gen_str=}, {ref_str_exc_stop=}")
|
||||
assert gen_str == ref_str_exc_stop, f"{gen_str=}, {ref_str_exc_stop=}"
|
||||
|
||||
# Confirmed tracked logprobs match what we expect
|
||||
_validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs,
|
||||
gen_cumulative_logprobs, dummy_test_vectors,
|
||||
request_id_list, num_sample_logprobs, None)
|
||||
_validate_logprobs(
|
||||
gen_tokens,
|
||||
gen_logprobs,
|
||||
gen_prompt_logprobs,
|
||||
gen_cumulative_logprobs,
|
||||
dummy_test_vectors,
|
||||
request_id_list,
|
||||
num_sample_logprobs,
|
||||
None,
|
||||
)
|
||||
|
||||
assert output_processor.get_num_unfinished_requests() == 0
|
||||
assert not output_processor.has_unfinished_requests()
|
||||
|
||||
|
||||
def test_iteration_stats(dummy_test_vectors):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
|
||||
log_stats=True)
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
|
||||
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
|
||||
engine_core_timestamp = time.monotonic()
|
||||
|
||||
@@ -782,7 +844,8 @@ def test_iteration_stats(dummy_test_vectors):
|
||||
data_parallel_rank=None,
|
||||
sampling_params=SamplingParams(),
|
||||
pooling_params=None,
|
||||
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||
)
|
||||
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||
]
|
||||
|
||||
# Add all requests except one to the OutputProcessor.
|
||||
@@ -794,12 +857,13 @@ def test_iteration_stats(dummy_test_vectors):
|
||||
# First iteration has 2 prefills.
|
||||
outputs = engine_core.get_outputs()[:num_active]
|
||||
iteration_stats = IterationStats()
|
||||
output_processor.process_outputs(outputs, engine_core_timestamp,
|
||||
iteration_stats)
|
||||
total_prompt_tokens = sum([
|
||||
len(prompt_tokens)
|
||||
for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active]
|
||||
])
|
||||
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
|
||||
total_prompt_tokens = sum(
|
||||
[
|
||||
len(prompt_tokens)
|
||||
for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active]
|
||||
]
|
||||
)
|
||||
|
||||
assert iteration_stats.num_prompt_tokens == total_prompt_tokens
|
||||
assert iteration_stats.num_generation_tokens == num_active
|
||||
@@ -807,8 +871,7 @@ def test_iteration_stats(dummy_test_vectors):
|
||||
# Just decodes in this step.
|
||||
outputs = engine_core.get_outputs()[:num_active]
|
||||
iteration_stats = IterationStats()
|
||||
output_processor.process_outputs(outputs, engine_core_timestamp,
|
||||
iteration_stats)
|
||||
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
|
||||
|
||||
assert iteration_stats.num_prompt_tokens == 0
|
||||
assert iteration_stats.num_generation_tokens == num_active
|
||||
@@ -818,8 +881,7 @@ def test_iteration_stats(dummy_test_vectors):
|
||||
num_active += 1
|
||||
outputs = engine_core.get_outputs()[:num_active]
|
||||
iteration_stats = IterationStats()
|
||||
output_processor.process_outputs(outputs, engine_core_timestamp,
|
||||
iteration_stats)
|
||||
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
|
||||
total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1])
|
||||
|
||||
assert iteration_stats.num_prompt_tokens == total_prompt_tokens
|
||||
@@ -828,8 +890,7 @@ def test_iteration_stats(dummy_test_vectors):
|
||||
# Just decodes in this step.
|
||||
outputs = engine_core.get_outputs()[:num_active]
|
||||
iteration_stats = IterationStats()
|
||||
output_processor.process_outputs(outputs, engine_core_timestamp,
|
||||
iteration_stats)
|
||||
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
|
||||
|
||||
assert iteration_stats.num_prompt_tokens == 0
|
||||
assert iteration_stats.num_generation_tokens == num_active
|
||||
@@ -853,16 +914,13 @@ async def test_request_output_collector():
|
||||
text=TEXT,
|
||||
token_ids=[idx],
|
||||
cumulative_logprob=(idx + 1 * 1.0),
|
||||
logprobs=[{
|
||||
"a": idx,
|
||||
"b": idx
|
||||
}],
|
||||
finish_reason="length" if
|
||||
(idx == NUM_REQS - 1) else None,
|
||||
logprobs=[{"a": idx, "b": idx}],
|
||||
finish_reason="length" if (idx == NUM_REQS - 1) else None,
|
||||
)
|
||||
],
|
||||
finished=(idx == NUM_REQS - 1),
|
||||
) for idx in range(NUM_REQS)
|
||||
)
|
||||
for idx in range(NUM_REQS)
|
||||
]
|
||||
|
||||
collector = RequestOutputCollector(RequestOutputKind.DELTA)
|
||||
@@ -888,8 +946,7 @@ async def test_request_output_collector():
|
||||
assert not output.finished
|
||||
# Text, token_ids, and logprobs should get merged.
|
||||
assert output.outputs[0].text == TEXT * num_to_put
|
||||
for tok_0, tok_1 in zip(output.outputs[0].token_ids,
|
||||
list(range(num_to_put))):
|
||||
for tok_0, tok_1 in zip(output.outputs[0].token_ids, list(range(num_to_put))):
|
||||
assert tok_0 == tok_1
|
||||
assert len(output.outputs[0].logprobs) == num_to_put
|
||||
|
||||
@@ -910,8 +967,7 @@ async def test_request_output_collector():
|
||||
assert output.outputs[0].finish_reason == "length"
|
||||
# Text, token_ids, and logprobs should get merged.
|
||||
assert output.outputs[0].text == TEXT * num_to_put
|
||||
for tok_0, tok_1 in zip(output.outputs[0].token_ids,
|
||||
list(range(num_to_put))):
|
||||
for tok_0, tok_1 in zip(output.outputs[0].token_ids, list(range(num_to_put))):
|
||||
assert tok_0 == tok_1
|
||||
assert len(output.outputs[0].logprobs) == num_to_put
|
||||
|
||||
@@ -1003,8 +1059,7 @@ async def test_cumulative_output_collector_n():
|
||||
|
||||
@pytest.mark.parametrize("runner", ["generate", "pooling"])
|
||||
def test_abort_requests(runner: str, dummy_test_vectors):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
|
||||
log_stats=True)
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
|
||||
requests = [
|
||||
EngineCoreRequest(
|
||||
request_id=f"request-{idx}",
|
||||
@@ -1016,9 +1071,9 @@ def test_abort_requests(runner: str, dummy_test_vectors):
|
||||
cache_salt=None,
|
||||
data_parallel_rank=None,
|
||||
sampling_params=SamplingParams() if runner == "generate" else None,
|
||||
pooling_params=PoolingParams(
|
||||
task="embed") if runner == "pooling" else None,
|
||||
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||
pooling_params=PoolingParams(task="embed") if runner == "pooling" else None,
|
||||
)
|
||||
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||
]
|
||||
|
||||
for request in requests:
|
||||
|
||||
@@ -16,35 +16,33 @@ baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays
|
||||
|
||||
|
||||
# Mock processor for testing
|
||||
def _mk_processor(monkeypatch,
|
||||
*,
|
||||
mm_cache_gb: float = 4.0,
|
||||
enable_prefix_caching: bool = True) -> Processor:
|
||||
def _mk_processor(
|
||||
monkeypatch, *, mm_cache_gb: float = 4.0, enable_prefix_caching: bool = True
|
||||
) -> Processor:
|
||||
"""
|
||||
Create a Processor instance with minimal configuration suitable for unit
|
||||
tests without accessing external resources.
|
||||
"""
|
||||
monkeypatch.setattr(ModelConfig,
|
||||
"try_get_generation_config",
|
||||
lambda self: {},
|
||||
raising=True)
|
||||
monkeypatch.setattr(ModelConfig,
|
||||
"__post_init__",
|
||||
lambda self, *args: None,
|
||||
raising=True)
|
||||
monkeypatch.setattr(ModelConfig,
|
||||
"verify_with_parallel_config",
|
||||
lambda self, parallel_config: None,
|
||||
raising=True)
|
||||
monkeypatch.setattr(processor_mod,
|
||||
"processor_cache_from_config",
|
||||
lambda vllm_config, mm_registry: None,
|
||||
raising=True)
|
||||
monkeypatch.setattr(
|
||||
ModelConfig, "try_get_generation_config", lambda self: {}, raising=True
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ModelConfig, "__post_init__", lambda self, *args: None, raising=True
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ModelConfig,
|
||||
"verify_with_parallel_config",
|
||||
lambda self, parallel_config: None,
|
||||
raising=True,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
processor_mod,
|
||||
"processor_cache_from_config",
|
||||
lambda vllm_config, mm_registry: None,
|
||||
raising=True,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(VllmConfig,
|
||||
"__post_init__",
|
||||
lambda self: None,
|
||||
raising=True)
|
||||
monkeypatch.setattr(VllmConfig, "__post_init__", lambda self: None, raising=True)
|
||||
|
||||
model_config = ModelConfig(
|
||||
skip_tokenizer_init=True,
|
||||
@@ -57,12 +55,10 @@ def _mk_processor(monkeypatch,
|
||||
# Minimal multimodal_config to satisfy references in
|
||||
# Processor.process_inputs.
|
||||
class _MockMMConfig:
|
||||
|
||||
def __init__(self, gb: float):
|
||||
self.mm_processor_cache_gb = gb
|
||||
|
||||
model_config.multimodal_config = _MockMMConfig(
|
||||
mm_cache_gb) # type: ignore[attr-defined]
|
||||
model_config.multimodal_config = _MockMMConfig(mm_cache_gb) # type: ignore[attr-defined]
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching),
|
||||
@@ -79,13 +75,9 @@ def test_multi_modal_uuids_length_mismatch_raises(monkeypatch):
|
||||
|
||||
prompt = {
|
||||
"prompt": "USER: <image>\nDescribe\nASSISTANT:",
|
||||
"multi_modal_data": {
|
||||
"image": [cherry_pil_image, stop_pil_image]
|
||||
},
|
||||
"multi_modal_data": {"image": [cherry_pil_image, stop_pil_image]},
|
||||
# Mismatch: 2 items but only 1 uuid provided
|
||||
"multi_modal_uuids": {
|
||||
"image": ["hash_cherry"]
|
||||
},
|
||||
"multi_modal_uuids": {"image": ["hash_cherry"]},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="must have same length as data"):
|
||||
@@ -104,16 +96,13 @@ def test_multi_modal_uuids_missing_modality_raises(monkeypatch):
|
||||
# Two modalities provided in data
|
||||
"multi_modal_data": {
|
||||
"image": [cherry_pil_image],
|
||||
"video": [baby_reading_np_ndarrays]
|
||||
"video": [baby_reading_np_ndarrays],
|
||||
},
|
||||
# Only image uuids provided; video missing should raise
|
||||
"multi_modal_uuids": {
|
||||
"image": ["hash_cherry"]
|
||||
},
|
||||
"multi_modal_uuids": {"image": ["hash_cherry"]},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
match="must be provided if multi_modal_data"):
|
||||
with pytest.raises(ValueError, match="must be provided if multi_modal_data"):
|
||||
processor.process_inputs(
|
||||
request_id="req-2",
|
||||
prompt=prompt, # type: ignore[arg-type]
|
||||
@@ -130,28 +119,28 @@ def test_multi_modal_uuids_missing_modality_raises(monkeypatch):
|
||||
],
|
||||
)
|
||||
def test_multi_modal_uuids_accepts_none_and_passes_through(
|
||||
monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool):
|
||||
processor = _mk_processor(monkeypatch,
|
||||
mm_cache_gb=mm_cache_gb,
|
||||
enable_prefix_caching=enable_prefix_caching)
|
||||
monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool
|
||||
):
|
||||
processor = _mk_processor(
|
||||
monkeypatch,
|
||||
mm_cache_gb=mm_cache_gb,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
)
|
||||
|
||||
# Capture the overrides passed to InputPreprocessor.preprocess
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def fake_preprocess(prompt,
|
||||
*,
|
||||
tokenization_kwargs=None,
|
||||
lora_request=None,
|
||||
mm_uuids=None):
|
||||
def fake_preprocess(
|
||||
prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
|
||||
):
|
||||
captured["mm_uuids"] = mm_uuids
|
||||
# Minimal processed inputs for decoder-only flow
|
||||
return {"type": "token", "prompt_token_ids": [1]}
|
||||
|
||||
# Monkeypatch only the bound preprocess method on this instance
|
||||
monkeypatch.setattr(processor.input_preprocessor,
|
||||
"preprocess",
|
||||
fake_preprocess,
|
||||
raising=True)
|
||||
monkeypatch.setattr(
|
||||
processor.input_preprocessor, "preprocess", fake_preprocess, raising=True
|
||||
)
|
||||
|
||||
# Use a consistent two-image scenario across all configurations
|
||||
mm_uuids = {"image": [None, "hash_stop"], "video": None}
|
||||
@@ -176,24 +165,19 @@ def test_multi_modal_uuids_accepts_none_and_passes_through(
|
||||
def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
|
||||
# When both processor cache is 0 and prefix caching disabled, the
|
||||
# processor builds overrides from request id instead of using user UUIDs.
|
||||
processor = _mk_processor(monkeypatch,
|
||||
mm_cache_gb=0.0,
|
||||
enable_prefix_caching=False)
|
||||
processor = _mk_processor(monkeypatch, mm_cache_gb=0.0, enable_prefix_caching=False)
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def fake_preprocess(prompt,
|
||||
*,
|
||||
tokenization_kwargs=None,
|
||||
lora_request=None,
|
||||
mm_uuids=None):
|
||||
def fake_preprocess(
|
||||
prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
|
||||
):
|
||||
captured["mm_uuids"] = mm_uuids
|
||||
return {"type": "token", "prompt_token_ids": [1]}
|
||||
|
||||
monkeypatch.setattr(processor.input_preprocessor,
|
||||
"preprocess",
|
||||
fake_preprocess,
|
||||
raising=True)
|
||||
monkeypatch.setattr(
|
||||
processor.input_preprocessor, "preprocess", fake_preprocess, raising=True
|
||||
)
|
||||
|
||||
request_id = "req-42"
|
||||
mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": "hash_video"}
|
||||
|
||||
@@ -82,11 +82,12 @@ def _create_random_top_logprob_test_matrix(
|
||||
|
||||
|
||||
def _create_random_top_token_test_vector(
|
||||
num_logprobs: int,
|
||||
lower: int,
|
||||
upper: int,
|
||||
sampled_token_id: int,
|
||||
adjust_num_logprobs: bool = True) -> tuple[torch.Tensor, int]:
|
||||
num_logprobs: int,
|
||||
lower: int,
|
||||
upper: int,
|
||||
sampled_token_id: int,
|
||||
adjust_num_logprobs: bool = True,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
"""Create a random vector of top logprob token indices
|
||||
|
||||
Use to create fake sample logprobs for testing. The sampled token
|
||||
@@ -127,8 +128,9 @@ def _create_random_top_token_test_vector(
|
||||
|
||||
# Check if the sampled_token_id occurs in choice_tensor[1:]
|
||||
if sampled_token_id in choice_tensor[1:]:
|
||||
sampled_token_rank = (choice_tensor[1:] == sampled_token_id).nonzero(
|
||||
as_tuple=True)[0].item()
|
||||
sampled_token_rank = (
|
||||
(choice_tensor[1:] == sampled_token_id).nonzero(as_tuple=True)[0].item()
|
||||
)
|
||||
else:
|
||||
# If not found, assign a random int between num_logprobs and 50700
|
||||
sampled_token_rank = random.randint(num_logprobs, 50700)
|
||||
@@ -164,9 +166,12 @@ def _create_random_top_token_test_matrix(
|
||||
num_elements = shape[0] * shape[1]
|
||||
choice_tensor = torch.randperm(upper - lower)[:num_elements] + lower
|
||||
matrix = torch.cat(
|
||||
(torch.tensor(tokens_list, dtype=torch.int).unsqueeze(-1),
|
||||
choice_tensor.view(shape)),
|
||||
dim=1)
|
||||
(
|
||||
torch.tensor(tokens_list, dtype=torch.int).unsqueeze(-1),
|
||||
choice_tensor.view(shape),
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# Initialize the tensor for storing the ranks
|
||||
prompt_token_ranks = torch.empty(shape[0], dtype=torch.int)
|
||||
@@ -174,8 +179,7 @@ def _create_random_top_token_test_matrix(
|
||||
# Iterate over each row to check presence of
|
||||
# tokens_list[rdx] and determine its index
|
||||
for rdx in range(shape[0]):
|
||||
row = matrix[rdx,
|
||||
1:] # Skip the first column as it contains the token list
|
||||
row = matrix[rdx, 1:] # Skip the first column as it contains the token list
|
||||
token_index = (row == tokens_list[rdx]).nonzero(as_tuple=True)[0]
|
||||
if token_index.numel() > 0:
|
||||
prompt_token_ranks[rdx] = token_index.item()
|
||||
@@ -229,19 +233,21 @@ def generate_dummy_sample_logprobs(
|
||||
(
|
||||
token_vector,
|
||||
sampled_token_rank,
|
||||
) = _create_random_top_token_test_vector(num_logprobs, 0,
|
||||
len(tokenizer.vocab) - 1,
|
||||
sampled_token_id)
|
||||
) = _create_random_top_token_test_vector(
|
||||
num_logprobs, 0, len(tokenizer.vocab) - 1, sampled_token_id
|
||||
)
|
||||
|
||||
res.append(
|
||||
(token_vector,
|
||||
_create_random_top_logprob_test_vector(num_logprobs + 1, -100,
|
||||
0), sampled_token_rank))
|
||||
(
|
||||
token_vector,
|
||||
_create_random_top_logprob_test_vector(num_logprobs + 1, -100, 0),
|
||||
sampled_token_rank,
|
||||
)
|
||||
)
|
||||
|
||||
# Convert tensors in the list tuples to Python lists
|
||||
res_list_format = [
|
||||
(log_probs_tensor.tolist(), token_ids_tensor.tolist(),
|
||||
sampled_token_rank)
|
||||
(log_probs_tensor.tolist(), token_ids_tensor.tolist(), sampled_token_rank)
|
||||
for log_probs_tensor, token_ids_tensor, sampled_token_rank in res
|
||||
]
|
||||
|
||||
@@ -282,18 +288,24 @@ def generate_dummy_prompt_logprobs_tensors(
|
||||
token_vector,
|
||||
prompt_token_ranks,
|
||||
) = _create_random_top_token_test_matrix(
|
||||
(num_prompt_logprobs, num_logprobs), 0,
|
||||
len(tokenizer.vocab) - 1, prompt_tokens_list[1:])
|
||||
(num_prompt_logprobs, num_logprobs),
|
||||
0,
|
||||
len(tokenizer.vocab) - 1,
|
||||
prompt_tokens_list[1:],
|
||||
)
|
||||
return LogprobsTensors(
|
||||
token_vector,
|
||||
_create_random_top_logprob_test_matrix(
|
||||
(num_prompt_logprobs, num_logprobs + 1), -100, 0),
|
||||
prompt_token_ranks)
|
||||
(num_prompt_logprobs, num_logprobs + 1), -100, 0
|
||||
),
|
||||
prompt_token_ranks,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DummyOutputProcessorTestVectors:
|
||||
"""Dummy test vectors for output processor tests"""
|
||||
|
||||
tokenizer: GeneralTokenizerType
|
||||
vllm_config: EngineArgs
|
||||
full_tokens: list[list[int]] # Prompt + generated tokens
|
||||
@@ -320,9 +332,9 @@ class MockEngineCore:
|
||||
# For each request, for each sampled token offset,
|
||||
# a tuple of
|
||||
# (list of topk token ids, list of sample logprob vals, rank)
|
||||
generated_logprobs_raw: Optional[list[list[tuple[list[int],
|
||||
list[float],
|
||||
int]]]] = None,
|
||||
generated_logprobs_raw: Optional[
|
||||
list[list[tuple[list[int], list[float], int]]]
|
||||
] = None,
|
||||
# For each request, a tuple of
|
||||
# (prompt logprob val matrix, prompt logprob tok id matrix);
|
||||
# each matrix has dimensions
|
||||
@@ -355,7 +367,8 @@ class MockEngineCore:
|
||||
if do_logprobs:
|
||||
assert self.generated_logprobs_raw is not None
|
||||
(logprobs_token_ids_, logprobs_, sampled_token_ranks_) = (
|
||||
self.generated_logprobs_raw[req_idx][token_idx])
|
||||
self.generated_logprobs_raw[req_idx][token_idx]
|
||||
)
|
||||
logprobs = LogprobsLists(
|
||||
[logprobs_token_ids_],
|
||||
[logprobs_],
|
||||
|
||||
Reference in New Issue
Block a user