[Misc/Testing] Use torch.testing.assert_close (#7324)

This commit is contained in:
jon-chuang
2024-08-15 21:24:04 -07:00
committed by GitHub
parent e165528778
commit 50b8d08dbd
25 changed files with 197 additions and 188 deletions

View File

@@ -67,14 +67,14 @@ def test_rotary_embedding(
ref_query, ref_key = rope.forward_native(positions, query, key)
out_query, out_key = rope.forward(positions, query, key)
# Compare the results.
assert torch.allclose(out_query,
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
assert torch.allclose(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
torch.testing.assert_close(out_query,
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
torch.testing.assert_close(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@@ -129,14 +129,14 @@ def test_batched_rotary_embedding(
dtype=torch.long,
device=device))
# Compare the results.
assert torch.allclose(out_query,
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
assert torch.allclose(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
torch.testing.assert_close(out_query,
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
torch.testing.assert_close(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@@ -200,14 +200,14 @@ def test_batched_rotary_embedding_multi_lora(
out_query, out_key = rope.forward(positions, query, key,
query_offsets.flatten())
# Compare the results.
assert torch.allclose(out_query,
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
assert torch.allclose(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
torch.testing.assert_close(out_query,
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
torch.testing.assert_close(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
@torch.inference_mode()