Implement LLaMA (#9)

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
Woosuk Kwon
2023-03-29 21:25:32 -07:00
committed by GitHub
parent a1b3de86cd
commit 80a2f812f1
7 changed files with 500 additions and 35 deletions

View File

@@ -6,16 +6,20 @@ import torch.nn as nn
from transformers import AutoConfig
from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
from cacheflow.models.llama import LlamaForCausalLM
from cacheflow.models.opt import OPTForCausalLM
from cacheflow.models.utils import get_torch_dtype
_MODELS = {
'llama': LlamaForCausalLM,
'opt': OPTForCausalLM,
}
_MEMORY_ANALYZERS = {
'llama': LlamaMemoryAnalyzer,
'opt': OPTMemoryAnalyzer,
}
@@ -31,7 +35,7 @@ def get_model(
for model_class_name, model_class in _MODELS.items():
if model_class_name in model_name:
# Download model weights if it's not cached.
weights_dir = model_class.download_weights(model_name, path=path)
weights_dir = model_class.get_weights(model_name, path=path)
# Create a model instance.
model = model_class(config)
# Load the weights from the cached or downloaded files.