Add output streaming support to multi-step + async while ensuring RequestOutput obj reuse (#8335)

This commit is contained in:
Alexander Matveev
2024-09-23 18:38:04 -04:00
committed by GitHub
parent 5f7bb58427
commit 1a2aef3e59
7 changed files with 142 additions and 42 deletions

View File

@@ -436,7 +436,7 @@ class Sequence:
self.stop_reason: Union[int, str, None] = None
# These are used to keep track of delta outputs
self._last_token_ids_offset: int = 0
self._last_output_token_ids_offset: int = 0
self._last_output_text_offset: int = 0
# Used for incremental detokenization
@@ -499,18 +499,26 @@ class Sequence:
return self.output_text[last_offset:length]
return ""
def get_output_token_ids_to_return(self,
delta: bool) -> GenericSequence[int]:
def get_output_token_ids_to_return(
self, delta: bool) -> Union[GenericSequence[int], int]:
"""If delta is True, only new tokens since the last call to
this method are returned"""
if not delta:
return self.get_output_token_ids()
length = self.get_output_len()
last_offset = self._last_token_ids_offset
if last_offset < length:
self._last_token_ids_offset = length
return self.data._output_token_ids[last_offset:]
return ()
output_len = self.get_output_len()
# Get the number of new tokens
num_new_tokens = output_len - self._last_output_token_ids_offset
self._last_output_token_ids_offset = output_len
# Return new tokens
if num_new_tokens == 1:
# Optimization for single decode token case
# (which is what we have most of the time)
return self.data._cached_all_token_ids[-1]
return self.data._cached_all_token_ids[-num_new_tokens:]
def hash_of_block(self, logical_idx: int) -> int:
# TODO This can produce incorrect hash when block size > prompt size
@@ -671,6 +679,8 @@ class SequenceGroup:
self.encoder_seq = encoder_seq
self.trace_headers = trace_headers
self.cached_request_output = None
@property
def prompt(self) -> Optional[str]:
# All sequences in the group should have the same prompt.