[Frontend] API support for beam search for MQLLMEngine (#9117)

This commit is contained in:
Brendan Wong
2024-10-07 22:51:43 -07:00
committed by GitHub
parent e1faa2a598
commit 8c746226c9
8 changed files with 215 additions and 106 deletions

View File

@@ -1370,22 +1370,3 @@ class AtomicCounter:
@property
def value(self):
return self._value
def get_beam_search_score(
tokens: List[int],
cumulative_logprob: float,
eos_token_id: int,
length_penalty: float = 1.0,
) -> float:
"""Calculate the beam search score with length penalty.
Adapted from
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
"""
seq_len = len(tokens)
if tokens[-1] == eos_token_id:
seq_len -= 1
return cumulative_logprob / (seq_len**length_penalty)