Rename variables and methods (#91)

This commit is contained in:
Woosuk Kwon
2023-05-10 00:58:31 -07:00
committed by GitHub
parent ce26e57fd3
commit 8d66a7b6d7
7 changed files with 64 additions and 83 deletions

View File

@@ -18,45 +18,46 @@ class Sequence:
def __init__(
self,
seq_id: int,
token_ids: List[int],
prompt_token_ids: List[int],
block_size: int,
) -> None:
self.seq_id = seq_id
self.block_size = block_size
self.prompt_len = len(prompt_token_ids)
self.logical_token_blocks: List[LogicalTokenBlock] = []
# Initialize the logical token blocks with the given token ids.
self.add(token_ids)
# Initialize the logical token blocks with the prompt token ids.
self._append_tokens(prompt_token_ids)
self.prompt_len = len(token_ids)
self.status = SequenceStatus.WAITING
# Used for beam search.
self.output_logprobs: List[Dict[int, float]] = []
self.cumulative_logprobs = 0.0
def add_block(self) -> None:
def _append_logical_block(self) -> None:
block = LogicalTokenBlock(
block_number=len(self.logical_token_blocks),
block_size=self.block_size,
)
self.logical_token_blocks.append(block)
def add(self, token_ids: List[int]) -> None:
def _append_tokens(self, token_ids: List[int]) -> None:
while token_ids:
if not self.logical_token_blocks:
self.add_block()
self._append_logical_block()
last_block = self.logical_token_blocks[-1]
if last_block.is_full():
self.add_block()
self._append_logical_block()
last_block = self.logical_token_blocks[-1]
num_empty_slots = last_block.get_num_empty_slots()
last_block.append(token_ids[:num_empty_slots])
last_block.append_tokens(token_ids[:num_empty_slots])
token_ids = token_ids[num_empty_slots:]
def append(self, token_id: int, logprobs: Dict[int, float]) -> None:
def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
assert token_id in logprobs
self.add([token_id])
self._append_tokens([token_id])
self.output_logprobs.append(logprobs)
self.cumulative_logprobs += logprobs[token_id]
@@ -121,7 +122,7 @@ class SequenceGroup:
f'num_seqs={len(self.seqs)})')
class SequenceGroupInputs:
class SequenceGroupMetadata:
def __init__(
self,