Refactor system architecture (#109)
This commit is contained in:
@@ -4,6 +4,7 @@ from typing import Dict, List, Tuple
|
||||
import torch
|
||||
|
||||
from cacheflow import cache_ops
|
||||
from cacheflow.config import CacheConfig, ModelConfig, ParallelConfig
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
@@ -18,27 +19,22 @@ class CacheEngine:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
worker_id: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
cache_config: CacheConfig,
|
||||
model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
) -> None:
|
||||
if head_size % 16 != 0:
|
||||
raise ValueError(
|
||||
f'head_size ({head_size}) must be a multiple of 16.')
|
||||
self.cache_config = cache_config
|
||||
self.model_config = model_config
|
||||
self.parallel_config = parallel_config
|
||||
|
||||
self.worker_id = worker_id
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.block_size = block_size
|
||||
self.num_gpu_blocks = num_gpu_blocks
|
||||
self.num_cpu_blocks = num_cpu_blocks
|
||||
self.dtype = dtype
|
||||
self.head_size = model_config.get_head_size()
|
||||
self.num_layers = model_config.get_num_layers(parallel_config)
|
||||
self.num_heads = model_config.get_num_heads(parallel_config)
|
||||
self.dtype = model_config.dtype
|
||||
|
||||
self.block_size = cache_config.block_size
|
||||
self.num_gpu_blocks = cache_config.num_gpu_blocks
|
||||
self.num_cpu_blocks = cache_config.num_cpu_blocks
|
||||
|
||||
# Initialize the cache.
|
||||
self.gpu_cache = self.allocate_gpu_cache()
|
||||
@@ -48,7 +44,7 @@ class CacheEngine:
|
||||
self.cache_stream = torch.cuda.Stream()
|
||||
assert self.cache_stream != torch.cuda.current_stream()
|
||||
# Initialize the events for stream synchronization.
|
||||
self.events = [torch.cuda.Event() for _ in range(num_layers)]
|
||||
self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
|
||||
|
||||
def get_key_block_shape(self) -> Tuple[int, int, int, int]:
|
||||
element_size = torch.tensor([], dtype=self.dtype).element_size()
|
||||
@@ -133,3 +129,23 @@ class CacheEngine:
|
||||
value_caches = [value_cache for _, value_cache in self.gpu_cache]
|
||||
# NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
|
||||
cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
|
||||
|
||||
@staticmethod
|
||||
def get_cache_block_size(
|
||||
block_size: int,
|
||||
model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
) -> int:
|
||||
head_size = model_config.get_head_size()
|
||||
num_heads = model_config.get_num_heads(parallel_config)
|
||||
num_layers = model_config.get_num_layers(parallel_config)
|
||||
|
||||
key_cache_block = block_size * num_heads * head_size
|
||||
value_cache_block = key_cache_block
|
||||
total = num_layers * (key_cache_block + value_cache_block)
|
||||
dtype_size = _get_dtype_size(model_config.dtype)
|
||||
return dtype_size * total
|
||||
|
||||
|
||||
def _get_dtype_size(dtype: torch.dtype) -> int:
|
||||
return torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
Reference in New Issue
Block a user