[Kernel] Add cuda kernel for gpt_oss activation (#22951)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2025-08-17 13:03:24 +08:00
committed by GitHub
parent 87f48623a5
commit 4d4061b6e7
9 changed files with 157 additions and 42 deletions

View File

@@ -11,7 +11,7 @@ from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
GeluAndMul, MulAndSilu,
NewGELU, QuickGELU,
SiluAndMul)
SiluAndMul, SwigluOAIAndMul)
from vllm.platforms import current_platform
DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -25,7 +25,15 @@ CUDA_DEVICES = [
@pytest.mark.parametrize(
"activation",
["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"])
[
"silu_and_mul",
"mul_and_silu",
"gelu",
"gelu_tanh",
"fatrelu",
"swigluoai_and_mul",
],
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@@ -59,18 +67,43 @@ def test_act_and_mul(
threshold = random.uniform(0, 1)
layer = FatreluAndMul(threshold)
fn = torch.ops._C.fatrelu_and_mul
elif activation == "swigluoai_and_mul":
layer = SwigluOAIAndMul()
fn = torch.ops._C.swigluoai_and_mul
out = layer(x)
ref_out = layer.forward_native(x)
# The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
# equivalent to the native PyTorch implementations, so we can do exact
# comparison.
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
if activation == "swigluoai_and_mul":
rtol = {
#For fp16, change the relative tolerance from 1e-3 to 2e-3
torch.float16:
2e-3,
torch.bfloat16:
2e-2,
torch.float:
1.3e-6
}
def _get_rtol(output) -> float:
return rtol[output.dtype]
torch.testing.assert_close(out,
ref_out,
atol=get_default_atol(out),
rtol=_get_rtol(out))
else:
# The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
# equivalent to the native PyTorch implementations, so we can do exact
# comparison.
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if activation == "fatrelu":
opcheck(fn, (out, x, threshold))
elif activation == "swigluoai_and_mul":
opcheck(fn, (out, x, layer.alpha, layer.limit))
else:
opcheck(fn, (out, x))