Use runtime profiling to replace manual memory analyzers (#81)

This commit is contained in:
Zhuohan Li
2023-05-19 11:35:44 -06:00
committed by GitHub
parent 825d8892b5
commit f756799b84
14 changed files with 211 additions and 478 deletions

View File

@@ -15,7 +15,7 @@ from cacheflow.core.server import (Server, add_server_arguments,
 from cacheflow.frontend.utils import get_tokenizer
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
-from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
+from cacheflow.utils import Counter
 from cacheflow.worker.controller import DeviceID
 TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
@@ -34,6 +34,7 @@ class FastAPIServer:
 dtype: str,
 seed: int,
 swap_space: int,
+gpu_memory_utilization: float,
 max_num_batched_tokens: int,
 max_num_sequences: int,
 num_nodes: int,
@@ -41,6 +42,7 @@ class FastAPIServer:
 distributed_init_method: str,
 all_stage_devices: List[List[DeviceID]],
 server_use_ray: bool,
+log_stats: bool,
 ):
 self.block_size = block_size
@@ -62,15 +64,15 @@ class FastAPIServer:
 dtype=dtype,
 seed=seed,
 swap_space=swap_space,
+gpu_memory_utilization=gpu_memory_utilization,
 max_num_batched_tokens=max_num_batched_tokens,
 max_num_sequences=max_num_sequences,
 num_nodes=num_nodes,
 num_devices_per_node=num_devices_per_node,
 distributed_init_method=distributed_init_method,
 all_stage_devices=all_stage_devices,
-gpu_memory=get_gpu_memory(),
-cpu_memory=get_cpu_memory(),
 use_ray=server_use_ray,
+log_stats=log_stats,
 )
 self.running_seq_groups: Dict[int, SequenceGroup] = {}
@@ -182,6 +184,7 @@ if __name__ == "__main__":
 dtype=args.dtype,
 seed=args.seed,
 swap_space=args.swap_space,
+gpu_memory_utilization=args.gpu_memory_utilization,
 max_num_batched_tokens=args.max_num_batched_tokens,
 max_num_sequences=args.max_num_sequences,
 num_nodes=num_nodes,
@@ -189,6 +192,7 @@ if __name__ == "__main__":
 distributed_init_method=distributed_init_method,
 all_stage_devices=all_stage_devices,
 server_use_ray=args.use_ray,
+log_stats=args.log_stats,
 )
 uvicorn.run(app, host=args.host, port=args.port, log_level="info")