New weight loader without np copy (#52)

This commit is contained in:
Zhuohan Li
2023-05-03 15:32:04 +08:00
committed by GitHub
parent 4858f3bb45
commit 27f1410d06
12 changed files with 284 additions and 352 deletions

View File

@@ -1,19 +1,15 @@
"""1D OPT model compatible with HuggingFace weights."""
import os
import glob
import filelock
from tqdm import tqdm
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from transformers import OPTConfig
from huggingface_hub import snapshot_download
from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.models.utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from cacheflow.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
@@ -257,73 +253,42 @@ class OPTForCausalLM(nn.Module):
_column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
_row_parallel_weights = ["out_proj.weight", "fc2.weight"]
def load_weights(self, weights_path: str):
def load_weights(self, model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, param in state_dict.items():
if "lm_head_weight" in name:
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if "lm_head.weight" in name:
continue
if "qkv_proj" in name:
if name.startswith("decoder."):
name = "model." + name
is_attention_weight = False
for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3
weights_to_concat = []
for weight_name in ["q_proj", "k_proj", "v_proj"]:
weight = np.load(os.path.join(
weights_path, name.replace("qkv_proj", weight_name)))
weights_to_concat.append(weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)])
loaded_weight = torch.from_numpy(
np.concatenate(weights_to_concat, axis=0))
else:
loaded_weight = torch.from_numpy(
np.load(os.path.join(weights_path, name)))
for p in self._column_parallel_weights:
if p in name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
for p in self._row_parallel_weights:
if p in name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id
:shard_size * (stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break
if is_attention_weight:
continue
assert param.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
@staticmethod
def get_weights(model_name: str, path: str):
path = os.path.join(path, f"{model_name}-np")
path = os.path.abspath(os.path.expanduser(path))
os.makedirs(path, exist_ok=True)
lock_path = os.path.join(path, "file_lock")
lock = filelock.FileLock(lock_path)
with lock:
test_weight_path = os.path.join(
path, "model.decoder.embed_positions.weight")
if os.path.exists(test_weight_path):
return path
folder = snapshot_download(model_name, allow_patterns="*.bin",
cache_dir=os.path.join(path, "cache"))
bin_files = glob.glob(os.path.join(folder, "*.bin"))
for bin_file in tqdm(bin_files, desc="Convert format"):
state = torch.load(bin_file, map_location="cpu")
for name, param in tqdm(state.items(), leave=False):
if name.startswith("decoder."):
name = "model." + name
param_path = os.path.join(path, name)
with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy())
return path
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights)
def initialize_dummy_weights(self) -> None:
for param in self.state_dict().values():