Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -18,6 +18,7 @@ class BeamSearchSequence:
|
||||
The text field is optional and will only be filled when the sequence is
|
||||
about to be returned to the user.
|
||||
"""
|
||||
|
||||
# The tokens include the prompt.
|
||||
tokens: list[int]
|
||||
logprobs: list[dict[int, Logprob]]
|
||||
@@ -36,11 +37,11 @@ class BeamSearchOutput:
|
||||
It contains the list of the best beam search sequences.
|
||||
The length of the list is equal to the beam width.
|
||||
"""
|
||||
|
||||
sequences: list[BeamSearchSequence]
|
||||
|
||||
|
||||
class BeamSearchInstance:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_tokens: list[int],
|
||||
@@ -79,9 +80,9 @@ def get_beam_search_score(
|
||||
|
||||
|
||||
def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):
|
||||
|
||||
def sort_beams_key(x: BeamSearchSequence) -> float:
|
||||
return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id,
|
||||
length_penalty)
|
||||
return get_beam_search_score(
|
||||
x.tokens, x.cum_logprob, eos_token_id, length_penalty
|
||||
)
|
||||
|
||||
return sort_beams_key
|
||||
|
||||
Reference in New Issue
Block a user