Add activation registry (#126)

Woosuk Kwon
2023-05-25 00:09:07 -07:00
committed by GitHub
parent 057daef778
commit 4a151dd453
5 changed files with 22 additions and 13 deletions


@@ -4,6 +4,21 @@ import torch.nn as nn
from cacheflow import activation_ops

_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_new": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
    "gelu_fast": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
    "relu": nn.ReLU(),
}


def get_act_fn(act_fn: str) -> nn.Module:
    """Get an activation function by name."""
    act_fn = act_fn.lower()
    if act_fn in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[act_fn]
    raise ValueError(f"Activation function {act_fn!r} is not supported.")


class SiluAndMul(nn.Module):
    """An activation function for SwiGLU.