@@ -6,16 +6,20 @@ import torch.nn as nn
|
||||
from transformers import AutoConfig
|
||||
|
||||
from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
|
||||
from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
|
||||
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
|
||||
from cacheflow.models.llama import LlamaForCausalLM
|
||||
from cacheflow.models.opt import OPTForCausalLM
|
||||
from cacheflow.models.utils import get_torch_dtype
|
||||
|
||||
|
||||
_MODELS = {
|
||||
'llama': LlamaForCausalLM,
|
||||
'opt': OPTForCausalLM,
|
||||
}
|
||||
|
||||
_MEMORY_ANALYZERS = {
|
||||
'llama': LlamaMemoryAnalyzer,
|
||||
'opt': OPTMemoryAnalyzer,
|
||||
}
|
||||
|
||||
@@ -31,7 +35,7 @@ def get_model(
|
||||
for model_class_name, model_class in _MODELS.items():
|
||||
if model_class_name in model_name:
|
||||
# Download model weights if it's not cached.
|
||||
weights_dir = model_class.download_weights(model_name, path=path)
|
||||
weights_dir = model_class.get_weights(model_name, path=path)
|
||||
# Create a model instance.
|
||||
model = model_class(config)
|
||||
# Load the weights from the cached or downloaded files.
|
||||
|
||||
Reference in New Issue
Block a user