Enable safetensors loading for all models (#974)

This commit is contained in:
Zhuohan Li
2023-09-07 15:49:52 -07:00
committed by GitHub
parent c07ece5ca4
commit c957c741d9
18 changed files with 143 additions and 83 deletions

View File

@@ -35,8 +35,8 @@ from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
PagedAttentionWithALiBi)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights)
convert_pyslice_to_tensor, hf_model_weights_iterator,
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -303,16 +303,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
load_format: str = "auto"):
tp_world_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, load_format):
if "rotary_emb.inv_freq" in name:
continue
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
if "W_pack" in name:
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size