Add an option to launch cacheflow without ray (#51)
This commit is contained in:
@@ -13,7 +13,8 @@ import uvicorn
|
||||
from cacheflow.sampling_params import SamplingParams
|
||||
from cacheflow.sequence import Sequence, SequenceGroup
|
||||
from cacheflow.master.server import (Server, add_server_arguments,
|
||||
initialize_ray_cluster)
|
||||
process_server_arguments,
|
||||
initialize_cluster)
|
||||
from cacheflow.worker.controller import DeviceID
|
||||
from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
|
||||
|
||||
@@ -33,17 +34,22 @@ class FastAPIFrontend:
|
||||
seed: int,
|
||||
swap_space: int,
|
||||
max_num_batched_tokens: int,
|
||||
max_num_sequences: int,
|
||||
num_nodes: int,
|
||||
num_devices_per_node: int,
|
||||
distributed_init_method: str,
|
||||
all_stage_devices: List[List[DeviceID]],
|
||||
server_use_ray: bool,
|
||||
):
|
||||
self.block_size = block_size
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model)
|
||||
self.seq_group_counter = Counter()
|
||||
self.seq_counter = Counter()
|
||||
remote_server_class = ray.remote(num_cpus=0)(Server)
|
||||
if server_use_ray:
|
||||
remote_server_class = ray.remote(num_cpus=0)(Server)
|
||||
else:
|
||||
remote_server_class = ray.remote(num_gpus=1)(Server)
|
||||
self.server = remote_server_class.remote(
|
||||
model=model,
|
||||
model_path=model_path,
|
||||
@@ -55,12 +61,14 @@ class FastAPIFrontend:
|
||||
seed=seed,
|
||||
swap_space=swap_space,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
max_num_sequences=max_num_sequences,
|
||||
num_nodes=num_nodes,
|
||||
num_devices_per_node=num_devices_per_node,
|
||||
distributed_init_method=distributed_init_method,
|
||||
all_stage_devices=all_stage_devices,
|
||||
gpu_memory=get_gpu_memory(),
|
||||
cpu_memory=get_cpu_memory(),
|
||||
use_ray=server_use_ray,
|
||||
)
|
||||
|
||||
self.running_seq_groups: Dict[int, SequenceGroup] = {}
|
||||
@@ -149,6 +157,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--port", type=int, default=10002)
|
||||
parser = add_server_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args = process_server_arguments(args)
|
||||
|
||||
# TODO(zhuohan): Support pipeline parallelism.
|
||||
assert args.pipeline_parallel_size == 1, (
|
||||
@@ -156,7 +165,8 @@ if __name__ == "__main__":
|
||||
|
||||
(num_nodes, num_devices_per_node, distributed_init_method,
|
||||
all_stage_devices) = (
|
||||
initialize_ray_cluster(
|
||||
initialize_cluster(
|
||||
use_ray=True,
|
||||
pipeline_parallel_size=args.pipeline_parallel_size,
|
||||
tensor_parallel_size=args.tensor_parallel_size))
|
||||
|
||||
@@ -170,10 +180,12 @@ if __name__ == "__main__":
|
||||
seed=args.seed,
|
||||
swap_space=args.swap_space,
|
||||
max_num_batched_tokens=args.max_num_batched_tokens,
|
||||
max_num_sequences=args.max_num_sequences,
|
||||
num_nodes=num_nodes,
|
||||
num_devices_per_node=num_devices_per_node,
|
||||
distributed_init_method=distributed_init_method,
|
||||
all_stage_devices=all_stage_devices,
|
||||
server_use_ray=args.use_ray,
|
||||
)
|
||||
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||
|
||||
Reference in New Issue
Block a user