[Model][Bugfix] Add FATReLU activation and support for openbmb/MiniCPM-S-1B-sft (#9396)

This commit is contained in:
Junhao Li
2024-10-16 12:40:24 -04:00
committed by GitHub
parent fb60ae9b91
commit 5b8a1fde84
3 changed files with 37 additions and 5 deletions

View File

@@ -33,7 +33,7 @@ from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -152,6 +152,7 @@ class MiniCPMMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
hidden_act_param: float,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
@@ -163,10 +164,13 @@ class MiniCPMMLP(nn.Module):
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
if hidden_act == "silu":
self.act_fn = SiluAndMul()
elif hidden_act == "fatrelu":
self.act_fn = FatreluAndMul(threshold=hidden_act_param)
else:
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
"Only silu and fatrelu are supported for now.")
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
@@ -304,6 +308,7 @@ class MiniCPMDecoderLayer(nn.Module):
hidden_size=self.hidden_size,
intermediate_size=self.config.intermediate_size,
hidden_act=self.config.hidden_act,
hidden_act_param=getattr(self.config, "hidden_act_param", 0.),
quant_config=self.quant_config,
)
else: