Implement presence and frequency penalties (#95)

This commit is contained in:
Woosuk Kwon
2023-05-10 23:39:12 -07:00
committed by GitHub
parent 9f88db35da
commit 55f8b0a5de
9 changed files with 215 additions and 82 deletions

View File

@@ -13,26 +13,55 @@ class SequenceStatus(enum.Enum):
FINISHED = enum.auto()
class SequenceData:
def __init__(
self,
prompt_token_ids: List[int],
) -> None:
self.prompt_token_ids = prompt_token_ids
self.output_token_ids: List[int] = []
self.cumulative_logprobs = 0.0
def get_len(self) -> int:
return len(self.output_token_ids) + len(self.prompt_token_ids)
def get_token_ids(self) -> List[int]:
return self.prompt_token_ids + self.output_token_ids
def get_last_token_id(self) -> int:
if not self.output_token_ids:
return self.prompt_token_ids[-1]
return self.output_token_ids[-1]
def __repr__(self) -> str:
return (f"SequenceData("
f"prompt={self.prompt}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"output_token_ids={self.output_token_ids})")
class Sequence:
def __init__(
self,
seq_id: int,
prompt: str,
prompt_token_ids: List[int],
block_size: int,
) -> None:
self.seq_id = seq_id
self.prompt = prompt
self.block_size = block_size
self.prompt_len = len(prompt_token_ids)
self.data = SequenceData(prompt_token_ids)
self.output_logprobs: List[Dict[int, float]] = []
self.logical_token_blocks: List[LogicalTokenBlock] = []
# Initialize the logical token blocks with the prompt token ids.
self._append_tokens(prompt_token_ids)
self._append_tokens_to_blocks(prompt_token_ids)
self.status = SequenceStatus.WAITING
# Used for beam search.
self.output_logprobs: List[Dict[int, float]] = []
self.cumulative_logprobs = 0.0
def _append_logical_block(self) -> None:
block = LogicalTokenBlock(
@@ -41,7 +70,7 @@ class Sequence:
)
self.logical_token_blocks.append(block)
def _append_tokens(self, token_ids: List[int]) -> None:
def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
while token_ids:
if not self.logical_token_blocks:
self._append_logical_block()
@@ -57,26 +86,24 @@ class Sequence:
def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
assert token_id in logprobs
self._append_tokens([token_id])
self._append_tokens_to_blocks([token_id])
self.output_logprobs.append(logprobs)
self.cumulative_logprobs += logprobs[token_id]
self.data.output_token_ids.append(token_id)
self.data.cumulative_logprobs += logprobs[token_id]
def get_len(self) -> int:
return sum(block.num_tokens for block in self.logical_token_blocks)
return self.data.get_len()
def get_token_ids(self) -> List[int]:
token_ids: List[int] = []
for block in self.logical_token_blocks:
token_ids.extend(block.get_token_ids())
return token_ids
return self.data.get_token_ids()
def get_last_token_id(self) -> int:
return self.logical_token_blocks[-1].get_last_token_id()
return self.data.get_last_token_id()
def fork(self, child_seq: 'Sequence') -> 'Sequence':
child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
child_seq.cumulative_logprobs = self.cumulative_logprobs
child_seq.data = copy.deepcopy(self.data)
def __repr__(self) -> str:
return (f'Sequence(seq_id={self.seq_id}, '
@@ -128,17 +155,13 @@ class SequenceGroupMetadata:
self,
group_id: int,
is_prompt: bool,
input_tokens: Dict[int, List[int]], # Seq id -> token ids.
context_len: int,
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
seq_data: Dict[int, SequenceData], # Seq id -> sequence data.
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]], # Seq id -> List of physical block numbers.
block_tables: Dict[int, List[int]], # Seq id -> list of physical block numbers.
) -> None:
self.group_id = group_id
self.is_prompt = is_prompt
self.input_tokens = input_tokens
self.context_len = context_len
self.seq_logprobs = seq_logprobs
self.seq_data = seq_data
self.sampling_params = sampling_params
self.block_tables = block_tables