Refactor system architecture (#109)
This commit is contained in:
165
cacheflow/config.py
Normal file
165
cacheflow/config.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, PretrainedConfig
|
||||
|
||||
|
||||
class ModelConfig:
    """Model-related settings together with the Hugging Face config.

    The constructor stores its arguments, loads the Hugging Face config for
    `model` via `AutoConfig.from_pretrained`, and resolves `dtype` into a
    torch dtype with `_get_and_verify_dtype`.

    Attributes:
        model: Value passed to `AutoConfig.from_pretrained`.
        download_dir: Stored as given; not used inside this class.
        use_np_weights: Stored as given; not used inside this class.
        use_dummy_weights: Stored as given; not used inside this class.
        seed: Stored as given; not used inside this class.
        hf_config: Config object returned by `AutoConfig.from_pretrained`.
        dtype: Result of `_get_and_verify_dtype(hf_config, dtype)`.
    """

    def __init__(
        self,
        model: str,
        download_dir: Optional[str],
        use_np_weights: bool,
        use_dummy_weights: bool,
        dtype: str,
        seed: int,
    ) -> None:
        # Plain pass-through settings.
        self.model = model
        self.download_dir = download_dir
        self.use_np_weights = use_np_weights
        self.use_dummy_weights = use_dummy_weights
        self.seed = seed

        # Derived state: the HF config first, because the dtype check
        # needs it.
        self.hf_config: PretrainedConfig = AutoConfig.from_pretrained(model)
        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check that the model can be split evenly across the parallel layout.

        Raises:
            ValueError: If the number of attention heads is not divisible by
                the tensor parallel size, or the number of hidden layers is
                not divisible by the pipeline parallel size.
        """
        # Attention heads must split evenly across tensor-parallel ranks.
        num_heads = self.hf_config.num_attention_heads
        tp_size = parallel_config.tensor_parallel_size
        if num_heads % tp_size != 0:
            raise ValueError(
                f"Total number of attention heads ({num_heads})"
                " must be divisible by tensor parallel size "
                f"({tp_size}).")

        # Hidden layers must split evenly across pipeline stages.
        num_layers = self.hf_config.num_hidden_layers
        pp_size = parallel_config.pipeline_parallel_size
        if num_layers % pp_size != 0:
            raise ValueError(
                f"Total number of hidden layers ({num_layers}) "
                "must be divisible by pipeline parallel size "
                f"({pp_size}).")

    def get_hidden_size(self) -> int:
        """Return `hidden_size` from the Hugging Face config."""
        return self.hf_config.hidden_size

    def get_head_size(self) -> int:
        """Return `hidden_size` floor-divided by `num_attention_heads`."""
        # FIXME(woosuk): This may not be true for all models.
        hf_config = self.hf_config
        return hf_config.hidden_size // hf_config.num_attention_heads

    def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
        """Return the attention-head count divided by the tensor parallel size."""
        return (self.hf_config.num_attention_heads //
                parallel_config.tensor_parallel_size)

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        """Return the hidden-layer count divided by the pipeline parallel size."""
        return (self.hf_config.num_hidden_layers //
                parallel_config.pipeline_parallel_size)
|
||||
|
||||
|
||||
class CacheConfig:
    """Settings for the cache, stored exactly as passed in.

    Attributes:
        block_size: Stored as given.
        gpu_memory_utilization: Stored as given.
            NOTE(review): presumably a fraction of GPU memory; confirm
            against the code that consumes it.
        swap_space: Stored as given.
            NOTE(review): units are not visible from this class; confirm.
        num_gpu_blocks: Starts as None; filled in later (after profiling).
        num_cpu_blocks: Starts as None; filled in later (after profiling).
    """

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: int,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space = swap_space

        # Will be set after profiling.
        self.num_gpu_blocks = self.num_cpu_blocks = None
|
||||
|
||||
|
||||
class ParallelConfig:
    """Parallel layout settings.

    Attributes:
        pipeline_parallel_size: Stored as given; values above 1 are rejected.
        tensor_parallel_size: Stored as given.
        world_size: `pipeline_parallel_size * tensor_parallel_size`.
        use_ray: The `use_ray` argument, forced to True whenever
            `world_size` is greater than 1.
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        use_ray: bool,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size

        self.world_size = pipeline_parallel_size * tensor_parallel_size
        # More than one worker always goes through Ray, regardless of the
        # caller's choice; otherwise the caller's value is kept untouched.
        self.use_ray = True if self.world_size > 1 else use_ray
        self._verify_args()

    def _verify_args(self) -> None:
        """Reject configurations that are not supported.

        Raises:
            NotImplementedError: If `pipeline_parallel_size` is greater than 1.
        """
        if self.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is not supported yet.")
|
||||
|
||||
|
||||
class SchedulerConfig:
    """Scheduler limits, stored exactly as passed in.

    Attributes:
        max_num_batched_tokens: Stored as given.
            NOTE(review): presumably an upper bound on tokens per batch;
            confirm against the scheduler that reads it.
        max_num_seqs: Stored as given.
            NOTE(review): presumably an upper bound on sequences per
            batch; confirm against the scheduler that reads it.
    """

    def __init__(
        self,
        max_num_batched_tokens: int,
        max_num_seqs: int,
    ) -> None:
        # No validation or derived values here; both limits are kept as-is.
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = max_num_batched_tokens
|
||||
|
||||
|
||||
# Accepted spellings of the `dtype` option, mapped to torch dtypes.
# The special value "default" is not listed here; it is resolved from the
# model config inside _get_and_verify_dtype.
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


def _get_and_verify_dtype(
    config: "PretrainedConfig",
    dtype: str,
) -> torch.dtype:
    """Resolve a dtype string against the model config and validate it.

    Args:
        config: Model config. Only its `torch_dtype` attribute is read; a
            missing or None value is treated as torch.float32.
        dtype: Either "default" or one of the keys of
            `_STR_DTYPE_TO_TORCH_DTYPE`. Matching is case-insensitive.

    Returns:
        The torch dtype to use. With "default", a float32 config resolves
        to torch.float16 and any other config dtype is used unchanged.

    Raises:
        ValueError: If `dtype` is not a recognised name, if it would cast
            between float16 and bfloat16, or if bfloat16 is selected on a
            GPU whose compute capability is below 8.0.
    """
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    dtype = dtype.lower()
    if dtype == "default":
        if config_dtype == torch.float32:
            # Following the common practice, we use float16 for float32 models.
            torch_dtype = torch.float16
        else:
            torch_dtype = config_dtype
    elif dtype in _STR_DTYPE_TO_TORCH_DTYPE:
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    else:
        # Report bad input the same way as the other dtype errors below,
        # instead of leaking a bare KeyError from the table lookup.
        supported = ", ".join(
            repr(name) for name in ("default", *_STR_DTYPE_TO_TORCH_DTYPE))
        raise ValueError(
            f"Unknown dtype: {dtype!r}. Supported values: {supported}.")

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            pass
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            pass
        else:
            # Casting between float16 and bfloat16 is not allowed.
            raise ValueError(
                f"Cannot use {torch_dtype} for {config_dtype} model.")

    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:
        # Queries the current CUDA device; only reached when bfloat16 is
        # actually selected.
        compute_capability = torch.cuda.get_device_capability()
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                f"{compute_capability[0]}.{compute_capability[1]}.")
    return torch_dtype
|
||||
Reference in New Issue
Block a user