Enable safetensors loading for all models (#974)

This commit is contained in:
Zhuohan Li
2023-09-07 15:49:52 -07:00
committed by GitHub
parent c07ece5ca4
commit c957c741d9
18 changed files with 143 additions and 83 deletions

View File

@@ -271,8 +271,7 @@ class LlamaForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False,
use_safetensor: bool = True):
load_format: str = "auto"):
tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -289,7 +288,7 @@ class LlamaForCausalLM(nn.Module):
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache, use_safetensor):
model_name_or_path, cache_dir, load_format):
if "rotary_emb.inv_freq" in name:
continue