FastAPI-based working frontend (#10)

Commit 721fa3df15 (parent d359cda5fa)
Author: Zhuohan Li
Date: 2023-03-29 14:48:56 +08:00
Committed by: GitHub

15 changed files with 536 additions and 146 deletions

View File

@@ -1,9 +1,7 @@
 import torch
 from transformers import AutoConfig
 
-from cacheflow.models.utils import get_cpu_memory
 from cacheflow.models.utils import get_dtype_size
-from cacheflow.models.utils import get_gpu_memory
 
 _GiB = 1 << 30
@@ -31,11 +29,15 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         model_name: str,
         block_size: int,
         dtype: torch.dtype,
+        gpu_memory: int,
+        cpu_memory: int,
         tensor_parallel_size: int,
     ) -> None:
         self.model_name = model_name
         self.block_size = block_size
         self.dtype = dtype
+        self.gpu_memory = gpu_memory
+        self.cpu_memory = cpu_memory
         self.tensor_parallel_size = tensor_parallel_size
 
         config = AutoConfig.from_pretrained(model_name)
@@ -106,8 +108,7 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         memory_utilization: float = 0.95,
     ) -> int:
         # NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
-        gpu_memory = get_gpu_memory()
-        usable_memory = int(memory_utilization * gpu_memory)
+        usable_memory = int(memory_utilization * self.gpu_memory)
 
         param_size = self._get_param_size()
         act_size = self._get_max_act_size(max_num_batched_tokens)
@@ -122,16 +123,15 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         swap_space: int,
     ) -> int:
         swap_space = swap_space * _GiB
-        cpu_memory = get_cpu_memory()
-        if swap_space > 0.8 * cpu_memory:
+        if swap_space > 0.8 * self.cpu_memory:
             raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) '
                              'takes more than 80% of the available memory '
-                             f'({cpu_memory / _GiB:.2f} GiB).'
+                             f'({self.cpu_memory / _GiB:.2f} GiB).'
                              'Please check the swap space size.')
-        if swap_space > 0.5 * cpu_memory:
+        if swap_space > 0.5 * self.cpu_memory:
             print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) '
                   'takes more than 50% of the available memory '
-                  f'({cpu_memory / _GiB:.2f} GiB).'
+                  f'({self.cpu_memory / _GiB:.2f} GiB).'
                   'This may slow the system performance.')
         max_num_blocks = swap_space // self._get_cache_block_size()
         return max_num_blocks
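
The refactor above threads gpu_memory and cpu_memory through the constructor instead of probing the machine inside each method, so the swap-space check reduces to plain byte arithmetic: the caller's swap_space in GiB is scaled by _GiB = 1 << 30, bounded against the stored CPU memory, and divided by the size of one cache block. A minimal standalone sketch of that logic (the 64 GiB CPU memory and 16 MiB block size are made-up numbers, not values from this commit):

_GiB = 1 << 30

def max_num_cpu_blocks(swap_space_gib: int, cpu_memory: int,
                       cache_block_size: int) -> int:
    # Convert the requested swap space from GiB to bytes.
    swap_space = swap_space_gib * _GiB
    if swap_space > 0.8 * cpu_memory:
        raise ValueError('swap space exceeds 80% of CPU memory')
    # Whole cache blocks only; any remainder goes unused.
    return swap_space // cache_block_size

# 8 GiB of swap, 64 GiB of CPU memory, 16 MiB blocks -> 512 blocks.
print(max_num_cpu_blocks(8, 64 * _GiB, 16 * (1 << 20)))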

View File

@@ -44,11 +44,14 @@ def get_memory_analyzer(
     model_name: str,
     block_size: int,
     dtype: Union[torch.dtype, str],
+    gpu_memory: int,
+    cpu_memory: int,
     tensor_parallel_size: int = 1,
 ) -> CacheFlowMemoryAnalyzer:
     torch_dtype = get_torch_dtype(dtype)
     for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
         if model_class in model_name:
             return memory_analyzer(
-                model_name, block_size, torch_dtype, tensor_parallel_size)
+                model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
+                tensor_parallel_size)
     raise ValueError(f'Unsupported model name: {model_name}')
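
The dispatch itself is unchanged: the first key of _MEMORY_ANALYZERS found as a substring of model_name wins, and that analyzer class now also receives gpu_memory and cpu_memory. A hypothetical call site under the widened signature (the model name, block size, and dtype are illustrative, and get_memory_analyzer is assumed to be imported from the module above; the two probes mirror the helpers removed from cacheflow.models.utils below):

import psutil
import torch

# Measure total device and host memory once, up front, then pass the
# values through (this commit removes the in-module probes below).
# Assumes a CUDA-capable device 0 is present.
gpu_memory = torch.cuda.get_device_properties(0).total_memory
cpu_memory = psutil.virtual_memory().total

memory_analyzer = get_memory_analyzer(
    'facebook/opt-125m', block_size=8, dtype=torch.float16,
    gpu_memory=gpu_memory, cpu_memory=cpu_memory)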

View File

@@ -1,9 +1,5 @@
 from typing import Union
 
-import random
-
-import numpy as np
-import psutil
 import torch
 
 _STR_DTYPE_TO_TORCH_DTYPE = {
@@ -26,10 +22,3 @@ def get_dtype_size(dtype: Union[torch.dtype, str]) -> int:
     torch_dtype = get_torch_dtype(dtype)
     return torch.tensor([], dtype=torch_dtype).element_size()
-
-
-def get_gpu_memory(gpu: int = 0) -> int:
-    return torch.cuda.get_device_properties(gpu).total_memory
-
-
-def get_cpu_memory() -> int:
-    return psutil.virtual_memory().total
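
What survives in this module is the dtype handling: get_dtype_size measures the width of a dtype by asking an empty tensor for its element size, which works uniformly for every torch dtype. The trick in isolation (using torch dtypes directly rather than the module's string aliases):

import torch

def get_dtype_size(dtype: torch.dtype) -> int:
    # An empty tensor still carries its dtype, so element_size()
    # reports the number of bytes per value without allocating data.
    return torch.tensor([], dtype=dtype).element_size()

print(get_dtype_size(torch.float16))  # 2
print(get_dtype_size(torch.float32))  # 4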