Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -8,19 +8,27 @@ from typing import Any, Callable, Optional
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
|
||||
ModelConfig, SchedulerConfig, VllmConfig)
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory)
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
DeviceConfig,
|
||||
KVTransferConfig,
|
||||
ModelConfig,
|
||||
SchedulerConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
|
||||
SharedStorageConnector)
|
||||
SharedStorageConnector,
|
||||
)
|
||||
from vllm.utils import sha256
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
||||
init_none_hash)
|
||||
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec)
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheGroupSpec,
|
||||
)
|
||||
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
@@ -42,14 +50,24 @@ def assert_scheduler_empty(scheduler: Scheduler):
|
||||
assert len(scheduler.encoder_cache_manager.cached) == 0
|
||||
|
||||
# KVCache Manager.
|
||||
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
|
||||
req_to_blocks) == 0
|
||||
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
|
||||
num_cached_block) == 0
|
||||
assert (
|
||||
len(
|
||||
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks
|
||||
)
|
||||
== 0
|
||||
)
|
||||
assert (
|
||||
len(
|
||||
scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0
|
||||
].num_cached_block
|
||||
)
|
||||
== 0
|
||||
)
|
||||
num_free_blocks = (
|
||||
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
|
||||
assert num_free_blocks == (
|
||||
scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
|
||||
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks
|
||||
)
|
||||
assert num_free_blocks == (scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
|
||||
|
||||
# NOTE(rob): just the ref count on blocks will be 0. The hash
|
||||
# value, etc will remain since we lazily evict for prefix cache.
|
||||
@@ -90,11 +108,13 @@ def create_vllm_config(
|
||||
kv_connector="NixlConnector",
|
||||
kv_role="kv_both",
|
||||
)
|
||||
return VllmConfig(scheduler_config=scheduler_config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
kv_transfer_config=kv_transfer_config,
|
||||
device_config=DeviceConfig("cpu"))
|
||||
return VllmConfig(
|
||||
scheduler_config=scheduler_config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
kv_transfer_config=kv_transfer_config,
|
||||
device_config=DeviceConfig("cpu"),
|
||||
)
|
||||
|
||||
|
||||
def create_scheduler(
|
||||
@@ -107,9 +127,9 @@ def create_scheduler(
|
||||
num_blocks=num_blocks, # A large number of blocks to hold all requests
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(['layer'],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32,
|
||||
False))
|
||||
KVCacheGroupSpec(
|
||||
["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
|
||||
)
|
||||
],
|
||||
)
|
||||
vllm_config.cache_config.num_gpu_blocks = num_blocks
|
||||
@@ -151,16 +171,16 @@ def create_request(
|
||||
|
||||
if do_remote_decode:
|
||||
assert not do_remote_prefill
|
||||
kv_transfer_params = dict(do_remote_prefill=False,
|
||||
do_remote_decode=True)
|
||||
kv_transfer_params = dict(do_remote_prefill=False, do_remote_decode=True)
|
||||
elif do_remote_prefill:
|
||||
kv_transfer_params = dict(do_remote_prefill=True,
|
||||
do_remote_decode=False,
|
||||
remote_engine_id="my-engine-id",
|
||||
remote_block_ids=list(
|
||||
range(num_remote_blocks)),
|
||||
remote_host="my-host",
|
||||
remote_port=1234)
|
||||
kv_transfer_params = dict(
|
||||
do_remote_prefill=True,
|
||||
do_remote_decode=False,
|
||||
remote_engine_id="my-engine-id",
|
||||
remote_block_ids=list(range(num_remote_blocks)),
|
||||
remote_host="my-host",
|
||||
remote_port=1234,
|
||||
)
|
||||
|
||||
max_tokens = 1 if do_remote_decode else max_tokens
|
||||
sampling_params = SamplingParams(max_tokens=max_tokens)
|
||||
@@ -200,13 +220,19 @@ def create_model_runner_output(
|
||||
sampled_token = EOS_TOKEN_ID if use_eos else token_id
|
||||
sampled_token_ids = [[sampled_token] for _ in req_ids]
|
||||
|
||||
kv_connector_output = None if (
|
||||
finished_sending is None and finished_recving is None
|
||||
and invalid_block_ids is None) else KVConnectorOutput(
|
||||
kv_connector_output = (
|
||||
None
|
||||
if (
|
||||
finished_sending is None
|
||||
and finished_recving is None
|
||||
and invalid_block_ids is None
|
||||
)
|
||||
else KVConnectorOutput(
|
||||
finished_sending=finished_sending,
|
||||
finished_recving=finished_recving,
|
||||
invalid_block_ids=invalid_block_ids or set(),
|
||||
)
|
||||
)
|
||||
|
||||
# Make output data structure.
|
||||
return ModelRunnerOutput(
|
||||
@@ -221,22 +247,30 @@ def create_model_runner_output(
|
||||
|
||||
|
||||
class TestSharedStorageConnector(SharedStorageConnector):
|
||||
|
||||
def __init__(self, config: VllmConfig, role):
|
||||
self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
|
||||
self._connector = SharedStorageConnector(config, role)
|
||||
self.call_record: dict[str, int] = defaultdict(int)
|
||||
# Use a unique temp file per connector
|
||||
self._event_file = tempfile.gettempdir(
|
||||
) + f"/connector_{self.name}-{self.role.name}_events.log"
|
||||
self._event_file = (
|
||||
tempfile.gettempdir()
|
||||
+ f"/connector_{self.name}-{self.role.name}_events.log"
|
||||
)
|
||||
# Start with an empty file
|
||||
with open(self._event_file, "w") as _:
|
||||
pass
|
||||
|
||||
def __getattribute__(self, name):
|
||||
if name in ("_connector", "call_record", "name", "_event_file",
|
||||
"__class__", "__dict__", "__getattribute__",
|
||||
"__init__"): # avoid recursion
|
||||
if name in (
|
||||
"_connector",
|
||||
"call_record",
|
||||
"name",
|
||||
"_event_file",
|
||||
"__class__",
|
||||
"__dict__",
|
||||
"__getattribute__",
|
||||
"__init__",
|
||||
): # avoid recursion
|
||||
return object.__getattribute__(self, name)
|
||||
if not hasattr(self._connector, name):
|
||||
return object.__getattribute__(self, name)
|
||||
@@ -255,21 +289,20 @@ class TestSharedStorageConnector(SharedStorageConnector):
|
||||
if isinstance(arg, int):
|
||||
to_log.append(str(arg))
|
||||
elif isinstance(arg, KVCacheBlocks):
|
||||
to_log.append(
|
||||
f"num_blocks={[len(b) for b in arg.blocks]}")
|
||||
to_log.append(f"num_blocks={[len(b) for b in arg.blocks]}")
|
||||
|
||||
# Log the event as a line to the file
|
||||
try:
|
||||
with open(self._event_file, "a") as f:
|
||||
f.write(' '.join(to_log) + "\n")
|
||||
f.write(" ".join(to_log) + "\n")
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Could not log event {name} "
|
||||
f"for {self.name}: {e}")
|
||||
print(f"[ERROR] Could not log event {name} for {self.name}: {e}")
|
||||
return attr(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return attr
|
||||
|
||||
|
||||
KVConnectorFactory.register_connector("TestSharedStorageConnector", __name__,
|
||||
TestSharedStorageConnector.__name__)
|
||||
KVConnectorFactory.register_connector(
|
||||
"TestSharedStorageConnector", __name__, TestSharedStorageConnector.__name__
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user