Optimize data movement (#20)

This commit is contained in:
Woosuk Kwon
2023-04-02 00:30:17 -07:00
committed by GitHub
parent 1f01a18d39
commit 897cb2ae28
17 changed files with 275 additions and 135 deletions

View File

@@ -85,15 +85,13 @@ def test_rotary_embedding_neox(
cos_sin_cache = torch.cat((cos, sin), dim=-1)
cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
# Run the kernel.
out_query = torch.empty_like(query)
out_key = torch.empty_like(key)
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
out_query = query.clone()
out_key = key.clone()
pos_encoding_ops.rotary_embedding_neox(
positions,
out_query,
out_key,
positions,
query,
key,
cos_sin_cache,
)