Optimize data movement (#20)
This commit is contained in:
@@ -85,15 +85,13 @@ def test_rotary_embedding_neox(
|
||||
cos_sin_cache = torch.cat((cos, sin), dim=-1)
|
||||
cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
|
||||
|
||||
# Run the kernel.
|
||||
out_query = torch.empty_like(query)
|
||||
out_key = torch.empty_like(key)
|
||||
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
|
||||
out_query = query.clone()
|
||||
out_key = key.clone()
|
||||
pos_encoding_ops.rotary_embedding_neox(
|
||||
positions,
|
||||
out_query,
|
||||
out_key,
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
cos_sin_cache,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user