Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
Parent: 17edd8a807
Commit: d6953beb91

1508 changed files with 115244 additions and 94146 deletions
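For context on the mechanics of the change: a yapf + isort to ruff migration of this kind is typically driven by two commands, `ruff format` (replacing yapf) and `ruff check --select I --fix` (replacing isort's import sorting); the accompanying configuration changes are not visible in this excerpt. The net effect on style, using an import taken from the first hunk below:

# yapf/isort style: continuation aligned under the opening parenthesis.
from vllm.distributed.kv_transfer.kv_transfer_state import (
    ensure_kv_transfer_initialized, get_kv_transfer_group)

# ruff (Black-compatible) style: one name per line, trailing comma,
# closing parenthesis dedented to the statement's indentation.
from vllm.distributed.kv_transfer.kv_transfer_state import (
    ensure_kv_transfer_initialized,
    get_kv_transfer_group,
)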


@@ -2,12 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa: E501
SharedStorageConnectorMetadata)
SharedStorageConnectorMetadata,
)
from vllm.distributed.kv_transfer.kv_transfer_state import (
ensure_kv_transfer_initialized, get_kv_transfer_group)
ensure_kv_transfer_initialized,
get_kv_transfer_group,
)
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin)
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
# Importing utils registers TestSharedStorageConnector with the factory
from .utils import create_vllm_config
@@ -34,7 +36,7 @@ def test_kv_connector_mixin_clears_metadata():
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "TestSharedStorageConnector"
vllm_config.kv_transfer_config.kv_role = "kv_both"
vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = ("unit")
vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = "unit"
# Initialize the global connector instance
ensure_kv_transfer_initialized(vllm_config)
@@ -46,7 +48,8 @@ def test_kv_connector_mixin_clears_metadata():
# Invoke the no-forward path which uses the mixin context manager
KVConnectorModelRunnerMixin.kv_connector_no_forward(
scheduler_output, vllm_config)
scheduler_output, vllm_config
)
# Verify clear_connector_metadata was called on the connector
connector = get_kv_transfer_group()
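A note on why some imports in this hunk collapse to one line (`KVConnectorModelRunnerMixin`) while others expand to one name per line: ruff's formatter, like Black, splits a statement only when it exceeds the line limit, and a trailing comma already present in the source is treated as "magic", keeping the construct expanded even when it would fit (the `skip-magic-trailing-comma` format option disables this; whether vLLM sets it is not shown here). A minimal sketch:

# Fits within the line limit, no trailing comma: ruff joins it.
names = {"alpha", "beta"}

# Trailing comma in the source: ruff keeps one element per line.
names = {
    "alpha",
    "beta",
}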


@@ -9,17 +9,19 @@ import pytest
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import Request, RequestStatus
from .utils import (create_model_runner_output, create_request,
create_scheduler, create_vllm_config)
from .utils import (
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)
def _make_get_num_new_matched_tokens(
req_num_new_matched_tokens: dict[str, int],
async_load,
) -> Callable[[Request, int], tuple[int, bool]]:
def get_num_new_matched_tokens(request: Request,
_: int) -> tuple[int, bool]:
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
value = req_num_new_matched_tokens.get(request.request_id, 0)
return value, async_load
@@ -33,9 +35,7 @@ def scheduler():
@pytest.mark.parametrize(
"num_prompt_blocks,"
"num_external_computed_blocks,"
"invalid_block_idxs",
"num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs",
[
(100, 99, {0, 98}),
(100, 99, {50, 98}),
@@ -51,8 +51,7 @@ def test_async_load_failure(
assert num_prompt_blocks >= num_external_computed_blocks
num_prompt_tokens = num_prompt_blocks * scheduler.block_size
num_external_computed_tokens = (num_external_computed_blocks *
scheduler.block_size)
num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size
request1 = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request1)
@@ -71,8 +70,8 @@ def test_async_load_failure(
scheduler.connector = Mock()
scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens,
async_load=True))
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=True)
)
scheduler.connector.take_events.return_value = ()
scheduler_output = scheduler.schedule()
@@ -84,14 +83,14 @@ def test_async_load_failure(
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
# Simulate a failure in loading some of request2 blocks.
(req2_block_ids, ) = scheduler.kv_cache_manager.get_block_ids(
request2.request_id)
(req2_block_ids,) = scheduler.kv_cache_manager.get_block_ids(request2.request_id)
invalid_block_ids = {req2_block_ids[i] for i in invalid_block_idxs}
model_runner_output = create_model_runner_output(
reqs=[],
finished_recving={request1.request_id, request3.request_id},
invalid_block_ids=invalid_block_ids,
use_eos=True)
use_eos=True,
)
scheduler.update_from_output(scheduler_output, model_runner_output)
@@ -100,8 +99,9 @@ def test_async_load_failure(
assert len(scheduler.waiting) == 3
for request in scheduler.waiting:
if request.request_id == request2.request_id:
assert request.num_computed_tokens == (min_invalid_block_idx *
scheduler.block_size)
assert request.num_computed_tokens == (
min_invalid_block_idx * scheduler.block_size
)
else:
assert request.num_computed_tokens == 0
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
@@ -110,9 +110,7 @@ def test_async_load_failure(
@pytest.mark.parametrize(
"num_prompt_blocks,"
"num_external_computed_blocks,"
"invalid_block_idxs",
"num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs",
[
(100, 99, {0, 98}),
(100, 99, {50, 98}),
@@ -128,8 +126,7 @@ def test_sync_load_failure(
assert num_prompt_blocks >= num_external_computed_blocks
num_prompt_tokens = num_prompt_blocks * scheduler.block_size
num_external_computed_tokens = (num_external_computed_blocks *
scheduler.block_size)
num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size
request1 = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request1)
@@ -148,8 +145,8 @@ def test_sync_load_failure(
scheduler.connector = Mock()
scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens,
async_load=False))
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=False)
)
scheduler.connector.request_finished.return_value = (False, None)
scheduler.connector.take_events.return_value = ()
@@ -165,8 +162,7 @@ def test_sync_load_failure(
assert len(scheduler.running) == 3
assert len(scheduler_output.scheduled_new_reqs) == 3
for request in scheduler_output.scheduled_new_reqs:
assert request.num_computed_tokens == expected_computed_tokens[
request.req_id]
assert request.num_computed_tokens == expected_computed_tokens[request.req_id]
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
# Simulate a failure in loading some of request2 blocks.
@@ -175,14 +171,16 @@ def test_sync_load_failure(
model_runner_output = create_model_runner_output(
[request1, request2, request3],
invalid_block_ids=invalid_block_ids,
use_eos=True)
use_eos=True,
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == request2.request_id
assert scheduler.running[0].num_computed_tokens == (
min(invalid_block_idxs) * scheduler.block_size)
min(invalid_block_idxs) * scheduler.block_size
)
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
assert scheduler.connector.request_finished.call_count == 2
@@ -205,19 +203,19 @@ def test_sync_load_failure_with_shared_blocks(
num_common_prefix_blocks: int,
invalid_block_idxs: set[int],
):
assert (num_prompt_blocks >= num_external_computed_blocks >=
num_common_prefix_blocks)
assert num_prompt_blocks >= num_external_computed_blocks >= num_common_prefix_blocks
num_prompt_tokens = num_prompt_blocks * scheduler.block_size
num_external_computed_tokens = (num_external_computed_blocks *
scheduler.block_size)
num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size
common_prefix_len = num_common_prefix_blocks * scheduler.block_size
request1 = create_request(num_tokens=num_prompt_tokens,
common_prefix_len=common_prefix_len)
request1 = create_request(
num_tokens=num_prompt_tokens, common_prefix_len=common_prefix_len
)
scheduler.add_request(request=request1)
request2 = create_request(num_tokens=num_prompt_tokens,
common_prefix_len=common_prefix_len)
request2 = create_request(
num_tokens=num_prompt_tokens, common_prefix_len=common_prefix_len
)
scheduler.add_request(request=request2)
# Mock KV connector method.
@@ -228,8 +226,8 @@ def test_sync_load_failure_with_shared_blocks(
scheduler.connector = Mock()
scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens,
async_load=False))
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=False)
)
scheduler.connector.take_events.return_value = ()
scheduler_output = scheduler.schedule()
@@ -243,17 +241,15 @@ def test_sync_load_failure_with_shared_blocks(
assert len(scheduler.running) == 2
assert len(scheduler_output.scheduled_new_reqs) == 2
for request in scheduler_output.scheduled_new_reqs:
assert request.num_computed_tokens == expected_computed_tokens[
request.req_id]
assert request.num_computed_tokens == expected_computed_tokens[request.req_id]
assert scheduler.connector.get_num_new_matched_tokens.call_count == 2
# Simulate a failure in loading some of the shared blocks.
req1_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
invalid_block_ids = {req1_block_ids[i] for i in invalid_block_idxs}
model_runner_output = create_model_runner_output(
[request1, request2],
invalid_block_ids=invalid_block_ids,
use_eos=True)
[request1, request2], invalid_block_ids=invalid_block_ids, use_eos=True
)
scheduler.update_from_output(scheduler_output, model_runner_output)
@@ -266,15 +262,14 @@ def test_sync_load_failure_with_shared_blocks(
assert len(scheduler.running) == 2
for request in scheduler.running:
assert request.num_computed_tokens == expected_computed_tokens[
request.request_id]
assert (
request.num_computed_tokens == expected_computed_tokens[request.request_id]
)
assert scheduler.connector.get_num_new_matched_tokens.call_count == 2
@pytest.mark.parametrize(
"num_prompt_blocks,"
"num_external_computed_blocks,"
"invalid_block_idxs",
"num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs",
[
(100, 99, {0, 50, 98}),
(100, 99, {98, 50, 0}),
@@ -289,8 +284,7 @@ def test_async_progressive_load_failure(
assert num_prompt_blocks >= num_external_computed_blocks
num_prompt_tokens = num_prompt_blocks * scheduler.block_size
num_external_computed_tokens = (num_external_computed_blocks *
scheduler.block_size)
num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size
request = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request)
@@ -303,8 +297,8 @@ def test_async_progressive_load_failure(
scheduler.connector = Mock()
scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens,
async_load=True))
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=True)
)
scheduler.connector.take_events.return_value = ()
scheduler_output = scheduler.schedule()
@@ -318,24 +312,24 @@ def test_async_progressive_load_failure(
min_invalid_block_idx = max(invalid_block_idxs) + 1
# Simulate failures when progressively loading request blocks.
for invalid_block_idx in invalid_block_idxs:
(req_block_ids, ) = scheduler.kv_cache_manager.get_block_ids(
request.request_id)
(req_block_ids,) = scheduler.kv_cache_manager.get_block_ids(request.request_id)
invalid_block_ids = {req_block_ids[invalid_block_idx]}
model_runner_output = create_model_runner_output(
reqs=[],
finished_recving=set(),
invalid_block_ids=invalid_block_ids,
use_eos=True)
use_eos=True,
)
scheduler.update_from_output(scheduler_output, model_runner_output)
min_invalid_block_idx = min(min_invalid_block_idx, invalid_block_idx)
assert len(scheduler.waiting) == 1
assert scheduler.waiting.peek_request(
).request_id == request.request_id
assert request.num_computed_tokens == (min_invalid_block_idx *
scheduler.block_size)
assert scheduler.waiting.peek_request().request_id == request.request_id
assert request.num_computed_tokens == (
min_invalid_block_idx * scheduler.block_size
)
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.failed_recving_kv_req_ids == {request.request_id}
assert scheduler.connector.get_num_new_matched_tokens.call_count == 1
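The recurring rewrite in this file, where a parenthesized two-line product becomes a single assignment, follows from the formatter's line-length budget: ruff defaults to 88 columns (vLLM's configured value is not shown in this excerpt, so treat 88 as an assumption), and any expression that fits is joined with the now-redundant wrapping parentheses removed. From the hunks above:

# yapf: split with aligned continuation and extra parentheses.
num_external_computed_tokens = (num_external_computed_blocks *
                                scheduler.block_size)

# ruff: the expression fits within 88 columns, so it becomes one line.
num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size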


@@ -52,29 +52,26 @@ def test_multi_shared_storage_connector_consistency():
kv_connector="MultiConnector",
kv_role="kv_both",
kv_connector_extra_config={
"connectors": [{
"kv_connector":
"TestSharedStorageConnector",
"kv_role":
"kv_both",
"kv_connector_extra_config": {
"shared_storage_path": str(storage_1_path),
"name": "storage1",
"connectors": [
{
"kv_connector": "TestSharedStorageConnector",
"kv_role": "kv_both",
"kv_connector_extra_config": {
"shared_storage_path": str(storage_1_path),
"name": "storage1",
},
"kv_connector_module_path": "tests.v1.kv_connector.unit.utils",
},
"kv_connector_module_path":
"tests.v1.kv_connector.unit.utils",
}, {
"kv_connector":
"TestSharedStorageConnector",
"kv_role":
"kv_both",
"kv_connector_extra_config": {
"shared_storage_path": str(storage_2_path),
"name": "storage2",
{
"kv_connector": "TestSharedStorageConnector",
"kv_role": "kv_both",
"kv_connector_extra_config": {
"shared_storage_path": str(storage_2_path),
"name": "storage2",
},
"kv_connector_module_path": "tests.v1.kv_connector.unit.utils",
},
"kv_connector_module_path":
"tests.v1.kv_connector.unit.utils",
}]
]
},
)
@@ -93,14 +90,16 @@ def test_multi_shared_storage_connector_consistency():
local_subdirs = list(storage_1_path.iterdir())
external_subdirs = list(storage_2_path.iterdir())
assert len(
local_subdirs
) > 0, f"Local storage path {storage_1_path} is empty after generation."
assert len(local_subdirs) > 0, (
f"Local storage path {storage_1_path} is empty after generation."
)
assert len(external_subdirs) > 0, (
f"External storage path {storage_2_path} is empty after generation.")
f"External storage path {storage_2_path} is empty after generation."
)
assert len(local_subdirs) == len(external_subdirs), (
f"Mismatch in number of cache entries: "
f"Local={len(local_subdirs)}, External={len(external_subdirs)}")
f"Local={len(local_subdirs)}, External={len(external_subdirs)}"
)
# The subdirectories should correspond to the prompt hashes
# Since prompts are the same, the hash directories should be the same name
@@ -113,29 +112,39 @@ def test_multi_shared_storage_connector_consistency():
# Compare the contents of each corresponding cache directory
for subdir_name in local_subdir_names:
print(f"Comparing contents of cache directory: {subdir_name}")
assert _compare_directories(storage_1_path / subdir_name,
storage_2_path / subdir_name), \
(f"Contents differ for cache directory '{subdir_name}' between "
f"{storage_1_path} and {storage_2_path}")
assert _compare_directories(
storage_1_path / subdir_name, storage_2_path / subdir_name
), (
f"Contents differ for cache directory '{subdir_name}' between "
f"{storage_1_path} and {storage_2_path}"
)
events = get_connector_events()
# get_num_new_matched_tokens and update_state_after_alloc will be called
# on each connector in turn.
assert events["storage1-SCHEDULER"][:3] == [
'get_num_new_matched_tokens 0',
'update_state_after_alloc num_blocks=[0] 0', 'build_connector_meta'
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[0] 0",
"build_connector_meta",
]
assert events["storage1-WORKER"][:5] == [
'register_kv_caches', 'bind_connector_metadata', 'start_load_kv',
'wait_for_layer_load', 'save_kv_layer'
"register_kv_caches",
"bind_connector_metadata",
"start_load_kv",
"wait_for_layer_load",
"save_kv_layer",
]
assert events["storage2-SCHEDULER"][:3] == [
'get_num_new_matched_tokens 0',
'update_state_after_alloc num_blocks=[0] 0', 'build_connector_meta'
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[0] 0",
"build_connector_meta",
]
assert events["storage2-WORKER"][:5] == [
'register_kv_caches', 'bind_connector_metadata', 'start_load_kv',
'wait_for_layer_load', 'save_kv_layer'
"register_kv_caches",
"bind_connector_metadata",
"start_load_kv",
"wait_for_layer_load",
"save_kv_layer",
]
# Reset prefix cache or else we'll just get the tokens back from there.
@@ -151,12 +160,14 @@ def test_multi_shared_storage_connector_consistency():
# on that one but with zero blocks for others (first nonzero match is
# chosen).
assert events["storage1-SCHEDULER"][:3] == [
'get_num_new_matched_tokens 0',
'update_state_after_alloc num_blocks=[7] 96', 'build_connector_meta'
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[7] 96",
"build_connector_meta",
]
assert events["storage2-SCHEDULER"][:3] == [
'get_num_new_matched_tokens 0',
'update_state_after_alloc num_blocks=[0] 0', 'build_connector_meta'
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[0] 0",
"build_connector_meta",
]
# Delete storage1 connector state
@@ -175,12 +186,14 @@ def test_multi_shared_storage_connector_consistency():
# a hit, so update_state_after_alloc will only be called with allocated
# blocks for the second connector.
assert events["storage1-SCHEDULER"][:3] == [
'get_num_new_matched_tokens 0',
'update_state_after_alloc num_blocks=[0] 0', 'build_connector_meta'
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[0] 0",
"build_connector_meta",
]
assert events["storage2-SCHEDULER"][:3] == [
'get_num_new_matched_tokens 0',
'update_state_after_alloc num_blocks=[7] 96', 'build_connector_meta'
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[7] 96",
"build_connector_meta",
]
# Clean up
@@ -191,15 +204,14 @@ def test_multi_shared_storage_connector_consistency():
def get_connector_events() -> dict[str, list[str]]:
# Read in connector events and reset the files.
import glob
event_files = glob.glob(tempfile.gettempdir() + "/connector_*_events.log")
connector_events = {}
for fname in event_files:
name = fname.split("connector_")[1].split("_events.log")[0]
try:
with open(fname, "r+") as f:
connector_events[name] = [
line.strip() for line in f if line.strip()
]
connector_events[name] = [line.strip() for line in f if line.strip()]
f.truncate(0)
except Exception as e:
print(f"[ERROR] Could not read connector events for {name}: {e}")
@@ -211,5 +223,5 @@ def test_engine_id_conflict():
configs = [KVTransferConfig() for _ in range(2)]
ids = [config.engine_id for config in configs]
assert ids[0] != ids[1], (
"Engine IDs should be different for different configs. "
f"Got {ids}")
f"Engine IDs should be different for different configs. Got {ids}"
)
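Two more mechanical changes run through this file: string literals are normalized to double quotes (ruff's default quote style), and multi-part assert messages move from a trailing implicitly-concatenated literal to a parenthesized group, with fragments merged where they fit on one line. The engine-ID assert above illustrates the second pattern:

# Before: message split across implicitly concatenated literals.
assert ids[0] != ids[1], (
    "Engine IDs should be different for different configs. "
    f"Got {ids}")

# After: fragments merged into a single f-string, with the closing
# parenthesis on its own line.
assert ids[0] != ids[1], (
    f"Engine IDs should be different for different configs. Got {ids}"
)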


@@ -19,15 +19,22 @@ import torch
from vllm import LLM
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiKVConnectorStats)
MultiKVConnectorStats,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
NixlConnectorWorker, NixlKVConnectorStats)
KVConnectorRole,
NixlAgentMetadata,
NixlConnector,
NixlConnectorMetadata,
NixlConnectorWorker,
NixlKVConnectorStats,
)
from vllm.distributed.kv_transfer.kv_transfer_state import (
ensure_kv_transfer_shutdown, has_kv_transfer_group)
ensure_kv_transfer_shutdown,
has_kv_transfer_group,
)
from vllm.forward_context import ForwardContext
from vllm.platforms.interface import Platform
from vllm.sampling_params import SamplingParams
@@ -42,14 +49,14 @@ from .utils import create_request, create_scheduler, create_vllm_config
def clear_kv_transfer():
"""
The test cases in this file use `VLLM_ENABLE_V1_MULTIPROCESSING=0`,
causing the global variable `_KV_CONNECTOR_AGENT`
to be assigned but never deleted.
Since the current pytest process does not terminate and instead
continues running tests from other files,
this global variable remains in memory and interferes
with test cases in other modules.
So we use this fixture to ensure that the global variable
`_KV_CONNECTOR_AGENT` is properly cleaned up after each test.
"""
@@ -58,11 +65,12 @@ def clear_kv_transfer():
ensure_kv_transfer_shutdown()
def get_default_xfer_telemetry(xferDurationS: float = 1,
postDurationS: float = 1,
totalBytes: int = 1,
descCount: int = 1) -> dict:
def get_default_xfer_telemetry(
xferDurationS: float = 1,
postDurationS: float = 1,
totalBytes: int = 1,
descCount: int = 1,
) -> dict:
class AttributeDict(dict):
__slots__ = ()
__getattr__ = dict.__getitem__
@@ -83,7 +91,7 @@ class FakeNixlWrapper:
We don't inherit from nixl._api.nixl_agent because nixl may not be
installed.
Note: The complete source of this class is also used in the
`_make_fake_nixl_pkg` function to create a fake nixl package
for Ray workers.
@@ -94,8 +102,7 @@ class FakeNixlWrapper:
def __init__(self, agent_name: str, *args, **kwargs):
self._cycles_before_xfer_done = 0
self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
lambda: 0)
self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(lambda: 0)
def get_reg_descs(self, caches_data, memory_type: str) -> list:
return [str(uuid.uuid4()) for _ in caches_data]
@@ -123,8 +130,7 @@ class FakeNixlWrapper:
return {}
def check_xfer_state(self, handle: int) -> str:
if self._check_xfer_state_cycles[
handle] >= self._cycles_before_xfer_done:
if self._check_xfer_state_cycles[handle] >= self._cycles_before_xfer_done:
return "DONE"
self._check_xfer_state_cycles[handle] += 1
return "PROC"
@@ -141,13 +147,15 @@ class FakeNixlWrapper:
def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
pass
def make_prepped_xfer(self,
xfer_type: str,
local_xfer_side_handle: int,
local_block_descs_ids: list[int],
remote_xfer_side_handle: int,
remote_block_descs_ids: list[int],
notif_msg: Optional[bytes] = None) -> int:
def make_prepped_xfer(
self,
xfer_type: str,
local_xfer_side_handle: int,
local_block_descs_ids: list[int],
remote_xfer_side_handle: int,
remote_block_descs_ids: list[int],
notif_msg: Optional[bytes] = None,
) -> int:
return uuid.uuid4().int
def transfer(self, handle: int) -> str:
@@ -168,7 +176,7 @@ class FakeNixlWrapper:
def _make_fake_nixl_pkg():
"""Context manager that creates a temporary package making
`from nixl._api import nixl_agent` resolve to our FakeNixlWrapper.
Automatically cleans up the temporary directory when done.
"""
with tempfile.TemporaryDirectory() as td:
@@ -214,10 +222,12 @@ def test_basic_interface():
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
)
request_id = request.request_id
scheduler.add_request(request)
@@ -233,8 +243,11 @@ def test_basic_interface():
req_meta = kv_connector_metadata.reqs_to_recv[request_id]
for block_id, block in zip(
req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks[request_id]):
req_meta.local_block_ids,
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
request_id
],
):
assert block_id == block.block_id
@@ -254,11 +267,13 @@ def test_prompt_less_than_block_size():
NUM_TOKENS = int(BLOCK_SIZE * 0.5)
# Request will have 1 partial remote block.
request = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
num_remote_blocks=1)
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
num_remote_blocks=1,
)
scheduler.add_request(request)
scheduler_output = scheduler.schedule()
@@ -271,15 +286,15 @@ def test_prompt_less_than_block_size():
class FakeNixlConnectorWorker(NixlConnectorWorker):
REMOTE_ENGINE_ID = "remote_engine"
def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
super().__init__(*args, **kwargs)
self._hand_shake_latency = hand_shake_latency
def _nixl_handshake(self, host: str, port: int, remote_tp_size: int,
expected_engine_id: str) -> dict[int, str]:
def _nixl_handshake(
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
) -> dict[int, str]:
# Mimic slow _nixl_handshake, as well as bypass zmq communication.
time.sleep(self._hand_shake_latency)
# These should've been done in register_kv_caches(), called by
@@ -304,21 +319,23 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
# is started. We mock HND here.
kv_cache_layout="HND",
),
remote_tp_size=remote_tp_size)
remote_tp_size=remote_tp_size,
)
return {0: remote_agent_name}
class TestNixlHandshake:
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
FakeNixlWrapper,
)
def test_multi_xfer_one_engine(
self,
# dist_init is a fixture that initializes the distributed environment.
dist_init):
dist_init,
):
"""Test case where multiple xfers are initiated to the same engine.
This test triggers the connector to load remote KV for the same
`request_id`. The transfer is not done immediately due to
`set_cycles_before_xfer_done`, so there is a state where there are
@@ -332,9 +349,9 @@ class TestNixlHandshake:
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0)
assert isinstance(connector.connector_worker.nixl_wrapper,
FakeNixlWrapper)
vllm_config, connector.engine_id, hand_shake_latency=0
)
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
num_xfers = 4
while True:
@@ -345,21 +362,19 @@ class TestNixlHandshake:
num_xfers -= 1
metadata.add_new_req(
request_id=request_id,
local_block_ids=[
num_xfers + 1, num_xfers + 2, num_xfers + 3
],
local_block_ids=[num_xfers + 1, num_xfers + 2, num_xfers + 3],
kv_transfer_params={
"remote_block_ids":
[num_xfers + 4, num_xfers + 5, num_xfers + 6],
"remote_engine_id":
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host":
"localhost",
"remote_port":
1234,
"remote_tp_size":
1,
})
"remote_block_ids": [
num_xfers + 4,
num_xfers + 5,
num_xfers + 6,
],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
},
)
connector.bind_connector_metadata(metadata)
# Mimic maybe_setup_kv_connector in gpu_model_runner.
@@ -371,8 +386,9 @@ class TestNixlHandshake:
_before_load = time.perf_counter()
connector.start_load_kv(dummy_ctx)
_after_load = time.perf_counter()
assert _after_load - _before_load < 0.1, "start_load_kv took " \
f"{_after_load - _before_load} seconds"
assert _after_load - _before_load < 0.1, (
f"start_load_kv took {_after_load - _before_load} seconds"
)
# Mimic get_finished_kv_transfers in gpu_model_runner.
_, done_recving = connector.get_finished(finished_req_ids=set())
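The assert just above shows another rewrite ruff applies throughout this file: the formatter never emits backslash line continuations, so a backslash-split assert message becomes a parenthesized one. From the hunk above:

# Before: backslash continuation splitting the assert message.
assert _after_load - _before_load < 0.1, "start_load_kv took " \
    f"{_after_load - _before_load} seconds"

# After: parenthesized message, fragments merged into one f-string.
assert _after_load - _before_load < 0.1, (
    f"start_load_kv took {_after_load - _before_load} seconds"
)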
@@ -384,20 +400,25 @@ class TestNixlHandshake:
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
@pytest.mark.parametrize("decode_tp_size, prefill_tp_size", [
(1, 1),
(2, 1),
(4, 2),
(4, 4),
])
FakeNixlWrapper,
)
@pytest.mark.parametrize(
"decode_tp_size, prefill_tp_size",
[
(1, 1),
(2, 1),
(4, 2),
(4, 4),
],
)
def test_async_load_kv(
self,
# Fixture that initializes the distributed environment.
dist_init,
# Simulate consumer-producer TP sizes.
decode_tp_size,
prefill_tp_size):
self,
# Fixture that initializes the distributed environment.
dist_init,
# Simulate consumer-producer TP sizes.
decode_tp_size,
prefill_tp_size,
):
"""Test that NixlConnector's start_load_kv should be non-blocking."""
vllm_config = create_vllm_config()
@@ -406,18 +427,20 @@ class TestNixlHandshake:
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id)
vllm_config, connector.engine_id
)
metadata = NixlConnectorMetadata()
metadata.add_new_req(request_id="id",
local_block_ids=[1, 2, 3],
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id":
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": prefill_tp_size,
})
metadata.add_new_req(
request_id="id",
local_block_ids=[1, 2, 3],
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": prefill_tp_size,
},
)
connector.bind_connector_metadata(metadata)
timeout = 2.5
@@ -431,8 +454,9 @@ class TestNixlHandshake:
_before_load = time.perf_counter()
connector.start_load_kv(dummy_ctx)
_after_load = time.perf_counter()
assert _after_load - _before_load < 0.1, "start_load_kv took " \
f"{_after_load - _before_load} seconds"
assert _after_load - _before_load < 0.1, (
f"start_load_kv took {_after_load - _before_load} seconds"
)
time.sleep(0.5) # backoff for the async handshake to complete.
connector.bind_connector_metadata(NixlConnectorMetadata())
_, done_recving = connector.get_finished(finished_req_ids=set())
@@ -442,11 +466,13 @@ class TestNixlHandshake:
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
FakeNixlWrapper,
)
def test_concurrent_load_kv(
self,
# dist_init is a fixture that initializes the distributed environment.
dist_init):
dist_init,
):
"""Test that multiple start_load_kv calls should occur concurrently."""
vllm_config = create_vllm_config()
@@ -454,20 +480,22 @@ class TestNixlHandshake:
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id)
vllm_config, connector.engine_id
)
metadata = NixlConnectorMetadata()
total_reqs = 5
for i in range(total_reqs):
metadata.add_new_req(request_id=f"id_{i}",
local_block_ids=[1, 2, 3],
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id":
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
})
metadata.add_new_req(
request_id=f"id_{i}",
local_block_ids=[1, 2, 3],
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
},
)
connector.bind_connector_metadata(metadata)
timeout = 2.5 * total_reqs
@@ -482,8 +510,9 @@ class TestNixlHandshake:
_before_load = time.perf_counter()
connector.start_load_kv(dummy_ctx)
_after_load = time.perf_counter()
assert _after_load - _before_load < 0.1, "start_load_kv took " \
f"{_after_load - _before_load} seconds"
assert _after_load - _before_load < 0.1, (
f"start_load_kv took {_after_load - _before_load} seconds"
)
time.sleep(0.5) # backoff for the async handshake to complete.
connector.bind_connector_metadata(NixlConnectorMetadata())
_, done_recving = connector.get_finished(finished_req_ids=set())
@@ -495,7 +524,8 @@ class TestNixlHandshake:
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
FakeNixlWrapper,
)
def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
"""
Verify that adding a remote agent fails if kv_cache_layout differs.
@@ -506,12 +536,14 @@ class TestNixlHandshake:
# Mock TP world size to 2 to force heterogeneous TP when
# remote_tp_size=1
with patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", # noqa: E501
return_value=2):
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", # noqa: E501
return_value=2,
):
# Initialize connector and worker (with fake NIXL wrapper)
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0)
vllm_config, connector.engine_id, hand_shake_latency=0
)
worker = connector.connector_worker
# Minimal local registration params used by add_remote_agent
@@ -521,8 +553,7 @@ class TestNixlHandshake:
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
# Metadata with different kv_cache_layout than local worker
mismatched_layout = "HND" if worker.kv_cache_layout != "HND" \
else "NHD"
mismatched_layout = "HND" if worker.kv_cache_layout != "HND" else "NHD"
meta = NixlAgentMetadata(
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
@@ -545,16 +576,17 @@ class TestNixlHandshake:
# the rest of the tests.
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
FakeNixlWrapper,
)
def test_kv_connector_stats(dist_init):
"""Test that KV transfer stats are properly recorded and retrieved."""
vllm_config = create_vllm_config()
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
connector.engine_id,
hand_shake_latency=0)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
# Verify that xfer_stats starts empty
initial_stats = connector.get_kv_connector_stats()
@@ -563,16 +595,17 @@ def test_kv_connector_stats(dist_init):
# Create transfer metadata
request_id = "test_req_for_stats"
metadata = NixlConnectorMetadata()
metadata.add_new_req(request_id=request_id,
local_block_ids=[1, 2, 3],
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id":
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
})
metadata.add_new_req(
request_id=request_id,
local_block_ids=[1, 2, 3],
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
},
)
connector.bind_connector_metadata(metadata)
# Start the transfer
@@ -593,8 +626,7 @@ def test_kv_connector_stats(dist_init):
_, done_recving = connector.get_finished(finished_req_ids=set())
if len(done_recving) > 0 and request_id in done_recving:
break
time.sleep(
0.1) # Small delay to allow background handshake to complete
time.sleep(0.1) # Small delay to allow background handshake to complete
else:
assert "Transfer did not complete within expected iterations"
@@ -613,7 +645,7 @@ def test_kv_connector_stats(dist_init):
def test_kv_connector_stats_aggregation():
"""
Test KV transfer stats aggregation across TP ranks using
KVOutputAggregator (used by MultiprocExecutor).
"""
@@ -636,18 +668,16 @@ def test_kv_connector_stats_aggregation():
worker2_stats.record_transfer(stats)
# Worker 3: 3 transfers
stats = get_default_xfer_telemetry(xferDurationS=2,
postDurationS=2,
totalBytes=2,
descCount=2)
stats = get_default_xfer_telemetry(
xferDurationS=2, postDurationS=2, totalBytes=2, descCount=2
)
worker3_stats.record_transfer(stats)
worker3_stats.record_transfer(stats)
worker3_stats.record_transfer(stats)
# Create ModelRunnerOutput instances for each worker
worker_outputs = []
for i, worker_stats in enumerate(
[worker1_stats, worker2_stats, worker3_stats]):
for i, worker_stats in enumerate([worker1_stats, worker2_stats, worker3_stats]):
output = ModelRunnerOutput(
req_ids=[f"req_{i}"],
req_id_to_index={f"req_{i}": 0},
@@ -657,17 +687,19 @@ def test_kv_connector_stats_aggregation():
pooler_output=[None],
kv_connector_output=KVConnectorOutput(
finished_sending=set([f"req_{i}_send"])
if i < 2 else None, # Workers 0,1 finished sending
if i < 2
else None, # Workers 0,1 finished sending
finished_recving=set([f"req_{i}_recv"])
if i > 0 else None, # Workers 1,2 finished receiving
if i > 0
else None, # Workers 1,2 finished receiving
kv_connector_stats=worker_stats,
))
),
)
worker_outputs.append(output)
# Use the real aggregation mechanism (like MultiprocExecutor.execute_model)
aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
kv_connector_stats = \
aggregated_output.kv_connector_output.kv_connector_stats
kv_connector_stats = aggregated_output.kv_connector_output.kv_connector_stats
assert isinstance(kv_connector_stats, NixlKVConnectorStats)
# Number of total transfers across all workers.
assert kv_connector_stats.num_successful_transfers == 6
@@ -691,7 +723,6 @@ def test_multi_kv_connector_stats_aggregation():
# Mock a KVConnectorStats class for testing aggregation over connectors.
@dataclass
class FooKVConnectorStats(KVConnectorStats):
def reset(self):
self.data = {"num_foo_transfers": 0}
@@ -703,15 +734,12 @@ def test_multi_kv_connector_stats_aggregation():
def is_empty(self) -> bool:
return self.data["num_foo_transfers"] == 0
def aggregate(self,
other: "FooKVConnectorStats") -> "FooKVConnectorStats":
def aggregate(self, other: "FooKVConnectorStats") -> "FooKVConnectorStats":
if not other.is_empty():
self.data["num_foo_transfers"] += other.data[
"num_foo_transfers"]
self.data["num_foo_transfers"] += other.data["num_foo_transfers"]
return self
def make_multi_stats(nixl_count: int,
foo_count: int) -> MultiKVConnectorStats:
def make_multi_stats(nixl_count: int, foo_count: int) -> MultiKVConnectorStats:
data: dict[str, KVConnectorStats] = {}
if nixl_count > 0:
nixl_stats = NixlKVConnectorStats()
@@ -747,13 +775,11 @@ def test_multi_kv_connector_stats_aggregation():
worker_outputs.append(output)
aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
kv_connector_stats = \
aggregated_output.kv_connector_output.kv_connector_stats
kv_connector_stats = aggregated_output.kv_connector_output.kv_connector_stats
assert isinstance(kv_connector_stats, MultiKVConnectorStats)
# Validate per-connector totals across workers
assert isinstance(kv_connector_stats["NixlConnector"],
NixlKVConnectorStats)
assert isinstance(kv_connector_stats["NixlConnector"], NixlKVConnectorStats)
assert kv_connector_stats["NixlConnector"].num_successful_transfers == 5
assert isinstance(kv_connector_stats["FooConnector"], FooKVConnectorStats)
assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6
@@ -762,11 +788,12 @@ def test_multi_kv_connector_stats_aggregation():
@pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
FakeNixlWrapper,
)
def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
"""
Test lifecycle of an aborted Remote Prefill request hitting the timeout.
-----> P
| {process request}
<-/--- | {result is NOT delivered, eg proxy is down}
|
@@ -823,39 +850,38 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=1,
extra_args={"kv_transfer_params": remote_prefill_opts})
extra_args={"kv_transfer_params": remote_prefill_opts},
)
scheduler = llm.llm_engine.engine_core.engine_core.scheduler
req_to_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0].req_to_blocks
0
].req_to_blocks
padding = "Just making this request a little longer so that we're sure "
"we're not hitting the small-request lower bound beneath which we don't "
"actually trigger the whole kv transfer, but rather just recompute the "
"blocks on D."
_ = llm.generate([f"What is the capital of Japan? {padding}"],
sampling_params)
_ = llm.generate([f"What is the capital of Japan? {padding}"], sampling_params)
# Request finished but not freed
assert '0' in scheduler.finished_req_ids and '0' in req_to_blocks
assert "0" in scheduler.finished_req_ids and "0" in req_to_blocks
# Some other request, 0 still not freed
_ = llm.generate([f"What is the capital of Italy? {padding}"],
sampling_params)
assert '0' in req_to_blocks
assert '1' in scheduler.finished_req_ids and '1' in req_to_blocks
_ = llm.generate([f"What is the capital of Italy? {padding}"], sampling_params)
assert "0" in req_to_blocks
assert "1" in scheduler.finished_req_ids and "1" in req_to_blocks
# Wait for timeout and trigger another scheduler loop
time.sleep(timeout)
_ = llm.generate([f"What is the capital of France? {padding}"],
sampling_params)
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
# Request-0 times out and is cleared!
assert '0' not in req_to_blocks
assert "0" not in req_to_blocks
def test_register_kv_caches(dist_init):
"""
Test that register_kv_caches() properly calls nixl_wrapper methods with
correct data.
This test verifies:
1. nixl_wrapper.get_reg_descs() is called with caches_data containing
tensor metadata
@@ -866,10 +892,9 @@ def test_register_kv_caches(dist_init):
vllm_config = create_vllm_config()
# Create test kv cache tensors using proper backend shape
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(num_blocks=2,
block_size=16,
num_kv_heads=4,
head_size=64)
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
@@ -879,21 +904,30 @@ def test_register_kv_caches(dist_init):
}
# Store tensor info for validation
expected_tensor_size = shared_tensor[0].element_size(
) * shared_tensor[0].numel()
expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel()
expected_base_addrs = [
shared_tensor[0].data_ptr(), shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(), unique_tensor[1].data_ptr()
shared_tensor[0].data_ptr(),
shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(),
unique_tensor[1].data_ptr(),
]
with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper") as mock_nixl_wrapper, \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"): # noqa: E501
with (
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
) as mock_nixl_wrapper,
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"
),
): # noqa: E501
# Create connector
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0)
vllm_config, connector.engine_id, hand_shake_latency=0
)
# Get the mock instance
mock_wrapper_instance = mock_nixl_wrapper.return_value
@@ -909,12 +943,13 @@ def test_register_kv_caches(dist_init):
for i, cache_entry in enumerate(caches_data):
base_addr, size, _tp_rank, _ = cache_entry
assert size == expected_tensor_size, \
f"Entry {i}: Expected tensor size {expected_tensor_size}, " \
f"got {size}"
assert base_addr == expected_base_addrs[i], \
f"Entry {i}: Expected base address {expected_base_addrs[i]}, " \
assert size == expected_tensor_size, (
f"Entry {i}: Expected tensor size {expected_tensor_size}, got {size}"
)
assert base_addr == expected_base_addrs[i], (
f"Entry {i}: Expected base address {expected_base_addrs[i]}, "
f"got {base_addr}"
)
# Verify get_xfer_descs was called with blocks_data
assert mock_wrapper_instance.get_xfer_descs.called
@@ -922,16 +957,17 @@ def test_register_kv_caches(dist_init):
# Validate blocks_data structure and size
expected_blocks_count = 8
assert len(blocks_data) == expected_blocks_count, \
f"Expected {expected_blocks_count} blocks, " \
f"got {len(blocks_data)}"
assert len(blocks_data) == expected_blocks_count, (
f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}"
)
expected_block_len = expected_tensor_size // 2
for i, block_entry in enumerate(blocks_data):
block_start_addr, block_len, tp_rank = block_entry
assert block_len == expected_block_len, \
f"Block entry {i}: Expected block len {expected_block_len}, " \
assert block_len == expected_block_len, (
f"Block entry {i}: Expected block len {expected_block_len}, "
f"got {block_len}"
)
class FakePlatform(Platform):
@@ -940,24 +976,26 @@ class FakePlatform(Platform):
@classmethod
def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
"""
Returns a mapping from device_type to a tuple of supported
kv_buffer_device for nixl.
"""
return {'oot': ('oot', )}
return {"oot": ("oot",)}
@classmethod
def get_nixl_memory_type(cls) -> Optional[str]:
"""
Returns the nixl memory type for the current platform.
"""
return 'VRAM'
return "VRAM"
@pytest.mark.parametrize("kv_buffer_device, nixl_memory_type", [
("oot", "VRAM"),
])
def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device,
nixl_memory_type):
@pytest.mark.parametrize(
"kv_buffer_device, nixl_memory_type",
[
("oot", "VRAM"),
],
)
def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, nixl_memory_type):
"""
Test that register_kv_caches() passes the correct memory types from the
config to the nixl_wrapper.
@@ -966,15 +1004,30 @@ def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device,
# Override the default memory types in the config
vllm_config.kv_transfer_config.kv_buffer_device = kv_buffer_device
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
_NIXL_SUPPORTED_DEVICE)
_NIXL_SUPPORTED_DEVICE,
)
_NIXL_SUPPORTED_DEVICE.update(FakePlatform.get_nixl_supported_devices())
with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform", FakePlatform), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector._NIXL_SUPPORTED_DEVICE", _NIXL_SUPPORTED_DEVICE): # noqa: E501
with (
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform",
FakePlatform,
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector._NIXL_SUPPORTED_DEVICE",
_NIXL_SUPPORTED_DEVICE,
),
): # noqa: E501
# Create connector and replace its worker with a fake one for isolation
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
@@ -985,22 +1038,23 @@ def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device,
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
FakeNixlWrapper,
)
def test_shutdown_cleans_up_resources(dist_init):
"""Test that shutdown() properly cleans up all resources."""
vllm_config = create_vllm_config()
worker = NixlConnectorWorker(vllm_config,
vllm_config.kv_transfer_config.engine_id)
worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id)
nixl_wrapper = worker.nixl_wrapper
with patch.object(worker, '_handshake_initiation_executor') as mock_exec, \
patch.object(worker, '_nixl_handshake_listener_t') as mock_listener, \
patch.object(nixl_wrapper, 'release_xfer_handle') as mock_rel_xfer, \
patch.object(nixl_wrapper, 'release_dlist_handle') as mock_rel_dlist, \
patch.object(nixl_wrapper, 'remove_remote_agent') as mock_rem_agent, \
patch.object(nixl_wrapper, 'deregister_memory') as mock_dereg:
with (
patch.object(worker, "_handshake_initiation_executor") as mock_exec,
patch.object(worker, "_nixl_handshake_listener_t") as mock_listener,
patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer,
patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist,
patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent,
patch.object(nixl_wrapper, "deregister_memory") as mock_dereg,
):
worker._recving_transfers = {"req1": [(123, time.perf_counter())]}
worker.src_xfer_side_handle = 456
worker.dst_xfer_side_handles = {"engine1": 789}
@@ -1028,7 +1082,8 @@ def test_shutdown_cleans_up_resources(dist_init):
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
FakeNixlWrapper,
)
def test_aborted_request_removed_from_worker_in_batch(dist_init):
"""
Create and schedule a request so that P adds it to in-batch tracking via
@@ -1040,9 +1095,9 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init):
scheduler = create_scheduler(vllm_config)
# KVConnector Worker in P
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
connector.engine_id,
hand_shake_latency=0)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
# Create a request that triggers do_remote_decode so that
# the scheduler adds it to reqs_in_batch
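The multi-`patch` blocks in this file show the same no-backslash rule applied to context managers: backslash-continued `with` statements become the parenthesized form, which ruff's formatter emits when the configured Python target version supports that syntax (officially Python 3.10+; vLLM's target-version setting is not shown here). A minimal sketch based on the shutdown test above:

# Before: backslash-continued context managers.
with patch.object(worker, '_handshake_initiation_executor') as mock_exec, \
        patch.object(worker, '_nixl_handshake_listener_t') as mock_listener:
    ...

# After: parenthesized form, one manager per line with a trailing comma.
with (
    patch.object(worker, "_handshake_initiation_executor") as mock_exec,
    patch.object(worker, "_nixl_handshake_listener_t") as mock_listener,
):
    ...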


@@ -14,27 +14,42 @@ from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_events import BlockRemoved, BlockStored
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
OffloadingConnector, OffloadingConnectorMetadata)
OffloadingConnector,
OffloadingConnectorMetadata,
)
from vllm.forward_context import ForwardContext
from vllm.utils import sha256
from vllm.v1.core.kv_cache_utils import (BlockHash, get_request_block_hasher,
init_none_hash)
from vllm.v1.core.kv_cache_utils import (
BlockHash,
get_request_block_hasher,
init_none_hash,
)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_offload.abstract import (LoadStoreSpec, OffloadingEvent,
OffloadingManager, PrepareStoreOutput)
from vllm.v1.kv_offload.abstract import (
LoadStoreSpec,
OffloadingEvent,
OffloadingManager,
PrepareStoreOutput,
)
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec
from vllm.v1.kv_offload.worker.worker import (OffloadingHandler,
TransferResult, TransferSpec)
from vllm.v1.kv_offload.worker.worker import (
OffloadingHandler,
TransferResult,
TransferSpec,
)
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.request import Request
from .utils import (EOS_TOKEN_ID, create_model_runner_output, create_scheduler,
create_vllm_config)
from .utils import (
EOS_TOKEN_ID,
create_model_runner_output,
create_scheduler,
create_vllm_config,
)
class MockLoadStoreSpec(LoadStoreSpec):
def __init__(self, block_hashes: Iterable[BlockHash]):
self.block_hashes: list[BlockHash] = list(block_hashes)
@@ -47,7 +62,6 @@ class MockLoadStoreSpec(LoadStoreSpec):
class MockOffloadingHandler(OffloadingHandler):
def __init__(self):
self.completed_transfers: list[TransferResult] = []
self.completed_specs: list[TransferSpec] = []
@@ -64,14 +78,14 @@ class MockOffloadingHandler(OffloadingHandler):
class MockOffloadingSpec(OffloadingSpec):
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.manager = MagicMock(spec=OffloadingManager)
self.manager.lookup.return_value = 0
self.manager.prepare_load = lambda block_hashes: (MockLoadStoreSpec(
block_hashes))
self.manager.prepare_load = lambda block_hashes: (
MockLoadStoreSpec(block_hashes)
)
self.handler = MockOffloadingHandler()
def get_manager(self) -> OffloadingManager:
@@ -79,9 +93,7 @@ class MockOffloadingSpec(OffloadingSpec):
def get_handlers(
self, _
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec],
OffloadingHandler]]:
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler
yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler
@@ -98,35 +110,35 @@ class TransferSummary:
class RequestRunner:
def __init__(self, offloaded_block_size: int, gpu_block_size: int,
num_gpu_blocks: int):
def __init__(
self, offloaded_block_size: int, gpu_block_size: int, num_gpu_blocks: int
):
self.offloaded_block_size: int = offloaded_block_size
self.gpu_block_size: int = gpu_block_size
self.num_gpu_blocks: int = num_gpu_blocks
self.req_id: int = -1
vllm_config = create_vllm_config(block_size=gpu_block_size,
max_num_batched_tokens=1000)
vllm_config = create_vllm_config(
block_size=gpu_block_size, max_num_batched_tokens=1000
)
vllm_config.kv_transfer_config = KVTransferConfig(
kv_connector="OffloadingConnector",
kv_role="kv_both",
kv_connector_extra_config={
"spec_name": "MockOffloadingSpec",
"spec_module_path":
"tests.v1.kv_connector.unit.test_offloading_connector",
"spec_module_path": "tests.v1.kv_connector.unit.test_offloading_connector",
"block_size": offloaded_block_size,
})
},
)
self.scheduler: Scheduler = create_scheduler(vllm_config,
num_blocks=num_gpu_blocks)
self.worker_connector = OffloadingConnector(vllm_config,
KVConnectorRole.WORKER)
self.scheduler: Scheduler = create_scheduler(
vllm_config, num_blocks=num_gpu_blocks
)
self.worker_connector = OffloadingConnector(vllm_config, KVConnectorRole.WORKER)
# register worker kv_caches to enable OffloadingWorker creations
self.worker_connector.register_kv_caches(
kv_caches={"a": torch.empty(0)})
self.worker_connector.register_kv_caches(kv_caches={"a": torch.empty(0)})
# extract connector of scheduler
scheduler_connector = self.scheduler.connector
@@ -166,9 +178,9 @@ class RequestRunner:
init_none_hash(sha256)
self._block_hasher = get_request_block_hasher(gpu_block_size, sha256)
self._dummy_ctx: ForwardContext = ForwardContext(no_compile_layers={},
attn_metadata={},
virtual_engine=0)
self._dummy_ctx: ForwardContext = ForwardContext(
no_compile_layers={}, attn_metadata={}, virtual_engine=0
)
def new_request(self, token_ids: list[int]):
assert not self.scheduler.requests
@@ -189,8 +201,7 @@ class RequestRunner:
block_size_factor = self.offloaded_block_size // self.gpu_block_size
while self.pending_loads_count or self.pending_stores_count:
for transfer_spec in (
self.offloading_spec.get_completed_transfers()):
for transfer_spec in self.offloading_spec.get_completed_transfers():
src_spec, dst_spec = transfer_spec
if isinstance(src_spec, GPULoadStoreSpec):
@@ -207,8 +218,7 @@ class RequestRunner:
gpu_block_indices: list[int] = []
for block_id in gpu_spec.block_ids:
gpu_block_indices.append(
self.gpu_block_index[block_id.item()])
gpu_block_indices.append(self.gpu_block_index[block_id.item()])
# list of (block_hash, sub_block_offset)
offload_addresses: list[Any] = []
@@ -220,23 +230,26 @@ class RequestRunner:
assert len(gpu_block_indices) == len(offload_addresses)
self.completed_stores.append(
TransferSummary(gpu_block_indices, offload_addresses))
TransferSummary(gpu_block_indices, offload_addresses)
)
self.pending_stores_count -= 1
else:
remainder_sub_block_count = (len(offload_addresses) -
len(gpu_block_indices))
remainder_sub_block_count = len(offload_addresses) - len(
gpu_block_indices
)
assert remainder_sub_block_count >= 0
assert remainder_sub_block_count < block_size_factor
offload_addresses = offload_addresses[
remainder_sub_block_count:]
offload_addresses = offload_addresses[remainder_sub_block_count:]
self.completed_loads.append(
TransferSummary(gpu_block_indices, offload_addresses))
TransferSummary(gpu_block_indices, offload_addresses)
)
self.pending_loads_count -= 1
def _update_gpu_block_idx(self):
for blocks in (self.scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks.values()):
for blocks in self.scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].req_to_blocks.values():
for block_idx, block in enumerate(blocks):
self.gpu_block_index[block.block_id] = block_idx
@@ -259,23 +272,20 @@ class RequestRunner:
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata,
OffloadingConnectorMetadata)
assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)
self.pending_loads_count += len(kv_connector_metadata.reqs_to_load)
self.pending_stores_count += len(
kv_connector_metadata.reqs_to_store)
self.pending_stores_count += len(kv_connector_metadata.reqs_to_store)
self.worker_connector.bind_connector_metadata(
kv_connector_metadata)
self.worker_connector.bind_connector_metadata(kv_connector_metadata)
self.worker_connector.start_load_kv(self._dummy_ctx)
if scheduler_output.total_num_scheduled_tokens > 0:
self.worker_connector.wait_for_save()
finished_sending, finished_recving = self.worker_connector.get_finished(
scheduler_output.finished_req_ids
)
self.worker_connector.clear_connector_metadata()
@@ -283,13 +293,13 @@ class RequestRunner:
reqs=self.scheduler.running,
finished_sending=finished_sending,
finished_recving=finished_recving,
token_id=token_id,
)
if self.scheduler.running:
token_id = next(tokens_iter, None)
self.scheduler.update_from_output(scheduler_output, model_runner_output)
self._wait_for_transfers()
@@ -300,24 +310,24 @@ class RequestRunner:
while self.scheduler.requests:
scheduler_output = self.scheduler.schedule()
finished_sending, finished_recving = self.worker_connector.get_finished(
scheduler_output.finished_req_ids
)
assert not finished_recving
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending
)
self.scheduler.update_from_output(scheduler_output, model_runner_output)
def run(
self,
decoded_tokens: list[int],
expected_stored_gpu_block_indexes: tuple[int, ...] = (),
expected_loaded_gpu_block_indexes: tuple[int, ...] = (),
):
"""
Runs multiple engine (scheduler + worker) steps.
@@ -337,23 +347,23 @@ class RequestRunner:
loaded_gpu_block_indexes: set[int] = set()
for transfer in self.completed_loads:
for gpu_block_idx, offloaded_address in zip(
transfer.gpu_block_indices, transfer.offload_addresses
):
loaded_gpu_block_indexes.add(gpu_block_idx)
assert gpu_block_idx == self.offloaded[offloaded_address]
assert set(expected_loaded_gpu_block_indexes) == loaded_gpu_block_indexes
self.completed_loads.clear()
stored_gpu_block_indexes: set[int] = set()
for transfer in self.completed_stores:
for gpu_block_idx, offloaded_address in zip(
transfer.gpu_block_indices, transfer.offload_addresses
):
stored_gpu_block_indexes.add(gpu_block_idx)
self.offloaded[offloaded_address] = gpu_block_idx
assert set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes
self.completed_stores.clear()
@@ -362,9 +372,11 @@ def request_runner():
runners = []
def runner_factory(offloaded_block_size, gpu_block_size, num_gpu_blocks):
runner = RequestRunner(
offloaded_block_size=offloaded_block_size,
gpu_block_size=gpu_block_size,
num_gpu_blocks=num_gpu_blocks,
)
runners.append(runner)
return runner
@@ -386,15 +398,18 @@ def test_offloading_connector(request_runner):
num_gpu_blocks = 100
block_size_factor = offloaded_block_size // gpu_block_size
runner = request_runner(
offloaded_block_size=offloaded_block_size,
gpu_block_size=gpu_block_size,
num_gpu_blocks=num_gpu_blocks,
)
# 3 blocks, store just the middle block (skip first and last)
# blocks = [0, 1, 2], [3, 4, 5], [6, 7, 8]
runner.new_request(token_ids=[0] * offloaded_block_size * 3)
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output(list(block_hashes)[1:2])
)
runner.run(decoded_tokens=[0], expected_stored_gpu_block_indexes=(3, 4, 5))
# add block missing 1 token -> no offload
@@ -402,21 +417,24 @@ def test_offloading_connector(request_runner):
runner.manager.prepare_store.assert_not_called()
# +1 token -> single block, fail prepare_store
runner.manager.prepare_store.side_effect = lambda block_hashes: None
runner.run(decoded_tokens=[0])
runner.manager.prepare_store.assert_called()
# 1 more block, now set block_hashes_to_store = []
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output([])
)
runner.run(decoded_tokens=[0] * offloaded_block_size)
# 1 more block, now check touch was called with all 6 blocks
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output(block_hashes)
)
runner.run(
decoded_tokens=[0] * offloaded_block_size,
expected_stored_gpu_block_indexes=(15, 16, 17),
)
runner.manager.touch.assert_called()
block_hashes1 = list(runner.manager.touch.call_args.args[0])
assert len(block_hashes1) == 6
@@ -426,9 +444,10 @@ def test_offloading_connector(request_runner):
# create a new request differing only on the last token
runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1])
runner.run(
decoded_tokens=[0],
expected_stored_gpu_block_indexes=tuple(range(6 * block_size_factor)),
)
runner.manager.touch.assert_called()
block_hashes2 = list(runner.manager.touch.call_args.args[0])
assert len(block_hashes2) == 6
@@ -441,17 +460,20 @@ def test_offloading_connector(request_runner):
runner.run(decoded_tokens=[EOS_TOKEN_ID])
# full_block_tokens - num_computed_tokens < offloaded_block_size
runner.new_request(
token_ids=[0] * gpu_block_size + [1] * (offloaded_block_size - gpu_block_size)
)
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output([])
)
runner.run(decoded_tokens=[EOS_TOKEN_ID])
runner.manager.lookup.assert_not_called()
# single block lookup with no hits
runner.new_request(token_ids=[1] * offloaded_block_size)
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output([])
)
runner.run(decoded_tokens=[EOS_TOKEN_ID])
runner.manager.lookup.assert_called()
assert len(list(runner.manager.lookup.call_args.args[0])) == 1
@@ -459,34 +481,37 @@ def test_offloading_connector(request_runner):
# single block lookup with a hit
runner.scheduler.reset_prefix_cache()
runner.new_request(token_ids=[0] * offloaded_block_size)
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output([])
)
runner.manager.lookup.return_value = 1
runner.run(
decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(0, 1, 2)
)
# single block lookup with a hit in a middle block
runner.new_request(
token_ids=[0] * offloaded_block_size * 2 + [1] * offloaded_block_size
)
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output([])
)
runner.manager.lookup.return_value = 1
runner.run(
decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(3, 4, 5)
)
# test take_events
def to_hashes(int_hashes: list[int]) -> list[BlockHash]:
return [BlockHash(str(i).encode()) for i in int_hashes]
def take_events() -> Iterable[OffloadingEvent]:
yield OffloadingEvent(
block_hashes=to_hashes([1, 2, 3]), block_size=16, medium="A", removed=False
)
yield OffloadingEvent(
block_hashes=to_hashes([4, 5, 6]), block_size=32, medium="B", removed=True
)
runner.manager.take_events.side_effect = take_events
events = list(runner.scheduler_connector.take_events())
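For reference, a minimal sketch of consuming the OffloadingEvent stream exercised above. It relies only on the fields used in take_events() (block_hashes, block_size, medium, removed); the helper name and its grouping logic are illustrative assumptions, not part of the test:

def summarize_offloading_events(events):
    # Partition events into stored vs. removed, recording (medium, block_size, count).
    stored, removed = [], []
    for event in events:
        bucket = removed if event.removed else stored
        bucket.append((event.medium, event.block_size, len(list(event.block_hashes))))
    return stored, removed

# Applied to the two events yielded above, this returns
# ([("A", 16, 3)], [("B", 32, 3)]).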

View File

@@ -12,22 +12,25 @@ pytestmark = pytest.mark.cpu_test
class DummyModelRunnerOutput(ModelRunnerOutput):
def __init__(
self,
finished_sending: Optional[set[str]] = None,
finished_recving: Optional[set[str]] = None,
invalid_block_ids: Optional[set[int]] = None,
):
self.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
invalid_block_ids=invalid_block_ids or set(),
)
def __repr__(self):
return (
f"DummyModelRunnerOutput("
f"finished_sending={self.kv_connector_output.finished_sending},"
f"finished_recving={self.kv_connector_output.finished_recving})"
f"invalid_block_ids={self.kv_connector_output.invalid_block_ids})")
f"invalid_block_ids={self.kv_connector_output.invalid_block_ids})"
)
def test_aggregate_workers_output():
@@ -44,8 +47,9 @@ def test_aggregate_workers_output():
assert aggregated.finished_recving is None
assert not aggregated.invalid_block_ids
output1 = DummyModelRunnerOutput(
finished_sending={"req1"}, finished_recving={"req2"}
)
output2 = DummyModelRunnerOutput(invalid_block_ids={1})
aggregated = aggregator.aggregate([output1, output2])
@@ -57,26 +61,27 @@ def test_aggregate_workers_output():
assert aggregated.invalid_block_ids == {1}
output1 = DummyModelRunnerOutput(invalid_block_ids={2})
output2 = DummyModelRunnerOutput(finished_sending={"req1"})
aggregated = aggregator.aggregate([output1, output2])
assert aggregated is output1
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending == {"req1"}
assert aggregated.finished_recving is None
assert aggregated.invalid_block_ids == {2}
output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
output2 = DummyModelRunnerOutput(
finished_recving={"req2"}, invalid_block_ids={4, 5}
)
aggregated = aggregator.aggregate([output1, output2])
assert aggregated is output1
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending is None
assert aggregated.finished_recving == {"req2"}
assert aggregated.invalid_block_ids == {3, 4, 5}
@@ -104,8 +109,9 @@ def test_async_aggregate_workers_output():
future2 = Future()
result_future = aggregator.async_aggregate([future1, future2])
output1 = DummyModelRunnerOutput(
finished_sending={"req1"}, finished_recving={"req2"}
)
output2 = DummyModelRunnerOutput(invalid_block_ids={1})
future1.set_result(output1)
future2.set_result(output2)
@@ -123,7 +129,7 @@ def test_async_aggregate_workers_output():
result_future = aggregator.async_aggregate([future1, future2])
output1 = DummyModelRunnerOutput(invalid_block_ids={2})
output2 = DummyModelRunnerOutput(finished_sending={"req1"})
future1.set_result(output1)
future2.set_result(output2)
@@ -131,7 +137,7 @@ def test_async_aggregate_workers_output():
aggregated = result_future.result()
assert aggregated is output1
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending == {"req1"}
assert aggregated.finished_recving is None
assert aggregated.invalid_block_ids == {2}
@@ -140,8 +146,9 @@ def test_async_aggregate_workers_output():
result_future = aggregator.async_aggregate([future1, future2])
output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
output2 = DummyModelRunnerOutput(
finished_recving={"req2"}, invalid_block_ids={4, 5}
)
future1.set_result(output1)
future2.set_result(output2)
@@ -150,5 +157,5 @@ def test_async_aggregate_workers_output():
assert aggregated is output1
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending is None
assert aggregated.finished_recving == {"req2"}
assert aggregated.invalid_block_ids == {3, 4, 5}
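A toy model of the aggregation semantics the assertions above pin down (inferred from the test behavior, not copied from vLLM's aggregator): a request counts as finished only once every worker has reported it, possibly across separate aggregate calls, while invalid_block_ids are unioned immediately.

from collections import Counter

class ToyFinishedTracker:
    # Reference-counts finished request ids across workers.
    def __init__(self, world_size: int):
        self.world_size = world_size
        self._counts: Counter = Counter()

    def update(self, per_worker_finished: list) -> set:
        finished = set()
        for worker_set in per_worker_finished:
            for req_id in worker_set or ():
                self._counts[req_id] += 1
                if self._counts[req_id] == self.world_size:
                    del self._counts[req_id]
                    finished.add(req_id)
        return finished

# tracker = ToyFinishedTracker(world_size=2)
# tracker.update([{"req1"}, None])  # -> set(): only one worker has reported
# tracker.update([None, {"req1"}])  # -> {"req1"}: both workers have reported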

View File

@@ -7,8 +7,13 @@ import pytest
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.request import FinishReason, RequestStatus
from .utils import (
assert_scheduler_empty,
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)
pytestmark = pytest.mark.cpu_test
@@ -24,11 +29,13 @@ def test_basic_lifecycle():
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
max_tokens=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)
scheduler.add_request(request)
request_id = request.request_id
@@ -43,8 +50,9 @@ def test_basic_lifecycle():
model_runner_output = create_model_runner_output(reqs=[request])
# (1c): update_from_output()
engine_core_outputs = scheduler.update_from_output(
scheduler_output, model_runner_output
)
# Ensure the request is finished after 1 token.
assert request.is_finished()
@@ -60,7 +68,8 @@ def test_basic_lifecycle():
# ... but blocks should not be freed.
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].req_to_blocks[request_id]
for block in blocks:
assert block.ref_cnt == 1
@@ -92,7 +101,8 @@ def test_basic_lifecycle():
# (3b): execute_model()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending={request_id}
)
# (3c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
@@ -110,11 +120,13 @@ def test_short_prompt_lifecycle():
# Not enough tokens for full block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_TOKENS = BLOCK_SIZE // 2
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
max_tokens=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)
scheduler.add_request(request)
@@ -132,14 +144,15 @@ def test_short_prompt_lifecycle():
eco = scheduler.update_from_output(scheduler_output, model_runner_output)
kv_transfer_params = eco[0].outputs[0].kv_transfer_params
assert (len(kv_transfer_params["remote_block_ids"]) == 1)
assert len(kv_transfer_params["remote_block_ids"]) == 1
# Confirm we do not have any memory leaks after req lifecycle.
# We need to mark sending finish to clear data for persistent batch.
scheduler_output = scheduler.schedule()
# Use create_model_runner_output to pass kv_connector_output along
model_runner_output = create_model_runner_output(
reqs=[request], finished_sending={request.request_id}
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert_scheduler_empty(scheduler)
@@ -155,14 +168,15 @@ def test_prefix_cache_lifecycle():
NUM_EXTERNAL_FULL_BLOCKS = 3
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_normal = create_request(
request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS
)
scheduler.add_request(request_normal)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=[request_normal], use_eos=True
)
scheduler.update_from_output(scheduler_output, model_runner_output)
scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
@@ -174,10 +188,12 @@ def test_prefix_cache_lifecycle():
NUM_EXTERNAL_FULL_BLOCKS -= 1
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_remote = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
@@ -187,14 +203,13 @@ def test_prefix_cache_lifecycle():
# Ensure we send all block ids, including the partial blocks,
# even if there is a cache hit.
assert len(kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + 1)
# STEP (2): Ensure it is freed.
scheduler_output = scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending={request_remote.request_id}
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert_scheduler_empty(scheduler)
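The decode-side lifecycle these steps walk through reduces to a small state machine, sketched below. The state names are invented for illustration; only the finished_sending signal comes from the tests. The request finishes generating, its blocks stay referenced while the KVs are sent, and they are freed once finished_sending arrives.

from enum import Enum, auto

class RemoteDecodeState(Enum):
    RUNNING = auto()
    FINISHED_AWAITING_SEND = auto()  # finished, blocks still ref-counted
    FREED = auto()

def step(state, *, finished_generating=False, finished_sending=False):
    if state is RemoteDecodeState.RUNNING and finished_generating:
        return RemoteDecodeState.FINISHED_AWAITING_SEND
    if state is RemoteDecodeState.FINISHED_AWAITING_SEND and finished_sending:
        return RemoteDecodeState.FREED
    return state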

View File

@@ -7,8 +7,13 @@ import pytest
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.request import FinishReason, RequestStatus
from .utils import (
assert_scheduler_empty,
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)
pytestmark = pytest.mark.cpu_test
@@ -24,12 +29,15 @@ def test_basic_lifecycle():
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
START_FREE_BLOCK_QUEUE_SIZE = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks
)
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
)
scheduler.add_request(request)
request_id = request.request_id
@@ -48,16 +56,16 @@ def test_basic_lifecycle():
# Req waiting for KVs with no computed/scheduled toks ...
assert len(scheduler.waiting) == 1
assert request in scheduler.waiting
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == 0
# ... but should have (uncached) blocks allocated to it.
block_pool = scheduler.kv_cache_manager.block_pool
assert block_pool.free_block_queue.num_free_blocks < START_FREE_BLOCK_QUEUE_SIZE
assert len(block_pool.cached_block_hash_to_block) == 0
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].req_to_blocks[request_id]
for block in blocks:
assert block._block_hash is None
@@ -65,8 +73,9 @@ def test_basic_lifecycle():
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
# (1c): update_from_output()
engine_core_outputs = scheduler.update_from_output(
scheduler_output, model_runner_output
)
assert not engine_core_outputs or not engine_core_outputs[0].outputs
# STEP (2):
@@ -78,13 +87,15 @@ def test_basic_lifecycle():
# (2b): forward(): request finishes recv.
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_recving={request_id}
)
# (2c): update_from_output():
engine_core_outputs = scheduler.update_from_output(
scheduler_output, model_runner_output
)
assert len(scheduler.waiting) == 1
assert request_id in scheduler.finished_recving_kv_req_ids
# STEP (3):
# (3a): schedule(): this should actually schedule.
@@ -94,10 +105,11 @@ def test_basic_lifecycle():
# Confirm the block are actually allocated.
num_hashed_blocks = 0
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].req_to_blocks[request_id]
for block in blocks:
assert block.ref_cnt == 1
num_hashed_blocks += 1 if block._block_hash is not None else 0
assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS
# Confirm the rest of the prompt is scheduled in this step.
@@ -105,7 +117,7 @@ def test_basic_lifecycle():
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id]
num_computed_tokens = scheduled_req.num_computed_tokens
total_prompt_tokens = len(scheduled_req.prompt_token_ids)
assert num_scheduled_tokens == total_prompt_tokens - num_computed_tokens
# (3b): execute_model()
model_runner_output = create_model_runner_output([request])
@@ -115,8 +127,9 @@ def test_basic_lifecycle():
# Step (4): Hit EOS.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output([request], use_eos=True)
engine_core_outputs = scheduler.update_from_output(
scheduler_output, model_runner_output
)
scheduler.schedule()
outputs = engine_core_outputs[0].outputs
@@ -137,10 +150,12 @@ def test_interleaved_lifecycle():
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_remote = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
)
request_local_a = create_request(
request_id=2,
block_size=BLOCK_SIZE,
@@ -169,8 +184,7 @@ def test_interleaved_lifecycle():
assert len(scheduler_output.scheduled_new_reqs) == 1
assert scheduler_output.scheduled_cached_reqs.num_reqs == 1
model_runner_output = create_model_runner_output([request_local_a, request_local_b])
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP 3: continue running, KVs not arrived yet.
@@ -181,7 +195,8 @@ def test_interleaved_lifecycle():
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
model_runner_output = create_model_runner_output(
reqs=[request_local_a, request_local_b]
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
@@ -196,8 +211,8 @@ def test_interleaved_lifecycle():
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
model_runner_output = create_model_runner_output(
[request_local_a, request_local_b], finished_recving={request_remote.request_id}
)
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP 5: RECVed KVs are sent to ModelRunner.
@@ -208,7 +223,8 @@ def test_interleaved_lifecycle():
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
model_runner_output = create_model_runner_output(
[request_local_a, request_local_b, request_remote]
)
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP 6: Hit EOS and free.
@@ -273,15 +289,17 @@ def test_no_spurious_prefix_caching():
assert len(scheduler.waiting) == 1
local_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].req_to_blocks[request_local.request_id]
remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].req_to_blocks[request_remote.request_id]
# Local should have cached blocks (but not all due to preallocate).
num_hashed_blocks = 0
for block in local_blocks:
assert block.ref_cnt == 1
num_hashed_blocks += 1 if block._block_hash is not None else 0
assert num_hashed_blocks > 0
# Remote blocks should not be cached.
@@ -301,10 +319,12 @@ def test_full_block_prompt():
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
)
scheduler.add_request(request)
request_id = request.request_id
@@ -312,8 +332,11 @@ def test_full_block_prompt():
# STEP (1): Initialize a recv.
scheduler_output = scheduler.schedule()
# All blocks should be allocated.
num_blocks = len(
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
request_id
]
)
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
scheduler.update_from_output(scheduler_output, model_runner_output)
@@ -322,22 +345,25 @@ def test_full_block_prompt():
scheduler_output = scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_recving={request_id}
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.waiting) == 1
assert request_id in scheduler.finished_recving_kv_req_ids
# STEP (3): Run as usual.
scheduler_output = scheduler.schedule()
# We need to recompute the final token of the prompt to generate
# the first new token, so we should not have a new block.
num_blocks = len(
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
request_id
]
)
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
assert scheduler_output.scheduled_new_reqs[0].num_computed_tokens == NUM_TOKENS - 1
assert scheduler_output.num_scheduled_tokens[request_id] == 1
model_runner_output = create_model_runner_output([request])
scheduler.update_from_output(scheduler_output, model_runner_output)
@@ -345,8 +371,9 @@ def test_full_block_prompt():
# Step (4): Hit EOS.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output([request], use_eos=True)
engine_core_outputs = scheduler.update_from_output(
scheduler_output, model_runner_output
)
scheduler.schedule()
outputs = engine_core_outputs[0].outputs
@@ -375,13 +402,15 @@ def test_cannot_schedule_after_recv():
NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
NUM_TOKENS_REMOTE = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
request_normal = create_request(
request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS_LOCAL
)
request_remote = create_request(
request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS_REMOTE,
do_remote_prefill=True,
)
# STEP 1: 3 blocks are in use (2 for prompt, 1 for decode).
scheduler.add_request(request_normal)
@@ -402,7 +431,8 @@ def test_cannot_schedule_after_recv():
# Step 3: finish recving (5 blocks in use)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=[request_normal], finished_recving={request_remote.request_id}
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
@@ -411,7 +441,8 @@ def test_cannot_schedule_after_recv():
# because the transfer is completed.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=[request_normal, request_remote]
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 0
@@ -426,8 +457,9 @@ def test_cannot_schedule_after_recv():
# Step 6: finish the request, free it.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=[request_normal], use_eos=True
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
@@ -436,16 +468,19 @@ def test_cannot_schedule_after_recv():
# request is retrieved from preempted list.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote])
assert (
scheduler_output.scheduled_cached_reqs.num_computed_tokens[0]
== NUM_PROMPT_BLOCKS * BLOCK_SIZE
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
# Step 8: free everything.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=[request_remote], use_eos=True
)
scheduler.update_from_output(scheduler_output, model_runner_output)
_ = scheduler.schedule()
assert_scheduler_empty(scheduler)
@@ -470,13 +505,15 @@ def test_cannot_recv():
NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5))
request_normal = create_request(
request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS_LOCAL
)
request_remote = create_request(
request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS_REMOTE,
do_remote_prefill=True,
)
# STEP 1: 3 blocks are in use (2 for prompt, 1 for decode).
scheduler.add_request(request_normal)
@@ -495,12 +532,13 @@ def test_cannot_recv():
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
# Should not have KV transfer in progress.
assert request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS
# Step 3: finish the request, free it.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=[request_normal], use_eos=True
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
@@ -511,12 +549,13 @@ def test_cannot_recv():
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
assert request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS
# Step 5: finish recving (5 blocks in use)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=[], finished_recving={request_remote.request_id}
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
@@ -530,8 +569,9 @@ def test_cannot_recv():
# Step 7: free everything.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=[request_remote], use_eos=True
)
scheduler.update_from_output(scheduler_output, model_runner_output)
_ = scheduler.schedule()
assert_scheduler_empty(scheduler)
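On the prefill/recv side, the tests above pin down a simple gating rule, sketched here with an assumed helper name (WAITING_FOR_REMOTE_KVS and finished_recving_kv_req_ids are the identifiers actually used): a request waiting on remote KVs is held in the waiting queue until its id shows up in finished_recving, and only then may the scheduler resume it.

def can_resume(request_status: str, request_id: str,
               finished_recving_kv_req_ids: set) -> bool:
    # Requests not waiting on remote KVs are schedulable as usual.
    if request_status != "WAITING_FOR_REMOTE_KVS":
        return True
    # Otherwise, wait until the connector reports the recv as finished.
    return request_id in finished_recving_kv_req_ids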

View File

@@ -37,16 +37,22 @@ def _list_path(path):
return list(path.iterdir())
def run_test(
tmp_path,
processor,
llm: LLM,
question: str,
image_urls: list[Image],
expected_len: int,
info: str,
):
"""
Run one individual test: process the prompt and output based on one set of
inputs, then check that the number of entries in the storage path matches
the expected length. `info` describes the purpose of the individual test.
"""
print(f"***info: {info}***")
print(f"**Expected storage path length after llm generate: {expected_len}**")
process_prompt(processor, llm, question, image_urls)
print(f"Path matched expected length: {_check_path_len(tmp_path)}")
@@ -54,51 +60,42 @@ def run_test(tmp_path, processor, llm: LLM, question: str,
assert _check_path_len(tmp_path) == expected_len, (
f"Expect storage path length {expected_len} ;",
f"but end up {_check_path_len(tmp_path)} instead. ", f"Info: {info}")
f"but end up {_check_path_len(tmp_path)} instead. ",
f"Info: {info}",
)
def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]):
"""
Form the prompt from the text and image inputs, then have the LLM generate output
"""
placeholders = [
{
"type": "image_url",
"image_url": {"url": f"data:image;base64,{encode_image_base64(image_pil)}"},
}
for image_pil in image_urls
]
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
},
]
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
outputs = llm.generate(
{
"prompt":
prompt,
**({
"multi_modal_data": {
"image": [*image_urls]
}
} if image_urls else {})
"prompt": prompt,
**({"multi_modal_data": {"image": [*image_urls]}} if image_urls else {}),
},
sampling_params=SAMPLING_PARAMS,
)
@@ -114,7 +111,7 @@ def process_prompt(processor, llm: LLM, question: str,
def test_shared_storage_connector_hashes(tmp_path):
"""
Tests that SharedStorageConnector saves KV to the storage locations
with proper hashes that are unique for inputs with identical text but
different images (same size), or same multiple images but different orders.
"""
# Using tmp_path as the storage path to store KV
@@ -124,7 +121,8 @@ def test_shared_storage_connector_hashes(tmp_path):
kv_transfer_config = KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": str(tmp_path)})
kv_connector_extra_config={"shared_storage_path": str(tmp_path)},
)
engine_args = EngineArgs(
model=MODEL_NAME,
@@ -157,56 +155,88 @@ def test_shared_storage_connector_hashes(tmp_path):
# Prepare the input cases
input_cases = [
InputCase(
text=TEXT_PROMPTS[0],
img=[image_1],
expected_len=1,
info="image_1 single input the first time.",
),
InputCase(
text=TEXT_PROMPTS[0],
img=[image_2],
expected_len=2,
info=(
"image_2 single input the first time. "
"It is in same pixel size with image_1, yet it "
"should be able to form a new unique hash."
),
),
InputCase(
text=TEXT_PROMPTS[0],
img=[image_1],
expected_len=2,
info=(
"image_1 single input the 2nd time. "
"It should not form another new hash."
),
),
InputCase(
text=TEXT_PROMPTS[0],
img=[image_2],
expected_len=2,
info=(
"image_2 single input the 2nd time. "
"It should not form another new hash."
),
),
InputCase(
text=TEXT_PROMPTS[0],
img=[image_1, image_2],
expected_len=3,
info="image_1 with image_2 input the first time.",
),
InputCase(
text=TEXT_PROMPTS[0],
img=[image_2, image_1],
expected_len=4,
info="The image order is swapped. Should form new hash.",
),
InputCase(
text=TEXT_PROMPTS[0],
img=[image_1, image_2],
expected_len=4,
info=(
"[image_1, image_2] input the 2nd time. "
"It should not form another new hash."
),
),
InputCase(
text=TEXT_PROMPTS[0],
img=[image_2, image_1],
expected_len=4,
info=(
"[image_2, image_1] input the 2nd time. "
"It should not form another new hash."
),
),
InputCase(
text=TEXT_PROMPTS[0],
img=[],
expected_len=5,
info="Pure text input test as a case-control",
),
InputCase(
text=TEXT_PROMPTS[0],
img=[],
expected_len=5,
info="Identical pure text input as a case-control",
),
InputCase(
text=TEXT_PROMPTS[1],
img=[],
expected_len=6,
info="Another pure text input as a case-control",
),
]
# Run tests

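The order-sensitivity asserted in these cases (swapping image_1 and image_2 produces a new hash) is what chained block hashing gives you: each block hash folds in the previous one, so any change earlier in the sequence changes every later hash. Below is a toy chain for illustration; this is an assumption about the mechanism, not vLLM's actual get_request_block_hasher.

import hashlib

def chained_block_hashes(blocks: list[bytes]) -> list[str]:
    # Each digest depends on the previous digest plus the current block.
    prev = b""
    hashes = []
    for block in blocks:
        digest = hashlib.sha256(prev + block).hexdigest()
        hashes.append(digest)
        prev = digest.encode()
    return hashes

assert chained_block_hashes([b"img1", b"img2"]) != chained_block_hashes([b"img2", b"img1"])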
View File

@@ -8,19 +8,27 @@ from typing import Any, Callable, Optional
import torch
from vllm import SamplingParams
from vllm.config import (
CacheConfig,
DeviceConfig,
KVTransferConfig,
ModelConfig,
SchedulerConfig,
VllmConfig,
)
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
SharedStorageConnector,
)
from vllm.utils import sha256
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
)
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
@@ -42,14 +50,24 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert len(scheduler.encoder_cache_manager.cached) == 0
# KVCache Manager.
assert (
len(
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks
)
== 0
)
assert (
len(
scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].num_cached_block
)
== 0
)
num_free_blocks = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks
)
assert num_free_blocks == (scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
# NOTE(rob): just the ref count on blocks will be 0. The hash
# value, etc will remain since we lazily evict for prefix cache.
@@ -90,11 +108,13 @@ def create_vllm_config(
kv_connector="NixlConnector",
kv_role="kv_both",
)
return VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
device_config=DeviceConfig("cpu"),
)
def create_scheduler(
@@ -107,9 +127,9 @@ def create_scheduler(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
)
],
)
vllm_config.cache_config.num_gpu_blocks = num_blocks
@@ -151,16 +171,16 @@ def create_request(
if do_remote_decode:
assert not do_remote_prefill
kv_transfer_params = dict(do_remote_prefill=False, do_remote_decode=True)
elif do_remote_prefill:
kv_transfer_params = dict(
do_remote_prefill=True,
do_remote_decode=False,
remote_engine_id="my-engine-id",
remote_block_ids=list(range(num_remote_blocks)),
remote_host="my-host",
remote_port=1234,
)
max_tokens = 1 if do_remote_decode else max_tokens
sampling_params = SamplingParams(max_tokens=max_tokens)
@@ -200,13 +220,19 @@ def create_model_runner_output(
sampled_token = EOS_TOKEN_ID if use_eos else token_id
sampled_token_ids = [[sampled_token] for _ in req_ids]
kv_connector_output = (
None
if (
finished_sending is None
and finished_recving is None
and invalid_block_ids is None
)
else KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
invalid_block_ids=invalid_block_ids or set(),
)
)
# Make output data structure.
return ModelRunnerOutput(
@@ -221,22 +247,30 @@ def create_model_runner_output(
class TestSharedStorageConnector(SharedStorageConnector):
def __init__(self, config: VllmConfig, role):
self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
self._connector = SharedStorageConnector(config, role)
self.call_record: dict[str, int] = defaultdict(int)
# Use a unique temp file per connector
self._event_file = (
tempfile.gettempdir()
+ f"/connector_{self.name}-{self.role.name}_events.log"
)
# Start with an empty file
with open(self._event_file, "w") as _:
pass
def __getattribute__(self, name):
if name in ("_connector", "call_record", "name", "_event_file",
"__class__", "__dict__", "__getattribute__",
"__init__"): # avoid recursion
if name in (
"_connector",
"call_record",
"name",
"_event_file",
"__class__",
"__dict__",
"__getattribute__",
"__init__",
): # avoid recursion
return object.__getattribute__(self, name)
if not hasattr(self._connector, name):
return object.__getattribute__(self, name)
@@ -255,21 +289,20 @@ class TestSharedStorageConnector(SharedStorageConnector):
if isinstance(arg, int):
to_log.append(str(arg))
elif isinstance(arg, KVCacheBlocks):
to_log.append(f"num_blocks={[len(b) for b in arg.blocks]}")
# Log the event as a line to the file
try:
with open(self._event_file, "a") as f:
f.write(' '.join(to_log) + "\n")
f.write(" ".join(to_log) + "\n")
except Exception as e:
print(f"[ERROR] Could not log event {name} "
f"for {self.name}: {e}")
print(f"[ERROR] Could not log event {name} for {self.name}: {e}")
return attr(*args, **kwargs)
return wrapper
return attr
KVConnectorFactory.register_connector("TestSharedStorageConnector", __name__,
TestSharedStorageConnector.__name__)
KVConnectorFactory.register_connector(
"TestSharedStorageConnector", __name__, TestSharedStorageConnector.__name__
)
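The call-recording wrapper above leans on __getattribute__ interception to count and log connector calls while delegating everything else. The same pattern in stripped-down form, with placeholder names, works for any delegate:

from collections import defaultdict

class RecordingProxy:
    # Delegates attribute access to a target while counting method calls.
    def __init__(self, target):
        object.__setattr__(self, "_target", target)
        object.__setattr__(self, "call_record", defaultdict(int))

    def __getattr__(self, name):
        attr = getattr(object.__getattribute__(self, "_target"), name)
        if not callable(attr):
            return attr

        def wrapper(*args, **kwargs):
            self.call_record[name] += 1
            return attr(*args, **kwargs)

        return wrapper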