Rename variables and methods (#91)
@@ -8,10 +8,11 @@ from cacheflow.model_executor.parallel_utils.parallel_state import (
     initialize_all_reduce_launcher,
     get_tensor_model_parallel_world_size)
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.sequence import SequenceGroupInputs
+from cacheflow.sequence import SequenceGroupMetadata
+from cacheflow.sequence import SequenceOutputs
 from cacheflow.worker.cache_engine import CacheEngine
 
 
 class Worker:
 
     def __init__(
@@ -93,30 +94,29 @@ class Worker:
 
     def prepare_inputs(
         self,
-        input_seq_groups: List[SequenceGroupInputs],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
         seq_logprobs: Dict[int, float] = {}
-        sampling_params: Dict[int, SamplingParams] = {}
         input_tokens: List[int] = []
         input_positions: List[int] = []
         slot_mapping: List[int] = []
 
         # Add prompt tokens.
         prompt_lens: List[int] = []
-        for input_seq_group in input_seq_groups:
-            if not input_seq_group.is_prompt:
+        for seq_group_metadata in seq_group_metadata_list:
+            if not seq_group_metadata.is_prompt:
                 continue
 
-            seq_ids = list(input_seq_group.input_tokens.keys())
-            sampling_params = input_seq_group.sampling_params
+            seq_ids = list(seq_group_metadata.input_tokens.keys())
+            sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
-            seq_logprobs.update(input_seq_group.seq_logprobs)
+            seq_logprobs.update(seq_group_metadata.seq_logprobs)
 
             # Use any sequence in the group.
             seq_id = seq_ids[0]
 
-            prompt_tokens = input_seq_group.input_tokens[seq_id]
+            prompt_tokens = seq_group_metadata.input_tokens[seq_id]
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
 
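Note: the fields this diff reads from each SequenceGroupMetadata (is_prompt, input_tokens, sampling_params, seq_logprobs, block_tables, context_len) suggest roughly the following shape. This is a sketch inferred from the accesses above, not the actual definition in cacheflow.sequence:

    # Hypothetical approximation of SequenceGroupMetadata, reconstructed from
    # the attribute accesses in this diff; the real class may differ.
    from dataclasses import dataclass
    from typing import Dict, List

    from cacheflow.sampling_params import SamplingParams

    @dataclass
    class SequenceGroupMetadataSketch:
        is_prompt: bool                     # True while the group is processing its prompt
        input_tokens: Dict[int, List[int]]  # seq_id -> tokens to feed this step
        sampling_params: SamplingParams     # shared by the whole group
        seq_logprobs: Dict[int, float]      # seq_id -> cumulative log probability
        block_tables: Dict[int, List[int]]  # seq_id -> physical KV-cache block numbers
        context_len: int                    # number of tokens already in context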
@@ -126,7 +126,7 @@ class Worker:
             input_positions.extend(range(len(prompt_tokens)))
 
             # Compute the slot mapping.
-            block_table = input_seq_group.block_tables[seq_id]
+            block_table = seq_group_metadata.block_tables[seq_id]
             for i in range(prompt_len):
                 block_number = block_table[i // self.block_size]
                 block_offset = i % self.block_size
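The last two lines of this hunk start the prompt-phase slot mapping: token i lives in physical block block_table[i // block_size] at offset i % block_size. A self-contained sketch of the full computation, assuming the slot index is block_number * block_size + block_offset (the natural continuation of the lines shown; the helper name is made up):

    # Standalone sketch of the prompt-phase slot mapping; the helper name is
    # hypothetical, the arithmetic mirrors the diff above.
    from typing import List

    def prompt_slot_mapping(block_table: List[int], prompt_len: int,
                            block_size: int) -> List[int]:
        slots = []
        for i in range(prompt_len):
            block_number = block_table[i // block_size]
            block_offset = i % block_size
            slots.append(block_number * block_size + block_offset)
        return slots

    # A 5-token prompt with block_size=4 fills block 7, then spills into block 2.
    assert prompt_slot_mapping([7, 2], prompt_len=5, block_size=4) == [28, 29, 30, 31, 8]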
@@ -138,31 +138,31 @@ class Worker:
         max_num_blocks_per_seq = 0
         context_lens: List[int] = []
         generation_block_tables: List[List[int]] = []
-        for input_seq_group in input_seq_groups:
-            if input_seq_group.is_prompt:
+        for seq_group_metadata in seq_group_metadata_list:
+            if seq_group_metadata.is_prompt:
                 continue
 
-            seq_ids = list(input_seq_group.input_tokens.keys())
-            sampling_params = input_seq_group.sampling_params
+            seq_ids = list(seq_group_metadata.input_tokens.keys())
+            sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
-            seq_logprobs.update(input_seq_group.seq_logprobs)
+            seq_logprobs.update(seq_group_metadata.seq_logprobs)
 
             for seq_id in seq_ids:
-                assert len(input_seq_group.input_tokens[seq_id]) == 1
-                generation_token = input_seq_group.input_tokens[seq_id][0]
+                assert len(seq_group_metadata.input_tokens[seq_id]) == 1
+                generation_token = seq_group_metadata.input_tokens[seq_id][0]
                 input_tokens.append(generation_token)
 
-                position = input_seq_group.context_len - 1
+                position = seq_group_metadata.context_len - 1
                 input_positions.append(position)
 
-                block_table = input_seq_group.block_tables[seq_id]
+                block_table = seq_group_metadata.block_tables[seq_id]
                 generation_block_tables.append(block_table)
 
                 max_context_len = max(
-                    max_context_len, input_seq_group.context_len)
+                    max_context_len, seq_group_metadata.context_len)
                 max_num_blocks_per_seq = max(
                     max_num_blocks_per_seq, len(block_table))
-                context_lens.append(input_seq_group.context_len)
+                context_lens.append(seq_group_metadata.context_len)
 
                 block_number = block_table[position // self.block_size]
                 block_offset = position % self.block_size
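In this generation branch, each sequence contributes exactly one new token per step (hence the assert on len == 1), and that token sits at position context_len - 1. A toy sketch of the indexing, with a hypothetical helper name:

    # Sketch of the generation-phase indexing above: one token per sequence,
    # located at position context_len - 1 within the sequence's block table.
    from typing import List, Tuple

    def generation_slot(block_table: List[int], context_len: int,
                        block_size: int) -> Tuple[int, int]:
        position = context_len - 1
        block_number = block_table[position // block_size]
        block_offset = position % block_size
        return position, block_number * block_size + block_offset

    # 10 tokens of context, block_size=4: the new token is position 9, in block_table[2].
    assert generation_slot([5, 9, 3], context_len=10, block_size=4) == (9, 13)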
@@ -203,30 +203,30 @@ class Worker:
     @torch.inference_mode()
     def execute_stage(
         self,
-        input_seq_groups: List[SequenceGroupInputs],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
     ) -> Dict[int, SequenceOutputs]:
         # Issue cache operations.
-        command_issued = False
+        issued_cache_op = False
         if blocks_to_swap_in:
             self.cache_engine.swap_in(blocks_to_swap_in)
-            command_issued = True
+            issued_cache_op = True
         if blocks_to_swap_out:
             self.cache_engine.swap_out(blocks_to_swap_out)
-            command_issued = True
+            issued_cache_op = True
         if blocks_to_copy:
             self.cache_engine.copy(blocks_to_copy)
-            command_issued = True
+            issued_cache_op = True
 
-        if command_issued:
+        if issued_cache_op:
             cache_events = self.cache_events
         else:
             cache_events = None
 
         # If there is no input, we don't need to execute the model.
-        if not input_seq_groups:
+        if not seq_group_metadata_list:
             if cache_events is not None:
                 for event in cache_events:
                     event.wait()
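The renamed issued_cache_op flag records whether any swap or copy was actually launched, so the worker only waits on cache events when there is something to wait for. The same pattern in isolation (engine and events are toy stand-ins, not cacheflow's real CacheEngine API):

    # Stripped-down sketch of the flag pattern above; `engine` and `events`
    # are placeholders, not the actual CacheEngine interface.
    from typing import Dict, List, Optional

    def issue_cache_ops(engine, events: List,
                        blocks_to_swap_in: Dict[int, int],
                        blocks_to_swap_out: Dict[int, int],
                        blocks_to_copy: Dict[int, List[int]]) -> Optional[List]:
        issued_cache_op = False
        if blocks_to_swap_in:
            engine.swap_in(blocks_to_swap_in)
            issued_cache_op = True
        if blocks_to_swap_out:
            engine.swap_out(blocks_to_swap_out)
            issued_cache_op = True
        if blocks_to_copy:
            engine.copy(blocks_to_copy)
            issued_cache_op = True
        # Only hand back events to wait on if an op was actually issued.
        return events if issued_cache_op else None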
@@ -234,7 +234,7 @@ class Worker:
 
         # Prepare input tensors.
         input_tokens, input_positions, input_metadata = self.prepare_inputs(
-            input_seq_groups)
+            seq_group_metadata_list)
 
         # Execute the model.
         output = self.model(
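For reference, a hypothetical call site for the renamed execute_stage signature; the argument values and the worker wiring are illustrative, not taken from cacheflow's scheduler:

    # Illustrative call of the renamed method; empty dicts mean no cache ops
    # were scheduled this step. Returns Dict[int, SequenceOutputs].
    output = worker.execute_stage(
        seq_group_metadata_list=seq_group_metadata_list,
        blocks_to_swap_in={},    # CPU -> GPU block moves
        blocks_to_swap_out={},   # GPU -> CPU block moves
        blocks_to_copy={},       # block-to-block copies within GPU memory
    )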