Implement stop strings and best_of (#114)

This commit is contained in:
Woosuk Kwon
2023-05-21 11:18:00 -07:00
committed by GitHub
parent c3442c1f6f
commit f746ced08d
9 changed files with 162 additions and 116 deletions

View File

@@ -73,8 +73,6 @@ class Scheduler:
self.waiting: List[SequenceGroup] = []
# Sequence groups in the RUNNING state.
self.running: List[SequenceGroup] = []
# Mapping: request_id -> num_steps.
self.num_steps: Dict[str, int] = {}
# Sequence groups in the SWAPPED state.
self.swapped: List[SequenceGroup] = []
@@ -84,7 +82,6 @@ class Scheduler:
# Queue a new sequence group for scheduling.
# Reads seq_group.request_id for the sanity check and appends the group to
# self.waiting; returns nothing.
# NOTE(review): the hunk header for this span (-84,7 +82,6) drops exactly one
# line, and the assert is the only statement here that touches self.num_steps;
# presumably it is the line this commit removes. Confirm against the raw diff.
def add_seq_group(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the waiting queue.
assert seq_group.request_id not in self.num_steps
self.waiting.append(seq_group)
def has_unfinished_seqs(self) -> bool:
@@ -178,7 +175,7 @@ class Scheduler:
break
# If the number of batched tokens exceeds the limit, stop.
num_prompt_tokens = seq_group.seqs[0].get_len()
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
if (num_batched_tokens + num_prompt_tokens
> self.scheduler_config.max_num_batched_tokens):
break
@@ -278,15 +275,8 @@ class Scheduler:
) -> List[SequenceGroup]:
# Update the running sequences and free blocks.
for seq_group in self.running:
request_id = seq_group.request_id
self.num_steps[request_id] += 1
stop_token_ids = seq_group.sampling_params.stop_token_ids
# Process beam search results before processing the next tokens.
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue
# Process beam search results before processing the new tokens.
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
output = seq_outputs[seq.seq_id]
if seq.seq_id != output.parent_seq_id:
# The sequence is a fork of the parent sequence (beam search).
@@ -297,43 +287,27 @@ class Scheduler:
parent_seq.fork(seq)
self.block_manager.fork(parent_seq, seq)
# Process the next tokens.
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue
# Process the new tokens.
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
# Append a new token to the sequence.
output = seq_outputs[seq.seq_id]
seq.append_token(output.output_token, output.logprobs)
return self.running.copy()
# Check if the sequence has generated a stop token.
if output.output_token in stop_token_ids:
self._free_seq(seq)
continue
# Retire a single sequence: set its status to SequenceStatus.FINISHED and then
# hand it to self.block_manager.free(). Returns nothing.
def free_seq(self, seq: Sequence) -> None:
# The status is updated before the block manager is called.
seq.status = SequenceStatus.FINISHED
# Presumably releases the blocks held by the sequence; the block manager is
# not visible from here, so confirm against its implementation.
self.block_manager.free(seq)
# Check if the sequence has reached the maximum number of steps.
max_num_steps = seq_group.sampling_params.max_tokens
if self.num_steps[request_id] == max_num_steps:
self._free_seq(seq)
continue
# Update the running sequences.
updated = self.running.copy()
running: List[SequenceGroup] = []
for seq_group in self.running:
if seq_group.is_finished():
self._free_seq_group(seq_group)
else:
running.append(seq_group)
self.running = running
return updated
# Remove finished groups from the running list.
# Rebinds self.running to a new list that keeps only the groups whose
# is_finished() returns a falsy value; the relative order of the kept groups
# is unchanged. Returns nothing.
def free_finished_seq_groups(self) -> None:
self.running = [
seq_group for seq_group in self.running
if not seq_group.is_finished()
]
# Allocate a sequence group through the block manager and mark its sequences
# as SequenceStatus.RUNNING.
# NOTE(review): this span comes from a diff view whose +/- markers and
# indentation are not visible, so before/after lines sit side by side. The
# pairing notes below are assumptions to confirm against the raw diff.
def _allocate(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
# The next two lines are two spellings of the same loop header (direct
# attribute access vs. get_seqs()); presumably each belongs to one side of
# the diff only.
for seq in seq_group.seqs:
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.RUNNING
# Seeds a zero entry in self.num_steps for a request id that is not there yet.
# Whether this sits inside or after the loop cannot be told without indentation.
if seq_group.request_id not in self.num_steps:
self.num_steps[seq_group.request_id] = 0
def _append_slot(
self,
@@ -403,13 +377,6 @@ class Scheduler:
self._swap_out(seq_group, blocks_to_swap_out)
self.swapped.append(seq_group)
# Private helper that retires a single sequence: set its status to
# SequenceStatus.FINISHED and then hand it to self.block_manager.free().
# Returns nothing.
def _free_seq(self, seq: Sequence) -> None:
# The status is updated before the block manager is called.
seq.status = SequenceStatus.FINISHED
# Presumably releases the blocks held by the sequence; the block manager is
# not visible from here, so confirm against its implementation.
self.block_manager.free(seq)
# Drop this group's entry from self.num_steps, keyed by its request_id.
# Returns nothing.
def _free_seq_group(self, seq_group: SequenceGroup) -> None:
# `del` on a missing key raises KeyError, so the request id must already be
# present (assuming self.num_steps is a plain dict, as its Dict[str, int]
# annotation suggests).
del self.num_steps[seq_group.request_id]
def _swap_in(
self,
seq_group: SequenceGroup,