[Misc] Add CustomOp interface for device portability (#5255)

This commit is contained in:
Woosuk Kwon
2024-06-05 09:18:19 -07:00
committed by GitHub
parent 974fc9b845
commit 41ca62cf03
7 changed files with 100 additions and 27 deletions

View File

@@ -64,7 +64,7 @@ def test_rotary_embedding(
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query, ref_key = rope._forward(positions, query, key)
ref_query, ref_key = rope.forward_native(positions, query, key)
out_query, out_key = rope.forward(positions, query, key)
# Compare the results.
assert torch.allclose(out_query,
@@ -121,7 +121,7 @@ def test_batched_rotary_embedding(
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query, ref_key = rope._forward(positions, query, key)
ref_query, ref_key = rope.forward_native(positions, query, key)
out_query, out_key = rope.forward(positions,
query,
key,
@@ -195,7 +195,8 @@ def test_batched_rotary_embedding_multi_lora(
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query, ref_key = rope._forward(positions, query, key, query_offsets)
ref_query, ref_key = rope.forward_native(positions, query, key,
query_offsets)
out_query, out_key = rope.forward(positions, query, key,
query_offsets.flatten())
# Compare the results.