Align vLLM's beam search implementation with HF generate (#857)
This commit is contained in:
@@ -69,6 +69,9 @@ class SequenceData:
|
||||
def get_len(self) -> int:
|
||||
return len(self.output_token_ids) + len(self.prompt_token_ids)
|
||||
|
||||
def get_prompt_len(self) -> int:
|
||||
return len(self.prompt_token_ids)
|
||||
|
||||
def get_output_len(self) -> int:
|
||||
return len(self.output_token_ids)
|
||||
|
||||
@@ -155,6 +158,9 @@ class Sequence:
|
||||
def get_len(self) -> int:
|
||||
return self.data.get_len()
|
||||
|
||||
def get_prompt_len(self) -> int:
|
||||
return self.data.get_prompt_len()
|
||||
|
||||
def get_output_len(self) -> int:
|
||||
return self.data.get_output_len()
|
||||
|
||||
@@ -170,14 +176,32 @@ class Sequence:
|
||||
def get_cumulative_logprob(self) -> float:
|
||||
return self.data.cumulative_logprob
|
||||
|
||||
def get_beam_search_score(self,
|
||||
length_penalty: float = 0.0,
|
||||
seq_len: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None) -> float:
|
||||
"""Calculate the beam search score with length penalty.
|
||||
|
||||
Adapted from
|
||||
|
||||
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
|
||||
"""
|
||||
if seq_len is None:
|
||||
seq_len = self.get_len()
|
||||
# Note: HF implementation does not count the EOS token
|
||||
# towards the length, we align with that here for testing.
|
||||
if (eos_token_id is not None
|
||||
and self.get_last_token_id() == eos_token_id):
|
||||
seq_len -= 1
|
||||
return self.get_cumulative_logprob() / (seq_len**length_penalty)
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
return SequenceStatus.is_finished(self.status)
|
||||
|
||||
def fork(self, child_seq: "Sequence") -> None:
|
||||
child_seq.logical_token_blocks = copy.deepcopy(
|
||||
self.logical_token_blocks)
|
||||
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
|
||||
child_seq.data = copy.deepcopy(self.data)
|
||||
def fork(self, new_seq_id: int) -> "Sequence":
|
||||
new_seq = copy.deepcopy(self)
|
||||
new_seq.seq_id = new_seq_id
|
||||
return new_seq
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"Sequence(seq_id={self.seq_id}, "
|
||||
@@ -203,35 +227,66 @@ class SequenceGroup:
|
||||
arrival_time: float,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.seqs = seqs
|
||||
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
|
||||
self.sampling_params = sampling_params
|
||||
self.arrival_time = arrival_time
|
||||
|
||||
def get_max_num_running_seqs(self) -> int:
|
||||
"""The maximum number of sequences running in parallel in the remaining
|
||||
lifetime of the request."""
|
||||
if self.sampling_params.use_beam_search:
|
||||
# For beam search, maximally there will always be `best_of` beam
|
||||
# candidates running in the future.
|
||||
return self.sampling_params.best_of
|
||||
else:
|
||||
if self.sampling_params.best_of > self.num_seqs():
|
||||
# At prompt stage, the sequence group is not yet filled up
|
||||
# and only have one sequence running. However, in the
|
||||
# generation stage, we will have `best_of` sequences running.
|
||||
return self.sampling_params.best_of
|
||||
# At sampling stages, return the number of actual sequences
|
||||
# running.
|
||||
return self.num_seqs(status=SequenceStatus.RUNNING)
|
||||
|
||||
def get_seqs(
|
||||
self,
|
||||
status: Optional[SequenceStatus] = None,
|
||||
) -> List[Sequence]:
|
||||
if status is None:
|
||||
return self.seqs
|
||||
return list(self.seqs_dict.values())
|
||||
else:
|
||||
return [seq for seq in self.seqs if seq.status == status]
|
||||
return [
|
||||
seq for seq in self.seqs_dict.values() if seq.status == status
|
||||
]
|
||||
|
||||
def get_finished_seqs(self) -> List[Sequence]:
|
||||
return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
|
||||
|
||||
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
|
||||
return len(self.get_seqs(status))
|
||||
|
||||
def find(self, seq_id: int) -> Sequence:
|
||||
for seq in self.seqs:
|
||||
if seq.seq_id == seq_id:
|
||||
return seq
|
||||
raise ValueError(f"Sequence {seq_id} not found.")
|
||||
if seq_id not in self.seqs_dict:
|
||||
raise ValueError(f"Sequence {seq_id} not found.")
|
||||
return self.seqs_dict[seq_id]
|
||||
|
||||
def add(self, seq: Sequence) -> None:
|
||||
if seq.seq_id in self.seqs_dict:
|
||||
raise ValueError(f"Sequence {seq.seq_id} already exists.")
|
||||
self.seqs_dict[seq.seq_id] = seq
|
||||
|
||||
def remove(self, seq_id: int) -> None:
|
||||
if seq_id not in self.seqs_dict:
|
||||
raise ValueError(f"Sequence {seq_id} not found.")
|
||||
del self.seqs_dict[seq_id]
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
return all(seq.is_finished() for seq in self.seqs)
|
||||
return all(seq.is_finished() for seq in self.get_seqs())
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SequenceGroup(request_id={self.request_id}, "
|
||||
f"sampling_params={self.sampling_params}, "
|
||||
f"num_seqs={len(self.seqs)})")
|
||||
f"num_seqs={len(self.seqs_dict)})")
|
||||
|
||||
|
||||
class SequenceGroupMetadata:
|
||||
@@ -266,7 +321,6 @@ class SequenceOutputs:
|
||||
"""The model output associated with a sequence.
|
||||
|
||||
Args:
|
||||
seq_id: The ID of the sequence.
|
||||
parent_seq_id: The ID of the parent sequence (for forking in beam
|
||||
search).
|
||||
output_token: The output token ID.
|
||||
@@ -276,26 +330,27 @@ class SequenceOutputs:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
seq_id: int,
|
||||
parent_seq_id: int,
|
||||
output_token: int,
|
||||
logprobs: Dict[int, float],
|
||||
) -> None:
|
||||
self.seq_id = seq_id
|
||||
self.parent_seq_id = parent_seq_id
|
||||
self.output_token = output_token
|
||||
self.logprobs = logprobs
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SequenceOutputs(seq_id={self.seq_id}, "
|
||||
f"parent_seq_id={self.parent_seq_id}, "
|
||||
return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
|
||||
f"output_token={self.output_token}), "
|
||||
f"logprobs={self.logprobs}")
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, SequenceOutputs):
|
||||
return NotImplemented
|
||||
return (self.seq_id == other.seq_id
|
||||
and self.parent_seq_id == other.parent_seq_id
|
||||
return NotImplementedError()
|
||||
return (self.parent_seq_id == other.parent_seq_id
|
||||
and self.output_token == other.output_token
|
||||
and self.logprobs == other.logprobs)
|
||||
|
||||
|
||||
# For each sequence group, we generate a list of SequenceOutputs object,
|
||||
# each of which contains one possible candidate for the next token.
|
||||
SamplerOutput = List[List[SequenceOutputs]]
|
||||
|
||||
Reference in New Issue
Block a user