Accelerate LLaMA model loading (#234)

This commit is contained in:
JFDuan
2023-08-30 16:00:13 +08:00
committed by GitHub
parent becd7a56f1
commit 0d93f15694
8 changed files with 190 additions and 112 deletions

View File

@@ -36,8 +36,9 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.weight_utils import (
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
hf_model_weights_iterator)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
@@ -263,15 +264,15 @@ class LlamaForCausalLM(nn.Module):
return next_tokens
_column_parallel_weights = [
"embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
"gate_proj.weight", "up_proj.weight"
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
use_np_cache: bool = False,
use_safetensor: bool = True):
tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -288,20 +289,10 @@ class LlamaForCausalLM(nn.Module):
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, use_np_cache, use_safetensor):
if "rotary_emb.inv_freq" in name:
continue
if "embed_tokens" in name or "lm_head" in name:
param = state_dict[name]
# Consider padding in the vocab size.
padded_vocab_size = (param.shape[0] * tp_size)
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
is_attention_weight = False
for weight_name, shard_size, offset in attention_weight_specs:
if weight_name not in name:
@@ -339,6 +330,12 @@ class LlamaForCausalLM(nn.Module):
continue
param = state_dict[name]
if "embed_tokens" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight,
tensor_model_parallel_rank)
continue
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,