Add an option to launch cacheflow without ray (#51)

This commit is contained in:
Zhuohan Li
2023-04-30 15:42:17 +08:00
committed by GitHub
parent a96d63c21d
commit 4858f3bb45
7 changed files with 102 additions and 28 deletions

View File

@@ -11,7 +11,8 @@ from transformers import AutoConfig
from benchmark.trace import generate_text_completion_requests
from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
initialize_ray_cluster)
process_server_arguments,
initialize_cluster)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory
@@ -25,8 +26,8 @@ def main(args: argparse.Namespace):
(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_ray_cluster(
address='local',
initialize_cluster(
use_ray=args.use_ray,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))
@@ -49,6 +50,7 @@ def main(args: argparse.Namespace):
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
use_ray=args.use_ray,
collect_stats=True,
do_memory_analysis=args.do_memory_analysis,
)
@@ -134,7 +136,7 @@ def main(args: argparse.Namespace):
finished.append({
'group_id': seq_group.group_id,
'seq_id': seq.seq_id,
'arrival_time': arrival_time,
'arrival_time': arrival_time,
'finish_time': finish_time,
'prompt_len': seq.prompt_len,
'output_len': output_len,
@@ -225,8 +227,9 @@ def get_sampling_dir_name(
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='CacheFlow simple server.')
parser = add_server_arguments(parser)
parser = argparse.ArgumentParser(
description='Benchmark the performance on a series of requests.')
parser = add_server_arguments(parser)
parser.add_argument('--output-dir', type=str, help='path to output directory', default=None)
parser.add_argument('--dataset', type=str, help='path to dataset', required=True)
@@ -246,6 +249,7 @@ if __name__ == '__main__':
parser.add_argument('--n6-beam', type=float, help='ratio of requests with n=6 & beam search', default=0.0)
parser.add_argument('--n8-beam', type=float, help='ratio of requests with n=8 & beam search', default=0.0)
args = parser.parse_args()
args = process_server_arguments(args)
if args.n1 + args.n2 + args.n3 + args.n4 + args.n6 + args.n2_beam + args.n4_beam + args.n6_beam + args.n8_beam != 1.0:
raise ValueError('The ratios of requests must sum to 1.')