New weight loader without np copy (#52)

This commit is contained in:
Zhuohan Li
2023-05-03 15:32:04 +08:00
committed by GitHub
parent 4858f3bb45
commit 27f1410d06
12 changed files with 284 additions and 352 deletions

View File

@@ -1,4 +1,4 @@
from typing import Union
from typing import Union, Optional
import torch
import torch.nn as nn
@@ -32,8 +32,9 @@ _MEMORY_ANALYZERS = {
def get_model(
model_name: str,
dtype: Union[torch.dtype, str],
path: str,
cache_dir: Optional[str],
use_dummy_weights: bool,
use_np_cache: bool,
) -> nn.Module:
torch_dtype = get_torch_dtype(dtype)
torch.set_default_dtype(torch_dtype)
@@ -46,15 +47,13 @@ def get_model(
model = model_class(config)
model = model.cuda()
# NOTE(woosuk): For precise performance evaluation, we assign
# random values to the weights.
# random values to the weights.
model.initialize_dummy_weights()
else:
# Download model weights if it's not cached.
weights_dir = model_class.get_weights(model_name, path=path)
# Create a model instance.
model = model_class(config)
# Load the weights from the cached or downloaded files.
model.load_weights(weights_dir)
model.load_weights(model_name, cache_dir, use_np_cache)
model = model.cuda()
return model.eval(), torch_dtype
raise ValueError(f'Unsupported model name: {model_name}')