[Misc/Testing] Use torch.testing.assert_close (#7324)
This commit is contained in:
@@ -77,7 +77,7 @@ def test_prepare_prompt(batch_size):
|
||||
device = model_runner.device
|
||||
assert attn_metadata.num_prefills > 0
|
||||
assert attn_metadata.num_decode_tokens == 0
|
||||
assert torch.allclose(
|
||||
torch.testing.assert_close(
|
||||
attn_metadata.seq_lens_tensor,
|
||||
torch.tensor(seq_lens, device=device, dtype=torch.int))
|
||||
assert attn_metadata.seq_lens == seq_lens
|
||||
@@ -90,7 +90,7 @@ def test_prepare_prompt(batch_size):
|
||||
for seq_len in seq_lens:
|
||||
start_idx += seq_len
|
||||
start_loc.append(start_idx)
|
||||
assert torch.allclose(
|
||||
torch.testing.assert_close(
|
||||
attn_metadata.query_start_loc,
|
||||
torch.tensor(start_loc, dtype=torch.int32, device=device))
|
||||
|
||||
@@ -102,10 +102,10 @@ def test_prepare_prompt(batch_size):
|
||||
start_idx += seq_len
|
||||
seq_start_loc.append(start_idx)
|
||||
|
||||
assert torch.allclose(
|
||||
torch.testing.assert_close(
|
||||
attn_metadata.seq_start_loc,
|
||||
torch.tensor(start_loc, dtype=torch.int32, device=device))
|
||||
assert torch.allclose(
|
||||
torch.testing.assert_close(
|
||||
attn_metadata.context_lens_tensor,
|
||||
torch.zeros(attn_metadata.context_lens_tensor.shape[0],
|
||||
dtype=torch.int,
|
||||
@@ -114,7 +114,7 @@ def test_prepare_prompt(batch_size):
|
||||
expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
|
||||
dtype=torch.int32,
|
||||
device=model_runner.device)
|
||||
assert torch.allclose(attn_metadata.block_tables, expected)
|
||||
torch.testing.assert_close(attn_metadata.block_tables, expected)
|
||||
# Cuda graph should not be used for prefill.
|
||||
assert attn_metadata.use_cuda_graph is False
|
||||
|
||||
@@ -201,7 +201,7 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
# decode has only 1 token for query.
|
||||
start_idx += 1
|
||||
start_loc.append(start_idx)
|
||||
assert torch.allclose(
|
||||
torch.testing.assert_close(
|
||||
attn_metadata.query_start_loc,
|
||||
torch.tensor(start_loc, dtype=torch.int32, device=device))
|
||||
|
||||
@@ -210,15 +210,15 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
for seq_len in seq_lens:
|
||||
start_idx += seq_len
|
||||
seq_start_loc.append(start_idx)
|
||||
assert torch.allclose(
|
||||
torch.testing.assert_close(
|
||||
attn_metadata.seq_start_loc,
|
||||
torch.tensor(seq_start_loc, dtype=torch.int32, device=device))
|
||||
|
||||
assert torch.allclose(
|
||||
torch.testing.assert_close(
|
||||
attn_metadata.context_lens_tensor,
|
||||
torch.tensor(context_lens, dtype=torch.int, device=device))
|
||||
assert attn_metadata.max_decode_seq_len == max(seq_lens)
|
||||
assert torch.allclose(
|
||||
torch.testing.assert_close(
|
||||
attn_metadata.seq_lens_tensor[:len(seq_lens)],
|
||||
torch.tensor(seq_lens, dtype=torch.int, device=device))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user