Enable safetensors loading for all models (#974)
This commit is contained in:
@@ -33,8 +33,8 @@ from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention import PagedAttention
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (
|
||||
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
|
||||
load_tensor_parallel_weights)
|
||||
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
@@ -259,14 +259,14 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
load_format: str = "auto"):
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if "lm_head.weight" in name:
|
||||
# GPT-2 ties the weights of the embedding layer and the final
|
||||
# linear layer.
|
||||
@@ -295,6 +295,7 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
head_start = tensor_model_parallel_rank * num_heads
|
||||
head_end = (tensor_model_parallel_rank + 1) * num_heads
|
||||
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
wq, wk, wv = torch.split(
|
||||
loaded_weight, [hidden_size, total_kv_size, total_kv_size],
|
||||
dim=0)
|
||||
|
||||
Reference in New Issue
Block a user