[Misc/Testing] Use torch.testing.assert_close (#7324)
This commit is contained in:
@@ -144,7 +144,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap)
|
||||
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
|
||||
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
||||
|
||||
|
||||
@@ -244,5 +244,5 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap)
|
||||
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
|
||||
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
||||
|
||||
Reference in New Issue
Block a user