Introduce LLM class for offline inference (#115)
This commit is contained in:
@@ -3,6 +3,8 @@ from typing import Optional
|
||||
import torch
|
||||
from transformers import AutoConfig, PretrainedConfig
|
||||
|
||||
_GiB = 1 << 30
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
|
||||
@@ -70,7 +72,7 @@ class CacheConfig:
|
||||
) -> None:
|
||||
self.block_size = block_size
|
||||
self.gpu_memory_utilization = gpu_memory_utilization
|
||||
self.swap_space = swap_space
|
||||
self.swap_space_bytes = swap_space * _GiB
|
||||
|
||||
# Will be set after profiling.
|
||||
self.num_gpu_blocks = None
|
||||
@@ -138,6 +140,8 @@ def _get_and_verify_dtype(
|
||||
else:
|
||||
torch_dtype = config_dtype
|
||||
else:
|
||||
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
|
||||
raise ValueError(f"Unknown dtype: {dtype}")
|
||||
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
||||
|
||||
# Verify the dtype.
|
||||
|
||||
Reference in New Issue
Block a user