Implement stop strings and best_of (#114)
This commit is contained in:
@@ -22,11 +22,18 @@ class SequenceData:
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
|
||||
self.output_token_ids: List[int] = []
|
||||
self.cumulative_logprobs = 0.0
|
||||
self.cumulative_logprob = 0.0
|
||||
|
||||
def append_token(self, token_id: int, logprob: float) -> None:
|
||||
self.output_token_ids.append(token_id)
|
||||
self.cumulative_logprob += logprob
|
||||
|
||||
def get_len(self) -> int:
|
||||
return len(self.output_token_ids) + len(self.prompt_token_ids)
|
||||
|
||||
def get_output_len(self) -> int:
|
||||
return len(self.output_token_ids)
|
||||
|
||||
def get_token_ids(self) -> List[int]:
|
||||
return self.prompt_token_ids + self.output_token_ids
|
||||
|
||||
@@ -37,9 +44,9 @@ class SequenceData:
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SequenceData("
|
||||
f"prompt={self.prompt}, "
|
||||
f"prompt_token_ids={self.prompt_token_ids}, "
|
||||
f"output_token_ids={self.output_token_ids})")
|
||||
f"output_token_ids={self.output_token_ids}, "
|
||||
f"cumulative_logprob={self.cumulative_logprob})")
|
||||
|
||||
|
||||
class Sequence:
|
||||
@@ -57,6 +64,7 @@ class Sequence:
|
||||
|
||||
self.data = SequenceData(prompt_token_ids)
|
||||
self.output_logprobs: List[Dict[int, float]] = []
|
||||
self.output_text = ""
|
||||
|
||||
self.logical_token_blocks: List[LogicalTokenBlock] = []
|
||||
# Initialize the logical token blocks with the prompt token ids.
|
||||
@@ -88,18 +96,26 @@ class Sequence:
|
||||
assert token_id in logprobs
|
||||
self._append_tokens_to_blocks([token_id])
|
||||
self.output_logprobs.append(logprobs)
|
||||
self.data.output_token_ids.append(token_id)
|
||||
self.data.cumulative_logprobs += logprobs[token_id]
|
||||
self.data.append_token(token_id, logprobs[token_id])
|
||||
|
||||
def get_len(self) -> int:
|
||||
return self.data.get_len()
|
||||
|
||||
def get_output_len(self) -> int:
|
||||
return self.data.get_output_len()
|
||||
|
||||
def get_token_ids(self) -> List[int]:
|
||||
return self.data.get_token_ids()
|
||||
|
||||
def get_last_token_id(self) -> int:
|
||||
return self.data.get_last_token_id()
|
||||
|
||||
def get_output_token_ids(self) -> List[int]:
|
||||
return self.data.output_token_ids
|
||||
|
||||
def get_cumulative_logprob(self) -> float:
|
||||
return self.data.cumulative_logprob
|
||||
|
||||
def fork(self, child_seq: 'Sequence') -> 'Sequence':
|
||||
child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
|
||||
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
|
||||
|
||||
Reference in New Issue
Block a user