[Misc/Testing] Use torch.testing.assert_close (#7324)
This commit is contained in:
@@ -632,7 +632,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
|
||||
|
||||
hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
|
||||
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
|
||||
assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
|
||||
torch.testing.assert_close(hf_probs, sample_probs, rtol=0.0, atol=1e-5)
|
||||
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user