Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -16,6 +16,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
class BatchLogprobsComposition(Enum):
|
||||
"""Types of logprobs configs to include in test batch"""
|
||||
|
||||
NONE = 0
|
||||
SAMPLE = 1
|
||||
PROMPT = 2
|
||||
@@ -26,10 +27,10 @@ BatchLogprobsSpecType = list[tuple[Optional[int], Optional[int]]]
|
||||
|
||||
|
||||
def get_test_batch(
|
||||
batch_logprobs_composition: BatchLogprobsComposition
|
||||
batch_logprobs_composition: BatchLogprobsComposition,
|
||||
) -> BatchLogprobsSpecType:
|
||||
"""Generate logprobs configs for a batch of requests
|
||||
|
||||
|
||||
A given request's logprobs configuration is (1) num_sample_logprobs and (2)
|
||||
num_prompt_logprobs. The batch logprobs configuration is the list of request
|
||||
logprobs configs.
|
||||
@@ -101,7 +102,7 @@ def assert_incr_detok_str_matches_non_incr_detok_str(
|
||||
msg: str,
|
||||
) -> None:
|
||||
"""Compare incrementally detok. text to non-incrementally detok. text
|
||||
|
||||
|
||||
Fail if the strings mismatch after non-alphanumeric characters are stripped
|
||||
out.
|
||||
|
||||
@@ -120,15 +121,15 @@ def assert_incr_detok_str_matches_non_incr_detok_str(
|
||||
tokens
|
||||
msg: error message if `assert` fails
|
||||
"""
|
||||
rgx = r'[^a-zA-Z0-9]+'
|
||||
assert (re.sub(rgx, '', incremental_detokenization_str) == re.sub(
|
||||
rgx, '', non_incremental_detokenization_str)), (msg)
|
||||
rgx = r"[^a-zA-Z0-9]+"
|
||||
assert re.sub(rgx, "", incremental_detokenization_str) == re.sub(
|
||||
rgx, "", non_incremental_detokenization_str
|
||||
), msg
|
||||
|
||||
|
||||
def compute_correct_cumulative_logprob(
|
||||
completion_output: CompletionOutput) -> float:
|
||||
def compute_correct_cumulative_logprob(completion_output: CompletionOutput) -> float:
|
||||
"""Compute known-good value for evaluating cumulative logprob
|
||||
|
||||
|
||||
Args:
|
||||
completion_output: completion output from engine
|
||||
|
||||
@@ -146,12 +147,12 @@ def create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor:
|
||||
return fake_logits
|
||||
|
||||
|
||||
def create_penalty_tensor(batch_size: int, penalty_value: float,
|
||||
device: torch.device) -> torch.Tensor:
|
||||
return torch.full((batch_size, ),
|
||||
fill_value=penalty_value,
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
def create_penalty_tensor(
|
||||
batch_size: int, penalty_value: float, device: torch.device
|
||||
) -> torch.Tensor:
|
||||
return torch.full(
|
||||
(batch_size,), fill_value=penalty_value, dtype=torch.float, device=device
|
||||
)
|
||||
|
||||
|
||||
def create_prompt_tokens_tensor(
|
||||
@@ -170,6 +171,7 @@ def create_prompt_tokens_tensor(
|
||||
|
||||
class LogitsprocsTestFakes(NamedTuple):
|
||||
"""Wraps fake data structures to support testing"""
|
||||
|
||||
logits: torch.Tensor
|
||||
sampling_metadata: SamplingMetadata
|
||||
|
||||
@@ -178,15 +180,16 @@ class LogitsprocsTestFakes(NamedTuple):
|
||||
cls: type[LogitsProcessor],
|
||||
) -> Iterator[LogitsProcessor]:
|
||||
"""Yield logits processors of a specific class.
|
||||
|
||||
|
||||
Args:
|
||||
cls: :class:`LogitsProcessor` subclass
|
||||
|
||||
Returns:
|
||||
Iterator over logits processors
|
||||
"""
|
||||
return (lp for lp in self.sampling_metadata.logitsprocs.all
|
||||
if isinstance(lp, cls))
|
||||
return (
|
||||
lp for lp in self.sampling_metadata.logitsprocs.all if isinstance(lp, cls)
|
||||
)
|
||||
|
||||
def get_logitsprocs(self) -> Iterator[LogitsProcessor]:
|
||||
"""Iterator over all logits processors."""
|
||||
@@ -208,8 +211,7 @@ def fake_apply_logitsprocs(
|
||||
slice_indices: list[int],
|
||||
) -> torch.Tensor:
|
||||
"""Imitate application of logits processors in engine core"""
|
||||
logits = test_fakes.logits[torch.tensor(slice_indices,
|
||||
dtype=torch.long)].clone()
|
||||
logits = test_fakes.logits[torch.tensor(slice_indices, dtype=torch.long)].clone()
|
||||
for processor in test_fakes.get_logitsprocs():
|
||||
logits = processor.apply(logits)
|
||||
return logits
|
||||
|
||||
Reference in New Issue
Block a user