[Misc] Add CustomOp interface for device portability (#5255)
This commit is contained in:
@@ -44,7 +44,7 @@ def test_act_and_mul(
|
||||
elif activation == "gelu_tanh":
|
||||
layer = GeluAndMul(approximate="tanh")
|
||||
out = layer(x)
|
||||
- ref_out = layer._forward(x)
|
||||
+ ref_out = layer.forward_native(x)
|
||||
# The SiLU and GELU implementations are equivalent to the native PyTorch
|
||||
# implementations, so we can do exact comparison.
|
||||
assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)
|
||||
@@ -72,7 +72,7 @@ def test_activation(
|
||||
x = torch.randn(num_tokens, d, dtype=dtype)
|
||||
layer = activation()
|
||||
out = layer(x)
|
||||
- ref_out = layer._forward(x)
|
||||
+ ref_out = layer.forward_native(x)
|
||||
assert torch.allclose(out,
|
||||
ref_out,
|
||||
atol=get_default_atol(out),
|
||||
|
||||
@@ -42,7 +42,7 @@ def test_rms_norm(
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
# because the custom kernel is in-place.
|
||||
- ref_out = layer._forward(x, residual)
|
||||
+ ref_out = layer.forward_native(x, residual)
|
||||
out = layer(x, residual)
|
||||
# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
|
||||
# numerical errors than other operators because they involve reductions.
|
||||
|
||||
@@ -64,7 +64,7 @@ def test_rotary_embedding(
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
# because the custom kernel is in-place.
|
||||
- ref_query, ref_key = rope._forward(positions, query, key)
|
||||
+ ref_query, ref_key = rope.forward_native(positions, query, key)
|
||||
out_query, out_key = rope.forward(positions, query, key)
|
||||
# Compare the results.
|
||||
assert torch.allclose(out_query,
|
||||
@@ -121,7 +121,7 @@ def test_batched_rotary_embedding(
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
# because the custom kernel is in-place.
|
||||
- ref_query, ref_key = rope._forward(positions, query, key)
|
||||
+ ref_query, ref_key = rope.forward_native(positions, query, key)
|
||||
out_query, out_key = rope.forward(positions,
|
||||
query,
|
||||
key,
|
||||
@@ -195,7 +195,8 @@ def test_batched_rotary_embedding_multi_lora(
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
# because the custom kernel is in-place.
|
||||
- ref_query, ref_key = rope._forward(positions, query, key, query_offsets)
|
||||
+ ref_query, ref_key = rope.forward_native(positions, query, key,
|
||||
+ query_offsets)
|
||||
out_query, out_key = rope.forward(positions, query, key,
|
||||
query_offsets.flatten())
|
||||
# Compare the results.
|
||||
|
||||
Reference in New Issue
Block a user