[Misc] Add CustomOp interface for device portability (#5255)
This commit is contained in:
@@ -42,7 +42,7 @@ def test_rms_norm(
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
# because the custom kernel is in-place.
|
||||
ref_out = layer._forward(x, residual)
|
||||
ref_out = layer.forward_native(x, residual)
|
||||
out = layer(x, residual)
|
||||
# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
|
||||
# numerical errors than other operators because they involve reductions.
|
||||
|
||||
Reference in New Issue
Block a user