[Kernel] add kernel for FATReLU (#9610)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Jee Jee Li
Date: 2024-10-24 16:18:27 +08:00
Committed by: GitHub
Parent: 8a02cd045a
Commit: 295a061fb3
6 changed files with 78 additions and 8 deletions


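For context: FATReLU is a thresholded ReLU, zeroing activations at or below a fixed threshold and passing larger values through unchanged. The `FatreluAndMul` layer exercised below follows the same gated act-and-mul pattern as `SiluAndMul`: the activation is applied to the first half of the last dimension and multiplied elementwise by the second half. A minimal sketch of that reference behavior (an illustration of the math only, not the fused CUDA kernel; `fatrelu_and_mul_ref` is a hypothetical helper name):

```python
import torch
import torch.nn.functional as F


def fatrelu_and_mul_ref(x: torch.Tensor, threshold: float) -> torch.Tensor:
    """FATReLU(x[..., :d]) * x[..., d:], with d = last dim // 2."""
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    # FATReLU: values <= threshold are zeroed, larger values pass through.
    return F.threshold(x1, threshold, 0.0) * x2
```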
@@ -1,12 +1,13 @@
+import random
 from typing import Type
 
 import pytest
 import torch
 
 from tests.kernels.utils import opcheck
-from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
-                                                   NewGELU, QuickGELU,
-                                                   SiluAndMul)
+from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
+                                                   GeluAndMul, NewGELU,
+                                                   QuickGELU, SiluAndMul)
 from vllm.utils import seed_everything
 
 from .allclose_default import get_default_atol, get_default_rtol
@@ -20,7 +21,8 @@ CUDA_DEVICES = [
 ]
 
 
-@pytest.mark.parametrize("activation", ["silu", "gelu", "gelu_tanh"])
+@pytest.mark.parametrize("activation",
+                         ["silu", "gelu", "gelu_tanh", "fatrelu"])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -47,16 +49,23 @@ def test_act_and_mul(
     elif activation == "gelu_tanh":
         layer = GeluAndMul(approximate="tanh")
         fn = torch.ops._C.gelu_tanh_and_mul
+    elif activation == "fatrelu":
+        threshold = random.uniform(0, 1)
+        layer = FatreluAndMul(threshold)
+        fn = torch.ops._C.fatrelu_and_mul
     out = layer(x)
     ref_out = layer.forward_native(x)
-    # The SiLU and GELU implementations are equivalent to the native PyTorch
-    # implementations, so we can do exact comparison.
+    # The SiLU, GELU and FatReLU implementations are equivalent to the native
+    # PyTorch implementations, so we can do exact comparison.
     torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
 
     d = x.shape[-1] // 2
     output_shape = (x.shape[:-1] + (d, ))
     out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-    opcheck(fn, (out, x))
+    if activation == "fatrelu":
+        opcheck(fn, (out, x, threshold))
+    else:
+        opcheck(fn, (out, x))
 
 
 @pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),