Add model_utils
This commit is contained in:
13
cacheflow/worker/models/model_utils.py
Normal file
13
cacheflow/worker/models/model_utils.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import torch.nn as nn
|
||||
|
||||
from cacheflow.worker.models.opt import OPTForCausalLM
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'opt': OPTForCausalLM,
|
||||
}
|
||||
|
||||
|
||||
def get_model(model_name: str) -> nn.Module:
|
||||
if model_name not in MODEL_CLASSES:
|
||||
raise ValueError(f'Invalid model name: {model_name}')
|
||||
return MODEL_CLASSES[model_name].from_pretrained(model_name)
|
||||
Reference in New Issue
Block a user