[Misc/Testing] Use torch.testing.assert_close (#7324)

This commit is contained in:
jon-chuang
2024-08-15 21:24:04 -07:00
committed by GitHub
parent e165528778
commit 50b8d08dbd
25 changed files with 197 additions and 188 deletions

View File

@@ -247,10 +247,10 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
expected_result = torch.cat(expected_results)
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)
torch.testing.assert_close(lora_result,
expected_result,
rtol=rtol,
atol=atol)
# Check that resetting the lora weights succeeds
@@ -274,10 +274,10 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
expected_result = embedding(torch.cat(inputs))
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)
torch.testing.assert_close(lora_result,
expected_result,
rtol=rtol,
atol=atol)
@torch.inference_mode()
@@ -384,10 +384,10 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
expected_result = torch.cat(expected_results)
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)
torch.testing.assert_close(lora_result,
expected_result,
rtol=rtol,
atol=atol)
# Check that resetting the lora weights succeeds
@@ -411,10 +411,10 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
expected_result = expanded_embedding(torch.cat(inputs))
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)
torch.testing.assert_close(lora_result,
expected_result,
rtol=rtol,
atol=atol)
@torch.inference_mode()
@@ -541,10 +541,10 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
embedding_bias=None)
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)
torch.testing.assert_close(lora_result,
expected_result,
rtol=rtol,
atol=atol)
@torch.inference_mode()
@@ -614,10 +614,10 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
expected_result = torch.cat(expected_results)
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)
torch.testing.assert_close(lora_result,
expected_result,
rtol=rtol,
atol=atol)
# Check that resetting the lora weights succeeds
@@ -642,10 +642,10 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
expected_result = linear(torch.cat(inputs))[0]
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)
torch.testing.assert_close(lora_result,
expected_result,
rtol=rtol,
atol=atol)
@torch.inference_mode()
@@ -728,10 +728,10 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
expected_result = torch.cat(expected_results)
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)
torch.testing.assert_close(lora_result,
expected_result,
rtol=rtol,
atol=atol)
# Check that resetting the lora weights succeeds
@@ -756,10 +756,10 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
expected_result = linear(torch.cat(inputs))[0]
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)
torch.testing.assert_close(lora_result,
expected_result,
rtol=rtol,
atol=atol)
@torch.inference_mode()
@@ -868,10 +868,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
expected_result = torch.cat(expected_results)
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)
torch.testing.assert_close(lora_result,
expected_result,
rtol=rtol,
atol=atol)
for slot_idx in range(max_loras):
lora_linear.reset_lora(slot_idx)
@@ -900,10 +900,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
expected_result = linear(torch.cat(inputs))[0]
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)
torch.testing.assert_close(lora_result,
expected_result,
rtol=rtol,
atol=atol)
@torch.inference_mode()