[Misc/Testing] Use torch.testing.assert_close (#7324)
This commit is contained in:
@@ -100,11 +100,11 @@ def test_sample_decoding_only(random_sampling, max_best_of,
|
||||
if modify_greedy_probs and not request_uses_random_sampling:
|
||||
# If we are modifying greedy probs and the request is greedy,
|
||||
# we want to make sure the probs tensor is modified in place
|
||||
assert torch.allclose(
|
||||
torch.testing.assert_close(
|
||||
probs[i][sampled_tokens[i]],
|
||||
torch.full_like(probs[i][sampled_tokens[i]], 1.0))
|
||||
assert torch.sum(probs[i]) == 1.0
|
||||
assert torch.allclose(
|
||||
torch.testing.assert_close(
|
||||
sampled_modified_probs[i][0],
|
||||
torch.full_like(sampled_modified_probs[i][0], 1.0))
|
||||
elif request_uses_random_sampling:
|
||||
@@ -117,8 +117,8 @@ def test_sample_decoding_only(random_sampling, max_best_of,
|
||||
# If the request is greedy and we are not modifying greedy probs,
|
||||
# we want to make sure sampled_modified_probs tensor is the same as
|
||||
# the probs tensor.
|
||||
assert torch.allclose(sampled_modified_probs[i][0],
|
||||
probs[i][sampled_tokens[i]])
|
||||
torch.testing.assert_close(sampled_modified_probs[i],
|
||||
probs[i][sampled_tokens[i]])
|
||||
|
||||
if save_logprobs:
|
||||
assert sampled_logprobs.shape == (bs, max_best_of)
|
||||
|
||||
Reference in New Issue
Block a user