Incrementally decode output tokens (#121)
This commit is contained in:
@@ -24,7 +24,7 @@ class SequenceData:
|
||||
self.output_token_ids: List[int] = []
|
||||
self.cumulative_logprob = 0.0
|
||||
|
||||
- def append_token(self, token_id: int, logprob: float) -> None:
+ def append_token_id(self, token_id: int, logprob: float) -> None:
|
||||
self.output_token_ids.append(token_id)
|
||||
self.cumulative_logprob += logprob
|
||||
|
||||
@@ -64,6 +64,7 @@ class Sequence:
|
||||
|
||||
self.data = SequenceData(prompt_token_ids)
|
||||
self.output_logprobs: List[Dict[int, float]] = []
|
||||
+ self.output_tokens: List[str] = []
|
||||
self.output_text = ""
|
||||
|
||||
self.logical_token_blocks: List[LogicalTokenBlock] = []
|
||||
@@ -92,11 +93,15 @@ class Sequence:
|
||||
last_block.append_tokens(token_ids[:num_empty_slots])
|
||||
token_ids = token_ids[num_empty_slots:]
|
||||
|
||||
- def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
+ def append_token_id(
+ self,
+ token_id: int,
+ logprobs: Dict[int, float],
+ ) -> None:
|
||||
assert token_id in logprobs
|
||||
self._append_tokens_to_blocks([token_id])
|
||||
self.output_logprobs.append(logprobs)
|
||||
- self.data.append_token(token_id, logprobs[token_id])
+ self.data.append_token_id(token_id, logprobs[token_id])
|
||||
|
||||
def get_len(self) -> int:
|
||||
return self.data.get_len()
|
||||
|
||||
Reference in New Issue
Block a user