New weight loader without np copy (#52)
This commit is contained in:
@@ -1,11 +1,6 @@
|
||||
"""1D LLaMA model compatible with HuggingFace weights."""
|
||||
import os
|
||||
import glob
|
||||
import filelock
|
||||
from tqdm import tqdm
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
@@ -15,6 +10,8 @@ from cacheflow.models.activation import SiluAndMul
|
||||
from cacheflow.models.attention import LlamaCacheFlowAttention
|
||||
from cacheflow.models.layernorm import RMSNorm
|
||||
from cacheflow.models.sample import Sampler
|
||||
from cacheflow.models.utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from cacheflow.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
|
||||
@@ -216,76 +213,57 @@ class LlamaForCausalLM(nn.Module):
|
||||
"up_proj.weight"]
|
||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||
|
||||
def load_weights(self, weights_path: str):
|
||||
def load_weights(self, model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
for name, param in state_dict.items():
|
||||
if "qkv_proj" in name or "gate_up_proj" in name:
|
||||
if "qkv_proj" in name:
|
||||
original_name = "qkv_proj"
|
||||
weight_names = ["q_proj", "k_proj", "v_proj"]
|
||||
shard_size = param.shape[0] // 3
|
||||
else:
|
||||
original_name = "gate_up_proj"
|
||||
weight_names = ["gate_proj", "up_proj"]
|
||||
shard_size = param.shape[0] // 2
|
||||
weights_to_concat = []
|
||||
for weight_name in weight_names:
|
||||
weight = np.load(os.path.join(
|
||||
weights_path, name.replace(original_name, weight_name)))
|
||||
weights_to_concat.append(weight[
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)])
|
||||
loaded_weight = torch.from_numpy(
|
||||
np.concatenate(weights_to_concat, axis=0))
|
||||
else:
|
||||
loaded_weight = torch.from_numpy(
|
||||
np.load(os.path.join(weights_path, name)))
|
||||
for p in self._column_parallel_weights:
|
||||
if p in name:
|
||||
shard_size = param.shape[0]
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
break
|
||||
for p in self._row_parallel_weights:
|
||||
if p in name:
|
||||
shard_size = param.shape[1]
|
||||
loaded_weight = loaded_weight[
|
||||
:,
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
break
|
||||
|
||||
assert param.shape == loaded_weight.shape
|
||||
param.data.copy_(loaded_weight)
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
@staticmethod
|
||||
def get_weights(model_name: str, path: str):
|
||||
if not os.path.isfile(os.path.join(model_name, "config.json")):
|
||||
raise ValueError("LLaMA model's model_name has to be a path"
|
||||
"to the huggingface model's directory.")
|
||||
path = os.path.join(model_name, f"np")
|
||||
path = os.path.abspath(os.path.expanduser(path))
|
||||
os.makedirs(path, exist_ok=True)
|
||||
lock_path = os.path.join(path, "file_lock")
|
||||
lock = filelock.FileLock(lock_path)
|
||||
is_attention_weight = False
|
||||
for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
|
||||
if att_weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
|
||||
shard_size = param.shape[0] // 3
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id
|
||||
:shard_size * (stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_attention_weight = True
|
||||
break
|
||||
if is_attention_weight:
|
||||
continue
|
||||
|
||||
with lock:
|
||||
test_weight_path = os.path.join(path, "model.embed_tokens.weight")
|
||||
if os.path.exists(test_weight_path):
|
||||
return path
|
||||
is_gate_up_weight = False
|
||||
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||
shard_size = param.shape[0] // 2
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id
|
||||
:shard_size * (stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_gate_up_weight = True
|
||||
break
|
||||
if is_gate_up_weight:
|
||||
continue
|
||||
|
||||
bin_files = glob.glob(os.path.join(model_name, "*.bin"))
|
||||
|
||||
for bin_file in tqdm(bin_files, desc="Convert format"):
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
for name, param in tqdm(state.items(), leave=False):
|
||||
param_path = os.path.join(path, name)
|
||||
with open(param_path, "wb") as f:
|
||||
np.save(f, param.cpu().detach().numpy())
|
||||
|
||||
return path
|
||||
param = state_dict[name]
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights)
|
||||
|
||||
def initialize_dummy_weights(self) -> None:
|
||||
for param in self.state_dict().values():
|
||||
|
||||
Reference in New Issue
Block a user