[Misc] Add CustomOp interface for device portability (#5255)

This commit is contained in:
Woosuk Kwon
2024-06-05 09:18:19 -07:00
committed by GitHub
parent 974fc9b845
commit 41ca62cf03
7 changed files with 100 additions and 27 deletions

View File

@@ -42,7 +42,7 @@ def test_rms_norm(
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_out = layer._forward(x, residual)
ref_out = layer.forward_native(x, residual)
out = layer(x, residual)
# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
# numerical errors than other operators because they involve reductions.