Clean up kernel unit tests (#938)
@@ -1,11 +1,19 @@
-from typing import Tuple
+from typing import Optional, Tuple
 
+import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from vllm import pos_encoding_ops
 
+DTYPES = [torch.half, torch.bfloat16, torch.float]
+HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
+NUM_HEADS = [7, 12, 40, 52]  # Arbitrary values for testing
+NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+SEEDS = [0]
+
 
 def rotate_half(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., :x.shape[-1] // 2]
@@ -74,16 +82,28 @@ class RefRotaryEmbeddingNeox(nn.Module):
         return query, key
 
 
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_rotary_embedding_neox(
+def test_rotary_embedding_neox(
     num_tokens: int,
     num_heads: int,
     head_size: int,
-    max_position: int,
-    rotary_dim: int,
+    rotary_dim: Optional[int],
     dtype: torch.dtype,
+    seed: int,
+    max_position: int = 8192,
     base: int = 10000,
 ) -> None:
+    if rotary_dim is None:
+        rotary_dim = head_size
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
     positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
     query = torch.randn(num_tokens,
                         num_heads * head_size,
@@ -97,7 +117,7 @@ def run_rotary_embedding_neox(
     # Create the rotary embedding.
    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
     t = torch.arange(max_position).float()
-    freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
+    freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
     cos = freqs.cos()
     sin = freqs.sin()
     cos_sin_cache = torch.cat((cos, sin), dim=-1)
@@ -129,19 +149,5 @@ def run_rotary_embedding_neox(
     ref_key = ref_key.view(num_tokens, num_heads * head_size)
 
     # Compare the results.
-    assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
-    assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
-
-
-def test_rotary_embedding_neox() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
-            print(f'Running tests for head_size={head_size} and dtype={dtype}')
-            run_rotary_embedding_neox(
-                num_tokens=2145,
-                num_heads=5,
-                head_size=head_size,
-                max_position=8192,
-                rotary_dim=head_size,
-                dtype=dtype,
-            )
+    assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
+    assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
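For context, the reference path this test checks against is the standard GPT-NeoX rotate-half formulation. Below is a minimal, self-contained sketch of that math; the helper name neox_rotary and its signature are illustrative, not code from this file:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim in half and rotate: (x1, x2) -> (-x2, x1).
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def neox_rotary(q: torch.Tensor, positions: torch.Tensor, rotary_dim: int,
                base: int = 10000) -> torch.Tensor:
    # inv_freq[i] = 1 / base^(2i / rotary_dim), matching the cache built above.
    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    freqs = torch.einsum("i,j -> ij", positions.float(), inv_freq)
    # NeoX convention: the half-size cos/sin tables are repeated so that each
    # covers both halves of the rotary dimension.
    cos = freqs.cos().repeat(1, 2)
    sin = freqs.sin().repeat(1, 2)
    return q * cos + rotate_half(q) * sin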
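One property of the new parametrization worth noting: stacked @pytest.mark.parametrize markers combine multiplicatively, with pytest collecting one test case per element of the Cartesian product of the argument lists. A standalone sketch of the resulting case count (module constants restated here, with strings standing in for the torch dtypes):

NUM_TOKENS = [7, 83, 2048]
NUM_HEADS = [7, 12, 40, 52]
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
ROTARY_DIMS = [None, 32]
DTYPES = ["half", "bfloat16", "float"]  # stand-ins for the torch dtypes
SEEDS = [0]

cases = (len(NUM_TOKENS) * len(NUM_HEADS) * len(HEAD_SIZES)
         * len(ROTARY_DIMS) * len(DTYPES) * len(SEEDS))
print(cases)  # 3 * 4 * 6 * 2 * 3 * 1 = 432 parametrized cases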