Accelerate LLaMA model loading (#234)

JFDuan
2023-08-30 16:00:13 +08:00
committed by GitHub
parent becd7a56f1
commit 0d93f15694
8 changed files with 190 additions and 112 deletions


@@ -20,6 +20,7 @@ from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (
     hf_model_weights_iterator,
+    load_padded_tensor_parallel_vocab,
     load_tensor_parallel_weights,
 )
 from vllm.model_executor.parallel_utils.parallel_state import (
@@ -241,7 +242,7 @@ class QWenLMHeadModel(nn.Module):
                                    input_metadata)
         return next_tokens

-    _column_parallel_weights = ["wte.weight", "lm_head.weight"]
+    _column_parallel_weights = []
     _row_parallel_weights = ["c_proj.weight"]

     def load_weights(
@@ -259,16 +260,6 @@ class QWenLMHeadModel(nn.Module):
if "rotary_emb.inv_freq" in name:
continue
if "wte" in name or "lm_head" in name:
# Consider padding in the vocab size.
param = state_dict[name]
padded_vocab_size = param.shape[0] * tp_world_size
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
if "c_attn" in name:
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
@@ -306,6 +297,12 @@ class QWenLMHeadModel(nn.Module):
                 continue

             param = state_dict[name]
+
+            if "wte" in name or "lm_head" in name:
+                load_padded_tensor_parallel_vocab(param, loaded_weight,
+                                                  tp_rank)
+                continue
+
             load_tensor_parallel_weights(
                 param,
                 loaded_weight,
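
For reference, the removed block above padded wte.weight / lm_head.weight to the tensor-parallel vocab size inline, inside this model's load_weights. Below is a minimal sketch of what a shared helper like load_padded_tensor_parallel_vocab plausibly does, assuming the (param, loaded_weight, tp_rank) signature suggested by the call site in this diff; it is an illustration, not the verbatim vLLM implementation.

import torch

def load_padded_tensor_parallel_vocab(param: torch.Tensor,
                                      loaded_weight: torch.Tensor,
                                      tensor_model_parallel_rank: int) -> None:
    # Sketch only (assumed behavior): `param` is this rank's shard of the
    # already-padded vocab-parallel weight (e.g. wte.weight or lm_head.weight),
    # while `loaded_weight` is the full, unpadded checkpoint tensor.
    shard_size = param.shape[0]
    start_idx = tensor_model_parallel_rank * shard_size
    end_idx = start_idx + shard_size

    # Copy the checkpoint rows that fall into this shard; on the last rank the
    # slice can be shorter than the shard because of the padding rows, which
    # are simply left untouched in `param`.
    shard = loaded_weight[start_idx:end_idx]
    param[:shard.shape[0]].copy_(shard)

Compared with the removed per-model code, a helper along these lines slices the checkpoint tensor for the current rank instead of concatenating freshly allocated padding rows onto it, and it keeps the padding logic in one place for every vocab-parallel weight.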