[Misc/Testing] Use torch.testing.assert_close (#7324)
This commit is contained in:
@@ -90,5 +90,7 @@ def test_logits_processors(seed: int, device: str):
|
||||
assert torch.isinf(logits_processor_output[:, 0]).all()
|
||||
|
||||
fake_logits *= logits_processor.scale
|
||||
assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1],
|
||||
1e-4)
|
||||
torch.testing.assert_close(logits_processor_output[:, 1],
|
||||
fake_logits[:, 1],
|
||||
rtol=1e-4,
|
||||
atol=0.0)
|
||||
|
||||
Reference in New Issue
Block a user