Rename variables and methods (#91)
This commit is contained in:
@@ -18,45 +18,46 @@ class Sequence:
|
||||
def __init__(
|
||||
self,
|
||||
seq_id: int,
|
||||
token_ids: List[int],
|
||||
prompt_token_ids: List[int],
|
||||
block_size: int,
|
||||
) -> None:
|
||||
self.seq_id = seq_id
|
||||
self.block_size = block_size
|
||||
|
||||
self.prompt_len = len(prompt_token_ids)
|
||||
self.logical_token_blocks: List[LogicalTokenBlock] = []
|
||||
# Initialize the logical token blocks with the given token ids.
|
||||
self.add(token_ids)
|
||||
# Initialize the logical token blocks with the prompt token ids.
|
||||
self._append_tokens(prompt_token_ids)
|
||||
|
||||
self.prompt_len = len(token_ids)
|
||||
self.status = SequenceStatus.WAITING
|
||||
# Used for beam search.
|
||||
self.output_logprobs: List[Dict[int, float]] = []
|
||||
self.cumulative_logprobs = 0.0
|
||||
|
||||
def add_block(self) -> None:
|
||||
def _append_logical_block(self) -> None:
|
||||
block = LogicalTokenBlock(
|
||||
block_number=len(self.logical_token_blocks),
|
||||
block_size=self.block_size,
|
||||
)
|
||||
self.logical_token_blocks.append(block)
|
||||
|
||||
def add(self, token_ids: List[int]) -> None:
|
||||
def _append_tokens(self, token_ids: List[int]) -> None:
|
||||
while token_ids:
|
||||
if not self.logical_token_blocks:
|
||||
self.add_block()
|
||||
self._append_logical_block()
|
||||
|
||||
last_block = self.logical_token_blocks[-1]
|
||||
if last_block.is_full():
|
||||
self.add_block()
|
||||
self._append_logical_block()
|
||||
last_block = self.logical_token_blocks[-1]
|
||||
|
||||
num_empty_slots = last_block.get_num_empty_slots()
|
||||
last_block.append(token_ids[:num_empty_slots])
|
||||
last_block.append_tokens(token_ids[:num_empty_slots])
|
||||
token_ids = token_ids[num_empty_slots:]
|
||||
|
||||
def append(self, token_id: int, logprobs: Dict[int, float]) -> None:
|
||||
def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
|
||||
assert token_id in logprobs
|
||||
self.add([token_id])
|
||||
self._append_tokens([token_id])
|
||||
self.output_logprobs.append(logprobs)
|
||||
self.cumulative_logprobs += logprobs[token_id]
|
||||
|
||||
@@ -121,7 +122,7 @@ class SequenceGroup:
|
||||
f'num_seqs={len(self.seqs)})')
|
||||
|
||||
|
||||
class SequenceGroupInputs:
|
||||
class SequenceGroupMetadata:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user