Implement stop strings and best_of (#114)
This commit is contained in:
@@ -13,7 +13,7 @@ from cacheflow.logger import init_logger
|
||||
from cacheflow.outputs import RequestOutput
|
||||
from cacheflow.sampling_params import SamplingParams
|
||||
from cacheflow.server.tokenizer_utils import get_tokenizer
|
||||
from cacheflow.sequence import Sequence, SequenceGroup
|
||||
from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
|
||||
from cacheflow.utils import Counter
|
||||
from cacheflow.worker.worker import Worker
|
||||
|
||||
@@ -49,7 +49,6 @@ class LLMServer:
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.log_stats = log_stats
|
||||
|
||||
self._verify_args()
|
||||
|
||||
self.tokenizer = get_tokenizer(model_config.model)
|
||||
@@ -124,15 +123,11 @@ class LLMServer:
|
||||
# Create the sequences.
|
||||
block_size = self.cache_config.block_size
|
||||
seqs: List[Sequence] = []
|
||||
for _ in range(sampling_params.n):
|
||||
for _ in range(sampling_params.best_of):
|
||||
seq_id = next(self.seq_counter)
|
||||
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
|
||||
seqs.append(seq)
|
||||
|
||||
# FIXME(woosuk)
|
||||
# Add the EOS token to the stop token list.
|
||||
sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
|
||||
|
||||
# Create the sequence group.
|
||||
seq_group = SequenceGroup(request_id, seqs, sampling_params,
|
||||
arrival_time)
|
||||
@@ -157,18 +152,65 @@ class LLMServer:
|
||||
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
||||
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||
)
|
||||
# Update the scheduler.
|
||||
updated_seq_groups = self.scheduler.update(output)
|
||||
# Update the scheduler with the model outputs.
|
||||
seq_groups = self.scheduler.update(output)
|
||||
|
||||
# Decode the sequences.
|
||||
self._decode_sequences(seq_groups)
|
||||
# Stop the sequences that meet the stopping criteria.
|
||||
self._stop_sequences(seq_groups)
|
||||
# Free the finished sequence groups.
|
||||
self.scheduler.free_finished_seq_groups()
|
||||
|
||||
# Create the outputs.
|
||||
request_outputs: List[RequestOutput] = []
|
||||
for seq_group in updated_seq_groups:
|
||||
# TODO(woosuk): Batch-decode the outputs for speedup.
|
||||
request_output = RequestOutput.from_seq_group(seq_group,
|
||||
self.tokenizer)
|
||||
for seq_group in seq_groups:
|
||||
request_output = RequestOutput.from_seq_group(seq_group)
|
||||
request_outputs.append(request_output)
|
||||
return request_outputs
|
||||
|
||||
def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
    """Refresh ``output_text`` on every running sequence in ``seq_groups``.

    The output token ids of all sequences whose status is
    ``SequenceStatus.RUNNING`` are gathered and handed to the tokenizer in
    a single ``batch_decode`` call (special tokens skipped). Each decoded
    string is then stored on the sequence it came from.

    Args:
        seq_groups: Sequence groups whose running sequences are decoded.
    """
    # Flatten the running sequences of all groups, preserving group order
    # so that the decoded texts can be zipped back positionally.
    running_seqs: List[Sequence] = [
        seq
        for group in seq_groups
        for seq in group.get_seqs(status=SequenceStatus.RUNNING)
    ]
    token_id_batches = [seq.get_output_token_ids() for seq in running_seqs]

    # One tokenizer call for the whole batch instead of one per sequence.
    decoded_texts = self.tokenizer.batch_decode(
        token_id_batches, skip_special_tokens=True)

    # Write each decoded text back onto its sequence.
    for seq, text in zip(running_seqs, decoded_texts):
        seq.output_text = text
|
||||
|
||||
def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
    """Stop the running sequences that meet a stopping criterion.

    For every sequence with status ``SequenceStatus.RUNNING`` the criteria
    are checked in this order, and the first one that holds wins:

    1. ``output_text`` ends with one of ``sampling_params.stop``; the stop
       string is stripped from ``output_text`` before the sequence is
       stopped.
    2. The output length equals ``sampling_params.max_tokens``.
    3. ``sampling_params.ignore_eos`` is false and the last token id is
       the tokenizer's EOS token id.

    A stopped sequence is passed to ``self.scheduler.free_seq``.

    Args:
        seq_groups: Sequence groups whose running sequences are checked.
    """
    for seq_group in seq_groups:
        sampling_params = seq_group.sampling_params
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            # Criterion 1: a stop string was generated.
            # Empty stop strings are skipped: ``endswith("")`` is always
            # true and ``text[:-0]`` is ``text[:0]``, which would erase
            # the entire output instead of trimming a suffix.
            matched_stop = next(
                (stop_str for stop_str in sampling_params.stop
                 if stop_str and seq.output_text.endswith(stop_str)),
                None,
            )
            if matched_stop is not None:
                # Truncate the output text so that the stop string is
                # not included in the output.
                seq.output_text = seq.output_text[:-len(matched_stop)]
                self.scheduler.free_seq(seq)
                continue

            # Criterion 2: the sequence has reached max_tokens.
            if seq.get_output_len() == sampling_params.max_tokens:
                self.scheduler.free_seq(seq)
                continue

            # Criterion 3: the sequence has generated the EOS token.
            if (not sampling_params.ignore_eos and
                    seq.get_last_token_id() == self.tokenizer.eos_token_id):
                self.scheduler.free_seq(seq)
|
||||
|
||||
def _run_workers(
|
||||
self,
|
||||
method: str,
|
||||
|
||||
Reference in New Issue
Block a user