Support tensor parallel (#2)
This commit is contained in:
@@ -1,45 +1,62 @@
|
||||
from typing import Dict, List, Union
|
||||
from typing import Dict, List, Union, Tuple
|
||||
|
||||
import ray
|
||||
|
||||
from cacheflow.master.scheduler import Scheduler
|
||||
from cacheflow.sequence import SequenceGroupInputs
|
||||
from cacheflow.worker.worker import Worker
|
||||
|
||||
|
||||
DeviceID = Tuple[int, str, int] # rank, node resource (node IP), device id
|
||||
|
||||
|
||||
class Controller:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_id: int,
|
||||
num_workers: int,
|
||||
stage_id: int,
|
||||
stage_devices: List[DeviceID],
|
||||
world_size: int,
|
||||
tensor_parallel_size: int,
|
||||
pipeline_parallel_size: int,
|
||||
distributed_init_method: str,
|
||||
model_name: str,
|
||||
block_size: int,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
dtype: str,
|
||||
seed: int,
|
||||
model_path: str,
|
||||
) -> None:
|
||||
self.node_id = node_id
|
||||
self.num_workers = num_workers
|
||||
self.stage_id = stage_id
|
||||
self.stage_devices = stage_devices
|
||||
self.model_name = model_name
|
||||
self.block_size = block_size
|
||||
self.num_gpu_blocks = num_gpu_blocks
|
||||
self.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
# Which pipeline stage is this node assigned to?
|
||||
self.is_first_stage = node_id == 0
|
||||
self.is_first_stage = stage_id == 0
|
||||
self.is_last_stage = False
|
||||
|
||||
self.workers: List[Worker] = []
|
||||
for i in range(num_workers):
|
||||
worker = Worker(
|
||||
worker_id=node_id + i,
|
||||
gpu_id=i,
|
||||
for rank, node_resource, device_id in stage_devices:
|
||||
worker_cls = ray.remote(num_cpus=0,
|
||||
num_gpus=1,
|
||||
resources={node_resource: 1e-5})(Worker)
|
||||
worker = worker_cls.remote(
|
||||
model_name=model_name,
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
dtype=dtype,
|
||||
seed=seed,
|
||||
distributed_init_method=distributed_init_method,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
pipeline_parallel_size=pipeline_parallel_size,
|
||||
model_path=model_path,
|
||||
)
|
||||
self.workers.append(worker)
|
||||
|
||||
@@ -57,15 +74,21 @@ class Controller:
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
) -> None:
|
||||
# FIXME: Support tensor parallelism.
|
||||
assert len(self.workers) == 1
|
||||
worker = self.workers[0]
|
||||
output = worker.execute_stage(
|
||||
input_seq_groups,
|
||||
blocks_to_swap_in,
|
||||
blocks_to_swap_out,
|
||||
blocks_to_copy,
|
||||
)
|
||||
futures = []
|
||||
for worker in self.workers:
|
||||
future = worker.execute_stage.remote(
|
||||
input_seq_groups,
|
||||
blocks_to_swap_in,
|
||||
blocks_to_swap_out,
|
||||
blocks_to_copy,
|
||||
)
|
||||
futures.append(future)
|
||||
|
||||
all_outputs = ray.get(futures)
|
||||
# Make sure all workers have the same results.
|
||||
output = all_outputs[0]
|
||||
for other_output in all_outputs[1:]:
|
||||
assert output == other_output
|
||||
|
||||
if self.is_last_stage:
|
||||
self.next_node.post_step(output)
|
||||
|
||||
Reference in New Issue
Block a user