Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
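
Almost every hunk below is a mechanical consequence of ruff's Black-compatible formatter, chiefly its "magic trailing comma" rule: a trailing comma inside brackets forces one element per line, while a bracketed expression without one is collapsed onto a single line whenever it fits. A minimal standalone sketch of the rule (toy code, not from this diff; make_config is a hypothetical stand-in):

    def make_config(**kwargs: object) -> dict[str, object]:
        """Hypothetical stand-in for a config constructor."""
        return dict(kwargs)

    # No trailing comma and the call fits the line limit, so ruff format
    # keeps it collapsed on a single line.
    flat = make_config(block_size=16, cache_dtype="auto")

    # A trailing comma after the last argument tells ruff format to keep
    # the call expanded, one argument per line.
    expanded = make_config(
        block_size=16,
        gpu_memory_utilization=0.9,
        cache_dtype="auto",
    )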
@@ -4,18 +4,29 @@ from typing import Optional, Union
 
 import torch
 
-from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
-                         SchedulerConfig, SpeculativeConfig, VllmConfig)
-from vllm.multimodal.inputs import (MultiModalFeatureSpec,
-                                    MultiModalKwargsItem, PlaceholderRange)
+from vllm.config import (
+    CacheConfig,
+    KVTransferConfig,
+    ModelConfig,
+    SchedulerConfig,
+    SpeculativeConfig,
+    VllmConfig,
+)
+from vllm.multimodal.inputs import (
+    MultiModalFeatureSpec,
+    MultiModalKwargsItem,
+    PlaceholderRange,
+)
 from vllm.sampling_params import SamplingParams
 from vllm.utils import sha256
-from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
-                                         init_none_hash)
+from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
 from vllm.v1.core.sched.async_scheduler import AsyncScheduler
 from vllm.v1.core.sched.scheduler import Scheduler
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheGroupSpec)
+from vllm.v1.kv_cache_interface import (
+    FullAttentionSpec,
+    KVCacheConfig,
+    KVCacheGroupSpec,
+)
 from vllm.v1.request import Request
 from vllm.v1.structured_output import StructuredOutputManager
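
The import hunk above shows ruff's isort-compatible sorting (its "I" lint rule family) working with the formatter: an import that exceeds the line limit is exploded to one name per line with a trailing comma, while a parenthesized import that fits (here get_request_block_hasher plus init_none_hash) is collapsed back to a single line. A toy sketch of the two shapes using only stdlib modules:

    # The magic trailing comma keeps this import exploded, one name per
    # line, even though it would fit on a single line.
    from collections import (
        Counter,
        OrderedDict,
        defaultdict,
    )

    # Without a trailing comma, an import that fits the line limit is
    # collapsed onto one line.
    from functools import lru_cache, partial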
@@ -37,7 +48,7 @@ def create_scheduler(
     skip_tokenizer_init: bool = False,
     async_scheduling: bool = False,
 ) -> Union[Scheduler, AsyncScheduler]:
-    '''Create scheduler under test.
+    """Create scheduler under test.
 
     Args:
       model: model under test
@@ -49,7 +60,7 @@ def create_scheduler(
 
     Returns:
      {class}`Scheduler` instance
-    '''
+    """
    if max_model_len is None:
        max_model_len = max_num_batched_tokens
    scheduler_config = SchedulerConfig(
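
The docstring change in the two hunks above is ruff format's quote normalization: '''-quoted docstrings become """-quoted, and single-quoted string literals become double-quoted (the same rule rewrites 'enable_prefix_caching' and ['layer'] further down). A toy sketch of the normalized output:

    def scheduler_stub() -> dict[str, str]:
        """Docstrings come out triple-double-quoted under ruff format."""
        # String literals are normalized to double quotes as well.
        return {"kv_role": "kv_both"}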
@@ -69,9 +80,11 @@ def create_scheduler(
         skip_tokenizer_init=skip_tokenizer_init,
     )
     # Cache config, optionally force APC
-    kwargs_cache = ({} if enable_prefix_caching is None else {
-        'enable_prefix_caching': enable_prefix_caching
-    })
+    kwargs_cache = (
+        {}
+        if enable_prefix_caching is None
+        else {"enable_prefix_caching": enable_prefix_caching}
+    )
     cache_config = CacheConfig(
         block_size=block_size,
         gpu_memory_utilization=0.9,
@@ -79,16 +92,21 @@ def create_scheduler(
         cache_dtype="auto",
         **kwargs_cache,
     )
-    kv_transfer_config = KVTransferConfig(
-        kv_connector="SharedStorageConnector",
-        kv_role="kv_both",
-        kv_connector_extra_config={"shared_storage_path": "local_storage"},
-    ) if use_kv_connector else None
+    kv_transfer_config = (
+        KVTransferConfig(
+            kv_connector="SharedStorageConnector",
+            kv_role="kv_both",
+            kv_connector_extra_config={"shared_storage_path": "local_storage"},
+        )
+        if use_kv_connector
+        else None
+    )
 
     speculative_config: Optional[SpeculativeConfig] = None
     if num_speculative_tokens is not None:
         speculative_config = SpeculativeConfig(
-            model="ngram", num_speculative_tokens=num_speculative_tokens)
+            model="ngram", num_speculative_tokens=num_speculative_tokens
+        )
 
     vllm_config = VllmConfig(
         scheduler_config=scheduler_config,
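
The kv_transfer_config hunk above is the clearest stylistic difference between the two tools: yapf left the condition hanging after the closing parenthesis (") if use_kv_connector else None"), whereas ruff format parenthesizes the whole conditional expression and gives the value, the condition, and the fallback a line each. A standalone sketch of the same shape, with a plain dict standing in for KVTransferConfig:

    use_kv_connector = True

    # Black/ruff style: the conditional expression is wrapped in
    # parentheses, one clause per line.
    kv_transfer_config = (
        {
            "kv_connector": "SharedStorageConnector",
            "kv_role": "kv_both",
        }
        if use_kv_connector
        else None
    )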
@@ -101,9 +119,9 @@ def create_scheduler(
         num_blocks=num_blocks,  # A large number of blocks to hold all requests
         kv_cache_tensors=[],
         kv_cache_groups=[
-            KVCacheGroupSpec(['layer'],
-                             FullAttentionSpec(block_size, 1, 1, torch.float32,
-                                               False))
+            KVCacheGroupSpec(
+                ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
+            )
         ],
     )
     cache_config.num_gpu_blocks = num_blocks
@@ -135,10 +153,12 @@ def create_requests(
         _none_hash_initialized = True
 
     block_hasher = get_request_block_hasher(block_size, sha256)
-    sampling_params = SamplingParams(ignore_eos=False,
-                                     max_tokens=max_tokens,
-                                     stop_token_ids=stop_token_ids,
-                                     prompt_logprobs=prompt_logprobs)
+    sampling_params = SamplingParams(
+        ignore_eos=False,
+        max_tokens=max_tokens,
+        stop_token_ids=stop_token_ids,
+        prompt_logprobs=prompt_logprobs,
+    )
     requests = []
     for i in range(num_requests):
         mm_features = []
@@ -152,11 +172,11 @@ def create_requests(
                 data=MultiModalKwargsItem.dummy("dummy_m"),
                 mm_position=position,
                 identifier=identifier,
-                modality="image")
+                modality="image",
+            )
             mm_features.append(mm_feature)
 
-        prompt_token_ids = ([0] * num_tokens if same_prompt else [i] *
-                            num_tokens)
+        prompt_token_ids = [0] * num_tokens if same_prompt else [i] * num_tokens
         request = Request(
             request_id=f"{i}",
             prompt_token_ids=prompt_token_ids,