Enable safetensors loading for all models (#974)

This commit is contained in:
Zhuohan Li
2023-09-07 15:49:52 -07:00
committed by GitHub
parent c07ece5ca4
commit c957c741d9
18 changed files with 143 additions and 83 deletions

View File

@@ -297,12 +297,12 @@ class OPTForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
load_format: str = "auto"):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, load_format):
if "lm_head.weight" in name:
continue