Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -13,8 +13,12 @@ from vllm.multimodal.inputs import MultiModalFeatureSpec
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
|
||||
EngineCoreRequest, FinishReason)
|
||||
from vllm.v1.engine import (
|
||||
EngineCoreEvent,
|
||||
EngineCoreEventType,
|
||||
EngineCoreRequest,
|
||||
FinishReason,
|
||||
)
|
||||
from vllm.v1.structured_output.request import StructuredOutputRequest
|
||||
from vllm.v1.utils import ConstantList
|
||||
|
||||
@@ -24,7 +28,6 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class Request:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request_id: str,
|
||||
@@ -41,8 +44,7 @@ class Request:
|
||||
cache_salt: Optional[str] = None,
|
||||
priority: int = 0,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
block_hasher: Optional[Callable[["Request"],
|
||||
list["BlockHash"]]] = None,
|
||||
block_hasher: Optional[Callable[["Request"], list["BlockHash"]]] = None,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.client_index = client_index
|
||||
@@ -53,8 +55,7 @@ class Request:
|
||||
self.eos_token_id = eos_token_id
|
||||
self.lora_request = lora_request
|
||||
self.structured_output_request = structured_output_request
|
||||
self.arrival_time = arrival_time if arrival_time is not None else \
|
||||
time.time()
|
||||
self.arrival_time = arrival_time if arrival_time is not None else time.time()
|
||||
|
||||
self.status = RequestStatus.WAITING
|
||||
self.use_structured_output = False
|
||||
@@ -76,20 +77,23 @@ class Request:
|
||||
self.use_structured_output = True
|
||||
|
||||
if sampling_params.extra_args is not None:
|
||||
self.kv_transfer_params = \
|
||||
sampling_params.extra_args.get("kv_transfer_params")
|
||||
self.kv_transfer_params = sampling_params.extra_args.get(
|
||||
"kv_transfer_params"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"sampling_params and pooling_params can't both be unset")
|
||||
raise ValueError("sampling_params and pooling_params can't both be unset")
|
||||
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
self.prompt_embeds = prompt_embeds
|
||||
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||
prompt_token_ids, prompt_embeds)
|
||||
prompt_token_ids, prompt_embeds
|
||||
)
|
||||
self._output_token_ids: list[int] = []
|
||||
self._all_token_ids: list[int] = self.prompt_token_ids.copy(
|
||||
) if self.prompt_token_ids is not None else [0
|
||||
] * self.num_prompt_tokens
|
||||
self._all_token_ids: list[int] = (
|
||||
self.prompt_token_ids.copy()
|
||||
if self.prompt_token_ids is not None
|
||||
else [0] * self.num_prompt_tokens
|
||||
)
|
||||
self.num_output_placeholders = 0 # Used in async scheduling.
|
||||
self.spec_token_ids: list[int] = []
|
||||
self.num_computed_tokens = 0
|
||||
@@ -119,16 +123,16 @@ class Request:
|
||||
self.num_preemptions = 0
|
||||
|
||||
self.block_hashes: list[BlockHash] = []
|
||||
self.get_hash_new_full_blocks: Optional[Callable[
|
||||
[], list[BlockHash]]] = None
|
||||
self.get_hash_new_full_blocks: Optional[Callable[[], list[BlockHash]]] = None
|
||||
if block_hasher is not None:
|
||||
self.get_hash_new_full_blocks = partial(block_hasher, self)
|
||||
self.block_hashes = self.get_hash_new_full_blocks()
|
||||
|
||||
@classmethod
|
||||
def from_engine_core_request(
|
||||
cls, request: EngineCoreRequest,
|
||||
block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
|
||||
cls,
|
||||
request: EngineCoreRequest,
|
||||
block_hasher: Optional[Callable[["Request"], list["BlockHash"]]],
|
||||
) -> "Request":
|
||||
return cls(
|
||||
request_id=request.request_id,
|
||||
@@ -142,8 +146,10 @@ class Request:
|
||||
arrival_time=request.arrival_time,
|
||||
lora_request=request.lora_request,
|
||||
structured_output_request=StructuredOutputRequest(
|
||||
sampling_params=request.sampling_params) \
|
||||
if request.sampling_params else None,
|
||||
sampling_params=request.sampling_params
|
||||
)
|
||||
if request.sampling_params
|
||||
else None,
|
||||
cache_salt=request.cache_salt,
|
||||
priority=request.priority,
|
||||
trace_headers=request.trace_headers,
|
||||
@@ -207,6 +213,7 @@ class Request:
|
||||
|
||||
class RequestStatus(enum.IntEnum):
|
||||
"""Status of a request."""
|
||||
|
||||
WAITING = enum.auto()
|
||||
WAITING_FOR_FSM = enum.auto()
|
||||
WAITING_FOR_REMOTE_KVS = enum.auto()
|
||||
@@ -227,8 +234,7 @@ class RequestStatus(enum.IntEnum):
|
||||
return status > RequestStatus.PREEMPTED
|
||||
|
||||
@staticmethod
|
||||
def get_finished_reason(
|
||||
status: "RequestStatus") -> Union[FinishReason, None]:
|
||||
def get_finished_reason(status: "RequestStatus") -> Union[FinishReason, None]:
|
||||
return _FINISHED_REASON_MAP.get(status)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user