New weight loader without np copy (#52)
This commit is contained in:
@@ -1,6 +1,16 @@
|
||||
from typing import Union
|
||||
import os
|
||||
import glob
|
||||
import json
|
||||
import filelock
|
||||
from typing import Union, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
from huggingface_hub import snapshot_download
|
||||
from cacheflow.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank)
|
||||
|
||||
|
||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
'half': torch.half,
|
||||
@@ -22,3 +32,86 @@ def get_dtype_size(dtype: Union[torch.dtype, str]) -> int:
|
||||
torch_dtype = get_torch_dtype(dtype)
|
||||
return torch.tensor([], dtype=torch_dtype).element_size()
|
||||
|
||||
|
||||
class Disabledtqdm(tqdm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs, disable=True)
|
||||
|
||||
|
||||
def hf_model_weights_iterator(model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
# Prepare file lock directory to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
lock_dir = cache_dir if cache_dir is not None else "/tmp"
|
||||
lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
|
||||
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
|
||||
|
||||
# Download model weights from huggingface.
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
if not is_local:
|
||||
with lock:
|
||||
hf_folder = snapshot_download(model_name_or_path,
|
||||
allow_patterns="*.bin",
|
||||
cache_dir=cache_dir,
|
||||
tqdm_class=Disabledtqdm)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
|
||||
hf_bin_files = glob.glob(os.path.join(hf_folder, "*.bin"))
|
||||
|
||||
if use_np_cache:
|
||||
# Convert the model weights from torch tensors to numpy arrays for
|
||||
# faster loading.
|
||||
np_folder = os.path.join(hf_folder, 'np')
|
||||
os.makedirs(np_folder, exist_ok=True)
|
||||
weight_names_file = os.path.join(np_folder, 'weight_names.json')
|
||||
with lock:
|
||||
if not os.path.exists(weight_names_file):
|
||||
weight_names = []
|
||||
for bin_file in hf_bin_files:
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
for name, param in state.items():
|
||||
param_path = os.path.join(np_folder, name)
|
||||
with open(param_path, "wb") as f:
|
||||
np.save(f, param.cpu().detach().numpy())
|
||||
weight_names.append(name)
|
||||
with open(weight_names_file, 'w') as f:
|
||||
json.dump(weight_names, f)
|
||||
|
||||
with open(weight_names_file, 'r') as f:
|
||||
weight_names = json.load(f)
|
||||
|
||||
for name in weight_names:
|
||||
param_path = os.path.join(np_folder, name)
|
||||
with open(param_path, "rb") as f:
|
||||
param = np.load(f)
|
||||
yield name, torch.from_numpy(param)
|
||||
else:
|
||||
for bin_file in hf_bin_files:
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
for name, param in state.items():
|
||||
yield name, param
|
||||
|
||||
|
||||
def load_tensor_parallel_weights(param, loaded_weight, param_name,
|
||||
column_parallel_weight_names,
|
||||
row_parallel_weight_names):
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
for p in column_parallel_weight_names:
|
||||
if p in param_name:
|
||||
shard_size = param.shape[0]
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
break
|
||||
for p in row_parallel_weight_names:
|
||||
if p in param_name:
|
||||
shard_size = param.shape[1]
|
||||
loaded_weight = loaded_weight[
|
||||
:,
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
break
|
||||
assert param.shape == loaded_weight.shape
|
||||
param.data.copy_(loaded_weight)
|
||||
|
||||
Reference in New Issue
Block a user