Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
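This change is formatting-only: it switches the repository's Python auto-formatting from yapf + isort to ruff, and the hunks below are the resulting mechanical rewrites of the test files. As a rough illustration of the pattern repeated throughout (a hypothetical snippet, not taken from this commit), yapf's hanging-indent wrapping becomes ruff's Black-compatible layout with one element per line and a trailing comma:

    # Before: yapf + isort wrap continuation lines under the opening bracket
    from os.path import (basename, dirname,
                         join)

    full_path = join("some", "deeply", "nested",
                     "directory", "file.txt")

    # After: ruff format explodes the bracketed list and keeps the
    # magic trailing comma, one element per line
    from os.path import (
        basename,
        dirname,
        join,
    )

    full_path = join(
        "some",
        "deeply",
        "nested",
        "directory",
        "file.txt",
    )

Imports and calls that fit within the line-length limit are instead collapsed onto a single line, which is why many multi-line yapf wrappings in the diff below become one line.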
@@ -2,12 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa: E501
SharedStorageConnectorMetadata)
SharedStorageConnectorMetadata,
)
from vllm.distributed.kv_transfer.kv_transfer_state import (
ensure_kv_transfer_initialized, get_kv_transfer_group)
ensure_kv_transfer_initialized,
get_kv_transfer_group,
)
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin)
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin

# Importing utils registers TestSharedStorageConnector with the factory
from .utils import create_vllm_config
@@ -34,7 +36,7 @@ def test_kv_connector_mixin_clears_metadata():
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "TestSharedStorageConnector"
vllm_config.kv_transfer_config.kv_role = "kv_both"
vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = ("unit")
vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = "unit"

# Initialize the global connector instance
ensure_kv_transfer_initialized(vllm_config)
@@ -46,7 +48,8 @@ def test_kv_connector_mixin_clears_metadata():

# Invoke the no-forward path which uses the mixin context manager
KVConnectorModelRunnerMixin.kv_connector_no_forward(
scheduler_output, vllm_config)
scheduler_output, vllm_config
)

# Verify clear_connector_metadata was called on the connector
connector = get_kv_transfer_group()

@@ -9,17 +9,19 @@ import pytest
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import Request, RequestStatus

from .utils import (create_model_runner_output, create_request,
create_scheduler, create_vllm_config)
from .utils import (
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)


def _make_get_num_new_matched_tokens(
req_num_new_matched_tokens: dict[str, int],
async_load,
) -> Callable[[Request, int], tuple[int, bool]]:

def get_num_new_matched_tokens(request: Request,
_: int) -> tuple[int, bool]:
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
value = req_num_new_matched_tokens.get(request.request_id, 0)
return value, async_load

@@ -33,9 +35,7 @@ def scheduler():


@pytest.mark.parametrize(
"num_prompt_blocks,"
"num_external_computed_blocks,"
"invalid_block_idxs",
"num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs",
[
(100, 99, {0, 98}),
(100, 99, {50, 98}),
@@ -51,8 +51,7 @@ def test_async_load_failure(
assert num_prompt_blocks >= num_external_computed_blocks

num_prompt_tokens = num_prompt_blocks * scheduler.block_size
num_external_computed_tokens = (num_external_computed_blocks *
scheduler.block_size)
num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size

request1 = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request1)
@@ -71,8 +70,8 @@ def test_async_load_failure(

scheduler.connector = Mock()
scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens,
async_load=True))
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=True)
)
scheduler.connector.take_events.return_value = ()

scheduler_output = scheduler.schedule()
@@ -84,14 +83,14 @@ def test_async_load_failure(
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3

# Simulate a failure in loading some of request2 blocks.
(req2_block_ids, ) = scheduler.kv_cache_manager.get_block_ids(
request2.request_id)
(req2_block_ids,) = scheduler.kv_cache_manager.get_block_ids(request2.request_id)
invalid_block_ids = {req2_block_ids[i] for i in invalid_block_idxs}
model_runner_output = create_model_runner_output(
reqs=[],
finished_recving={request1.request_id, request3.request_id},
invalid_block_ids=invalid_block_ids,
use_eos=True)
use_eos=True,
)

scheduler.update_from_output(scheduler_output, model_runner_output)

@@ -100,8 +99,9 @@ def test_async_load_failure(
assert len(scheduler.waiting) == 3
for request in scheduler.waiting:
if request.request_id == request2.request_id:
assert request.num_computed_tokens == (min_invalid_block_idx *
scheduler.block_size)
assert request.num_computed_tokens == (
min_invalid_block_idx * scheduler.block_size
)
else:
assert request.num_computed_tokens == 0
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
@@ -110,9 +110,7 @@ def test_async_load_failure(


@pytest.mark.parametrize(
"num_prompt_blocks,"
"num_external_computed_blocks,"
"invalid_block_idxs",
"num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs",
[
(100, 99, {0, 98}),
(100, 99, {50, 98}),
@@ -128,8 +126,7 @@ def test_sync_load_failure(
assert num_prompt_blocks >= num_external_computed_blocks

num_prompt_tokens = num_prompt_blocks * scheduler.block_size
num_external_computed_tokens = (num_external_computed_blocks *
scheduler.block_size)
num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size

request1 = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request1)
@@ -148,8 +145,8 @@ def test_sync_load_failure(

scheduler.connector = Mock()
scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens,
async_load=False))
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=False)
)
scheduler.connector.request_finished.return_value = (False, None)
scheduler.connector.take_events.return_value = ()

@@ -165,8 +162,7 @@ def test_sync_load_failure(
assert len(scheduler.running) == 3
assert len(scheduler_output.scheduled_new_reqs) == 3
for request in scheduler_output.scheduled_new_reqs:
assert request.num_computed_tokens == expected_computed_tokens[
request.req_id]
assert request.num_computed_tokens == expected_computed_tokens[request.req_id]
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3

# Simulate a failure in loading some of request2 blocks.
@@ -175,14 +171,16 @@ def test_sync_load_failure(
model_runner_output = create_model_runner_output(
[request1, request2, request3],
invalid_block_ids=invalid_block_ids,
use_eos=True)
use_eos=True,
)

scheduler.update_from_output(scheduler_output, model_runner_output)

assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == request2.request_id
assert scheduler.running[0].num_computed_tokens == (
min(invalid_block_idxs) * scheduler.block_size)
min(invalid_block_idxs) * scheduler.block_size
)
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
assert scheduler.connector.request_finished.call_count == 2

@@ -205,19 +203,19 @@ def test_sync_load_failure_with_shared_blocks(
num_common_prefix_blocks: int,
invalid_block_idxs: set[int],
):
assert (num_prompt_blocks >= num_external_computed_blocks >=
num_common_prefix_blocks)
assert num_prompt_blocks >= num_external_computed_blocks >= num_common_prefix_blocks

num_prompt_tokens = num_prompt_blocks * scheduler.block_size
num_external_computed_tokens = (num_external_computed_blocks *
scheduler.block_size)
num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size
common_prefix_len = num_common_prefix_blocks * scheduler.block_size

request1 = create_request(num_tokens=num_prompt_tokens,
common_prefix_len=common_prefix_len)
request1 = create_request(
num_tokens=num_prompt_tokens, common_prefix_len=common_prefix_len
)
scheduler.add_request(request=request1)
request2 = create_request(num_tokens=num_prompt_tokens,
common_prefix_len=common_prefix_len)
request2 = create_request(
num_tokens=num_prompt_tokens, common_prefix_len=common_prefix_len
)
scheduler.add_request(request=request2)

# Mock KV connector method.
@@ -228,8 +226,8 @@ def test_sync_load_failure_with_shared_blocks(

scheduler.connector = Mock()
scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens,
async_load=False))
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=False)
)
scheduler.connector.take_events.return_value = ()

scheduler_output = scheduler.schedule()
@@ -243,17 +241,15 @@ def test_sync_load_failure_with_shared_blocks(
assert len(scheduler.running) == 2
assert len(scheduler_output.scheduled_new_reqs) == 2
for request in scheduler_output.scheduled_new_reqs:
assert request.num_computed_tokens == expected_computed_tokens[
request.req_id]
assert request.num_computed_tokens == expected_computed_tokens[request.req_id]
assert scheduler.connector.get_num_new_matched_tokens.call_count == 2

# Simulate a failure in loading some of the shared blocks.
req1_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
invalid_block_ids = {req1_block_ids[i] for i in invalid_block_idxs}
model_runner_output = create_model_runner_output(
[request1, request2],
invalid_block_ids=invalid_block_ids,
use_eos=True)
[request1, request2], invalid_block_ids=invalid_block_ids, use_eos=True
)

scheduler.update_from_output(scheduler_output, model_runner_output)

@@ -266,15 +262,14 @@ def test_sync_load_failure_with_shared_blocks(

assert len(scheduler.running) == 2
for request in scheduler.running:
assert request.num_computed_tokens == expected_computed_tokens[
request.request_id]
assert (
request.num_computed_tokens == expected_computed_tokens[request.request_id]
)
assert scheduler.connector.get_num_new_matched_tokens.call_count == 2


@pytest.mark.parametrize(
"num_prompt_blocks,"
"num_external_computed_blocks,"
"invalid_block_idxs",
"num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs",
[
(100, 99, {0, 50, 98}),
(100, 99, {98, 50, 0}),
@@ -289,8 +284,7 @@ def test_async_progressive_load_failure(
assert num_prompt_blocks >= num_external_computed_blocks

num_prompt_tokens = num_prompt_blocks * scheduler.block_size
num_external_computed_tokens = (num_external_computed_blocks *
scheduler.block_size)
num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size

request = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request)
@@ -303,8 +297,8 @@ def test_async_progressive_load_failure(

scheduler.connector = Mock()
scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens,
async_load=True))
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=True)
)
scheduler.connector.take_events.return_value = ()

scheduler_output = scheduler.schedule()
@@ -318,24 +312,24 @@ def test_async_progressive_load_failure(
min_invalid_block_idx = max(invalid_block_idxs) + 1
# Simulate failures when progressively loading request blocks.
for invalid_block_idx in invalid_block_idxs:
(req_block_ids, ) = scheduler.kv_cache_manager.get_block_ids(
request.request_id)
(req_block_ids,) = scheduler.kv_cache_manager.get_block_ids(request.request_id)
invalid_block_ids = {req_block_ids[invalid_block_idx]}
model_runner_output = create_model_runner_output(
reqs=[],
finished_recving=set(),
invalid_block_ids=invalid_block_ids,
use_eos=True)
use_eos=True,
)

scheduler.update_from_output(scheduler_output, model_runner_output)

min_invalid_block_idx = min(min_invalid_block_idx, invalid_block_idx)

assert len(scheduler.waiting) == 1
assert scheduler.waiting.peek_request(
).request_id == request.request_id
assert request.num_computed_tokens == (min_invalid_block_idx *
scheduler.block_size)
assert scheduler.waiting.peek_request().request_id == request.request_id
assert request.num_computed_tokens == (
min_invalid_block_idx * scheduler.block_size
)
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.failed_recving_kv_req_ids == {request.request_id}
assert scheduler.connector.get_num_new_matched_tokens.call_count == 1

@@ -52,29 +52,26 @@ def test_multi_shared_storage_connector_consistency():
|
||||
kv_connector="MultiConnector",
|
||||
kv_role="kv_both",
|
||||
kv_connector_extra_config={
|
||||
"connectors": [{
|
||||
"kv_connector":
|
||||
"TestSharedStorageConnector",
|
||||
"kv_role":
|
||||
"kv_both",
|
||||
"kv_connector_extra_config": {
|
||||
"shared_storage_path": str(storage_1_path),
|
||||
"name": "storage1",
|
||||
"connectors": [
|
||||
{
|
||||
"kv_connector": "TestSharedStorageConnector",
|
||||
"kv_role": "kv_both",
|
||||
"kv_connector_extra_config": {
|
||||
"shared_storage_path": str(storage_1_path),
|
||||
"name": "storage1",
|
||||
},
|
||||
"kv_connector_module_path": "tests.v1.kv_connector.unit.utils",
|
||||
},
|
||||
"kv_connector_module_path":
|
||||
"tests.v1.kv_connector.unit.utils",
|
||||
}, {
|
||||
"kv_connector":
|
||||
"TestSharedStorageConnector",
|
||||
"kv_role":
|
||||
"kv_both",
|
||||
"kv_connector_extra_config": {
|
||||
"shared_storage_path": str(storage_2_path),
|
||||
"name": "storage2",
|
||||
{
|
||||
"kv_connector": "TestSharedStorageConnector",
|
||||
"kv_role": "kv_both",
|
||||
"kv_connector_extra_config": {
|
||||
"shared_storage_path": str(storage_2_path),
|
||||
"name": "storage2",
|
||||
},
|
||||
"kv_connector_module_path": "tests.v1.kv_connector.unit.utils",
|
||||
},
|
||||
"kv_connector_module_path":
|
||||
"tests.v1.kv_connector.unit.utils",
|
||||
}]
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@@ -93,14 +90,16 @@ def test_multi_shared_storage_connector_consistency():
|
||||
local_subdirs = list(storage_1_path.iterdir())
|
||||
external_subdirs = list(storage_2_path.iterdir())
|
||||
|
||||
assert len(
|
||||
local_subdirs
|
||||
) > 0, f"Local storage path {storage_1_path} is empty after generation."
|
||||
assert len(local_subdirs) > 0, (
|
||||
f"Local storage path {storage_1_path} is empty after generation."
|
||||
)
|
||||
assert len(external_subdirs) > 0, (
|
||||
f"External storage path {storage_2_path} is empty after generation.")
|
||||
f"External storage path {storage_2_path} is empty after generation."
|
||||
)
|
||||
assert len(local_subdirs) == len(external_subdirs), (
|
||||
f"Mismatch in number of cache entries: "
|
||||
f"Local={len(local_subdirs)}, External={len(external_subdirs)}")
|
||||
f"Local={len(local_subdirs)}, External={len(external_subdirs)}"
|
||||
)
|
||||
|
||||
# The subdirectories should correspond to the prompt hashes
|
||||
# Since prompts are the same, the hash directories should be the same name
|
||||
@@ -113,29 +112,39 @@ def test_multi_shared_storage_connector_consistency():
|
||||
# Compare the contents of each corresponding cache directory
|
||||
for subdir_name in local_subdir_names:
|
||||
print(f"Comparing contents of cache directory: {subdir_name}")
|
||||
assert _compare_directories(storage_1_path / subdir_name,
|
||||
storage_2_path / subdir_name), \
|
||||
(f"Contents differ for cache directory '{subdir_name}' between "
|
||||
f"{storage_1_path} and {storage_2_path}")
|
||||
assert _compare_directories(
|
||||
storage_1_path / subdir_name, storage_2_path / subdir_name
|
||||
), (
|
||||
f"Contents differ for cache directory '{subdir_name}' between "
|
||||
f"{storage_1_path} and {storage_2_path}"
|
||||
)
|
||||
|
||||
events = get_connector_events()
|
||||
# get_num_new_matched_tokens and update_state_after_alloc will be called
|
||||
# on each connector in turn.
|
||||
assert events["storage1-SCHEDULER"][:3] == [
|
||||
'get_num_new_matched_tokens 0',
|
||||
'update_state_after_alloc num_blocks=[0] 0', 'build_connector_meta'
|
||||
"get_num_new_matched_tokens 0",
|
||||
"update_state_after_alloc num_blocks=[0] 0",
|
||||
"build_connector_meta",
|
||||
]
|
||||
assert events["storage1-WORKER"][:5] == [
|
||||
'register_kv_caches', 'bind_connector_metadata', 'start_load_kv',
|
||||
'wait_for_layer_load', 'save_kv_layer'
|
||||
"register_kv_caches",
|
||||
"bind_connector_metadata",
|
||||
"start_load_kv",
|
||||
"wait_for_layer_load",
|
||||
"save_kv_layer",
|
||||
]
|
||||
assert events["storage2-SCHEDULER"][:3] == [
|
||||
'get_num_new_matched_tokens 0',
|
||||
'update_state_after_alloc num_blocks=[0] 0', 'build_connector_meta'
|
||||
"get_num_new_matched_tokens 0",
|
||||
"update_state_after_alloc num_blocks=[0] 0",
|
||||
"build_connector_meta",
|
||||
]
|
||||
assert events["storage2-WORKER"][:5] == [
|
||||
'register_kv_caches', 'bind_connector_metadata', 'start_load_kv',
|
||||
'wait_for_layer_load', 'save_kv_layer'
|
||||
"register_kv_caches",
|
||||
"bind_connector_metadata",
|
||||
"start_load_kv",
|
||||
"wait_for_layer_load",
|
||||
"save_kv_layer",
|
||||
]
|
||||
|
||||
# Reset prefix cache or else we'll just get the tokens back from there.
|
||||
@@ -151,12 +160,14 @@ def test_multi_shared_storage_connector_consistency():
|
||||
# on that one but with zero blocks for others (first nonzero match is
|
||||
# chosen).
|
||||
assert events["storage1-SCHEDULER"][:3] == [
|
||||
'get_num_new_matched_tokens 0',
|
||||
'update_state_after_alloc num_blocks=[7] 96', 'build_connector_meta'
|
||||
"get_num_new_matched_tokens 0",
|
||||
"update_state_after_alloc num_blocks=[7] 96",
|
||||
"build_connector_meta",
|
||||
]
|
||||
assert events["storage2-SCHEDULER"][:3] == [
|
||||
'get_num_new_matched_tokens 0',
|
||||
'update_state_after_alloc num_blocks=[0] 0', 'build_connector_meta'
|
||||
"get_num_new_matched_tokens 0",
|
||||
"update_state_after_alloc num_blocks=[0] 0",
|
||||
"build_connector_meta",
|
||||
]
|
||||
|
||||
# Delete storage1 connector state
|
||||
@@ -175,12 +186,14 @@ def test_multi_shared_storage_connector_consistency():
|
||||
# a hit, so update_state_after_alloc will only be called with allocated
|
||||
# blocks for the second connector.
|
||||
assert events["storage1-SCHEDULER"][:3] == [
|
||||
'get_num_new_matched_tokens 0',
|
||||
'update_state_after_alloc num_blocks=[0] 0', 'build_connector_meta'
|
||||
"get_num_new_matched_tokens 0",
|
||||
"update_state_after_alloc num_blocks=[0] 0",
|
||||
"build_connector_meta",
|
||||
]
|
||||
assert events["storage2-SCHEDULER"][:3] == [
|
||||
'get_num_new_matched_tokens 0',
|
||||
'update_state_after_alloc num_blocks=[7] 96', 'build_connector_meta'
|
||||
"get_num_new_matched_tokens 0",
|
||||
"update_state_after_alloc num_blocks=[7] 96",
|
||||
"build_connector_meta",
|
||||
]
|
||||
|
||||
# Clean up
|
||||
@@ -191,15 +204,14 @@ def test_multi_shared_storage_connector_consistency():
|
||||
def get_connector_events() -> dict[str, list[str]]:
|
||||
# Read in connector events and reset the files.
|
||||
import glob
|
||||
|
||||
event_files = glob.glob(tempfile.gettempdir() + "/connector_*_events.log")
|
||||
connector_events = {}
|
||||
for fname in event_files:
|
||||
name = fname.split("connector_")[1].split("_events.log")[0]
|
||||
try:
|
||||
with open(fname, "r+") as f:
|
||||
connector_events[name] = [
|
||||
line.strip() for line in f if line.strip()
|
||||
]
|
||||
connector_events[name] = [line.strip() for line in f if line.strip()]
|
||||
f.truncate(0)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Could not read connector events for {name}: {e}")
|
||||
@@ -211,5 +223,5 @@ def test_engine_id_conflict():
|
||||
configs = [KVTransferConfig() for _ in range(2)]
|
||||
ids = [config.engine_id for config in configs]
|
||||
assert ids[0] != ids[1], (
|
||||
"Engine IDs should be different for different configs. "
|
||||
f"Got {ids}")
|
||||
f"Engine IDs should be different for different configs. Got {ids}"
|
||||
)
|
||||
|
||||
@@ -19,15 +19,22 @@ import torch
|
||||
from vllm import LLM
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorStats)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
|
||||
MultiKVConnectorStats)
|
||||
MultiKVConnectorStats,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
||||
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
|
||||
NixlConnectorWorker, NixlKVConnectorStats)
|
||||
KVConnectorRole,
|
||||
NixlAgentMetadata,
|
||||
NixlConnector,
|
||||
NixlConnectorMetadata,
|
||||
NixlConnectorWorker,
|
||||
NixlKVConnectorStats,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_transfer_state import (
|
||||
ensure_kv_transfer_shutdown, has_kv_transfer_group)
|
||||
ensure_kv_transfer_shutdown,
|
||||
has_kv_transfer_group,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.platforms.interface import Platform
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@@ -42,14 +49,14 @@ from .utils import create_request, create_scheduler, create_vllm_config
|
||||
def clear_kv_transfer():
|
||||
"""
|
||||
The test cases in this file use `VLLM_ENABLE_V1_MULTIPROCESSING=0`,
|
||||
causing the global variable `_KV_CONNECTOR_AGENT`
|
||||
causing the global variable `_KV_CONNECTOR_AGENT`
|
||||
to be assigned but never deleted.
|
||||
|
||||
Since the current pytest process does not terminate and instead
|
||||
Since the current pytest process does not terminate and instead
|
||||
continues running tests from other files,
|
||||
this global variable remains in memory and interferes
|
||||
this global variable remains in memory and interferes
|
||||
with test cases in other modules.
|
||||
|
||||
|
||||
So we use this fixture to ensure that the global variable
|
||||
`_KV_CONNECTOR_AGENT` is properly cleaned up after each test.
|
||||
"""
|
||||
@@ -58,11 +65,12 @@ def clear_kv_transfer():
|
||||
ensure_kv_transfer_shutdown()
|
||||
|
||||
|
||||
def get_default_xfer_telemetry(xferDurationS: float = 1,
|
||||
postDurationS: float = 1,
|
||||
totalBytes: int = 1,
|
||||
descCount: int = 1) -> dict:
|
||||
|
||||
def get_default_xfer_telemetry(
|
||||
xferDurationS: float = 1,
|
||||
postDurationS: float = 1,
|
||||
totalBytes: int = 1,
|
||||
descCount: int = 1,
|
||||
) -> dict:
|
||||
class AttributeDict(dict):
|
||||
__slots__ = ()
|
||||
__getattr__ = dict.__getitem__
|
||||
@@ -83,7 +91,7 @@ class FakeNixlWrapper:
|
||||
|
||||
We don't inherit from nixl._api.nixl_agent because nixl may not be
|
||||
installed.
|
||||
|
||||
|
||||
Note: The complete source of this class is also used in the
|
||||
`_make_fake_nixl_pkg` function to create a fake nixl package
|
||||
for Ray workers.
|
||||
@@ -94,8 +102,7 @@ class FakeNixlWrapper:
|
||||
|
||||
def __init__(self, agent_name: str, *args, **kwargs):
|
||||
self._cycles_before_xfer_done = 0
|
||||
self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
|
||||
lambda: 0)
|
||||
self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(lambda: 0)
|
||||
|
||||
def get_reg_descs(self, caches_data, memory_type: str) -> list:
|
||||
return [str(uuid.uuid4()) for _ in caches_data]
|
||||
@@ -123,8 +130,7 @@ class FakeNixlWrapper:
|
||||
return {}
|
||||
|
||||
def check_xfer_state(self, handle: int) -> str:
|
||||
if self._check_xfer_state_cycles[
|
||||
handle] >= self._cycles_before_xfer_done:
|
||||
if self._check_xfer_state_cycles[handle] >= self._cycles_before_xfer_done:
|
||||
return "DONE"
|
||||
self._check_xfer_state_cycles[handle] += 1
|
||||
return "PROC"
|
||||
@@ -141,13 +147,15 @@ class FakeNixlWrapper:
|
||||
def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
|
||||
pass
|
||||
|
||||
def make_prepped_xfer(self,
|
||||
xfer_type: str,
|
||||
local_xfer_side_handle: int,
|
||||
local_block_descs_ids: list[int],
|
||||
remote_xfer_side_handle: int,
|
||||
remote_block_descs_ids: list[int],
|
||||
notif_msg: Optional[bytes] = None) -> int:
|
||||
def make_prepped_xfer(
|
||||
self,
|
||||
xfer_type: str,
|
||||
local_xfer_side_handle: int,
|
||||
local_block_descs_ids: list[int],
|
||||
remote_xfer_side_handle: int,
|
||||
remote_block_descs_ids: list[int],
|
||||
notif_msg: Optional[bytes] = None,
|
||||
) -> int:
|
||||
return uuid.uuid4().int
|
||||
|
||||
def transfer(self, handle: int) -> str:
|
||||
@@ -168,7 +176,7 @@ class FakeNixlWrapper:
|
||||
def _make_fake_nixl_pkg():
|
||||
"""Context manager that creates a temporary package making
|
||||
`from nixl._api import nixl_agent` resolve to our FakeNixlWrapper.
|
||||
|
||||
|
||||
Automatically cleans up the temporary directory when done.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
@@ -214,10 +222,12 @@ def test_basic_interface():
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
request = create_request(request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True)
|
||||
request = create_request(
|
||||
request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True,
|
||||
)
|
||||
request_id = request.request_id
|
||||
|
||||
scheduler.add_request(request)
|
||||
@@ -233,8 +243,11 @@ def test_basic_interface():
|
||||
req_meta = kv_connector_metadata.reqs_to_recv[request_id]
|
||||
|
||||
for block_id, block in zip(
|
||||
req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator.
|
||||
single_type_managers[0].req_to_blocks[request_id]):
|
||||
req_meta.local_block_ids,
|
||||
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
|
||||
request_id
|
||||
],
|
||||
):
|
||||
assert block_id == block.block_id
|
||||
|
||||
|
||||
@@ -254,11 +267,13 @@ def test_prompt_less_than_block_size():
|
||||
NUM_TOKENS = int(BLOCK_SIZE * 0.5)
|
||||
|
||||
# Request will have 1 partial remote block.
|
||||
request = create_request(request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True,
|
||||
num_remote_blocks=1)
|
||||
request = create_request(
|
||||
request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True,
|
||||
num_remote_blocks=1,
|
||||
)
|
||||
scheduler.add_request(request)
|
||||
scheduler_output = scheduler.schedule()
|
||||
|
||||
@@ -271,15 +286,15 @@ def test_prompt_less_than_block_size():
|
||||
|
||||
|
||||
class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
|
||||
REMOTE_ENGINE_ID = "remote_engine"
|
||||
|
||||
def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._hand_shake_latency = hand_shake_latency
|
||||
|
||||
def _nixl_handshake(self, host: str, port: int, remote_tp_size: int,
|
||||
expected_engine_id: str) -> dict[int, str]:
|
||||
def _nixl_handshake(
|
||||
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
|
||||
) -> dict[int, str]:
|
||||
# Mimic slow _nixl_handshake, as well as bypass zmq communication.
|
||||
time.sleep(self._hand_shake_latency)
|
||||
# These should've been done in register_kv_caches(), called by
|
||||
@@ -304,21 +319,23 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
# is started. We mock HND here.
|
||||
kv_cache_layout="HND",
|
||||
),
|
||||
remote_tp_size=remote_tp_size)
|
||||
remote_tp_size=remote_tp_size,
|
||||
)
|
||||
return {0: remote_agent_name}
|
||||
|
||||
|
||||
class TestNixlHandshake:
|
||||
|
||||
@patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FakeNixlWrapper)
|
||||
FakeNixlWrapper,
|
||||
)
|
||||
def test_multi_xfer_one_engine(
|
||||
self,
|
||||
# dist_init is a fixture that initializes the distributed environment.
|
||||
dist_init):
|
||||
dist_init,
|
||||
):
|
||||
"""Test case where multiple xfers are initiated to the same engine.
|
||||
|
||||
|
||||
This test triggers the connector to load remote KV for the same
|
||||
`request_id`. The transfer is not done immediately due to
|
||||
`set_cycles_before_xfer_done`, so there is a state where there are
|
||||
@@ -332,9 +349,9 @@ class TestNixlHandshake:
|
||||
# Test worker role in decode server.
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0)
|
||||
assert isinstance(connector.connector_worker.nixl_wrapper,
|
||||
FakeNixlWrapper)
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0
|
||||
)
|
||||
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
|
||||
connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
|
||||
num_xfers = 4
|
||||
while True:
|
||||
@@ -345,21 +362,19 @@ class TestNixlHandshake:
|
||||
num_xfers -= 1
|
||||
metadata.add_new_req(
|
||||
request_id=request_id,
|
||||
local_block_ids=[
|
||||
num_xfers + 1, num_xfers + 2, num_xfers + 3
|
||||
],
|
||||
local_block_ids=[num_xfers + 1, num_xfers + 2, num_xfers + 3],
|
||||
kv_transfer_params={
|
||||
"remote_block_ids":
|
||||
[num_xfers + 4, num_xfers + 5, num_xfers + 6],
|
||||
"remote_engine_id":
|
||||
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
"remote_host":
|
||||
"localhost",
|
||||
"remote_port":
|
||||
1234,
|
||||
"remote_tp_size":
|
||||
1,
|
||||
})
|
||||
"remote_block_ids": [
|
||||
num_xfers + 4,
|
||||
num_xfers + 5,
|
||||
num_xfers + 6,
|
||||
],
|
||||
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
"remote_host": "localhost",
|
||||
"remote_port": 1234,
|
||||
"remote_tp_size": 1,
|
||||
},
|
||||
)
|
||||
connector.bind_connector_metadata(metadata)
|
||||
|
||||
# Mimic maybe_setup_kv_connector in gpu_model_runner.
|
||||
@@ -371,8 +386,9 @@ class TestNixlHandshake:
|
||||
_before_load = time.perf_counter()
|
||||
connector.start_load_kv(dummy_ctx)
|
||||
_after_load = time.perf_counter()
|
||||
assert _after_load - _before_load < 0.1, "start_load_kv took " \
|
||||
f"{_after_load - _before_load} seconds"
|
||||
assert _after_load - _before_load < 0.1, (
|
||||
f"start_load_kv took {_after_load - _before_load} seconds"
|
||||
)
|
||||
|
||||
# Mimic get_finished_kv_transfers in gpu_model_runner.
|
||||
_, done_recving = connector.get_finished(finished_req_ids=set())
|
||||
@@ -384,20 +400,25 @@ class TestNixlHandshake:
|
||||
|
||||
@patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FakeNixlWrapper)
|
||||
@pytest.mark.parametrize("decode_tp_size, prefill_tp_size", [
|
||||
(1, 1),
|
||||
(2, 1),
|
||||
(4, 2),
|
||||
(4, 4),
|
||||
])
|
||||
FakeNixlWrapper,
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"decode_tp_size, prefill_tp_size",
|
||||
[
|
||||
(1, 1),
|
||||
(2, 1),
|
||||
(4, 2),
|
||||
(4, 4),
|
||||
],
|
||||
)
|
||||
def test_async_load_kv(
|
||||
self,
|
||||
# Fixture that initializes the distributed environment.
|
||||
dist_init,
|
||||
# Simulate consumer-producer TP sizes.
|
||||
decode_tp_size,
|
||||
prefill_tp_size):
|
||||
self,
|
||||
# Fixture that initializes the distributed environment.
|
||||
dist_init,
|
||||
# Simulate consumer-producer TP sizes.
|
||||
decode_tp_size,
|
||||
prefill_tp_size,
|
||||
):
|
||||
"""Test that NixlConnector's start_load_kv should be non-blocking."""
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
@@ -406,18 +427,20 @@ class TestNixlHandshake:
|
||||
# Test worker role in decode server.
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id)
|
||||
vllm_config, connector.engine_id
|
||||
)
|
||||
metadata = NixlConnectorMetadata()
|
||||
metadata.add_new_req(request_id="id",
|
||||
local_block_ids=[1, 2, 3],
|
||||
kv_transfer_params={
|
||||
"remote_block_ids": [4, 5, 6],
|
||||
"remote_engine_id":
|
||||
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
"remote_host": "localhost",
|
||||
"remote_port": 1234,
|
||||
"remote_tp_size": prefill_tp_size,
|
||||
})
|
||||
metadata.add_new_req(
|
||||
request_id="id",
|
||||
local_block_ids=[1, 2, 3],
|
||||
kv_transfer_params={
|
||||
"remote_block_ids": [4, 5, 6],
|
||||
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
"remote_host": "localhost",
|
||||
"remote_port": 1234,
|
||||
"remote_tp_size": prefill_tp_size,
|
||||
},
|
||||
)
|
||||
connector.bind_connector_metadata(metadata)
|
||||
|
||||
timeout = 2.5
|
||||
@@ -431,8 +454,9 @@ class TestNixlHandshake:
|
||||
_before_load = time.perf_counter()
|
||||
connector.start_load_kv(dummy_ctx)
|
||||
_after_load = time.perf_counter()
|
||||
assert _after_load - _before_load < 0.1, "start_load_kv took " \
|
||||
f"{_after_load - _before_load} seconds"
|
||||
assert _after_load - _before_load < 0.1, (
|
||||
f"start_load_kv took {_after_load - _before_load} seconds"
|
||||
)
|
||||
time.sleep(0.5) # backoff for the async handshake to complete.
|
||||
connector.bind_connector_metadata(NixlConnectorMetadata())
|
||||
_, done_recving = connector.get_finished(finished_req_ids=set())
|
||||
@@ -442,11 +466,13 @@ class TestNixlHandshake:
|
||||
|
||||
@patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FakeNixlWrapper)
|
||||
FakeNixlWrapper,
|
||||
)
|
||||
def test_concurrent_load_kv(
|
||||
self,
|
||||
# dist_init is a fixture that initializes the distributed environment.
|
||||
dist_init):
|
||||
dist_init,
|
||||
):
|
||||
"""Test that multiple start_load_kv calls should occur concurrently."""
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
@@ -454,20 +480,22 @@ class TestNixlHandshake:
|
||||
# Test worker role in decode server.
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id)
|
||||
vllm_config, connector.engine_id
|
||||
)
|
||||
metadata = NixlConnectorMetadata()
|
||||
total_reqs = 5
|
||||
for i in range(total_reqs):
|
||||
metadata.add_new_req(request_id=f"id_{i}",
|
||||
local_block_ids=[1, 2, 3],
|
||||
kv_transfer_params={
|
||||
"remote_block_ids": [4, 5, 6],
|
||||
"remote_engine_id":
|
||||
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
"remote_host": "localhost",
|
||||
"remote_port": 1234,
|
||||
"remote_tp_size": 1,
|
||||
})
|
||||
metadata.add_new_req(
|
||||
request_id=f"id_{i}",
|
||||
local_block_ids=[1, 2, 3],
|
||||
kv_transfer_params={
|
||||
"remote_block_ids": [4, 5, 6],
|
||||
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
"remote_host": "localhost",
|
||||
"remote_port": 1234,
|
||||
"remote_tp_size": 1,
|
||||
},
|
||||
)
|
||||
connector.bind_connector_metadata(metadata)
|
||||
|
||||
timeout = 2.5 * total_reqs
|
||||
@@ -482,8 +510,9 @@ class TestNixlHandshake:
|
||||
_before_load = time.perf_counter()
|
||||
connector.start_load_kv(dummy_ctx)
|
||||
_after_load = time.perf_counter()
|
||||
assert _after_load - _before_load < 0.1, "start_load_kv took " \
|
||||
f"{_after_load - _before_load} seconds"
|
||||
assert _after_load - _before_load < 0.1, (
|
||||
f"start_load_kv took {_after_load - _before_load} seconds"
|
||||
)
|
||||
time.sleep(0.5) # backoff for the async handshake to complete.
|
||||
connector.bind_connector_metadata(NixlConnectorMetadata())
|
||||
_, done_recving = connector.get_finished(finished_req_ids=set())
|
||||
@@ -495,7 +524,8 @@ class TestNixlHandshake:
|
||||
|
||||
@patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FakeNixlWrapper)
|
||||
FakeNixlWrapper,
|
||||
)
|
||||
def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
|
||||
"""
|
||||
Verify that adding a remote agent fails if kv_cache_layout differs.
|
||||
@@ -506,12 +536,14 @@ class TestNixlHandshake:
|
||||
# Mock TP world size to 2 to force heterogeneous TP when
|
||||
# remote_tp_size=1
|
||||
with patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", # noqa: E501
|
||||
return_value=2):
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", # noqa: E501
|
||||
return_value=2,
|
||||
):
|
||||
# Initialize connector and worker (with fake NIXL wrapper)
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0)
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0
|
||||
)
|
||||
worker = connector.connector_worker
|
||||
|
||||
# Minimal local registration params used by add_remote_agent
|
||||
@@ -521,8 +553,7 @@ class TestNixlHandshake:
|
||||
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
|
||||
|
||||
# Metadata with different kv_cache_layout than local worker
|
||||
mismatched_layout = "HND" if worker.kv_cache_layout != "HND" \
|
||||
else "NHD"
|
||||
mismatched_layout = "HND" if worker.kv_cache_layout != "HND" else "NHD"
|
||||
meta = NixlAgentMetadata(
|
||||
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||
@@ -545,16 +576,17 @@ class TestNixlHandshake:
|
||||
# the rest of the tests.
|
||||
@patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FakeNixlWrapper)
|
||||
FakeNixlWrapper,
|
||||
)
|
||||
def test_kv_connector_stats(dist_init):
|
||||
"""Test that KV transfer stats are properly recorded and retrieved."""
|
||||
vllm_config = create_vllm_config()
|
||||
|
||||
# Test worker role in decode server.
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
|
||||
connector.engine_id,
|
||||
hand_shake_latency=0)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0
|
||||
)
|
||||
|
||||
# Verify that xfer_stats starts empty
|
||||
initial_stats = connector.get_kv_connector_stats()
|
||||
@@ -563,16 +595,17 @@ def test_kv_connector_stats(dist_init):
|
||||
# Create transfer metadata
|
||||
request_id = "test_req_for_stats"
|
||||
metadata = NixlConnectorMetadata()
|
||||
metadata.add_new_req(request_id=request_id,
|
||||
local_block_ids=[1, 2, 3],
|
||||
kv_transfer_params={
|
||||
"remote_block_ids": [4, 5, 6],
|
||||
"remote_engine_id":
|
||||
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
"remote_host": "localhost",
|
||||
"remote_port": 1234,
|
||||
"remote_tp_size": 1,
|
||||
})
|
||||
metadata.add_new_req(
|
||||
request_id=request_id,
|
||||
local_block_ids=[1, 2, 3],
|
||||
kv_transfer_params={
|
||||
"remote_block_ids": [4, 5, 6],
|
||||
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
"remote_host": "localhost",
|
||||
"remote_port": 1234,
|
||||
"remote_tp_size": 1,
|
||||
},
|
||||
)
|
||||
connector.bind_connector_metadata(metadata)
|
||||
|
||||
# Start the transfer
|
||||
@@ -593,8 +626,7 @@ def test_kv_connector_stats(dist_init):
|
||||
_, done_recving = connector.get_finished(finished_req_ids=set())
|
||||
if len(done_recving) > 0 and request_id in done_recving:
|
||||
break
|
||||
time.sleep(
|
||||
0.1) # Small delay to allow background handshake to complete
|
||||
time.sleep(0.1) # Small delay to allow background handshake to complete
|
||||
else:
|
||||
assert "Transfer did not complete within expected iterations"
|
||||
|
||||
@@ -613,7 +645,7 @@ def test_kv_connector_stats(dist_init):
|
||||
|
||||
def test_kv_connector_stats_aggregation():
|
||||
"""
|
||||
Test KV transfer stats aggregation across TP ranks using
|
||||
Test KV transfer stats aggregation across TP ranks using
|
||||
KVOutputAggregator (used by MultiprocExecutor).
|
||||
"""
|
||||
|
||||
@@ -636,18 +668,16 @@ def test_kv_connector_stats_aggregation():
|
||||
worker2_stats.record_transfer(stats)
|
||||
|
||||
# Worker 3: 3 transfers
|
||||
stats = get_default_xfer_telemetry(xferDurationS=2,
|
||||
postDurationS=2,
|
||||
totalBytes=2,
|
||||
descCount=2)
|
||||
stats = get_default_xfer_telemetry(
|
||||
xferDurationS=2, postDurationS=2, totalBytes=2, descCount=2
|
||||
)
|
||||
worker3_stats.record_transfer(stats)
|
||||
worker3_stats.record_transfer(stats)
|
||||
worker3_stats.record_transfer(stats)
|
||||
|
||||
# Create ModelRunnerOutput instances for each worker
|
||||
worker_outputs = []
|
||||
for i, worker_stats in enumerate(
|
||||
[worker1_stats, worker2_stats, worker3_stats]):
|
||||
for i, worker_stats in enumerate([worker1_stats, worker2_stats, worker3_stats]):
|
||||
output = ModelRunnerOutput(
|
||||
req_ids=[f"req_{i}"],
|
||||
req_id_to_index={f"req_{i}": 0},
|
||||
@@ -657,17 +687,19 @@ def test_kv_connector_stats_aggregation():
|
||||
pooler_output=[None],
|
||||
kv_connector_output=KVConnectorOutput(
|
||||
finished_sending=set([f"req_{i}_send"])
|
||||
if i < 2 else None, # Workers 0,1 finished sending
|
||||
if i < 2
|
||||
else None, # Workers 0,1 finished sending
|
||||
finished_recving=set([f"req_{i}_recv"])
|
||||
if i > 0 else None, # Workers 1,2 finished receiving
|
||||
if i > 0
|
||||
else None, # Workers 1,2 finished receiving
|
||||
kv_connector_stats=worker_stats,
|
||||
))
|
||||
),
|
||||
)
|
||||
worker_outputs.append(output)
|
||||
|
||||
# Use the real aggregation mechanism (like MultiprocExecutor.execute_model)
|
||||
aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
|
||||
kv_connector_stats = \
|
||||
aggregated_output.kv_connector_output.kv_connector_stats
|
||||
kv_connector_stats = aggregated_output.kv_connector_output.kv_connector_stats
|
||||
assert isinstance(kv_connector_stats, NixlKVConnectorStats)
|
||||
# Number of total transfers across all workers.
|
||||
assert kv_connector_stats.num_successful_transfers == 6
|
||||
@@ -691,7 +723,6 @@ def test_multi_kv_connector_stats_aggregation():
|
||||
# Mock a KVConnectorStats class for testing aggregation over connectors.
|
||||
@dataclass
|
||||
class FooKVConnectorStats(KVConnectorStats):
|
||||
|
||||
def reset(self):
|
||||
self.data = {"num_foo_transfers": 0}
|
||||
|
||||
@@ -703,15 +734,12 @@ def test_multi_kv_connector_stats_aggregation():
|
||||
def is_empty(self) -> bool:
|
||||
return self.data["num_foo_transfers"] == 0
|
||||
|
||||
def aggregate(self,
|
||||
other: "FooKVConnectorStats") -> "FooKVConnectorStats":
|
||||
def aggregate(self, other: "FooKVConnectorStats") -> "FooKVConnectorStats":
|
||||
if not other.is_empty():
|
||||
self.data["num_foo_transfers"] += other.data[
|
||||
"num_foo_transfers"]
|
||||
self.data["num_foo_transfers"] += other.data["num_foo_transfers"]
|
||||
return self
|
||||
|
||||
def make_multi_stats(nixl_count: int,
|
||||
foo_count: int) -> MultiKVConnectorStats:
|
||||
def make_multi_stats(nixl_count: int, foo_count: int) -> MultiKVConnectorStats:
|
||||
data: dict[str, KVConnectorStats] = {}
|
||||
if nixl_count > 0:
|
||||
nixl_stats = NixlKVConnectorStats()
|
||||
@@ -747,13 +775,11 @@ def test_multi_kv_connector_stats_aggregation():
|
||||
worker_outputs.append(output)
|
||||
|
||||
aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
|
||||
kv_connector_stats = \
|
||||
aggregated_output.kv_connector_output.kv_connector_stats
|
||||
kv_connector_stats = aggregated_output.kv_connector_output.kv_connector_stats
|
||||
assert isinstance(kv_connector_stats, MultiKVConnectorStats)
|
||||
|
||||
# Validate per-connector totals across workers
|
||||
assert isinstance(kv_connector_stats["NixlConnector"],
|
||||
NixlKVConnectorStats)
|
||||
assert isinstance(kv_connector_stats["NixlConnector"], NixlKVConnectorStats)
|
||||
assert kv_connector_stats["NixlConnector"].num_successful_transfers == 5
|
||||
assert isinstance(kv_connector_stats["FooConnector"], FooKVConnectorStats)
|
||||
assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6
|
||||
@@ -762,11 +788,12 @@ def test_multi_kv_connector_stats_aggregation():
|
||||
@pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
|
||||
@patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FakeNixlWrapper)
|
||||
FakeNixlWrapper,
|
||||
)
|
||||
def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
|
||||
"""
|
||||
Test lifecycle of an aborted Remote Prefill request hitting the timeout.
|
||||
-----> P
|
||||
-----> P
|
||||
| {process request}
|
||||
<-/--- | {result is NOT delivered, eg proxy is down}
|
||||
|
|
||||
@@ -823,39 +850,38 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=1,
|
||||
extra_args={"kv_transfer_params": remote_prefill_opts})
|
||||
extra_args={"kv_transfer_params": remote_prefill_opts},
|
||||
)
|
||||
scheduler = llm.llm_engine.engine_core.engine_core.scheduler
|
||||
req_to_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0].req_to_blocks
|
||||
0
|
||||
].req_to_blocks
|
||||
|
||||
padding = "Just making this request a little longer so that we're sure "
|
||||
"we're not hitting the small-request lower bound beneath which we don't "
|
||||
"actually trigger the whole kv transfer, but rather just recompute the "
|
||||
"blocks on D."
|
||||
_ = llm.generate([f"What is the capital of Japan? {padding}"],
|
||||
sampling_params)
|
||||
_ = llm.generate([f"What is the capital of Japan? {padding}"], sampling_params)
|
||||
|
||||
# Request finished but not freed
|
||||
assert '0' in scheduler.finished_req_ids and '0' in req_to_blocks
|
||||
assert "0" in scheduler.finished_req_ids and "0" in req_to_blocks
|
||||
# Some other request, 0 still not freed
|
||||
_ = llm.generate([f"What is the capital of Italy? {padding}"],
|
||||
sampling_params)
|
||||
assert '0' in req_to_blocks
|
||||
assert '1' in scheduler.finished_req_ids and '1' in req_to_blocks
|
||||
_ = llm.generate([f"What is the capital of Italy? {padding}"], sampling_params)
|
||||
assert "0" in req_to_blocks
|
||||
assert "1" in scheduler.finished_req_ids and "1" in req_to_blocks
|
||||
|
||||
# Wait for timeout and trigger another scheduler loop
|
||||
time.sleep(timeout)
|
||||
_ = llm.generate([f"What is the capital of France? {padding}"],
|
||||
sampling_params)
|
||||
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
|
||||
# Request-0 times out and is cleared!
|
||||
assert '0' not in req_to_blocks
|
||||
assert "0" not in req_to_blocks
|
||||
|
||||
|
||||
def test_register_kv_caches(dist_init):
|
||||
"""
|
||||
Test that register_kv_caches() properly calls nixl_wrapper methods with
|
||||
correct data.
|
||||
|
||||
|
||||
This test verifies:
|
||||
1. nixl_wrapper.get_reg_descs() is called with caches_data containing
|
||||
tensor metadata
|
||||
@@ -866,10 +892,9 @@ def test_register_kv_caches(dist_init):
|
||||
vllm_config = create_vllm_config()
|
||||
|
||||
# Create test kv cache tensors using proper backend shape
|
||||
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(num_blocks=2,
|
||||
block_size=16,
|
||||
num_kv_heads=4,
|
||||
head_size=64)
|
||||
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
|
||||
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
|
||||
)
|
||||
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
|
||||
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
|
||||
kv_caches = {
|
||||
@@ -879,21 +904,30 @@ def test_register_kv_caches(dist_init):
|
||||
}
|
||||
|
||||
# Store tensor info for validation
|
||||
expected_tensor_size = shared_tensor[0].element_size(
|
||||
) * shared_tensor[0].numel()
|
||||
expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel()
|
||||
expected_base_addrs = [
|
||||
shared_tensor[0].data_ptr(), shared_tensor[1].data_ptr(),
|
||||
unique_tensor[0].data_ptr(), unique_tensor[1].data_ptr()
|
||||
shared_tensor[0].data_ptr(),
|
||||
shared_tensor[1].data_ptr(),
|
||||
unique_tensor[0].data_ptr(),
|
||||
unique_tensor[1].data_ptr(),
|
||||
]
|
||||
|
||||
with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper") as mock_nixl_wrapper, \
|
||||
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \
|
||||
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"): # noqa: E501
|
||||
|
||||
with (
|
||||
patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
|
||||
) as mock_nixl_wrapper,
|
||||
patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"
|
||||
),
|
||||
patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"
|
||||
),
|
||||
): # noqa: E501
|
||||
# Create connector
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0)
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0
|
||||
)
|
||||
|
||||
# Get the mock instance
|
||||
mock_wrapper_instance = mock_nixl_wrapper.return_value
|
||||
@@ -909,12 +943,13 @@ def test_register_kv_caches(dist_init):
|
||||
|
||||
for i, cache_entry in enumerate(caches_data):
|
||||
base_addr, size, _tp_rank, _ = cache_entry
|
||||
assert size == expected_tensor_size, \
|
||||
f"Entry {i}: Expected tensor size {expected_tensor_size}, " \
|
||||
f"got {size}"
|
||||
assert base_addr == expected_base_addrs[i], \
|
||||
f"Entry {i}: Expected base address {expected_base_addrs[i]}, " \
|
||||
assert size == expected_tensor_size, (
|
||||
f"Entry {i}: Expected tensor size {expected_tensor_size}, got {size}"
|
||||
)
|
||||
assert base_addr == expected_base_addrs[i], (
|
||||
f"Entry {i}: Expected base address {expected_base_addrs[i]}, "
|
||||
f"got {base_addr}"
|
||||
)
|
||||
|
||||
# Verify get_xfer_descs was called with blocks_data
|
||||
assert mock_wrapper_instance.get_xfer_descs.called
|
||||
@@ -922,16 +957,17 @@ def test_register_kv_caches(dist_init):
|
||||
|
||||
# Validate blocks_data structure and size
|
||||
expected_blocks_count = 8
|
||||
assert len(blocks_data) == expected_blocks_count, \
|
||||
f"Expected {expected_blocks_count} blocks, " \
|
||||
f"got {len(blocks_data)}"
|
||||
assert len(blocks_data) == expected_blocks_count, (
|
||||
f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}"
|
||||
)
|
||||
|
||||
expected_block_len = expected_tensor_size // 2
|
||||
for i, block_entry in enumerate(blocks_data):
|
||||
block_start_addr, block_len, tp_rank = block_entry
|
||||
assert block_len == expected_block_len, \
|
||||
f"Block entry {i}: Expected block len {expected_block_len}, " \
|
||||
assert block_len == expected_block_len, (
|
||||
f"Block entry {i}: Expected block len {expected_block_len}, "
|
||||
f"got {block_len}"
|
||||
)
|
||||
|
||||
|
||||
class FakePlatform(Platform):
|
||||
@@ -940,24 +976,26 @@ class FakePlatform(Platform):
|
||||
@classmethod
|
||||
def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
|
||||
"""
|
||||
Returns a mapping from device_type to a tuple of supported
|
||||
Returns a mapping from device_type to a tuple of supported
|
||||
kv_buffer_device for nixl.
|
||||
"""
|
||||
return {'oot': ('oot', )}
|
||||
return {"oot": ("oot",)}
|
||||
|
||||
@classmethod
|
||||
def get_nixl_memory_type(cls) -> Optional[str]:
|
||||
"""
|
||||
Returns the nixl memory type for the current platform.
|
||||
"""
|
||||
return 'VRAM'
|
||||
return "VRAM"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_buffer_device, nixl_memory_type", [
|
||||
("oot", "VRAM"),
|
||||
])
|
||||
def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device,
|
||||
nixl_memory_type):
|
||||
@pytest.mark.parametrize(
|
||||
"kv_buffer_device, nixl_memory_type",
|
||||
[
|
||||
("oot", "VRAM"),
|
||||
],
|
||||
)
|
||||
def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, nixl_memory_type):
|
||||
"""
|
||||
Test that register_kv_caches() passes the correct memory types from the
|
||||
config to the nixl_wrapper.
|
||||
@@ -966,15 +1004,30 @@ def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device,
|
||||
# Override the default memory types in the config
|
||||
vllm_config.kv_transfer_config.kv_buffer_device = kv_buffer_device
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
||||
_NIXL_SUPPORTED_DEVICE)
|
||||
_NIXL_SUPPORTED_DEVICE,
|
||||
)
|
||||
|
||||
_NIXL_SUPPORTED_DEVICE.update(FakePlatform.get_nixl_supported_devices())
|
||||
|
||||
with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"), \
|
||||
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \
|
||||
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"), \
|
||||
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform", FakePlatform), \
|
||||
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector._NIXL_SUPPORTED_DEVICE", _NIXL_SUPPORTED_DEVICE): # noqa: E501
|
||||
|
||||
with (
|
||||
patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
|
||||
),
|
||||
patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"
|
||||
),
|
||||
patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"
|
||||
),
|
||||
patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform",
|
||||
FakePlatform,
|
||||
),
|
||||
patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector._NIXL_SUPPORTED_DEVICE",
|
||||
_NIXL_SUPPORTED_DEVICE,
|
||||
),
|
||||
): # noqa: E501
|
||||
# Create connector and replace its worker with a fake one for isolation
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
|
||||
@@ -985,22 +1038,23 @@ def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device,

@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
FakeNixlWrapper,
)
def test_shutdown_cleans_up_resources(dist_init):
"""Test that shutdown() properly cleans up all resources."""
vllm_config = create_vllm_config()

worker = NixlConnectorWorker(vllm_config,
vllm_config.kv_transfer_config.engine_id)
worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id)
nixl_wrapper = worker.nixl_wrapper

with patch.object(worker, '_handshake_initiation_executor') as mock_exec, \
patch.object(worker, '_nixl_handshake_listener_t') as mock_listener, \
patch.object(nixl_wrapper, 'release_xfer_handle') as mock_rel_xfer, \
patch.object(nixl_wrapper, 'release_dlist_handle') as mock_rel_dlist, \
patch.object(nixl_wrapper, 'remove_remote_agent') as mock_rem_agent, \
patch.object(nixl_wrapper, 'deregister_memory') as mock_dereg:

with (
patch.object(worker, "_handshake_initiation_executor") as mock_exec,
patch.object(worker, "_nixl_handshake_listener_t") as mock_listener,
patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer,
patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist,
patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent,
patch.object(nixl_wrapper, "deregister_memory") as mock_dereg,
):
worker._recving_transfers = {"req1": [(123, time.perf_counter())]}
worker.src_xfer_side_handle = 456
worker.dst_xfer_side_handles = {"engine1": 789}
@@ -1028,7 +1082,8 @@ def test_shutdown_cleans_up_resources(dist_init):

@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
FakeNixlWrapper,
)
def test_aborted_request_removed_from_worker_in_batch(dist_init):
"""
Create and schedule a request so that P adds it to in-batch tracking via
@@ -1040,9 +1095,9 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init):
scheduler = create_scheduler(vllm_config)
# KVConnector Worker in P
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
connector.engine_id,
hand_shake_latency=0)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)

# Create a request that triggers do_remote_decode so that
# the scheduler adds it to reqs_in_batch

@@ -14,27 +14,42 @@ from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_events import BlockRemoved, BlockStored
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
OffloadingConnector, OffloadingConnectorMetadata)
OffloadingConnector,
OffloadingConnectorMetadata,
)
from vllm.forward_context import ForwardContext
from vllm.utils import sha256
from vllm.v1.core.kv_cache_utils import (BlockHash, get_request_block_hasher,
init_none_hash)
from vllm.v1.core.kv_cache_utils import (
BlockHash,
get_request_block_hasher,
init_none_hash,
)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_offload.abstract import (LoadStoreSpec, OffloadingEvent,
OffloadingManager, PrepareStoreOutput)
from vllm.v1.kv_offload.abstract import (
LoadStoreSpec,
OffloadingEvent,
OffloadingManager,
PrepareStoreOutput,
)
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec
from vllm.v1.kv_offload.worker.worker import (OffloadingHandler,
TransferResult, TransferSpec)
from vllm.v1.kv_offload.worker.worker import (
OffloadingHandler,
TransferResult,
TransferSpec,
)
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.request import Request

from .utils import (EOS_TOKEN_ID, create_model_runner_output, create_scheduler,
create_vllm_config)
from .utils import (
EOS_TOKEN_ID,
create_model_runner_output,
create_scheduler,
create_vllm_config,
)


class MockLoadStoreSpec(LoadStoreSpec):

def __init__(self, block_hashes: Iterable[BlockHash]):
self.block_hashes: list[BlockHash] = list(block_hashes)

@@ -47,7 +62,6 @@ class MockLoadStoreSpec(LoadStoreSpec):


class MockOffloadingHandler(OffloadingHandler):

def __init__(self):
self.completed_transfers: list[TransferResult] = []
self.completed_specs: list[TransferSpec] = []
@@ -64,14 +78,14 @@ class MockOffloadingHandler(OffloadingHandler):


class MockOffloadingSpec(OffloadingSpec):

def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)

self.manager = MagicMock(spec=OffloadingManager)
self.manager.lookup.return_value = 0
self.manager.prepare_load = lambda block_hashes: (MockLoadStoreSpec(
block_hashes))
self.manager.prepare_load = lambda block_hashes: (
MockLoadStoreSpec(block_hashes)
)
self.handler = MockOffloadingHandler()

def get_manager(self) -> OffloadingManager:
@@ -79,9 +93,7 @@ class MockOffloadingSpec(OffloadingSpec):

def get_handlers(
self, _
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec],
OffloadingHandler]]:

) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler
yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler

@@ -98,35 +110,35 @@ class TransferSummary:


class RequestRunner:

def __init__(self, offloaded_block_size: int, gpu_block_size: int,
num_gpu_blocks: int):
def __init__(
self, offloaded_block_size: int, gpu_block_size: int, num_gpu_blocks: int
):
self.offloaded_block_size: int = offloaded_block_size
self.gpu_block_size: int = gpu_block_size
self.num_gpu_blocks: int = num_gpu_blocks

self.req_id: int = -1

vllm_config = create_vllm_config(block_size=gpu_block_size,
max_num_batched_tokens=1000)
vllm_config = create_vllm_config(
block_size=gpu_block_size, max_num_batched_tokens=1000
)
vllm_config.kv_transfer_config = KVTransferConfig(
kv_connector="OffloadingConnector",
kv_role="kv_both",
kv_connector_extra_config={
"spec_name": "MockOffloadingSpec",
"spec_module_path":
"tests.v1.kv_connector.unit.test_offloading_connector",
"spec_module_path": "tests.v1.kv_connector.unit.test_offloading_connector",
"block_size": offloaded_block_size,
})
},
)

self.scheduler: Scheduler = create_scheduler(vllm_config,
num_blocks=num_gpu_blocks)
self.worker_connector = OffloadingConnector(vllm_config,
KVConnectorRole.WORKER)
self.scheduler: Scheduler = create_scheduler(
vllm_config, num_blocks=num_gpu_blocks
)
self.worker_connector = OffloadingConnector(vllm_config, KVConnectorRole.WORKER)

# register worker kv_caches to enable OffloadingWorker creations
self.worker_connector.register_kv_caches(
kv_caches={"a": torch.empty(0)})
self.worker_connector.register_kv_caches(kv_caches={"a": torch.empty(0)})

# extract connector of scheduler
scheduler_connector = self.scheduler.connector
@@ -166,9 +178,9 @@ class RequestRunner:
init_none_hash(sha256)
self._block_hasher = get_request_block_hasher(gpu_block_size, sha256)

self._dummy_ctx: ForwardContext = ForwardContext(no_compile_layers={},
attn_metadata={},
virtual_engine=0)
self._dummy_ctx: ForwardContext = ForwardContext(
no_compile_layers={}, attn_metadata={}, virtual_engine=0
)

def new_request(self, token_ids: list[int]):
assert not self.scheduler.requests
@@ -189,8 +201,7 @@ class RequestRunner:
block_size_factor = self.offloaded_block_size // self.gpu_block_size

while self.pending_loads_count or self.pending_stores_count:
for transfer_spec in (
self.offloading_spec.get_completed_transfers()):
for transfer_spec in self.offloading_spec.get_completed_transfers():
src_spec, dst_spec = transfer_spec

if isinstance(src_spec, GPULoadStoreSpec):
@@ -207,8 +218,7 @@ class RequestRunner:

gpu_block_indices: list[int] = []
for block_id in gpu_spec.block_ids:
gpu_block_indices.append(
self.gpu_block_index[block_id.item()])
gpu_block_indices.append(self.gpu_block_index[block_id.item()])

# list of (block_hash, sub_block_offset)
offload_addresses: list[Any] = []
@@ -220,23 +230,26 @@ class RequestRunner:
assert len(gpu_block_indices) == len(offload_addresses)

self.completed_stores.append(
TransferSummary(gpu_block_indices, offload_addresses))
TransferSummary(gpu_block_indices, offload_addresses)
)
self.pending_stores_count -= 1
else:
remainder_sub_block_count = (len(offload_addresses) -
len(gpu_block_indices))
remainder_sub_block_count = len(offload_addresses) - len(
gpu_block_indices
)
assert remainder_sub_block_count >= 0
assert remainder_sub_block_count < block_size_factor
offload_addresses = offload_addresses[
remainder_sub_block_count:]
offload_addresses = offload_addresses[remainder_sub_block_count:]

self.completed_loads.append(
TransferSummary(gpu_block_indices, offload_addresses))
TransferSummary(gpu_block_indices, offload_addresses)
)
self.pending_loads_count -= 1

def _update_gpu_block_idx(self):
for blocks in (self.scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks.values()):
for blocks in self.scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].req_to_blocks.values():
for block_idx, block in enumerate(blocks):
self.gpu_block_index[block.block_id] = block_idx

@@ -259,23 +272,20 @@ class RequestRunner:

kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata,
OffloadingConnectorMetadata)
assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)

self.pending_loads_count += len(kv_connector_metadata.reqs_to_load)
self.pending_stores_count += len(
kv_connector_metadata.reqs_to_store)
self.pending_stores_count += len(kv_connector_metadata.reqs_to_store)

self.worker_connector.bind_connector_metadata(
kv_connector_metadata)
self.worker_connector.bind_connector_metadata(kv_connector_metadata)
self.worker_connector.start_load_kv(self._dummy_ctx)

if scheduler_output.total_num_scheduled_tokens > 0:
self.worker_connector.wait_for_save()

finished_sending, finished_recving = (
self.worker_connector.get_finished(
scheduler_output.finished_req_ids))
finished_sending, finished_recving = self.worker_connector.get_finished(
scheduler_output.finished_req_ids
)

self.worker_connector.clear_connector_metadata()

@@ -283,13 +293,13 @@ class RequestRunner:
reqs=self.scheduler.running,
finished_sending=finished_sending,
finished_recving=finished_recving,
token_id=token_id)
token_id=token_id,
)

if self.scheduler.running:
token_id = next(tokens_iter, None)

self.scheduler.update_from_output(scheduler_output,
model_runner_output)
self.scheduler.update_from_output(scheduler_output, model_runner_output)

self._wait_for_transfers()

@@ -300,24 +310,24 @@ class RequestRunner:
while self.scheduler.requests:
scheduler_output = self.scheduler.schedule()

finished_sending, finished_recving = (
self.worker_connector.get_finished(
scheduler_output.finished_req_ids))
finished_sending, finished_recving = self.worker_connector.get_finished(
scheduler_output.finished_req_ids
)

assert not finished_recving

model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending)
finished_sending=finished_sending
)

self.scheduler.update_from_output(scheduler_output,
model_runner_output)
self.scheduler.update_from_output(scheduler_output, model_runner_output)

def run(
self,
decoded_tokens: list[int],
expected_stored_gpu_block_indexes: tuple[int, ...] = (),
expected_loaded_gpu_block_indexes: tuple[int, ...] = (),
self,
decoded_tokens: list[int],
expected_stored_gpu_block_indexes: tuple[int, ...] = (),
expected_loaded_gpu_block_indexes: tuple[int, ...] = (),
):
"""
Runs multiple engine (scheduler + worker) steps.
@@ -337,23 +347,23 @@ class RequestRunner:
loaded_gpu_block_indexes: set[int] = set()
for transfer in self.completed_loads:
for gpu_block_idx, offloaded_address in zip(
transfer.gpu_block_indices, transfer.offload_addresses):
transfer.gpu_block_indices, transfer.offload_addresses
):
loaded_gpu_block_indexes.add(gpu_block_idx)
assert gpu_block_idx == self.offloaded[offloaded_address]

assert (
set(expected_loaded_gpu_block_indexes) == loaded_gpu_block_indexes)
assert set(expected_loaded_gpu_block_indexes) == loaded_gpu_block_indexes
self.completed_loads.clear()

stored_gpu_block_indexes: set[int] = set()
for transfer in self.completed_stores:
for gpu_block_idx, offloaded_address in zip(
transfer.gpu_block_indices, transfer.offload_addresses):
transfer.gpu_block_indices, transfer.offload_addresses
):
stored_gpu_block_indexes.add(gpu_block_idx)
self.offloaded[offloaded_address] = gpu_block_idx

assert (
set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes)
assert set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes
self.completed_stores.clear()


@@ -362,9 +372,11 @@ def request_runner():
runners = []

def runner_factory(offloaded_block_size, gpu_block_size, num_gpu_blocks):
runner = RequestRunner(offloaded_block_size=offloaded_block_size,
gpu_block_size=gpu_block_size,
num_gpu_blocks=num_gpu_blocks)
runner = RequestRunner(
offloaded_block_size=offloaded_block_size,
gpu_block_size=gpu_block_size,
num_gpu_blocks=num_gpu_blocks,
)
runners.append(runner)
return runner

@@ -386,15 +398,18 @@ def test_offloading_connector(request_runner):
num_gpu_blocks = 100
block_size_factor = offloaded_block_size // gpu_block_size

runner = request_runner(offloaded_block_size=offloaded_block_size,
gpu_block_size=gpu_block_size,
num_gpu_blocks=num_gpu_blocks)
runner = request_runner(
offloaded_block_size=offloaded_block_size,
gpu_block_size=gpu_block_size,
num_gpu_blocks=num_gpu_blocks,
)

# 3 blocks, store just the middle block (skip first and last)
# blocks = [0, 1, 2], [3, 4, 5], [6, 7, 8]
runner.new_request(token_ids=[0] * offloaded_block_size * 3)
runner.manager.prepare_store.side_effect = \
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output(list(block_hashes)[1:2])
)
runner.run(decoded_tokens=[0], expected_stored_gpu_block_indexes=(3, 4, 5))

# add block missing 1 token -> no offload
@@ -402,21 +417,24 @@ def test_offloading_connector(request_runner):
runner.manager.prepare_store.assert_not_called()

# +1 token -> single block, fail prepare_store
runner.manager.prepare_store.side_effect = \
lambda block_hashes: None
runner.manager.prepare_store.side_effect = lambda block_hashes: None
runner.run(decoded_tokens=[0])
runner.manager.prepare_store.assert_called()

# 1 more block, now set block_hashes_to_store = []
runner.manager.prepare_store.side_effect = \
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output([])
)
runner.run(decoded_tokens=[0] * offloaded_block_size)

# 1 more block, now check touch was called with all 6 blocks
runner.manager.prepare_store.side_effect = \
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output(block_hashes)
runner.run(decoded_tokens=[0] * offloaded_block_size,
expected_stored_gpu_block_indexes=(15, 16, 17))
)
runner.run(
decoded_tokens=[0] * offloaded_block_size,
expected_stored_gpu_block_indexes=(15, 16, 17),
)
runner.manager.touch.assert_called()
block_hashes1 = list(runner.manager.touch.call_args.args[0])
assert len(block_hashes1) == 6
@@ -426,9 +444,10 @@ def test_offloading_connector(request_runner):

# create a new request differing only on the last token
runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1])
runner.run(decoded_tokens=[0],
expected_stored_gpu_block_indexes=tuple(
range(6 * block_size_factor)))
runner.run(
decoded_tokens=[0],
expected_stored_gpu_block_indexes=tuple(range(6 * block_size_factor)),
)
runner.manager.touch.assert_called()
block_hashes2 = list(runner.manager.touch.call_args.args[0])
assert len(block_hashes2) == 6
@@ -441,17 +460,20 @@ def test_offloading_connector(request_runner):
runner.run(decoded_tokens=[EOS_TOKEN_ID])

# full_block_tokens - num_computed_tokens < offloaded_block_size
runner.new_request(token_ids=[0] * gpu_block_size + [1] *
(offloaded_block_size - gpu_block_size))
runner.manager.prepare_store.side_effect = \
runner.new_request(
token_ids=[0] * gpu_block_size + [1] * (offloaded_block_size - gpu_block_size)
)
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output([])
)
runner.run(decoded_tokens=[EOS_TOKEN_ID])
runner.manager.lookup.assert_not_called()

# single block lookup with no hits
runner.new_request(token_ids=[1] * offloaded_block_size)
runner.manager.prepare_store.side_effect = \
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output([])
)
runner.run(decoded_tokens=[EOS_TOKEN_ID])
runner.manager.lookup.assert_called()
assert len(list(runner.manager.lookup.call_args.args[0])) == 1
@@ -459,34 +481,37 @@ def test_offloading_connector(request_runner):
# single block lookup with a hit
runner.scheduler.reset_prefix_cache()
runner.new_request(token_ids=[0] * offloaded_block_size)
runner.manager.prepare_store.side_effect = \
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output([])
)
runner.manager.lookup.return_value = 1
runner.run(decoded_tokens=[EOS_TOKEN_ID],
expected_loaded_gpu_block_indexes=(0, 1, 2))
runner.run(
decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(0, 1, 2)
)

# single block lookup with a hit in a middle block
runner.new_request(token_ids=[0] * offloaded_block_size * 2 +
[1] * offloaded_block_size)
runner.manager.prepare_store.side_effect = \
runner.new_request(
token_ids=[0] * offloaded_block_size * 2 + [1] * offloaded_block_size
)
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output([])
)
runner.manager.lookup.return_value = 1
runner.run(decoded_tokens=[EOS_TOKEN_ID],
expected_loaded_gpu_block_indexes=(3, 4, 5))
runner.run(
decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(3, 4, 5)
)

# test take_events
def to_hashes(int_hashes: list[int]) -> list[BlockHash]:
return [BlockHash(str(i).encode()) for i in int_hashes]

def take_events() -> Iterable[OffloadingEvent]:
yield OffloadingEvent(block_hashes=to_hashes([1, 2, 3]),
block_size=16,
medium="A",
removed=False)
yield OffloadingEvent(block_hashes=to_hashes([4, 5, 6]),
block_size=32,
medium="B",
removed=True)
yield OffloadingEvent(
block_hashes=to_hashes([1, 2, 3]), block_size=16, medium="A", removed=False
)
yield OffloadingEvent(
block_hashes=to_hashes([4, 5, 6]), block_size=32, medium="B", removed=True
)

runner.manager.take_events.side_effect = take_events
events = list(runner.scheduler_connector.take_events())

@@ -12,22 +12,25 @@ pytestmark = pytest.mark.cpu_test


class DummyModelRunnerOutput(ModelRunnerOutput):

def __init__(self,
finished_sending: Optional[set[str]] = None,
finished_recving: Optional[set[str]] = None,
invalid_block_ids: Optional[set[int]] = None):
def __init__(
self,
finished_sending: Optional[set[str]] = None,
finished_recving: Optional[set[str]] = None,
invalid_block_ids: Optional[set[int]] = None,
):
self.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
invalid_block_ids=invalid_block_ids or set())
invalid_block_ids=invalid_block_ids or set(),
)

def __repr__(self):
return (
f"DummyModelRunnerOutput("
f"finished_sending={self.kv_connector_output.finished_sending},"
f"finished_recving={self.kv_connector_output.finished_recving})"
f"invalid_block_ids={self.kv_connector_output.invalid_block_ids})")
f"invalid_block_ids={self.kv_connector_output.invalid_block_ids})"
)


def test_aggregate_workers_output():
@@ -44,8 +47,9 @@ def test_aggregate_workers_output():
assert aggregated.finished_recving is None
assert not aggregated.invalid_block_ids

output1 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving={'req2'})
output1 = DummyModelRunnerOutput(
finished_sending={"req1"}, finished_recving={"req2"}
)
output2 = DummyModelRunnerOutput(invalid_block_ids={1})

aggregated = aggregator.aggregate([output1, output2])
@@ -57,26 +61,27 @@ def test_aggregate_workers_output():
assert aggregated.invalid_block_ids == {1}

output1 = DummyModelRunnerOutput(invalid_block_ids={2})
output2 = DummyModelRunnerOutput(finished_sending={'req1'})
output2 = DummyModelRunnerOutput(finished_sending={"req1"})

aggregated = aggregator.aggregate([output1, output2])

assert aggregated is output1
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending == {'req1'}
assert aggregated.finished_sending == {"req1"}
assert aggregated.finished_recving is None
assert aggregated.invalid_block_ids == {2}

output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
output2 = DummyModelRunnerOutput(finished_recving={'req2'},
invalid_block_ids={4, 5})
output2 = DummyModelRunnerOutput(
finished_recving={"req2"}, invalid_block_ids={4, 5}
)

aggregated = aggregator.aggregate([output1, output2])

assert aggregated is output1
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending is None
assert aggregated.finished_recving == {'req2'}
assert aggregated.finished_recving == {"req2"}
assert aggregated.invalid_block_ids == {3, 4, 5}


@@ -104,8 +109,9 @@ def test_async_aggregate_workers_output():
future2 = Future()
result_future = aggregator.async_aggregate([future1, future2])

output1 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving={'req2'})
output1 = DummyModelRunnerOutput(
finished_sending={"req1"}, finished_recving={"req2"}
)
output2 = DummyModelRunnerOutput(invalid_block_ids={1})
future1.set_result(output1)
future2.set_result(output2)
@@ -123,7 +129,7 @@ def test_async_aggregate_workers_output():
result_future = aggregator.async_aggregate([future1, future2])

output1 = DummyModelRunnerOutput(invalid_block_ids={2})
output2 = DummyModelRunnerOutput(finished_sending={'req1'})
output2 = DummyModelRunnerOutput(finished_sending={"req1"})
future1.set_result(output1)
future2.set_result(output2)

@@ -131,7 +137,7 @@ def test_async_aggregate_workers_output():
aggregated = result_future.result()
assert aggregated is output1
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending == {'req1'}
assert aggregated.finished_sending == {"req1"}
assert aggregated.finished_recving is None
assert aggregated.invalid_block_ids == {2}

@@ -140,8 +146,9 @@ def test_async_aggregate_workers_output():
result_future = aggregator.async_aggregate([future1, future2])

output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
output2 = DummyModelRunnerOutput(finished_recving={'req2'},
invalid_block_ids={4, 5})
output2 = DummyModelRunnerOutput(
finished_recving={"req2"}, invalid_block_ids={4, 5}
)
future1.set_result(output1)
future2.set_result(output2)

@@ -150,5 +157,5 @@ def test_async_aggregate_workers_output():
assert aggregated is output1
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending is None
assert aggregated.finished_recving == {'req2'}
assert aggregated.finished_recving == {"req2"}
assert aggregated.invalid_block_ids == {3, 4, 5}

@@ -7,8 +7,13 @@ import pytest
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.request import FinishReason, RequestStatus

from .utils import (assert_scheduler_empty, create_model_runner_output,
create_request, create_scheduler, create_vllm_config)
from .utils import (
assert_scheduler_empty,
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)

pytestmark = pytest.mark.cpu_test

@@ -24,11 +29,13 @@ def test_basic_lifecycle():
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

request = create_request(request_id=1,
block_size=BLOCK_SIZE,
max_tokens=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True)
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
max_tokens=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)

scheduler.add_request(request)
request_id = request.request_id
@@ -43,8 +50,9 @@ def test_basic_lifecycle():
model_runner_output = create_model_runner_output(reqs=[request])

# (1c): update_from_output()
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
engine_core_outputs = scheduler.update_from_output(
scheduler_output, model_runner_output
)

# Ensure the request is finished after 1 token.
assert request.is_finished()
@@ -60,7 +68,8 @@ def test_basic_lifecycle():

# ... but blocks should not be freed.
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0].req_to_blocks[request_id]
0
].req_to_blocks[request_id]
for block in blocks:
assert block.ref_cnt == 1

@@ -92,7 +101,8 @@ def test_basic_lifecycle():
# (3b): execute_model()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending={request_id})
finished_sending={request_id}
)

# (3c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
@@ -110,11 +120,13 @@ def test_short_prompt_lifecycle():
# Not enough tokens for full block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_TOKENS = BLOCK_SIZE // 2
request = create_request(request_id=1,
block_size=BLOCK_SIZE,
max_tokens=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True)
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
max_tokens=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)

scheduler.add_request(request)

@@ -132,14 +144,15 @@ def test_short_prompt_lifecycle():
eco = scheduler.update_from_output(scheduler_output, model_runner_output)
kv_transfer_params = eco[0].outputs[0].kv_transfer_params

assert (len(kv_transfer_params["remote_block_ids"]) == 1)
assert len(kv_transfer_params["remote_block_ids"]) == 1

# Confirm we do not have any memory leaks after req lifecycle.
# We need to mark sending finish to clear data for persistent batch.
scheduler_output = scheduler.schedule()
# Use create_model_runner_output to pass kv_connector_output along
model_runner_output = create_model_runner_output(
reqs=[request], finished_sending={request.request_id})
reqs=[request], finished_sending={request.request_id}
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert_scheduler_empty(scheduler)

@@ -155,14 +168,15 @@ def test_prefix_cache_lifecycle():
NUM_EXTERNAL_FULL_BLOCKS = 3
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

request_normal = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS)
request_normal = create_request(
request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS
)

scheduler.add_request(request_normal)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_normal],
use_eos=True)
model_runner_output = create_model_runner_output(
reqs=[request_normal], use_eos=True
)
scheduler.update_from_output(scheduler_output, model_runner_output)
scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
@@ -174,10 +188,12 @@ def test_prefix_cache_lifecycle():
NUM_EXTERNAL_FULL_BLOCKS -= 1
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

request_remote = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True)
request_remote = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)

scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
@@ -187,14 +203,13 @@ def test_prefix_cache_lifecycle():

# Ensure we send all block ids, including the partial blocks,
# even if there is a cache hit.
assert (len(
kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS +
1))
assert len(kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + 1)

# STEP (2): Ensure it is freed.
scheduler_output = scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending={request_remote.request_id})
finished_sending={request_remote.request_id}
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert_scheduler_empty(scheduler)

@@ -7,8 +7,13 @@ import pytest
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.request import FinishReason, RequestStatus

from .utils import (assert_scheduler_empty, create_model_runner_output,
create_request, create_scheduler, create_vllm_config)
from .utils import (
assert_scheduler_empty,
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)

pytestmark = pytest.mark.cpu_test

|
||||
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
START_FREE_BLOCK_QUEUE_SIZE = (
|
||||
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
|
||||
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks
|
||||
)
|
||||
|
||||
request = create_request(request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True)
|
||||
request = create_request(
|
||||
request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True,
|
||||
)
|
||||
|
||||
scheduler.add_request(request)
|
||||
request_id = request.request_id
|
||||
@@ -48,16 +56,16 @@ def test_basic_lifecycle():
|
||||
# Req waiting for KVs with no computed/scheduled toks ...
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert request in scheduler.waiting
|
||||
assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
|
||||
assert (request.num_computed_tokens == 0)
|
||||
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
assert request.num_computed_tokens == 0
|
||||
|
||||
# ... but should have (uncached) blocks allocated to it.
|
||||
block_pool = scheduler.kv_cache_manager.block_pool
|
||||
assert (block_pool.free_block_queue.num_free_blocks
|
||||
< START_FREE_BLOCK_QUEUE_SIZE)
|
||||
assert block_pool.free_block_queue.num_free_blocks < START_FREE_BLOCK_QUEUE_SIZE
|
||||
assert len(block_pool.cached_block_hash_to_block) == 0
|
||||
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0].req_to_blocks[request_id]
|
||||
0
|
||||
].req_to_blocks[request_id]
|
||||
for block in blocks:
|
||||
assert block._block_hash is None
|
||||
|
||||
@@ -65,8 +73,9 @@ def test_basic_lifecycle():
|
||||
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
# (1c): update_from_output()
|
||||
engine_core_outputs = scheduler.update_from_output(scheduler_output,
|
||||
model_runner_output)
|
||||
engine_core_outputs = scheduler.update_from_output(
|
||||
scheduler_output, model_runner_output
|
||||
)
|
||||
assert not engine_core_outputs or not engine_core_outputs[0].outputs
|
||||
|
||||
# STEP (2):
|
||||
@@ -78,13 +87,15 @@ def test_basic_lifecycle():
|
||||
# (2b): forward(): request finishes recv.
|
||||
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
model_runner_output.kv_connector_output = KVConnectorOutput(
|
||||
finished_recving={request_id})
|
||||
finished_recving={request_id}
|
||||
)
|
||||
|
||||
# (2c): update_from_output():
|
||||
engine_core_outputs = scheduler.update_from_output(scheduler_output,
|
||||
model_runner_output)
|
||||
engine_core_outputs = scheduler.update_from_output(
|
||||
scheduler_output, model_runner_output
|
||||
)
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert (request_id in scheduler.finished_recving_kv_req_ids)
|
||||
assert request_id in scheduler.finished_recving_kv_req_ids
|
||||
|
||||
# STEP (3):
|
||||
# (3a): schedule(): this should actually schedule.
|
||||
@@ -94,10 +105,11 @@ def test_basic_lifecycle():
|
||||
# Confirm the block are actually allocated.
|
||||
num_hashed_blocks = 0
|
||||
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0].req_to_blocks[request_id]
|
||||
0
|
||||
].req_to_blocks[request_id]
|
||||
for block in blocks:
|
||||
assert block.ref_cnt == 1
|
||||
num_hashed_blocks += (1 if block._block_hash is not None else 0)
|
||||
num_hashed_blocks += 1 if block._block_hash is not None else 0
|
||||
assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS
|
||||
|
||||
# Confirm the rest of the prompt is scheduled in this step.
|
||||
@@ -105,7 +117,7 @@ def test_basic_lifecycle():
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id]
|
||||
num_computed_tokens = scheduled_req.num_computed_tokens
|
||||
total_prompt_tokens = len(scheduled_req.prompt_token_ids)
|
||||
assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens)
|
||||
assert num_scheduled_tokens == total_prompt_tokens - num_computed_tokens
|
||||
|
||||
# (3b): execute_model()
|
||||
model_runner_output = create_model_runner_output([request])
|
||||
@@ -115,8 +127,9 @@ def test_basic_lifecycle():
|
||||
# Step (4): Hit EOS.
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = create_model_runner_output([request], use_eos=True)
|
||||
engine_core_outputs = scheduler.update_from_output(scheduler_output,
|
||||
model_runner_output)
|
||||
engine_core_outputs = scheduler.update_from_output(
|
||||
scheduler_output, model_runner_output
|
||||
)
|
||||
scheduler.schedule()
|
||||
|
||||
outputs = engine_core_outputs[0].outputs
|
||||
@@ -137,10 +150,12 @@ def test_interleaved_lifecycle():
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
request_remote = create_request(request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True)
|
||||
request_remote = create_request(
|
||||
request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True,
|
||||
)
|
||||
request_local_a = create_request(
|
||||
request_id=2,
|
||||
block_size=BLOCK_SIZE,
|
||||
@@ -169,8 +184,7 @@ def test_interleaved_lifecycle():
|
||||
assert len(scheduler_output.scheduled_new_reqs) == 1
|
||||
assert scheduler_output.scheduled_cached_reqs.num_reqs == 1
|
||||
|
||||
model_runner_output = create_model_runner_output(
|
||||
[request_local_a, request_local_b])
|
||||
model_runner_output = create_model_runner_output([request_local_a, request_local_b])
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
# STEP 3: continue running, KVs not arrived yet.
|
||||
@@ -181,7 +195,8 @@ def test_interleaved_lifecycle():
|
||||
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
|
||||
|
||||
model_runner_output = create_model_runner_output(
|
||||
reqs=[request_local_a, request_local_b])
|
||||
reqs=[request_local_a, request_local_b]
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.running) == 2
|
||||
assert len(scheduler.waiting) == 1
|
||||
@@ -196,8 +211,8 @@ def test_interleaved_lifecycle():
|
||||
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
|
||||
|
||||
model_runner_output = create_model_runner_output(
|
||||
[request_local_a, request_local_b],
|
||||
finished_recving={request_remote.request_id})
|
||||
[request_local_a, request_local_b], finished_recving={request_remote.request_id}
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
# STEP 5: RECVed KVs are sent to ModelRunner.
|
||||
@@ -208,7 +223,8 @@ def test_interleaved_lifecycle():
|
||||
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
|
||||
|
||||
model_runner_output = create_model_runner_output(
|
||||
[request_local_a, request_local_b, request_remote])
|
||||
[request_local_a, request_local_b, request_remote]
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
# STEP 6: Hit EOS and free.
|
||||
@@ -273,15 +289,17 @@ def test_no_spurious_prefix_caching():
|
||||
assert len(scheduler.waiting) == 1
|
||||
|
||||
local_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0].req_to_blocks[request_local.request_id]
|
||||
0
|
||||
].req_to_blocks[request_local.request_id]
|
||||
remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0].req_to_blocks[request_remote.request_id]
|
||||
0
|
||||
].req_to_blocks[request_remote.request_id]
|
||||
|
||||
# Local should have cached blocks (but not all due to preallocate).
|
||||
num_hashed_blocks = 0
|
||||
for block in local_blocks:
|
||||
assert block.ref_cnt == 1
|
||||
num_hashed_blocks += (1 if block._block_hash is not None else 0)
|
||||
num_hashed_blocks += 1 if block._block_hash is not None else 0
|
||||
assert num_hashed_blocks > 0
|
||||
|
||||
# Remote blocks should not be cached.
|
||||
@@ -301,10 +319,12 @@ def test_full_block_prompt():
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||
NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)
|
||||
|
||||
request = create_request(request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True)
|
||||
request = create_request(
|
||||
request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True,
|
||||
)
|
||||
|
||||
scheduler.add_request(request)
|
||||
request_id = request.request_id
|
||||
@@ -312,8 +332,11 @@ def test_full_block_prompt():
|
||||
# STEP (1): Initialize a recv.
|
||||
scheduler_output = scheduler.schedule()
|
||||
# All blocks should be allocated.
|
||||
num_blocks = len(scheduler.kv_cache_manager.coordinator.
|
||||
single_type_managers[0].req_to_blocks[request_id])
|
||||
num_blocks = len(
|
||||
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
|
||||
request_id
|
||||
]
|
||||
)
|
||||
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
|
||||
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
@@ -322,22 +345,25 @@ def test_full_block_prompt():
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
model_runner_output.kv_connector_output = KVConnectorOutput(
|
||||
finished_recving={request_id})
|
||||
finished_recving={request_id}
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert (request_id in scheduler.finished_recving_kv_req_ids)
|
||||
assert request_id in scheduler.finished_recving_kv_req_ids
|
||||
|
||||
# # STEP (3): Run as usual.
|
||||
scheduler_output = scheduler.schedule()
|
||||
|
||||
# We need to recompute the final token of the prompt to generate
|
||||
# the first new token, so we should not have a new block.
|
||||
num_blocks = len(scheduler.kv_cache_manager.coordinator.
|
||||
single_type_managers[0].req_to_blocks[request_id])
|
||||
num_blocks = len(
|
||||
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
|
||||
request_id
|
||||
]
|
||||
)
|
||||
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
|
||||
assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
|
||||
NUM_TOKENS - 1)
|
||||
assert (scheduler_output.num_scheduled_tokens[request_id] == 1)
|
||||
assert scheduler_output.scheduled_new_reqs[0].num_computed_tokens == NUM_TOKENS - 1
|
||||
assert scheduler_output.num_scheduled_tokens[request_id] == 1
|
||||
|
||||
model_runner_output = create_model_runner_output([request])
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
@@ -345,8 +371,9 @@ def test_full_block_prompt():
|
||||
# # Step (4): Hit EOS.
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = create_model_runner_output([request], use_eos=True)
|
||||
engine_core_outputs = scheduler.update_from_output(scheduler_output,
|
||||
model_runner_output)
|
||||
engine_core_outputs = scheduler.update_from_output(
|
||||
scheduler_output, model_runner_output
|
||||
)
|
||||
scheduler.schedule()
|
||||
|
||||
outputs = engine_core_outputs[0].outputs
|
||||
@@ -375,13 +402,15 @@ def test_cannot_schedule_after_recv():
|
||||
NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
|
||||
NUM_TOKENS_REMOTE = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
|
||||
|
||||
request_normal = create_request(request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS_LOCAL)
|
||||
request_remote = create_request(request_id=2,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS_REMOTE,
|
||||
do_remote_prefill=True)
|
||||
request_normal = create_request(
|
||||
request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS_LOCAL
|
||||
)
|
||||
request_remote = create_request(
|
||||
request_id=2,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS_REMOTE,
|
||||
do_remote_prefill=True,
|
||||
)
|
||||
|
||||
# STEP 1: 3 blocks are in use (2 for prompt, 1 for decode).
|
||||
scheduler.add_request(request_normal)
|
||||
@@ -402,7 +431,8 @@ def test_cannot_schedule_after_recv():
|
||||
# Step 3: finish recving (5 blocks in use)
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = create_model_runner_output(
|
||||
reqs=[request_normal], finished_recving={request_remote.request_id})
|
||||
reqs=[request_normal], finished_recving={request_remote.request_id}
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.running) == 1
|
||||
assert len(scheduler.waiting) == 1
|
||||
@@ -411,7 +441,8 @@ def test_cannot_schedule_after_recv():
|
||||
# because the transfer is completed.
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = create_model_runner_output(
|
||||
reqs=[request_normal, request_remote])
|
||||
reqs=[request_normal, request_remote]
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.running) == 2
|
||||
assert len(scheduler.waiting) == 0
|
||||
@@ -426,8 +457,9 @@ def test_cannot_schedule_after_recv():
|
||||
|
||||
# Step 6: finish the request, free it.
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = create_model_runner_output(reqs=[request_normal],
|
||||
use_eos=True)
|
||||
model_runner_output = create_model_runner_output(
|
||||
reqs=[request_normal], use_eos=True
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(scheduler.waiting) == 1
|
||||
@@ -436,16 +468,19 @@ def test_cannot_schedule_after_recv():
|
||||
# request is retrieved from preempted list.
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = create_model_runner_output(reqs=[request_remote])
|
||||
assert (scheduler_output.scheduled_cached_reqs.num_computed_tokens[0] ==
|
||||
NUM_PROMPT_BLOCKS * BLOCK_SIZE)
|
||||
assert (
|
||||
scheduler_output.scheduled_cached_reqs.num_computed_tokens[0]
|
||||
== NUM_PROMPT_BLOCKS * BLOCK_SIZE
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.running) == 1
|
||||
assert len(scheduler.waiting) == 0
|
||||
|
||||
# Step 8: free everything.
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = create_model_runner_output(reqs=[request_remote],
|
||||
use_eos=True)
|
||||
model_runner_output = create_model_runner_output(
|
||||
reqs=[request_remote], use_eos=True
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
_ = scheduler.schedule()
|
||||
assert_scheduler_empty(scheduler)
|
||||
@@ -470,13 +505,15 @@ def test_cannot_recv():
NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5))

request_normal = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS_LOCAL)
request_remote = create_request(request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS_REMOTE,
do_remote_prefill=True)
request_normal = create_request(
request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS_LOCAL
)
request_remote = create_request(
request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS_REMOTE,
do_remote_prefill=True,
)

# STEP 1: 3 blocks are in use (2 for prompt, 1 for decode).
scheduler.add_request(request_normal)
@@ -495,12 +532,13 @@ def test_cannot_recv():
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
# Should not have KV transfer in progress.
assert (request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS)
assert request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS

# Step 3: finish the request, free it.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_normal],
use_eos=True)
model_runner_output = create_model_runner_output(
reqs=[request_normal], use_eos=True
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
@@ -511,12 +549,13 @@ def test_cannot_recv():
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
assert (request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
assert request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS

# Step 5: finish recving (5 blocks in use)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=[], finished_recving={request_remote.request_id})
reqs=[], finished_recving={request_remote.request_id}
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
@@ -530,8 +569,9 @@ def test_cannot_recv():

# Step 7: free everything.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote],
use_eos=True)
model_runner_output = create_model_runner_output(
reqs=[request_remote], use_eos=True
)
scheduler.update_from_output(scheduler_output, model_runner_output)
_ = scheduler.schedule()
assert_scheduler_empty(scheduler)

@@ -37,16 +37,22 @@ def _list_path(path):
return list(path.iterdir())


def run_test(tmp_path, processor, llm: LLM, question: str,
image_urls: list[Image], expected_len: int, info: str):
def run_test(
tmp_path,
processor,
llm: LLM,
question: str,
image_urls: list[Image],
expected_len: int,
info: str,
):
"""
One individual test to process the prompt and output base on 1 set of input
Then check if the length in the storage path matches the expected length
`info` introduces details or purpose of the individual test
"""
print(f"***info: {info}***")
print(
f"**Expected storage path length after llm generate: {expected_len}**")
print(f"**Expected storage path length after llm generate: {expected_len}**")
process_prompt(processor, llm, question, image_urls)

print(f"Path matched expected length: {_check_path_len(tmp_path)}")
@@ -54,51 +60,42 @@ def run_test(tmp_path, processor, llm: LLM, question: str,

assert _check_path_len(tmp_path) == expected_len, (
f"Expect storage path length {expected_len} ;",
f"but end up {_check_path_len(tmp_path)} instead. ", f"Info: {info}")
f"but end up {_check_path_len(tmp_path)} instead. ",
f"Info: {info}",
)


def process_prompt(processor, llm: LLM, question: str,
image_urls: list[Image]):
def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]):
"""
Form the prompt based on the text and image input, then llm generate output
"""
placeholders = [{
"type": "image_url",
"image_url": {
"url": f"data:image;base64,{encode_image_base64(image_pil)}"
placeholders = [
{
"type": "image_url",
"image_url": {"url": f"data:image;base64,{encode_image_base64(image_pil)}"},
}
} for image_pil in image_urls]
for image_pil in image_urls
]

messages = [
{
"role": "system",
"content": "You are a helpful assistant."
},
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
{"type": "text", "text": question},
],
},
]

prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

outputs = llm.generate(
{
"prompt":
prompt,
**({
"multi_modal_data": {
"image": [*image_urls]
}
} if image_urls else {})
"prompt": prompt,
**({"multi_modal_data": {"image": [*image_urls]}} if image_urls else {}),
},
sampling_params=SAMPLING_PARAMS,
)
@@ -114,7 +111,7 @@ def process_prompt(processor, llm: LLM, question: str,
def test_shared_storage_connector_hashes(tmp_path):
"""
Tests that SharedStorageConnector saves KV to the storage locations
with proper hashes; that are unique for inputs with identical text but
different images (same size), or same multiple images but different orders.
"""
# Using tmp_path as the storage path to store KV
@@ -124,7 +121,8 @@ def test_shared_storage_connector_hashes(tmp_path):
kv_transfer_config = KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": str(tmp_path)})
kv_connector_extra_config={"shared_storage_path": str(tmp_path)},
)

engine_args = EngineArgs(
model=MODEL_NAME,
@@ -157,56 +155,88 @@ def test_shared_storage_connector_hashes(tmp_path):

    # Prepare the input cases
    input_cases = [
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_1],
                  expected_len=1,
                  info="image_1 single input the first time."),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_2],
                  expected_len=2,
                  info=("image_2 single input the first time. "
                        "It is in same pixel size with image_1, yet it "
                        "should be able to form a new unique hash.")),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_1],
                  expected_len=2,
                  info=("image_1 single input the 2nd time. "
                        "It should not form another new hash.")),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_2],
                  expected_len=2,
                  info=("image_2 single input the 2nd time. "
                        "It should not form another new hash.")),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_1, image_2],
                  expected_len=3,
                  info="image_1 with image_2 input the first time."),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_2, image_1],
                  expected_len=4,
                  info="The image order is swapped. Should form new hash."),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_1, image_2],
                  expected_len=4,
                  info=("[image_1, image_2] input the 2nd time. "
                        "It should not form another new hash.")),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_2, image_1],
                  expected_len=4,
                  info=("[image_2, image_1] input the 2nd time. "
                        "It should not form another new hash.")),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[],
                  expected_len=5,
                  info="Pure text input test as a case-control"),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[],
                  expected_len=5,
                  info="Identical pure text input as a case-control"),
        InputCase(text=TEXT_PROMPTS[1],
                  img=[],
                  expected_len=6,
                  info="Another pure text input as a case-control"),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_1],
            expected_len=1,
            info="image_1 single input the first time.",
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_2],
            expected_len=2,
            info=(
                "image_2 single input the first time. "
                "It is in same pixel size with image_1, yet it "
                "should be able to form a new unique hash."
            ),
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_1],
            expected_len=2,
            info=(
                "image_1 single input the 2nd time. "
                "It should not form another new hash."
            ),
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_2],
            expected_len=2,
            info=(
                "image_2 single input the 2nd time. "
                "It should not form another new hash."
            ),
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_1, image_2],
            expected_len=3,
            info="image_1 with image_2 input the first time.",
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_2, image_1],
            expected_len=4,
            info="The image order is swapped. Should form new hash.",
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_1, image_2],
            expected_len=4,
            info=(
                "[image_1, image_2] input the 2nd time. "
                "It should not form another new hash."
            ),
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_2, image_1],
            expected_len=4,
            info=(
                "[image_2, image_1] input the 2nd time. "
                "It should not form another new hash."
            ),
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[],
            expected_len=5,
            info="Pure text input test as a case-control",
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[],
            expected_len=5,
            info="Identical pure text input as a case-control",
        ),
        InputCase(
            text=TEXT_PROMPTS[1],
            img=[],
            expected_len=6,
            info="Another pure text input as a case-control",
        ),
    ]

    # Run tests
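InputCase itself is not shown in this diff; inferred purely from the call sites above, its shape is roughly the following (a hypothetical sketch, not the test's actual definition):

from typing import NamedTuple

from PIL.Image import Image


class InputCase(NamedTuple):
    text: str  # text prompt sent with the request
    img: list[Image]  # zero or more PIL images attached to the prompt
    expected_len: int  # expected number of distinct KV hash entries after this case
    info: str  # human-readable description reported when the expectation fails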
@@ -8,19 +8,27 @@ from typing import Any, Callable, Optional
import torch

from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
                         ModelConfig, SchedulerConfig, VllmConfig)
from vllm.distributed.kv_transfer.kv_connector.factory import (
    KVConnectorFactory)
from vllm.config import (
    CacheConfig,
    DeviceConfig,
    KVTransferConfig,
    ModelConfig,
    SchedulerConfig,
    VllmConfig,
)
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
    SharedStorageConnector)
    SharedStorageConnector,
)
from vllm.utils import sha256
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
                                         init_none_hash)
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec)
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    KVCacheConfig,
    KVCacheGroupSpec,
)
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
@@ -42,14 +50,24 @@ def assert_scheduler_empty(scheduler: Scheduler):
    assert len(scheduler.encoder_cache_manager.cached) == 0

    # KVCache Manager.
    assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
               req_to_blocks) == 0
    assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
               num_cached_block) == 0
    assert (
        len(
            scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks
        )
        == 0
    )
    assert (
        len(
            scheduler.kv_cache_manager.coordinator.single_type_managers[
                0
            ].num_cached_block
        )
        == 0
    )
    num_free_blocks = (
        scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
    assert num_free_blocks == (
        scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
        scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks
    )
    assert num_free_blocks == (scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)

    # NOTE(rob): just the ref count on blocks will be 0. The hash
    # value, etc will remain since we lazily evict for prefix cache.
@@ -90,11 +108,13 @@ def create_vllm_config(
        kv_connector="NixlConnector",
        kv_role="kv_both",
    )
    return VllmConfig(scheduler_config=scheduler_config,
                      model_config=model_config,
                      cache_config=cache_config,
                      kv_transfer_config=kv_transfer_config,
                      device_config=DeviceConfig("cpu"))
    return VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config,
        kv_transfer_config=kv_transfer_config,
        device_config=DeviceConfig("cpu"),
    )


def create_scheduler(
@@ -107,9 +127,9 @@ def create_scheduler(
        num_blocks=num_blocks,  # A large number of blocks to hold all requests
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(['layer'],
                             FullAttentionSpec(block_size, 1, 1, torch.float32,
                                               False))
            KVCacheGroupSpec(
                ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
            )
        ],
    )
    vllm_config.cache_config.num_gpu_blocks = num_blocks
@@ -151,16 +171,16 @@ def create_request(

    if do_remote_decode:
        assert not do_remote_prefill
        kv_transfer_params = dict(do_remote_prefill=False,
                                  do_remote_decode=True)
        kv_transfer_params = dict(do_remote_prefill=False, do_remote_decode=True)
    elif do_remote_prefill:
        kv_transfer_params = dict(do_remote_prefill=True,
                                  do_remote_decode=False,
                                  remote_engine_id="my-engine-id",
                                  remote_block_ids=list(
                                      range(num_remote_blocks)),
                                  remote_host="my-host",
                                  remote_port=1234)
        kv_transfer_params = dict(
            do_remote_prefill=True,
            do_remote_decode=False,
            remote_engine_id="my-engine-id",
            remote_block_ids=list(range(num_remote_blocks)),
            remote_host="my-host",
            remote_port=1234,
        )

    max_tokens = 1 if do_remote_decode else max_tokens
    sampling_params = SamplingParams(max_tokens=max_tokens)
@@ -200,13 +220,19 @@ def create_model_runner_output(
    sampled_token = EOS_TOKEN_ID if use_eos else token_id
    sampled_token_ids = [[sampled_token] for _ in req_ids]

    kv_connector_output = None if (
        finished_sending is None and finished_recving is None
        and invalid_block_ids is None) else KVConnectorOutput(
    kv_connector_output = (
        None
        if (
            finished_sending is None
            and finished_recving is None
            and invalid_block_ids is None
        )
        else KVConnectorOutput(
            finished_sending=finished_sending,
            finished_recving=finished_recving,
            invalid_block_ids=invalid_block_ids or set(),
        )
    )

    # Make output data structure.
    return ModelRunnerOutput(
@@ -221,22 +247,30 @@ def create_model_runner_output(


class TestSharedStorageConnector(SharedStorageConnector):

    def __init__(self, config: VllmConfig, role):
        self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
        self._connector = SharedStorageConnector(config, role)
        self.call_record: dict[str, int] = defaultdict(int)
        # Use a unique temp file per connector
        self._event_file = tempfile.gettempdir(
        ) + f"/connector_{self.name}-{self.role.name}_events.log"
        self._event_file = (
            tempfile.gettempdir()
            + f"/connector_{self.name}-{self.role.name}_events.log"
        )
        # Start with an empty file
        with open(self._event_file, "w") as _:
            pass

    def __getattribute__(self, name):
        if name in ("_connector", "call_record", "name", "_event_file",
                    "__class__", "__dict__", "__getattribute__",
                    "__init__"):  # avoid recursion
        if name in (
            "_connector",
            "call_record",
            "name",
            "_event_file",
            "__class__",
            "__dict__",
            "__getattribute__",
            "__init__",
        ):  # avoid recursion
            return object.__getattribute__(self, name)
        if not hasattr(self._connector, name):
            return object.__getattribute__(self, name)
@@ -255,21 +289,20 @@ class TestSharedStorageConnector(SharedStorageConnector):
                    if isinstance(arg, int):
                        to_log.append(str(arg))
                    elif isinstance(arg, KVCacheBlocks):
                        to_log.append(
                            f"num_blocks={[len(b) for b in arg.blocks]}")
                        to_log.append(f"num_blocks={[len(b) for b in arg.blocks]}")

                # Log the event as a line to the file
                try:
                    with open(self._event_file, "a") as f:
                        f.write(' '.join(to_log) + "\n")
                        f.write(" ".join(to_log) + "\n")
                except Exception as e:
                    print(f"[ERROR] Could not log event {name} "
                          f"for {self.name}: {e}")
                    print(f"[ERROR] Could not log event {name} for {self.name}: {e}")
                return attr(*args, **kwargs)

            return wrapper
        return attr


KVConnectorFactory.register_connector("TestSharedStorageConnector", __name__,
                                      TestSharedStorageConnector.__name__)
KVConnectorFactory.register_connector(
    "TestSharedStorageConnector", __name__, TestSharedStorageConnector.__name__
)
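The registration above is what lets tests select the wrapper by name. A minimal sketch of a config that would pick it up (not part of the diff; the extra-config value is an arbitrary label that TestSharedStorageConnector.__init__ reads back to name its event log):

from vllm.config import KVTransferConfig

kv_transfer_config = KVTransferConfig(
    kv_connector="TestSharedStorageConnector",
    kv_role="kv_both",
    # consumed by TestSharedStorageConnector.__init__ as self.name
    kv_connector_extra_config={"name": "demo"},
)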