Implement stop strings and best_of (#114)

Woosuk Kwon
2023-05-21 11:18:00 -07:00
committed by GitHub
parent c3442c1f6f
commit f746ced08d
9 changed files with 162 additions and 116 deletions

@@ -13,7 +13,7 @@ from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.tokenizer_utils import get_tokenizer
-from cacheflow.sequence import Sequence, SequenceGroup
+from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
 from cacheflow.worker.worker import Worker
@@ -49,7 +49,6 @@ class LLMServer:
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.log_stats = log_stats
         self._verify_args()
         self.tokenizer = get_tokenizer(model_config.model)
@@ -124,15 +123,11 @@ class LLMServer:
         # Create the sequences.
         block_size = self.cache_config.block_size
         seqs: List[Sequence] = []
-        for _ in range(sampling_params.n):
+        for _ in range(sampling_params.best_of):
             seq_id = next(self.seq_counter)
             seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
             seqs.append(seq)
-        # FIXME(woosuk)
-        # Add the EOS token to the stop token list.
-        sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
 
         # Create the sequence group.
         seq_group = SequenceGroup(request_id, seqs, sampling_params,
                                   arrival_time)
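
Note: `best_of` replaces `n` as the number of candidate sequences scheduled per request; the idea is to generate `best_of` candidates and return the `n` best of them. The SamplingParams definition is not part of the hunks shown here, so the following is only a minimal sketch of the implied invariant, using a hypothetical stand-in class:

    # Hypothetical stand-in, not the actual cacheflow/sampling_params.py.
    from dataclasses import dataclass, field
    from typing import List, Optional

    @dataclass
    class SamplingParamsSketch:
        n: int = 1                     # sequences to return per request
        best_of: Optional[int] = None  # candidate sequences to generate
        stop: List[str] = field(default_factory=list)
        max_tokens: int = 16
        ignore_eos: bool = False

        def __post_init__(self) -> None:
            # Generating fewer candidates than requested outputs cannot
            # work, so best_of defaults to n and must never be smaller.
            if self.best_of is None:
                self.best_of = self.n
            if self.best_of < self.n:
                raise ValueError(
                    f"best_of ({self.best_of}) must be >= n ({self.n}).")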
@@ -157,18 +152,65 @@ class LLMServer:
             blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
             blocks_to_copy=scheduler_outputs.blocks_to_copy,
         )
-        # Update the scheduler.
-        updated_seq_groups = self.scheduler.update(output)
+        # Update the scheduler with the model outputs.
+        seq_groups = self.scheduler.update(output)
+
+        # Decode the sequences.
+        self._decode_sequences(seq_groups)
+        # Stop the sequences that meet the stopping criteria.
+        self._stop_sequences(seq_groups)
+        # Free the finished sequence groups.
+        self.scheduler.free_finished_seq_groups()
 
         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in updated_seq_groups:
-            # TODO(woosuk): Batch-decode the outputs for speedup.
-            request_output = RequestOutput.from_seq_group(seq_group,
-                                                          self.tokenizer)
+        for seq_group in seq_groups:
+            request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
         return request_outputs
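
With these changes, step() runs a fixed post-processing pipeline after each model iteration: update the scheduler, batch-decode, apply stopping criteria, free finished groups, and build the RequestOutputs. A caller drives it in a loop; below is a hedged sketch of that loop with a stub standing in for LLMServer, since neither the server construction nor a has_unfinished_requests() helper appears in this diff:

    # StubServer only mimics the step()/has_unfinished_requests() driver
    # pattern; it is not the LLMServer from this commit.
    from typing import List

    class StubServer:
        def __init__(self, steps_left: int) -> None:
            self.steps_left = steps_left

        def has_unfinished_requests(self) -> bool:
            return self.steps_left > 0

        def step(self) -> List[str]:
            # One schedule/generate/decode/stop iteration, returning the
            # outputs produced in that iteration.
            self.steps_left -= 1
            return [f"request output at step {self.steps_left}"]

    server = StubServer(steps_left=3)
    while server.has_unfinished_requests():
        for request_output in server.step():
            print(request_output)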
+
+    def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
+        # Batch-decode the sequence outputs.
+        seqs: List[Sequence] = []
+        for seq_group in seq_groups:
+            seqs.extend(seq_group.get_seqs(status=SequenceStatus.RUNNING))
+        output_tokens_per_seq = []
+        for seq in seqs:
+            output_tokens_per_seq.append(seq.get_output_token_ids())
+        output_texts = self.tokenizer.batch_decode(output_tokens_per_seq,
+                                                   skip_special_tokens=True)
+        # Update the sequences with the output texts.
+        for seq, output_text in zip(seqs, output_texts):
+            seq.output_text = output_text
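
_decode_sequences is what lets RequestOutput.from_seq_group drop its tokenizer argument (see above): instead of decoding each sequence separately while building outputs, all running sequences are detokenized in one call. Hugging Face tokenizers expose this as batch_decode, which maps a list of token-id lists to a list of strings. A self-contained sketch; the token ids are placeholders, not real model output:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # One list of output token ids per running sequence (placeholder values).
    output_tokens_per_seq = [
        [464, 2068, 7586, 21831],
        [18435, 995, 0],
    ]

    # A single call decodes every sequence; skip_special_tokens drops
    # special tokens such as EOS from the decoded text.
    output_texts = tokenizer.batch_decode(output_tokens_per_seq,
                                          skip_special_tokens=True)
    for i, text in enumerate(output_texts):
        print(f"seq {i}: {text!r}")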
+
+    def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
+        # Stop the sequences.
+        for seq_group in seq_groups:
+            sampling_params = seq_group.sampling_params
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+                # Check if the sequence has generated a stop string.
+                stopped = False
+                for stop_str in sampling_params.stop:
+                    if seq.output_text.endswith(stop_str):
+                        # Truncate the output text so that the stop string is
+                        # not included in the output.
+                        seq.output_text = seq.output_text[:-len(stop_str)]
+                        self.scheduler.free_seq(seq)
+                        stopped = True
+                        break
+                if stopped:
+                    continue
+
+                # Check if the sequence has reached max_tokens.
+                if seq.get_output_len() == sampling_params.max_tokens:
+                    self.scheduler.free_seq(seq)
+                    continue
+                # Check if the sequence has generated the EOS token.
+                if not sampling_params.ignore_eos:
+                    if seq.get_last_token_id() == self.tokenizer.eos_token_id:
+                        self.scheduler.free_seq(seq)
+                        continue
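
The checks above are ordered: stop strings are matched against the decoded text first, then the max_tokens budget, then the EOS token (unless ignore_eos is set), and a matched stop string is cut from the visible output before the sequence is freed. A pure-function sketch of just the truncation rule; the helper name is illustrative, not from this commit:

    from typing import List, Tuple

    def apply_stop_strings(output_text: str,
                           stop: List[str]) -> Tuple[str, bool]:
        # Return the (possibly truncated) text and whether to stop.
        for stop_str in stop:
            if output_text.endswith(stop_str):
                # Drop the stop string so callers never see it.
                return output_text[:-len(stop_str)], True
        return output_text, False

    assert apply_stop_strings("Hello\nUser:", ["User:"]) == ("Hello\n", True)
    assert apply_stop_strings("Hello", ["User:"]) == ("Hello", False)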
+
     def _run_workers(
         self,
         method: str,