Add PyTorch-native implementation of custom layers (#1898)

Woosuk Kwon
2023-12-02 21:18:40 -08:00
committed by GitHub
parent 5313c2cb8b
commit 9b294976a2
6 changed files with 149 additions and 184 deletions


@@ -1,9 +1,7 @@
 import pytest
 import torch
-import torch.nn.functional as F
-from transformers.activations import get_activation
-from vllm._C import ops
+from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
@@ -11,11 +9,6 @@ D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
 SEEDS = [0]
 
-def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
-    x1, x2 = x.chunk(chunks=2, dim=1)
-    return F.silu(x1) * x2
-
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -30,9 +23,9 @@ def test_silu_and_mul(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
-    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    ops.silu_and_mul(out, x)
-    ref_out = ref_silu_and_mul(x)
+    layer = SiluAndMul()
+    out = layer(x)
+    ref_out = layer._forward(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
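For context, the PyTorch-native reference that `layer._forward(x)` exercises can be sketched from the removed `ref_silu_and_mul` helper above. The class name and `_forward` method come straight from this diff; the `nn.Module` scaffolding and exact body are an assumption:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SiluAndMul(nn.Module):
        # Sketch of the native reference path; the real forward()
        # (not shown in this diff) dispatches to the fused CUDA kernel.
        def _forward(self, x: torch.Tensor) -> torch.Tensor:
            # Split the last dim in half, apply SiLU to the first half,
            # and gate with the second half -- the same math as the
            # removed ref_silu_and_mul helper (dim=-1 equals dim=1 for
            # the 2-D test input).
            x1, x2 = x.chunk(chunks=2, dim=-1)
            return F.silu(x1) * x2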
@@ -50,9 +43,9 @@ def test_gelu_new(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
-    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    ops.gelu_new(out, x)
-    ref_out = get_activation("gelu_new")(x)
+    layer = NewGELU()
+    out = layer(x)
+    ref_out = layer._forward(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
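Since this test previously compared against transformers' get_activation("gelu_new"), NewGELU's native `_forward` presumably reproduces that tanh approximation of GELU. A minimal sketch under that assumption:

    import math
    import torch
    import torch.nn as nn

    class NewGELU(nn.Module):
        def _forward(self, x: torch.Tensor) -> torch.Tensor:
            # Tanh approximation of GELU ("gelu_new" in transformers):
            # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
            c = math.sqrt(2.0 / math.pi)
            return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))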
@@ -69,7 +62,7 @@ def test_gelu_fast(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
-    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    ops.gelu_fast(out, x)
-    ref_out = get_activation("gelu_fast")(x)
+    layer = FastGELU()
+    out = layer(x)
+    ref_out = layer._forward(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
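Likewise, FastGELU's native `_forward` would mirror transformers' "gelu_fast", the reference this test used before, which folds sqrt(2/pi) into a precomputed constant. A sketch, with the same caveats as above:

    import torch
    import torch.nn as nn

    class FastGELU(nn.Module):
        def _forward(self, x: torch.Tensor) -> torch.Tensor:
            # Fast tanh-based GELU ("gelu_fast" in transformers);
            # 0.7978845608... is sqrt(2/pi) precomputed.
            return 0.5 * x * (1.0 + torch.tanh(
                x * 0.7978845608028654 * (1.0 + 0.044715 * x * x)))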