[V1][Kernel] Refactor the prefix_prefill kernel so that the caller no longer has to pass in the context lengths (#13095)
This commit is contained in:
@@ -100,7 +100,7 @@ def test_contexted_kv_attention(
|
||||
BS, max_block_per_request)
|
||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
|
||||
dtype=torch.long),
|
||||
dim=0)
|
||||
max_input_len = MAX_SEQ_LEN
|
||||
@@ -154,7 +154,6 @@ def test_contexted_kv_attention(
|
||||
block_table,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
@@ -171,7 +170,6 @@ def test_contexted_kv_attention(
|
||||
block_table,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
@@ -333,7 +331,7 @@ def test_contexted_kv_attention_alibi(
|
||||
BS, max_block_per_request)
|
||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
|
||||
dtype=torch.long),
|
||||
dim=0)
|
||||
max_input_len = MAX_SEQ_LEN
|
||||
@@ -387,7 +385,6 @@ def test_contexted_kv_attention_alibi(
|
||||
block_table,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
@@ -404,7 +401,6 @@ def test_contexted_kv_attention_alibi(
|
||||
block_table,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
|
||||
Reference in New Issue
Block a user