[Misc/Testing] Use torch.testing.assert_close (#7324)

This commit is contained in:
jon-chuang
2024-08-15 21:24:04 -07:00
committed by GitHub
parent e165528778
commit 50b8d08dbd
25 changed files with 197 additions and 188 deletions

View File

@@ -100,11 +100,11 @@ def test_sample_decoding_only(random_sampling, max_best_of,
if modify_greedy_probs and not request_uses_random_sampling:
# If we are modifying greedy probs and the request is greedy,
# we want to make sure the probs tensor is modified in place
assert torch.allclose(
torch.testing.assert_close(
probs[i][sampled_tokens[i]],
torch.full_like(probs[i][sampled_tokens[i]], 1.0))
assert torch.sum(probs[i]) == 1.0
assert torch.allclose(
torch.testing.assert_close(
sampled_modified_probs[i][0],
torch.full_like(sampled_modified_probs[i][0], 1.0))
elif request_uses_random_sampling:
@@ -117,8 +117,8 @@ def test_sample_decoding_only(random_sampling, max_best_of,
# If the request is greedy and we are not modifying greedy probs,
# we want to make sure sampled_modified_probs tensor is the same as
# the probs tensor.
assert torch.allclose(sampled_modified_probs[i][0],
probs[i][sampled_tokens[i]])
torch.testing.assert_close(sampled_modified_probs[i],
probs[i][sampled_tokens[i]])
if save_logprobs:
assert sampled_logprobs.shape == (bs, max_best_of)