[Model][Bugfix] Add FATReLU activation and support for openbmb/MiniCPM-S-1B-sft (#9396)

This commit is contained in:
Junhao Li
2024-10-16 12:40:24 -04:00
committed by GitHub
parent fb60ae9b91
commit 5b8a1fde84
3 changed files with 37 additions and 5 deletions

View File

@@ -13,6 +13,33 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
class FatreluAndMul(CustomOp):
"""An activation function for FATReLU.
The function computes x -> FATReLU(x[:d]) * x[d:] where
d = x.shape[-1] // 2.
This is used in openbmb/MiniCPM-S-1B-sft.
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def __init__(self, threshold: float = 0.):
super().__init__()
self.threshold = threshold
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
x1 = x[..., :d]
x2 = x[..., d:]
x1 = F.threshold(x1, self.threshold, 0.0)
return x1 * x2
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_native(x)
class SiluAndMul(CustomOp):
"""An activation function for SwiGLU.