Use runtime profiling to replace manual memory analyzers (#81)
This commit is contained in:
@@ -15,7 +15,7 @@ from cacheflow.core.server import (Server, add_server_arguments,
|
||||
from cacheflow.frontend.utils import get_tokenizer
|
||||
from cacheflow.sampling_params import SamplingParams
|
||||
from cacheflow.sequence import Sequence, SequenceGroup
|
||||
-from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
|
||||
+from cacheflow.utils import Counter
|
||||
from cacheflow.worker.controller import DeviceID
|
||||
|
||||
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
|
||||
@@ -34,6 +34,7 @@ class FastAPIServer:
|
||||
dtype: str,
|
||||
seed: int,
|
||||
swap_space: int,
|
||||
gpu_memory_utilization: float,
|
||||
max_num_batched_tokens: int,
|
||||
max_num_sequences: int,
|
||||
num_nodes: int,
|
||||
@@ -41,6 +42,7 @@ class FastAPIServer:
|
||||
distributed_init_method: str,
|
||||
all_stage_devices: List[List[DeviceID]],
|
||||
server_use_ray: bool,
|
||||
log_stats: bool,
|
||||
):
|
||||
self.block_size = block_size
|
||||
|
||||
@@ -62,15 +64,15 @@ class FastAPIServer:
|
||||
dtype=dtype,
|
||||
seed=seed,
|
||||
swap_space=swap_space,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
max_num_sequences=max_num_sequences,
|
||||
num_nodes=num_nodes,
|
||||
num_devices_per_node=num_devices_per_node,
|
||||
distributed_init_method=distributed_init_method,
|
||||
all_stage_devices=all_stage_devices,
|
||||
-gpu_memory=get_gpu_memory(),
|
||||
-cpu_memory=get_cpu_memory(),
|
||||
use_ray=server_use_ray,
|
||||
log_stats=log_stats,
|
||||
)
|
||||
|
||||
self.running_seq_groups: Dict[int, SequenceGroup] = {}
|
||||
@@ -182,6 +184,7 @@ if __name__ == "__main__":
|
||||
dtype=args.dtype,
|
||||
seed=args.seed,
|
||||
swap_space=args.swap_space,
|
||||
gpu_memory_utilization=args.gpu_memory_utilization,
|
||||
max_num_batched_tokens=args.max_num_batched_tokens,
|
||||
max_num_sequences=args.max_num_sequences,
|
||||
num_nodes=num_nodes,
|
||||
@@ -189,6 +192,7 @@ if __name__ == "__main__":
|
||||
distributed_init_method=distributed_init_method,
|
||||
all_stage_devices=all_stage_devices,
|
||||
server_use_ray=args.use_ray,
|
||||
log_stats=args.log_stats,
|
||||
)
|
||||
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||
|
||||
Reference in New Issue
Block a user