Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
Parent: 17edd8a807
Commit: d6953beb91

1508 changed files with 115244 additions and 94146 deletions

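For orientation, the hunks below are mechanical style changes of three kinds: imports of container and generator ABCs move from typing to collections.abc, trivial `...` stub bodies collapse onto the def line, and long signatures and calls are rewrapped at 88 columns with one argument per line and a trailing comma instead of yapf's paren-aligned continuations. A minimal sketch of these conventions, using hypothetical names that are not taken from this commit:

"""Toy module (hypothetical names, not from this commit) illustrating the
conventions applied in the hunks below."""

from abc import ABC, abstractmethod

# Container/generator ABCs come from collections.abc; the typing aliases
# (typing.Iterable and friends) have been deprecated since Python 3.9.
from collections.abc import Iterable
from typing import Optional


class Runner(ABC):
    # A body consisting only of "..." is collapsed onto the def line.
    @property
    @abstractmethod
    def is_running(self) -> bool: ...

    @abstractmethod
    def run(
        self,
        job_names: Iterable[str],
        timeout_before_forced_shutdown_in_seconds: Optional[float] = None,
    ) -> None:
        # A signature that would exceed 88 columns is wrapped with one
        # parameter per line and a magic trailing comma, not aligned to
        # the opening parenthesis as under yapf.
        ...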

@@ -3,7 +3,8 @@
 
 import asyncio
 from abc import ABC, abstractmethod
-from typing import Any, AsyncGenerator, Iterable, Mapping, Optional, Union
+from collections.abc import AsyncGenerator, Iterable, Mapping
+from typing import Any, Optional, Union
 
 from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import ModelConfig, VllmConfig
@@ -29,23 +30,19 @@ class EngineClient(ABC):
 
     @property
     @abstractmethod
-    def is_running(self) -> bool:
-        ...
+    def is_running(self) -> bool: ...
 
     @property
     @abstractmethod
-    def is_stopped(self) -> bool:
-        ...
+    def is_stopped(self) -> bool: ...
 
     @property
     @abstractmethod
-    def errored(self) -> bool:
-        ...
+    def errored(self) -> bool: ...
 
     @property
     @abstractmethod
-    def dead_error(self) -> BaseException:
-        ...
+    def dead_error(self) -> BaseException: ...
 
     @abstractmethod
     def generate(
@@ -71,7 +68,6 @@ class EngineClient(ABC):
         params: BeamSearchParams,
         lora_request: Optional[LoRARequest] = None,
     ) -> AsyncGenerator[RequestOutput, None]:
-
         beam_width = params.beam_width
         max_tokens = params.max_tokens
         ignore_eos = params.ignore_eos
@@ -112,8 +108,7 @@ class EngineClient(ABC):
 
         tokenized_length = len(prompt_token_ids)
 
-        sort_beams_key = create_sort_beams_key_function(
-            eos_token_id, length_penalty)
+        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
 
         beam_search_params = SamplingParams(
             logprobs=2 * beam_width,
@@ -121,35 +116,49 @@ class EngineClient(ABC):
             temperature=temperature,
         )
         all_beams = [
-            BeamSearchSequence(tokens=prompt_token_ids,
-                               cum_logprob=0,
-                               logprobs=[],
-                               multi_modal_data=multi_modal_data,
-                               mm_processor_kwargs=mm_processor_kwargs,
-                               lora_request=lora_request)
+            BeamSearchSequence(
+                tokens=prompt_token_ids,
+                cum_logprob=0,
+                logprobs=[],
+                multi_modal_data=multi_modal_data,
+                mm_processor_kwargs=mm_processor_kwargs,
+                lora_request=lora_request,
+            )
         ]
         completed = []
 
         for _ in range(max_tokens):
-            prompts_batch, lora_req_batch = zip(*[(
-                TokensPrompt(prompt_token_ids=beam.tokens,
-                             multi_modal_data=beam.multi_modal_data,
-                             mm_processor_kwargs=beam.mm_processor_kwargs),
-                beam.lora_request,
-            ) for beam in all_beams])
+            prompts_batch, lora_req_batch = zip(
+                *[
+                    (
+                        TokensPrompt(
+                            prompt_token_ids=beam.tokens,
+                            multi_modal_data=beam.multi_modal_data,
+                            mm_processor_kwargs=beam.mm_processor_kwargs,
+                        ),
+                        beam.lora_request,
+                    )
+                    for beam in all_beams
+                ]
+            )
 
             tasks = []
 
             request_id = f"beam_search-{random_uuid()}"
-            for i, (individual_prompt,
-                    lora_req) in enumerate(zip(prompts_batch, lora_req_batch)):
+            for i, (individual_prompt, lora_req) in enumerate(
+                zip(prompts_batch, lora_req_batch)
+            ):
                 request_id_item = f"{request_id}-{i}"
                 task = asyncio.create_task(
                     collect_from_async_generator(
-                        self.generate(individual_prompt,
-                                      beam_search_params,
-                                      request_id_item,
-                                      lora_request=lora_req)))
+                        self.generate(
+                            individual_prompt,
+                            beam_search_params,
+                            request_id_item,
+                            lora_request=lora_req,
+                        )
+                    )
+                )
                 tasks.append(task)
 
             output = await asyncio.gather(*tasks)
@@ -163,32 +172,31 @@ class EngineClient(ABC):
                 if result.outputs[0].logprobs is not None:
                     logprobs = result.outputs[0].logprobs[0]
                     for token_id, logprob_obj in logprobs.items():
-                        if token_id == eos_token_id and \
-                            not ignore_eos:
+                        if token_id == eos_token_id and not ignore_eos:
                             completed.append(
                                 BeamSearchSequence(
-                                    tokens=current_beam.tokens +
-                                    [token_id] if include_stop_str_in_output
+                                    tokens=current_beam.tokens + [token_id]
+                                    if include_stop_str_in_output
                                     else current_beam.tokens,
-                                    logprobs=current_beam.logprobs +
-                                    [logprobs],
-                                    cum_logprob=current_beam.cum_logprob +
-                                    logprob_obj.logprob,
+                                    logprobs=current_beam.logprobs + [logprobs],
+                                    cum_logprob=current_beam.cum_logprob
+                                    + logprob_obj.logprob,
                                     finish_reason="stop",
-                                    stop_reason=eos_token_id))
+                                    stop_reason=eos_token_id,
+                                )
+                            )
                         else:
                             new_beams.append(
                                 BeamSearchSequence(
                                     tokens=current_beam.tokens + [token_id],
-                                    logprobs=current_beam.logprobs +
-                                    [logprobs],
+                                    logprobs=current_beam.logprobs + [logprobs],
                                     lora_request=current_beam.lora_request,
-                                    cum_logprob=current_beam.cum_logprob +
-                                    logprob_obj.logprob,
-                                    multi_modal_data=current_beam.
-                                    multi_modal_data,
-                                    mm_processor_kwargs=current_beam.
-                                    mm_processor_kwargs))
+                                    cum_logprob=current_beam.cum_logprob
+                                    + logprob_obj.logprob,
+                                    multi_modal_data=current_beam.multi_modal_data,
+                                    mm_processor_kwargs=current_beam.mm_processor_kwargs,
+                                )
+                            )
 
             sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
             all_beams = sorted_beams[:beam_width]
@@ -198,7 +206,7 @@ class EngineClient(ABC):
 
         best_beams = sorted_completed[:beam_width]
         for beam in best_beams:
-            if (beam.tokens[-1] == eos_token_id and not ignore_eos):
+            if beam.tokens[-1] == eos_token_id and not ignore_eos:
                 # Skip the eos token in the text.
                 tokens = beam.tokens[tokenized_length:-1]
             else:
@@ -209,19 +217,23 @@ class EngineClient(ABC):
             request_id=request_id,
             prompt=prompt_text,
             outputs=[
-                CompletionOutput(text=beam.text,
-                                 cumulative_logprob=beam.cum_logprob,
-                                 token_ids=beam.tokens[tokenized_length:],
-                                 index=i,
-                                 logprobs=beam.logprobs,
-                                 finish_reason=beam.finish_reason if
-                                 beam.finish_reason is not None else "length",
-                                 stop_reason=beam.stop_reason)
+                CompletionOutput(
+                    text=beam.text,
+                    cumulative_logprob=beam.cum_logprob,
+                    token_ids=beam.tokens[tokenized_length:],
+                    index=i,
+                    logprobs=beam.logprobs,
+                    finish_reason=beam.finish_reason
+                    if beam.finish_reason is not None
+                    else "length",
+                    stop_reason=beam.stop_reason,
+                )
                 for (i, beam) in enumerate(best_beams)
             ],
             finished=True,
             prompt_token_ids=prompt_token_ids,
-            prompt_logprobs=None)
+            prompt_logprobs=None,
+        )
 
     @abstractmethod
     def encode(
@@ -271,12 +283,10 @@ class EngineClient(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    async def is_tracing_enabled(self) -> bool:
-        ...
+    async def is_tracing_enabled(self) -> bool: ...
 
     @abstractmethod
-    async def do_log_stats(self) -> None:
-        ...
+    async def do_log_stats(self) -> None: ...
 
     @abstractmethod
     async def check_health(self) -> None:
@@ -299,8 +309,7 @@ class EngineClient(ABC):
         ...
 
     @abstractmethod
-    async def reset_prefix_cache(self,
-                                 device: Optional[Device] = None) -> None:
+    async def reset_prefix_cache(self, device: Optional[Device] = None) -> None:
         """Reset the prefix cache"""
         ...
@@ -324,17 +333,19 @@ class EngineClient(ABC):
         """Load a new LoRA adapter into the engine for future requests."""
         ...
 
-    async def scale_elastic_ep(self,
-                               new_data_parallel_size: int,
-                               drain_timeout: int = 300) -> None:
+    async def scale_elastic_ep(
+        self, new_data_parallel_size: int, drain_timeout: int = 300
+    ) -> None:
         """Scale the engine"""
         raise NotImplementedError
 
-    async def collective_rpc(self,
-                             method: str,
-                             timeout: Optional[float] = None,
-                             args: tuple = (),
-                             kwargs: Optional[dict] = None):
+    async def collective_rpc(
+        self,
+        method: str,
+        timeout: Optional[float] = None,
+        args: tuple = (),
+        kwargs: Optional[dict] = None,
+    ):
         """Perform a collective RPC call to the given path."""
         raise NotImplementedError
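The collapsed `...` stubs above are purely cosmetic: the one-line body is still just the Ellipsis expression, so abstract-method enforcement is unchanged. A small self-contained sketch (hypothetical class names, not part of this commit):

from abc import ABC, abstractmethod


class Client(ABC):
    # Collapsed stub body, in the style used throughout the diff above.
    @property
    @abstractmethod
    def is_running(self) -> bool: ...


class LocalClient(Client):
    @property
    def is_running(self) -> bool:
        return True


# Instantiating Client() directly raises TypeError because the abstract
# property is unimplemented; the concrete subclass behaves as before.
print(LocalClient().is_running)  # prints: True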