[Misc/Testing] Use torch.testing.assert_close (#7324)

This commit is contained in:
jon-chuang
2024-08-15 21:24:04 -07:00
committed by GitHub
parent e165528778
commit 50b8d08dbd
25 changed files with 197 additions and 188 deletions

View File

@@ -90,5 +90,7 @@ def test_logits_processors(seed: int, device: str):
assert torch.isinf(logits_processor_output[:, 0]).all()
fake_logits *= logits_processor.scale
assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1],
1e-4)
torch.testing.assert_close(logits_processor_output[:, 1],
fake_logits[:, 1],
rtol=1e-4,
atol=0.0)