[CI/Build] Fix test_prefix_prefill for AMD (#28905)
Signed-off-by: Ryan Rock <ryan.rock@amd.com>
This commit is contained in:
@@ -174,11 +174,11 @@ def test_contexted_kv_attention(
|
||||
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
|
||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
|
||||
max_input_len = MAX_SEQ_LEN
|
||||
# copy kv to cache
|
||||
b_seq_start_loc = torch.cumsum(
|
||||
torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
|
||||
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
|
||||
torch.int32
|
||||
)
|
||||
for i in range(BS):
|
||||
for j in range(query_lens[i]):
|
||||
@@ -417,11 +417,11 @@ def test_contexted_kv_attention_alibi(
|
||||
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
|
||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
|
||||
max_input_len = MAX_SEQ_LEN
|
||||
# copy kv to cache
|
||||
b_seq_start_loc = torch.cumsum(
|
||||
torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
|
||||
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
|
||||
torch.int32
|
||||
)
|
||||
for i in range(BS):
|
||||
for j in range(query_lens[i]):
|
||||
|
||||
Reference in New Issue
Block a user