[Build] Avoid building too many extensions (#1624)

This commit is contained in:
Yanming W
2023-11-23 16:31:19 -08:00
committed by GitHub
parent de23687d16
commit e0c6f556e8
25 changed files with 206 additions and 272 deletions

View File

@@ -3,7 +3,7 @@ import torch
import torch.nn.functional as F
from transformers.activations import get_activation
from vllm import activation_ops
from vllm._C import ops
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
@@ -31,7 +31,7 @@ def test_silu_and_mul(
torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
activation_ops.silu_and_mul(out, x)
ops.silu_and_mul(out, x)
ref_out = ref_silu_and_mul(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
@@ -51,7 +51,7 @@ def test_gelu_new(
torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
activation_ops.gelu_new(out, x)
ops.gelu_new(out, x)
ref_out = get_activation("gelu_new")(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
@@ -70,6 +70,6 @@ def test_gelu_fast(
torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
activation_ops.gelu_fast(out, x)
ops.gelu_fast(out, x)
ref_out = get_activation("gelu_fast")(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)