Enable safetensors loading for all models (#974)
This commit is contained in:
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (
|
||||
+ convert_pyslice_to_tensor,
|
||||
hf_model_weights_iterator,
|
||||
load_padded_tensor_parallel_vocab,
|
||||
load_tensor_parallel_weights,
|
||||
@@ -249,17 +250,19 @@ class QWenLMHeadModel(nn.Module):
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
- use_np_cache: bool = False,
+ load_format: str = "auto",
|
||||
):
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
- model_name_or_path, cache_dir, use_np_cache):
+ model_name_or_path, cache_dir, load_format):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
+ loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
|
||||
if "c_attn" in name:
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
hidden_size = self.config.hidden_size
|
||||
|
||||
Reference in New Issue
Block a user