Incrementally decode output tokens (#121)

Woosuk Kwon
2023-05-23 20:46:32 -07:00
committed by GitHub
parent aedba6d5ec
commit e86717833d
4 changed files with 83 additions and 17 deletions

@@ -14,7 +14,8 @@ from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.ray_utils import initialize_cluster
-from cacheflow.server.tokenizer_utils import get_tokenizer
+from cacheflow.server.tokenizer_utils import (get_tokenizer,
+                                              detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
 from cacheflow.worker.worker import Worker
@@ -185,18 +186,17 @@ class LLMServer:
         return request_outputs
 
     def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        # Batch-decode the sequence outputs.
-        seqs: List[Sequence] = []
+        # Decode the sequence outputs.
         for seq_group in seq_groups:
-            seqs.extend(seq_group.get_seqs(status=SequenceStatus.RUNNING))
-        output_tokens_per_seq = []
-        for seq in seqs:
-            output_tokens_per_seq.append(seq.get_output_token_ids())
-        output_texts = self.tokenizer.batch_decode(output_tokens_per_seq,
-                                                   skip_special_tokens=True)
-        # Update the sequences with the output texts.
-        for seq, output_text in zip(seqs, output_texts):
-            seq.output_text = output_text
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+                new_token, new_output_text = detokenize_incrementally(
+                    self.tokenizer,
+                    seq.output_tokens,
+                    seq.get_last_token_id(),
+                    skip_special_tokens=True,
+                )
+                seq.output_tokens.append(new_token)
+                seq.output_text = new_output_text
 
     def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
         # Stop the sequences.
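The rewritten method calls detokenize_incrementally from cacheflow/server/tokenizer_utils.py, converting only the newly sampled token id on each step instead of re-running batch_decode over a sequence's full output ids on every iteration. The helper's body is not part of this view; the sketch below is a minimal reconstruction consistent with the call site above, assuming a Hugging Face tokenizer and that seq.output_tokens holds previously decoded token strings. It is an illustration, not the actual tokenizer_utils implementation.

    from typing import List, Tuple

    from transformers import PreTrainedTokenizer

    def detokenize_incrementally(
        tokenizer: PreTrainedTokenizer,
        prev_output_tokens: List[str],
        new_token_id: int,
        skip_special_tokens: bool,
    ) -> Tuple[str, str]:
        # Sketch only: convert just the new id to its token string,
        # instead of decoding the full list of output ids again.
        new_token = tokenizer.convert_ids_to_tokens(
            new_token_id, skip_special_tokens=skip_special_tokens)
        # Joining through convert_tokens_to_string (rather than appending
        # new_token to the previous text directly) lets the tokenizer
        # resolve subword markers and byte-level pieces that only render
        # correctly next to their neighbors.
        output_text = tokenizer.convert_tokens_to_string(
            prev_output_tokens + [new_token])
        return new_token, output_text

Note that the helper returns the new token string rather than mutating the sequence; the caller appends it to seq.output_tokens itself, which matches the seq.output_tokens.append(new_token) line in the diff.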