New weight loader without np copy (#52)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Union
|
||||
from typing import Union, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -32,8 +32,9 @@ _MEMORY_ANALYZERS = {
|
||||
def get_model(
|
||||
model_name: str,
|
||||
dtype: Union[torch.dtype, str],
|
||||
path: str,
|
||||
cache_dir: Optional[str],
|
||||
use_dummy_weights: bool,
|
||||
use_np_cache: bool,
|
||||
) -> nn.Module:
|
||||
torch_dtype = get_torch_dtype(dtype)
|
||||
torch.set_default_dtype(torch_dtype)
|
||||
@@ -46,15 +47,13 @@ def get_model(
|
||||
model = model_class(config)
|
||||
model = model.cuda()
|
||||
# NOTE(woosuk): For precise performance evaluation, we assign
|
||||
# random values to the weights.
|
||||
# random values to the weights.
|
||||
model.initialize_dummy_weights()
|
||||
else:
|
||||
# Download model weights if it's not cached.
|
||||
weights_dir = model_class.get_weights(model_name, path=path)
|
||||
# Create a model instance.
|
||||
model = model_class(config)
|
||||
# Load the weights from the cached or downloaded files.
|
||||
model.load_weights(weights_dir)
|
||||
model.load_weights(model_name, cache_dir, use_np_cache)
|
||||
model = model.cuda()
|
||||
return model.eval(), torch_dtype
|
||||
raise ValueError(f'Unsupported model name: {model_name}')
|
||||
|
||||
Reference in New Issue
Block a user