Implement presence and frequency penalties (#95)
This commit is contained in:
@@ -13,26 +13,55 @@ class SequenceStatus(enum.Enum):
|
||||
FINISHED = enum.auto()
|
||||
|
||||
|
||||
class SequenceData:
    """Token-level state of a single sequence.

    Holds the prompt token ids, the token ids generated so far, and a
    running cumulative log-probability value.
    """

    def __init__(
        self,
        prompt_token_ids: List[int],
    ) -> None:
        """Create the data holder for a sequence.

        Args:
            prompt_token_ids: Token ids of the prompt. The list is stored
                by reference, not copied.
        """
        self.prompt_token_ids = prompt_token_ids
        # Generated tokens; starts empty and is appended to by the owner.
        self.output_token_ids: List[int] = []
        # Running sum of log-probabilities; starts at 0.0.
        self.cumulative_logprobs = 0.0

    def get_len(self) -> int:
        """Return the total number of tokens (prompt plus output)."""
        return len(self.output_token_ids) + len(self.prompt_token_ids)

    def get_token_ids(self) -> List[int]:
        """Return a new list with the prompt tokens followed by the outputs."""
        return self.prompt_token_ids + self.output_token_ids

    def get_last_token_id(self) -> int:
        """Return the most recent token id.

        Falls back to the last prompt token when nothing has been generated
        yet.

        Raises:
            IndexError: If both the prompt and the output are empty.
        """
        if not self.output_token_ids:
            return self.prompt_token_ids[-1]
        return self.output_token_ids[-1]

    def __repr__(self) -> str:
        # Only attributes that exist on this class are shown. The previous
        # version also interpolated `self.prompt`, which is never set here,
        # so calling repr() raised AttributeError.
        return (f"SequenceData("
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"output_token_ids={self.output_token_ids})")
|
||||
|
||||
|
||||
class Sequence:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
seq_id: int,
|
||||
prompt: str,
|
||||
prompt_token_ids: List[int],
|
||||
block_size: int,
|
||||
) -> None:
|
||||
self.seq_id = seq_id
|
||||
self.prompt = prompt
|
||||
self.block_size = block_size
|
||||
|
||||
self.prompt_len = len(prompt_token_ids)
|
||||
self.data = SequenceData(prompt_token_ids)
|
||||
self.output_logprobs: List[Dict[int, float]] = []
|
||||
|
||||
self.logical_token_blocks: List[LogicalTokenBlock] = []
|
||||
# Initialize the logical token blocks with the prompt token ids.
|
||||
self._append_tokens(prompt_token_ids)
|
||||
|
||||
self._append_tokens_to_blocks(prompt_token_ids)
|
||||
self.status = SequenceStatus.WAITING
|
||||
# Used for beam search.
|
||||
self.output_logprobs: List[Dict[int, float]] = []
|
||||
self.cumulative_logprobs = 0.0
|
||||
|
||||
def _append_logical_block(self) -> None:
|
||||
block = LogicalTokenBlock(
|
||||
@@ -41,7 +70,7 @@ class Sequence:
|
||||
)
|
||||
self.logical_token_blocks.append(block)
|
||||
|
||||
def _append_tokens(self, token_ids: List[int]) -> None:
|
||||
def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
|
||||
while token_ids:
|
||||
if not self.logical_token_blocks:
|
||||
self._append_logical_block()
|
||||
@@ -57,26 +86,24 @@ class Sequence:
|
||||
|
||||
def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
|
||||
assert token_id in logprobs
|
||||
self._append_tokens([token_id])
|
||||
self._append_tokens_to_blocks([token_id])
|
||||
self.output_logprobs.append(logprobs)
|
||||
self.cumulative_logprobs += logprobs[token_id]
|
||||
self.data.output_token_ids.append(token_id)
|
||||
self.data.cumulative_logprobs += logprobs[token_id]
|
||||
|
||||
def get_len(self) -> int:
|
||||
return sum(block.num_tokens for block in self.logical_token_blocks)
|
||||
return self.data.get_len()
|
||||
|
||||
def get_token_ids(self) -> List[int]:
|
||||
token_ids: List[int] = []
|
||||
for block in self.logical_token_blocks:
|
||||
token_ids.extend(block.get_token_ids())
|
||||
return token_ids
|
||||
return self.data.get_token_ids()
|
||||
|
||||
def get_last_token_id(self) -> int:
|
||||
return self.logical_token_blocks[-1].get_last_token_id()
|
||||
return self.data.get_last_token_id()
|
||||
|
||||
def fork(self, child_seq: 'Sequence') -> 'Sequence':
|
||||
child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
|
||||
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
|
||||
child_seq.cumulative_logprobs = self.cumulative_logprobs
|
||||
child_seq.data = copy.deepcopy(self.data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f'Sequence(seq_id={self.seq_id}, '
|
||||
@@ -128,17 +155,13 @@ class SequenceGroupMetadata:
|
||||
self,
|
||||
group_id: int,
|
||||
is_prompt: bool,
|
||||
input_tokens: Dict[int, List[int]], # Seq id -> token ids.
|
||||
context_len: int,
|
||||
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
|
||||
seq_data: Dict[int, SequenceData], # Seq id -> sequence data.
|
||||
sampling_params: SamplingParams,
|
||||
block_tables: Dict[int, List[int]], # Seq id -> List of physical block numbers.
|
||||
block_tables: Dict[int, List[int]], # Seq id -> list of physical block numbers.
|
||||
) -> None:
|
||||
self.group_id = group_id
|
||||
self.is_prompt = is_prompt
|
||||
self.input_tokens = input_tokens
|
||||
self.context_len = context_len
|
||||
self.seq_logprobs = seq_logprobs
|
||||
self.seq_data = seq_data
|
||||
self.sampling_params = sampling_params
|
||||
self.block_tables = block_tables
|
||||
|
||||
|
||||
Reference in New Issue
Block a user