[ROCm] Fix some kernels failed unit tests (#2498)

This commit is contained in:
Hongxia Yang
2024-02-05 17:25:36 -05:00
committed by GitHub
parent 72d3a30c63
commit 56f738ae9b
5 changed files with 62 additions and 12 deletions

View File

@@ -0,0 +1,18 @@
import torch
# Reference default values of atol and rtol are from
# https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
default_rtol = {
torch.float16: 1e-3,
torch.bfloat16: 1.6e-2,
torch.float: 1.3e-6
}
def get_default_atol(output) -> float:
return default_atol[output.dtype]
def get_default_rtol(output) -> float:
return default_rtol[output.dtype]