fix qwen-14b model (#1173)

Author: Qing
Date: 2023-09-28 07:33:16 +08:00
Committed by: GitHub
parent 30e775281d
commit 28e616c4e3
2 changed files with 32 additions and 43 deletions


@@ -141,17 +141,17 @@ class QWenBlock(nn.Module):
     def __init__(self, config: QWenConfig):
         super().__init__()
-        self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

         rope_theta = getattr(config, "rope_theta", 10000)
-        self.attn = QWenAttention(config.n_embd,
+        self.attn = QWenAttention(config.hidden_size,
                                   config.num_attention_heads,
                                   config.max_position_embeddings,
                                   rope_theta=rope_theta)
-        self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

-        self.mlp = QWenMLP(config.n_embd, config.ffn_hidden_size // 2)
+        self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2)

     def forward(
         self,
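The renames above are the substance of the fix: the Qwen-14B checkpoint's config apparently exposes the HF-style keys hidden_size and intermediate_size rather than the n_embd / ffn_hidden_size names the code previously read, so model construction failed for the 14B weights. (The // 2 survives the rename because Qwen's reference MLP splits intermediate_size evenly between the gate and up projections.) A minimal compatibility sketch, hypothetical and not part of this commit, that would tolerate both key sets:

    def resolve_hidden_size(config) -> int:
        # Prefer the HF-style key used by Qwen-14B; fall back to the
        # legacy attribute name if an older config still defines it.
        if hasattr(config, "hidden_size"):
            return config.hidden_size
        return config.n_embd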
@@ -190,11 +190,11 @@ class QWenModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.wte = VocabParallelEmbedding(vocab_size,
-                                          config.n_embd,
+                                          config.hidden_size,
                                           perform_initialization=False)
         self.h = nn.ModuleList(
             [QWenBlock(config) for _ in range(config.num_hidden_layers)])
-        self.ln_f = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

     def forward(
         self,
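One context line here is worth unpacking: vocab_size = ((config.vocab_size + 63) // 64) * 64 rounds the vocabulary up to the next multiple of 64 so the parallel embedding and LM head can shard evenly across tensor-parallel ranks. A small self-contained sketch of that ceiling-division idiom (the helper name is ours, not vLLM's):

    def pad_vocab_size(vocab_size: int, multiple: int = 64) -> int:
        # Ceiling division, then scale back up: rounds up to the nearest multiple.
        return ((vocab_size + multiple - 1) // multiple) * multiple

    assert pad_vocab_size(50257) == 50304  # e.g. GPT-2's vocab pads up to 50304
    assert pad_vocab_size(50304) == 50304  # an already-aligned size is unchanged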
@@ -230,7 +230,7 @@ class QWenLMHeadModel(nn.Module):
         self.transformer = QWenModel(config)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ColumnParallelLinear(
-            config.n_embd,
+            config.hidden_size,
             vocab_size,
             bias=False,
             gather_output=False,
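With gather_output=False, each tensor-parallel rank keeps only its column shard of the LM-head output instead of all-gathering the full logits. A single-process sketch of what column-parallel sharding computes, assuming two ranks (illustrative only, not vLLM's actual implementation):

    import torch

    hidden_size, vocab_size, tp_size = 1024, 50304, 2
    weight = torch.randn(vocab_size, hidden_size)
    shards = torch.chunk(weight, tp_size, dim=0)  # each "rank" holds vocab_size / tp_size rows

    x = torch.randn(1, hidden_size)
    partials = [x @ w.t() for w in shards]        # per-rank partial logits (gather_output=False)
    gathered = torch.cat(partials, dim=-1)        # gathering reconstructs the full logits
    torch.testing.assert_close(gathered, x @ weight.t())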