[Misc] Add CustomOp interface for device portability (#5255)

This commit is contained in:
Woosuk Kwon
2024-06-05 09:18:19 -07:00
committed by GitHub
parent 974fc9b845
commit 41ca62cf03
7 changed files with 100 additions and 27 deletions

View File

@@ -44,7 +44,7 @@ def test_act_and_mul(
elif activation == "gelu_tanh":
layer = GeluAndMul(approximate="tanh")
out = layer(x)
ref_out = layer._forward(x)
ref_out = layer.forward_native(x)
# The SiLU and GELU implementations are equivalent to the native PyTorch
# implementations, so we can do exact comparison.
assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)
@@ -72,7 +72,7 @@ def test_activation(
x = torch.randn(num_tokens, d, dtype=dtype)
layer = activation()
out = layer(x)
ref_out = layer._forward(x)
ref_out = layer.forward_native(x)
assert torch.allclose(out,
ref_out,
atol=get_default_atol(out),