[Frontend] API support for beam search for MQLLMEngine (#9117)
This commit is contained in:
@@ -1370,22 +1370,3 @@ class AtomicCounter:
|
||||
@property
|
||||
def value(self):
|
||||
return self._value
|
||||
|
||||
|
||||
def get_beam_search_score(
|
||||
tokens: List[int],
|
||||
cumulative_logprob: float,
|
||||
eos_token_id: int,
|
||||
length_penalty: float = 1.0,
|
||||
) -> float:
|
||||
"""Calculate the beam search score with length penalty.
|
||||
|
||||
Adapted from
|
||||
|
||||
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
|
||||
"""
|
||||
seq_len = len(tokens)
|
||||
if tokens[-1] == eos_token_id:
|
||||
seq_len -= 1
|
||||
|
||||
return cumulative_logprob / (seq_len**length_penalty)
|
||||
|
||||
Reference in New Issue
Block a user