Clean up kernel unit tests (#938)

This commit is contained in:
Woosuk Kwon
2023-09-06 08:57:38 +09:00
committed by GitHub
parent 22379d5513
commit fbd80ad409
6 changed files with 368 additions and 403 deletions

View File

@@ -1,11 +1,19 @@
from typing import Tuple
from typing import Optional, Tuple
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm import pos_encoding_ops
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
NUM_HEADS = [7, 12, 40, 52] # Arbitrary values for testing
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
SEEDS = [0]
def rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., :x.shape[-1] // 2]
@@ -74,16 +82,28 @@ class RefRotaryEmbeddingNeox(nn.Module):
return query, key
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def run_rotary_embedding_neox(
def test_rotary_embedding_neox(
num_tokens: int,
num_heads: int,
head_size: int,
max_position: int,
rotary_dim: int,
rotary_dim: Optional[int],
dtype: torch.dtype,
seed: int,
max_position: int = 8192,
base: int = 10000,
) -> None:
if rotary_dim is None:
rotary_dim = head_size
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
query = torch.randn(num_tokens,
num_heads * head_size,
@@ -97,7 +117,7 @@ def run_rotary_embedding_neox(
# Create the rotary embedding.
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
cos = freqs.cos()
sin = freqs.sin()
cos_sin_cache = torch.cat((cos, sin), dim=-1)
@@ -129,19 +149,5 @@ def run_rotary_embedding_neox(
ref_key = ref_key.view(num_tokens, num_heads * head_size)
# Compare the results.
assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
def test_rotary_embedding_neox() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
print(f'Running tests for head_size={head_size} and dtype={dtype}')
run_rotary_embedding_neox(
num_tokens=2145,
num_heads=5,
head_size=head_size,
max_position=8192,
rotary_dim=head_size,
dtype=dtype,
)
assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)