Use runtime profiling to replace manual memory analyzers (#81)
This commit is contained in:
@@ -40,3 +40,15 @@ def set_random_seed(seed: int) -> None:
|
||||
|
||||
if model_parallel_is_initialized():
|
||||
model_parallel_cuda_manual_seed(seed)
|
||||
|
||||
|
||||
def get_cache_block_size(
    block_size: int,
    num_heads: int,
    head_size: int,
    num_layers: int,
    dtype: str,
) -> int:
    """Return the storage size of one cache block summed over all layers.

    A block's key cache holds ``block_size * num_heads * head_size``
    elements and its value cache holds the same number. That per-layer
    total is multiplied by ``num_layers`` and then by the per-element size
    reported by ``get_dtype_size(dtype)``.

    Args:
        block_size: First factor of the per-block element count.
        num_heads: Second factor of the per-block element count.
        head_size: Third factor of the per-block element count.
        num_layers: Number of layers the key/value pair is counted for.
        dtype: Data type identifier forwarded to ``get_dtype_size``.

    Returns:
        ``get_dtype_size(dtype)`` times the total element count.
        NOTE(review): presumably a size in bytes -- confirm against
        ``get_dtype_size``.
    """
    # Keys and values have identical shapes, hence the factor of two.
    elements_per_cache = block_size * num_heads * head_size
    total_elements = num_layers * (elements_per_cache + elements_per_cache)
    element_size = get_dtype_size(dtype)
    return element_size * total_elements
|
||||
|
||||
Reference in New Issue
Block a user