Incrementally decode output tokens (#121)

This commit is contained in:
Woosuk Kwon
2023-05-23 20:46:32 -07:00
committed by GitHub
parent aedba6d5ec
commit e86717833d
4 changed files with 83 additions and 17 deletions

View File

@@ -24,7 +24,7 @@ class SequenceData:
self.output_token_ids: List[int] = []
self.cumulative_logprob = 0.0
def append_token(self, token_id: int, logprob: float) -> None:
def append_token_id(self, token_id: int, logprob: float) -> None:
self.output_token_ids.append(token_id)
self.cumulative_logprob += logprob
@@ -64,6 +64,7 @@ class Sequence:
self.data = SequenceData(prompt_token_ids)
self.output_logprobs: List[Dict[int, float]] = []
self.output_tokens: List[str] = []
self.output_text = ""
self.logical_token_blocks: List[LogicalTokenBlock] = []
@@ -92,11 +93,15 @@ class Sequence:
last_block.append_tokens(token_ids[:num_empty_slots])
token_ids = token_ids[num_empty_slots:]
def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
def append_token_id(
self,
token_id: int,
logprobs: Dict[int, float],
) -> None:
assert token_id in logprobs
self._append_tokens_to_blocks([token_id])
self.output_logprobs.append(logprobs)
self.data.append_token(token_id, logprobs[token_id])
self.data.append_token_id(token_id, logprobs[token_id])
def get_len(self) -> int:
return self.data.get_len()