Implement stop strings and best_of (#114)

This commit is contained in:
Woosuk Kwon
2023-05-21 11:18:00 -07:00
committed by GitHub
parent c3442c1f6f
commit f746ced08d
9 changed files with 162 additions and 116 deletions

View File

@@ -22,11 +22,18 @@ class SequenceData:
self.prompt_token_ids = prompt_token_ids
self.output_token_ids: List[int] = []
self.cumulative_logprobs = 0.0
self.cumulative_logprob = 0.0
def append_token(self, token_id: int, logprob: float) -> None:
    """Record one generated token and accumulate its log-probability.

    Args:
        token_id: Token id appended to ``self.output_token_ids``.
        logprob: Value added to ``self.cumulative_logprob``.
    """
    self.output_token_ids.append(token_id)
    self.cumulative_logprob += logprob
def get_len(self) -> int:
    """Return the total token count: output tokens plus prompt tokens."""
    num_output = len(self.output_token_ids)
    num_prompt = len(self.prompt_token_ids)
    return num_output + num_prompt
def get_output_len(self) -> int:
    """Return how many tokens are stored in ``self.output_token_ids``."""
    generated = self.output_token_ids
    return len(generated)
def get_token_ids(self) -> List[int]:
    """Return the prompt token ids followed by the output token ids.

    The result is a newly built list; neither stored list is modified.
    """
    prompt_ids = self.prompt_token_ids
    output_ids = self.output_token_ids
    return prompt_ids + output_ids
@@ -37,9 +44,9 @@ class SequenceData:
def __repr__(self) -> str:
# Build a debug string that lists the prompt, the prompt token ids, the
# output token ids and the cumulative log-probability of this object.
return (f"SequenceData("
f"prompt={self.prompt}, "
f"prompt_token_ids={self.prompt_token_ids}, "
# NOTE(review): this text looks like a diff with its +/- markers stripped.
# The next line closes the string after output_token_ids and appears to be
# the pre-change version, superseded by the two lines that follow it --
# confirm against the repository before treating all three as live code.
f"output_token_ids={self.output_token_ids})")
f"output_token_ids={self.output_token_ids}, "
f"cumulative_logprob={self.cumulative_logprob})")
class Sequence:
@@ -57,6 +64,7 @@ class Sequence:
self.data = SequenceData(prompt_token_ids)
self.output_logprobs: List[Dict[int, float]] = []
self.output_text = ""
self.logical_token_blocks: List[LogicalTokenBlock] = []
# Initialize the logical token blocks with the prompt token ids.
@@ -88,18 +96,26 @@ class Sequence:
assert token_id in logprobs
self._append_tokens_to_blocks([token_id])
self.output_logprobs.append(logprobs)
self.data.output_token_ids.append(token_id)
self.data.cumulative_logprobs += logprobs[token_id]
self.data.append_token(token_id, logprobs[token_id])
def get_len(self) -> int:
    """Return the sequence length reported by ``self.data.get_len()``."""
    seq_data = self.data
    return seq_data.get_len()
def get_output_len(self) -> int:
    """Return the output length reported by ``self.data.get_output_len()``."""
    seq_data = self.data
    return seq_data.get_output_len()
def get_token_ids(self) -> List[int]:
    """Return the token ids produced by ``self.data.get_token_ids()``."""
    seq_data = self.data
    return seq_data.get_token_ids()
def get_last_token_id(self) -> int:
    """Return whatever ``self.data.get_last_token_id()`` reports."""
    seq_data = self.data
    return seq_data.get_last_token_id()
def get_output_token_ids(self) -> List[int]:
    """Return the output token id list held by ``self.data``.

    The stored list object itself is returned, not a copy.
    """
    seq_data = self.data
    return seq_data.output_token_ids
def get_cumulative_logprob(self) -> float:
    """Return the ``cumulative_logprob`` value stored on ``self.data``."""
    seq_data = self.data
    return seq_data.cumulative_logprob
def fork(self, child_seq: 'Sequence') -> 'Sequence':
child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)