Rename variables and methods (#91)

This commit is contained in:
Woosuk Kwon
2023-05-10 00:58:31 -07:00
committed by GitHub
parent ce26e57fd3
commit 8d66a7b6d7
7 changed files with 64 additions and 83 deletions

View File

@@ -9,7 +9,7 @@ from cacheflow.core.policy import PolicyFactory
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceGroupMetadata
from cacheflow.sequence import SequenceOutputs
from cacheflow.sequence import SequenceStatus
@@ -105,7 +105,7 @@ class Scheduler:
preempted: List[SequenceGroup] = []
while self.running:
seq_group = self.running.pop(0)
while not self.block_manager.can_append(seq_group):
while not self.block_manager.can_append_slot(seq_group):
if self.running:
# Preempt the lowest-priority sequence groups.
victim_seq_group = self.running.pop(-1)
@@ -119,7 +119,7 @@ class Scheduler:
break
else:
# Append new slots to the sequence group.
self._append(seq_group, blocks_to_copy)
self._append_slot(seq_group, blocks_to_copy)
running.append(seq_group)
self.running = running
@@ -143,7 +143,7 @@ class Scheduler:
seq_group = self.swapped.pop(0)
self._swap_in(seq_group, blocks_to_swap_in)
self._append(seq_group, blocks_to_copy)
self._append_slot(seq_group, blocks_to_copy)
self.running.append(seq_group)
num_batched_tokens = sum(
@@ -252,7 +252,7 @@ class Scheduler:
prompt_group_ids = scheduler_output[3]
# Create input data structures.
input_seq_groups: List[SequenceGroupInputs] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
updated_seq_groups: List[SequenceGroup] = self.running.copy()
for seq_group in self.running:
@@ -274,7 +274,7 @@ class Scheduler:
# sequence length
seq_len = seq.get_len()
input_seq_group = SequenceGroupInputs(
seq_group_metadata = SequenceGroupMetadata(
group_id=group_id,
is_prompt=is_prompt,
input_tokens=input_tokens,
@@ -283,14 +283,14 @@ class Scheduler:
sampling_params=self.sampling_params[group_id],
block_tables=block_tables,
)
input_seq_groups.append(input_seq_group)
seq_group_metadata_list.append(seq_group_metadata)
# Execute the first stage of the pipeline.
if input_seq_groups or blocks_to_swap_in or blocks_to_swap_out:
if seq_group_metadata_list or blocks_to_swap_in or blocks_to_swap_out:
# Swap in and swap out should never happen at the same time.
assert not (blocks_to_swap_in and blocks_to_swap_out)
self.controllers[0].execute_stage(
input_seq_groups,
seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
@@ -330,7 +330,7 @@ class Scheduler:
# Append a new token to the sequence.
output = seq_outputs[seq.seq_id]
seq.append(output.output_token, output.logprobs)
seq.append_token(output.output_token, output.logprobs)
# Check if the sequence has generated a stop token.
if output.output_token in stop_token_ids:
@@ -360,13 +360,13 @@ class Scheduler:
if seq_group.group_id not in self.num_steps:
self.num_steps[seq_group.group_id] = 0
def _append(
def _append_slot(
self,
seq_group: SequenceGroup,
blocks_to_copy: Dict[int, List[int]],
) -> None:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
ret = self.block_manager.append(seq)
ret = self.block_manager.append_slot(seq)
if ret is not None:
src_block, dst_block = ret
if src_block in blocks_to_copy: