Rename variables and methods (#91)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, List, Union, Tuple, Optional
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
try:
|
||||
import ray
|
||||
@@ -6,7 +6,6 @@ except ImportError:
|
||||
ray = None
|
||||
|
||||
from cacheflow.core.scheduler import Scheduler
|
||||
from cacheflow.sequence import SequenceGroupInputs
|
||||
from cacheflow.worker.worker import Worker
|
||||
|
||||
|
||||
@@ -81,23 +80,12 @@ class Controller:
|
||||
self.next_node = next_node
|
||||
self.is_last_stage = isinstance(next_node, Scheduler)
|
||||
|
||||
def execute_stage(
|
||||
self,
|
||||
input_seq_groups: List[SequenceGroupInputs],
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
) -> None:
|
||||
def execute_stage(self, *args, **kwargs) -> None:
|
||||
all_outputs = []
|
||||
for worker in self.workers:
|
||||
executor = (worker.execute_stage.remote
|
||||
if self.use_ray else worker.execute_stage)
|
||||
output = executor(
|
||||
input_seq_groups,
|
||||
blocks_to_swap_in,
|
||||
blocks_to_swap_out,
|
||||
blocks_to_copy,
|
||||
)
|
||||
output = executor(*args, **kwargs)
|
||||
all_outputs.append(output)
|
||||
|
||||
if self.use_ray:
|
||||
|
||||
Reference in New Issue
Block a user