[Model] support minicpm3 (#8297)
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -270,38 +270,47 @@ class MiniCPMDecoderLayer(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+        self.cache_config = cache_config
+        self.quant_config = quant_config
         self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          8192)
+        self.rope_theta = getattr(config, "rope_theta", 10000)
+        self.rope_scaling = getattr(config, "rope_scaling", None)
+        self.max_position_embeddings = getattr(config,
+                                               "max_position_embeddings", 8192)
+        self._init_attn_block()
+        self._init_ffn_block()
+
+    def _init_attn_block(self):
+        self.input_layernorm = RMSNorm(self.config.hidden_size,
+                                       eps=self.config.rms_norm_eps)
         self.self_attn = MiniCPMAttention(
             hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            num_kv_heads=config.num_key_value_heads,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
-            quant_config=quant_config,
+            num_heads=self.config.num_attention_heads,
+            num_kv_heads=self.config.num_key_value_heads,
+            rope_theta=self.rope_theta,
+            rope_scaling=self.rope_scaling,
+            max_position_embeddings=self.max_position_embeddings,
+            cache_config=self.cache_config,
+            quant_config=self.quant_config,
         )
+
+    def _init_ffn_block(self):
+        self.post_attention_layernorm = RMSNorm(self.config.hidden_size,
+                                                eps=self.config.rms_norm_eps)
         self.num_experts = getattr(self.config, "num_experts", 0)
         if self.num_experts == 0:
             self.mlp = MiniCPMMLP(
                 hidden_size=self.hidden_size,
-                intermediate_size=config.intermediate_size,
-                hidden_act=config.hidden_act,
-                quant_config=quant_config,
+                intermediate_size=self.config.intermediate_size,
+                hidden_act=self.config.hidden_act,
+                quant_config=self.quant_config,
             )
         else:
-            self.mlp = MiniCPMMoE(num_experts=config.num_experts,
-                                  top_k=config.num_experts_per_tok,
-                                  hidden_size=config.hidden_size,
-                                  intermediate_size=config.intermediate_size)
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
+            self.mlp = MiniCPMMoE(
+                num_experts=self.config.num_experts,
+                top_k=self.config.num_experts_per_tok,
+                hidden_size=self.config.hidden_size,
+                intermediate_size=self.config.intermediate_size)
 
     def forward(
         self,
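The hunk above is a template-method refactor: `__init__` now stashes the config objects and RoPE hyperparameters on `self`, then delegates to the new `_init_attn_block()` and `_init_ffn_block()` hooks, so a variant such as MiniCPM3 can swap in a different attention block without copying the rest of the layer. A minimal sketch of how a subclass might use the hook; `MyDecoderLayer` and `MyAttention` are illustrative names, not part of this commit:

    # Sketch only: overriding the _init_attn_block() hook added above.
    # MyDecoderLayer and MyAttention are assumed names for illustration.
    class MyDecoderLayer(MiniCPMDecoderLayer):

        def _init_attn_block(self):
            self.input_layernorm = RMSNorm(self.config.hidden_size,
                                           eps=self.config.rms_norm_eps)
            # A different attention implementation can read every needed
            # hyperparameter from the attributes stored in __init__.
            self.self_attn = MyAttention(
                hidden_size=self.hidden_size,
                num_heads=self.config.num_attention_heads,
                rope_theta=self.rope_theta,
                rope_scaling=self.rope_scaling,
                max_position_embeddings=self.max_position_embeddings,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
            )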
@@ -344,6 +353,8 @@ class MiniCPMModel(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+        self.cache_config = cache_config
+        self.quant_config = quant_config
         self.padding_idx = config.pad_token_id
         lora_vocab = (lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0
@@ -354,12 +365,16 @@ class MiniCPMModel(nn.Module):
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
         )
-        self.layers = nn.ModuleList([
-            MiniCPMDecoderLayer(config, cache_config, quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self._init_layers()
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
+    def _init_layers(self):
+        self.layers = nn.ModuleList([
+            MiniCPMDecoderLayer(self.config, self.cache_config,
+                                self.quant_config)
+            for _ in range(self.config.num_hidden_layers)
+        ])
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         embedding = self.embed_tokens(input_ids)
         return embedding * self.config.scale_emb
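Same pattern one level up: layer construction moves into `_init_layers()`, which reads everything from the attributes stored in `__init__`, so a subclass can rebuild the stack from its own decoder-layer class. A hedged sketch; `MyModel` and `MyDecoderLayer` are assumed names:

    # Sketch only: a model subclass swaps the decoder-layer class by
    # overriding the _init_layers() hook; names are illustrative.
    class MyModel(MiniCPMModel):

        def _init_layers(self):
            self.layers = nn.ModuleList([
                MyDecoderLayer(self.config, self.cache_config,
                               self.quant_config)
                for _ in range(self.config.num_hidden_layers)
            ])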
@@ -431,13 +446,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
 
         self.config = config
         self.lora_config = lora_config
+        self.cache_config = cache_config
+        self.quant_config = quant_config
 
         self.num_experts = getattr(self.config, "num_experts", 0)
-        self.quant_config = quant_config
-        self.model = MiniCPMModel(config,
-                                  cache_config,
-                                  quant_config,
-                                  lora_config=lora_config)
+        self._init_model()
         unpadded_vocab_size = config.vocab_size
         if lora_config:
             unpadded_vocab_size += lora_config.lora_extra_vocab_size
@@ -458,6 +471,12 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
                                                  config.vocab_size)
         self.sampler = Sampler()
 
+    def _init_model(self):
+        self.model = MiniCPMModel(config=self.config,
+                                  cache_config=self.cache_config,
+                                  quant_config=self.quant_config,
+                                  lora_config=self.lora_config)
+
     def forward(
         self,
         input_ids: torch.Tensor,
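And at the top level, `_init_model()` gives a subclass a single place to point the causal-LM wrapper at a different backbone. A sketch under the same assumptions; `MyForCausalLM` and `MyModel` are illustrative names:

    # Sketch only: the causal-LM wrapper selects its backbone through the
    # _init_model() hook added above; MyModel is an assumed name.
    class MyForCausalLM(MiniCPMForCausalLM):

        def _init_model(self):
            self.model = MyModel(config=self.config,
                                 cache_config=self.cache_config,
                                 quant_config=self.quant_config,
                                 lora_config=self.lora_config)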