Incrementally decode output tokens (#121)
@@ -14,7 +14,8 @@ from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.ray_utils import initialize_cluster
-from cacheflow.server.tokenizer_utils import get_tokenizer
+from cacheflow.server.tokenizer_utils import (get_tokenizer,
+                                              detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
 from cacheflow.worker.worker import Worker
@@ -185,18 +186,17 @@ class LLMServer:
         return request_outputs
 
     def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        # Batch-decode the sequence outputs.
-        seqs: List[Sequence] = []
+        # Decode the sequence outputs.
         for seq_group in seq_groups:
-            seqs.extend(seq_group.get_seqs(status=SequenceStatus.RUNNING))
-        output_tokens_per_seq = []
-        for seq in seqs:
-            output_tokens_per_seq.append(seq.get_output_token_ids())
-        output_texts = self.tokenizer.batch_decode(output_tokens_per_seq,
-                                                   skip_special_tokens=True)
-        # Update the sequences with the output texts.
-        for seq, output_text in zip(seqs, output_texts):
-            seq.output_text = output_text
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+                new_token, new_output_text = detokenize_incrementally(
+                    self.tokenizer,
+                    seq.output_tokens,
+                    seq.get_last_token_id(),
+                    skip_special_tokens=True,
+                )
+                seq.output_tokens.append(new_token)
+                seq.output_text = new_output_text
 
     def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
         # Stop the sequences.
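
The detokenize_incrementally helper lives in cacheflow/server/tokenizer_utils.py and its body is not part of this diff. A minimal sketch of the idea, assuming a Hugging Face-style tokenizer that exposes convert_ids_to_tokens, convert_tokens_to_string, and all_special_tokens; the signature is taken from the call site above, but the body here is an assumption, not the actual implementation:

from typing import List, Tuple

from transformers import PreTrainedTokenizer


def detokenize_incrementally(
    tokenizer: PreTrainedTokenizer,
    prev_output_tokens: List[str],
    new_token_id: int,
    skip_special_tokens: bool = False,
) -> Tuple[str, str]:
    # Convert only the newest token id; ids from earlier steps were
    # converted then and are cached as strings in prev_output_tokens.
    new_token = tokenizer.convert_ids_to_tokens(new_token_id)
    output_tokens = prev_output_tokens + [new_token]
    if skip_special_tokens:
        output_tokens = [t for t in output_tokens
                         if t not in tokenizer.all_special_tokens]
    # Joining cached token strings avoids re-running the id->token
    # conversion over the entire output on every decoding step.
    output_text = tokenizer.convert_tokens_to_string(output_tokens)
    return new_token, output_text

Compared to the old batch_decode path, which re-decoded every running sequence's full output token ids on each server step, the incremental path converts one new token per sequence per step and caches the decoded state on the Sequence itself (seq.output_tokens, seq.output_text).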