Implement approximate GELU kernels (#828)

Author: Woosuk Kwon
Date: 2023-08-23 07:43:21 +09:00
Committed by: GitHub
Parent: a41c20435e
Commit: d64bf1646c
4 changed files with 164 additions and 18 deletions
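For context, the two approximations being added are the standard tanh-based GELU variants; the tests below validate the new CUDA kernels against HuggingFace's get_activation("gelu_new") and get_activation("gelu_fast") references. A sketch of those reference definitions (mirroring transformers.activations):

import math
import torch

def gelu_new(x: torch.Tensor) -> torch.Tensor:
    # Tanh approximation of GELU, as used by GPT-2 ("gelu_new").
    return 0.5 * x * (1.0 + torch.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

def gelu_fast(x: torch.Tensor) -> torch.Tensor:
    # Faster variant with sqrt(2/pi) pre-folded into the constant ("gelu_fast").
    return 0.5 * x * (1.0 + torch.tanh(
        x * 0.7978845608 * (1.0 + 0.044715 * x * x)))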


@@ -1,6 +1,6 @@
 import torch
 import torch.nn.functional as F
+from transformers.activations import get_activation
 
 from vllm import activation_ops
@@ -28,3 +28,45 @@ def test_silu_and_mul() -> None:
             for d in [512, 4096, 5120, 13824]:
                 print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
                 run_silu_and_mul(num_tokens, d, dtype)
+
+
+@torch.inference_mode()
+def run_gelu_new(
+    num_tokens: int,
+    d: int,
+    dtype: torch.dtype,
+) -> None:
+    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
+    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
+    activation_ops.gelu_new(out, x)
+    ref_out = get_activation("gelu_new")(x)
+    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
+
+
+def test_gelu_new() -> None:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        for num_tokens in [7, 83, 2048]:
+            for d in [512, 4096, 5120, 13824]:
+                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
+                run_gelu_new(num_tokens, d, dtype)
+
+
+@torch.inference_mode()
+def run_gelu_fast(
+    num_tokens: int,
+    d: int,
+    dtype: torch.dtype,
+) -> None:
+    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
+    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
+    activation_ops.gelu_fast(out, x)
+    ref_out = get_activation("gelu_fast")(x)
+    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
+
+
+def test_gelu_fast() -> None:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        for num_tokens in [7, 83, 2048]:
+            for d in [512, 4096, 5120, 13824]:
+                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
+                run_gelu_fast(num_tokens, d, dtype)
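As the tests show, the new ops follow the existing activation_ops convention of writing into a preallocated output tensor rather than returning one. A minimal usage sketch (assuming the vllm extension with these ops is built and a CUDA device is available):

import torch
from vllm import activation_ops

x = torch.randn(16, 4096, dtype=torch.half, device='cuda')
out = torch.empty_like(x)         # preallocated output buffer, same shape/dtype as x
activation_ops.gelu_fast(out, x)  # fills out with the fast-GELU approximation of x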