Implement stop strings and best_of (#114)
This commit is contained in:
@@ -73,8 +73,6 @@ class Scheduler:
|
||||
self.waiting: List[SequenceGroup] = []
|
||||
# Sequence groups in the RUNNING state.
|
||||
self.running: List[SequenceGroup] = []
|
||||
# Mapping: request_id -> num_steps.
|
||||
self.num_steps: Dict[str, int] = {}
|
||||
# Sequence groups in the SWAPPED state.
|
||||
self.swapped: List[SequenceGroup] = []
|
||||
|
||||
@@ -84,7 +82,6 @@ class Scheduler:
|
||||
|
||||
def add_seq_group(self, seq_group: SequenceGroup) -> None:
|
||||
# Add sequence groups to the waiting queue.
|
||||
assert seq_group.request_id not in self.num_steps
|
||||
self.waiting.append(seq_group)
|
||||
|
||||
def has_unfinished_seqs(self) -> bool:
|
||||
@@ -178,7 +175,7 @@ class Scheduler:
|
||||
break
|
||||
|
||||
# If the number of batched tokens exceeds the limit, stop.
|
||||
num_prompt_tokens = seq_group.seqs[0].get_len()
|
||||
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
|
||||
if (num_batched_tokens + num_prompt_tokens
|
||||
> self.scheduler_config.max_num_batched_tokens):
|
||||
break
|
||||
@@ -278,15 +275,8 @@ class Scheduler:
|
||||
) -> List[SequenceGroup]:
|
||||
# Update the running sequences and free blocks.
|
||||
for seq_group in self.running:
|
||||
request_id = seq_group.request_id
|
||||
self.num_steps[request_id] += 1
|
||||
stop_token_ids = seq_group.sampling_params.stop_token_ids
|
||||
|
||||
# Process beam search results before processing the next tokens.
|
||||
for seq in seq_group.seqs:
|
||||
if seq.status == SequenceStatus.FINISHED:
|
||||
continue
|
||||
|
||||
# Process beam search results before processing the new tokens.
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
output = seq_outputs[seq.seq_id]
|
||||
if seq.seq_id != output.parent_seq_id:
|
||||
# The sequence is a fork of the parent sequence (beam search).
|
||||
@@ -297,43 +287,27 @@ class Scheduler:
|
||||
parent_seq.fork(seq)
|
||||
self.block_manager.fork(parent_seq, seq)
|
||||
|
||||
# Process the next tokens.
|
||||
for seq in seq_group.seqs:
|
||||
if seq.status == SequenceStatus.FINISHED:
|
||||
continue
|
||||
|
||||
# Process the new tokens.
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
# Append a new token to the sequence.
|
||||
output = seq_outputs[seq.seq_id]
|
||||
seq.append_token(output.output_token, output.logprobs)
|
||||
return self.running.copy()
|
||||
|
||||
# Check if the sequence has generated a stop token.
|
||||
if output.output_token in stop_token_ids:
|
||||
self._free_seq(seq)
|
||||
continue
|
||||
def free_seq(self, seq: Sequence) -> None:
|
||||
seq.status = SequenceStatus.FINISHED
|
||||
self.block_manager.free(seq)
|
||||
|
||||
# Check if the sequence has reached the maximum number of steps.
|
||||
max_num_steps = seq_group.sampling_params.max_tokens
|
||||
if self.num_steps[request_id] == max_num_steps:
|
||||
self._free_seq(seq)
|
||||
continue
|
||||
|
||||
# Update the running sequences.
|
||||
updated = self.running.copy()
|
||||
running: List[SequenceGroup] = []
|
||||
for seq_group in self.running:
|
||||
if seq_group.is_finished():
|
||||
self._free_seq_group(seq_group)
|
||||
else:
|
||||
running.append(seq_group)
|
||||
self.running = running
|
||||
return updated
|
||||
def free_finished_seq_groups(self) -> None:
|
||||
self.running = [
|
||||
seq_group for seq_group in self.running
|
||||
if not seq_group.is_finished()
|
||||
]
|
||||
|
||||
def _allocate(self, seq_group: SequenceGroup) -> None:
|
||||
self.block_manager.allocate(seq_group)
|
||||
for seq in seq_group.seqs:
|
||||
for seq in seq_group.get_seqs():
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
if seq_group.request_id not in self.num_steps:
|
||||
self.num_steps[seq_group.request_id] = 0
|
||||
|
||||
def _append_slot(
|
||||
self,
|
||||
@@ -403,13 +377,6 @@ class Scheduler:
|
||||
self._swap_out(seq_group, blocks_to_swap_out)
|
||||
self.swapped.append(seq_group)
|
||||
|
||||
def _free_seq(self, seq: Sequence) -> None:
|
||||
seq.status = SequenceStatus.FINISHED
|
||||
self.block_manager.free(seq)
|
||||
|
||||
def _free_seq_group(self, seq_group: SequenceGroup) -> None:
|
||||
del self.num_steps[seq_group.request_id]
|
||||
|
||||
def _swap_in(
|
||||
self,
|
||||
seq_group: SequenceGroup,
|
||||
|
||||
Reference in New Issue
Block a user