[Kernel] Have rotary embeddings support tensors (#18046)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
@@ -38,9 +38,10 @@ def rotary_embedding_opcheck(rot,
|
||||
@pytest.mark.parametrize("head_size", [32, 108])
|
||||
@pytest.mark.parametrize("seq_len", [11, 1024])
|
||||
@pytest.mark.parametrize("use_key", [True, False])
|
||||
@pytest.mark.parametrize("head_stride_is_contingous", [True, False])
|
||||
def test_rotary_embedding_opcheck(dist_init, device, max_position,
|
||||
is_neox_style, rotary_dim, head_size,
|
||||
seq_len, use_key):
|
||||
seq_len, use_key, head_stride_is_contingous):
|
||||
batch_size = 1
|
||||
base = 10000
|
||||
num_heads = 7
|
||||
@@ -50,15 +51,27 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
|
||||
positions = torch.randint(0,
|
||||
max_position, (batch_size, seq_len),
|
||||
device=device)
|
||||
head_stride = head_size + (64 if head_stride_is_contingous else 0)
|
||||
|
||||
query = torch.randn(batch_size,
|
||||
seq_len,
|
||||
num_heads * head_size,
|
||||
num_heads,
|
||||
head_stride,
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
key = torch.randn_like(query) if use_key else None
|
||||
query = query[..., :head_size]
|
||||
key = key[..., :head_size] if use_key else None
|
||||
|
||||
rotary_embedding_opcheck(rot, positions, query, key)
|
||||
offsets = torch.zeros(batch_size * seq_len,
|
||||
device=device,
|
||||
dtype=torch.long)
|
||||
rotary_embedding_opcheck(rot, positions, query, key, offsets)
|
||||
|
||||
# if we have a contiguous head stride, test the alternate
|
||||
# [..., num_heads * head_dim] shape/layout
|
||||
if head_stride_is_contingous:
|
||||
rotary_embedding_opcheck(
|
||||
rot, positions, query.flatten(start_dim=-2),
|
||||
key.flatten(start_dim=-2) if use_key else None)
|
||||
|
||||
Reference in New Issue
Block a user