[Misc/Testing] Use torch.testing.assert_close (#7324)
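Background on the change: `torch.testing.assert_close` performs essentially the same rtol/atol tolerance comparison as `torch.allclose`, but on failure it raises an `AssertionError` that describes the mismatch (number of differing elements, greatest absolute and relative difference) instead of the bare boolean failure a plain `assert` gives. A minimal sketch of the before/after pattern, using placeholder tensors and tolerances (the real tests look tolerances up per dtype from a `TOLERANCES` dict):

import torch

# Placeholder tensors standing in for lora_result / expected_result.
lora_result = torch.randn(4, 8)
expected_result = lora_result + 1e-6 * torch.randn(4, 8)

rtol, atol = 1e-2, 1e-2  # placeholder tolerances

# Before: on failure, only a bare AssertionError with no detail about the tensors.
assert torch.allclose(lora_result, expected_result, rtol=rtol, atol=atol)

# After: on failure, assert_close reports how many elements mismatch and the
# greatest absolute/relative difference, which makes tolerance failures easier to debug.
torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol)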
@@ -247,10 +247,10 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
         expected_result = torch.cat(expected_results)

         rtol, atol = TOLERANCES[lora_result.dtype]
-        assert torch.allclose(lora_result,
-                              expected_result,
-                              rtol=rtol,
-                              atol=atol)
+        torch.testing.assert_close(lora_result,
+                                   expected_result,
+                                   rtol=rtol,
+                                   atol=atol)

         # Check that resetting the lora weights succeeds

@@ -274,10 +274,10 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
         expected_result = embedding(torch.cat(inputs))

         rtol, atol = TOLERANCES[lora_result.dtype]
-        assert torch.allclose(lora_result,
-                              expected_result,
-                              rtol=rtol,
-                              atol=atol)
+        torch.testing.assert_close(lora_result,
+                                   expected_result,
+                                   rtol=rtol,
+                                   atol=atol)


 @torch.inference_mode()
@@ -384,10 +384,10 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
         expected_result = torch.cat(expected_results)

         rtol, atol = TOLERANCES[lora_result.dtype]
-        assert torch.allclose(lora_result,
-                              expected_result,
-                              rtol=rtol,
-                              atol=atol)
+        torch.testing.assert_close(lora_result,
+                                   expected_result,
+                                   rtol=rtol,
+                                   atol=atol)

         # Check that resetting the lora weights succeeds

@@ -411,10 +411,10 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
         expected_result = expanded_embedding(torch.cat(inputs))

         rtol, atol = TOLERANCES[lora_result.dtype]
-        assert torch.allclose(lora_result,
-                              expected_result,
-                              rtol=rtol,
-                              atol=atol)
+        torch.testing.assert_close(lora_result,
+                                   expected_result,
+                                   rtol=rtol,
+                                   atol=atol)


 @torch.inference_mode()
@@ -541,10 +541,10 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
             embedding_bias=None)

         rtol, atol = TOLERANCES[lora_result.dtype]
-        assert torch.allclose(lora_result,
-                              expected_result,
-                              rtol=rtol,
-                              atol=atol)
+        torch.testing.assert_close(lora_result,
+                                   expected_result,
+                                   rtol=rtol,
+                                   atol=atol)


 @torch.inference_mode()
@@ -614,10 +614,10 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
         expected_result = torch.cat(expected_results)

         rtol, atol = TOLERANCES[lora_result.dtype]
-        assert torch.allclose(lora_result,
-                              expected_result,
-                              rtol=rtol,
-                              atol=atol)
+        torch.testing.assert_close(lora_result,
+                                   expected_result,
+                                   rtol=rtol,
+                                   atol=atol)

         # Check that resetting the lora weights succeeds

@@ -642,10 +642,10 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
         expected_result = linear(torch.cat(inputs))[0]

         rtol, atol = TOLERANCES[lora_result.dtype]
-        assert torch.allclose(lora_result,
-                              expected_result,
-                              rtol=rtol,
-                              atol=atol)
+        torch.testing.assert_close(lora_result,
+                                   expected_result,
+                                   rtol=rtol,
+                                   atol=atol)


 @torch.inference_mode()
@@ -728,10 +728,10 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
         expected_result = torch.cat(expected_results)

         rtol, atol = TOLERANCES[lora_result.dtype]
-        assert torch.allclose(lora_result,
-                              expected_result,
-                              rtol=rtol,
-                              atol=atol)
+        torch.testing.assert_close(lora_result,
+                                   expected_result,
+                                   rtol=rtol,
+                                   atol=atol)

         # Check that resetting the lora weights succeeds

@@ -756,10 +756,10 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
         expected_result = linear(torch.cat(inputs))[0]

         rtol, atol = TOLERANCES[lora_result.dtype]
-        assert torch.allclose(lora_result,
-                              expected_result,
-                              rtol=rtol,
-                              atol=atol)
+        torch.testing.assert_close(lora_result,
+                                   expected_result,
+                                   rtol=rtol,
+                                   atol=atol)


 @torch.inference_mode()
@@ -868,10 +868,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
         expected_result = torch.cat(expected_results)

         rtol, atol = TOLERANCES[lora_result.dtype]
-        assert torch.allclose(lora_result,
-                              expected_result,
-                              rtol=rtol,
-                              atol=atol)
+        torch.testing.assert_close(lora_result,
+                                   expected_result,
+                                   rtol=rtol,
+                                   atol=atol)

         for slot_idx in range(max_loras):
             lora_linear.reset_lora(slot_idx)
@@ -900,10 +900,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
         expected_result = linear(torch.cat(inputs))[0]

         rtol, atol = TOLERANCES[lora_result.dtype]
-        assert torch.allclose(lora_result,
-                              expected_result,
-                              rtol=rtol,
-                              atol=atol)
+        torch.testing.assert_close(lora_result,
+                                   expected_result,
+                                   rtol=rtol,
+                                   atol=atol)


 @torch.inference_mode()