[Kernel][Misc] register ops to prevent graph breaks (#6917)
Co-authored-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
@@ -52,3 +53,10 @@ def test_rms_norm(
|
||||
torch.testing.assert_close(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
|
||||
else:
|
||||
torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
|
||||
|
||||
if residual is not None:
|
||||
opcheck(torch.ops._C.fused_add_rms_norm,
|
||||
(x, residual, layer.weight.data, layer.variance_epsilon))
|
||||
else:
|
||||
opcheck(torch.ops._C.rms_norm,
|
||||
(out, x, layer.weight.data, layer.variance_epsilon))
|
||||
|
||||
Reference in New Issue
Block a user