Implement approximate GELU kernels (#828)

This commit is contained in:
Woosuk Kwon
2023-08-23 07:43:21 +09:00
committed by GitHub
parent a41c20435e
commit d64bf1646c
4 changed files with 164 additions and 18 deletions

View File

@@ -4,9 +4,25 @@ void silu_and_mul(
torch::Tensor& out,
torch::Tensor& input);
// GELU variant registered with Python as "gelu_new" (described in the module
// docstring as the implementation used in GPT-2).
//
// Declaration only: the definition is not part of this file view.
//
// out   - tensor passed by mutable reference; presumably receives the result
//         (TODO confirm against the kernel implementation).
// input - tensor passed by mutable reference; presumably the values the
//         activation is applied to (TODO confirm shape/dtype requirements).
void gelu_new(torch::Tensor& out, torch::Tensor& input);
// GELU variant registered with Python as "gelu_fast" (described in the module
// docstring as an approximate GELU implementation).
//
// Declaration only: the definition is not part of this file view.
//
// out   - tensor passed by mutable reference; presumably receives the result
//         (TODO confirm against the kernel implementation).
// input - tensor passed by mutable reference; presumably the values the
//         activation is applied to (TODO confirm shape/dtype requirements).
void gelu_fast(torch::Tensor& out, torch::Tensor& input);
// Python bindings for the activation ops declared above.
//
// Each m.def() call exposes one C++ function under the given Python name and
// attaches the given string as its docstring. TORCH_EXTENSION_NAME is not
// defined in this file view (presumably supplied by the extension build).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Gated activation: SiLU combined with an element-wise multiply.
  m.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");

  // GELU variants.
  m.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
  m.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
}