[V1][Kernel] Refactor the prefix_prefill kernel so that the caller no longer has to pass in the context lengths (#13095)

This commit is contained in:
Sage Moore
2025-02-22 05:25:41 -08:00
committed by GitHub
parent e109e598c7
commit 558db8083c
6 changed files with 12 additions and 31 deletions

View File

@@ -100,7 +100,7 @@ def test_contexted_kv_attention(
BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
-    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
dtype=torch.long),
dim=0)
max_input_len = MAX_SEQ_LEN
@@ -154,7 +154,6 @@ def test_contexted_kv_attention(
block_table,
b_start_loc,
b_seq_len,
-                            b_ctx_len,
max_input_len,
k_scale,
v_scale,
@@ -171,7 +170,6 @@ def test_contexted_kv_attention(
block_table,
b_start_loc,
b_seq_len,
-                            b_ctx_len,
max_input_len,
k_scale,
v_scale,
@@ -333,7 +331,7 @@ def test_contexted_kv_attention_alibi(
BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
-    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
dtype=torch.long),
dim=0)
max_input_len = MAX_SEQ_LEN
@@ -387,7 +385,6 @@ def test_contexted_kv_attention_alibi(
block_table,
b_start_loc,
b_seq_len,
-                            b_ctx_len,
max_input_len,
k_scale,
v_scale,
@@ -404,7 +401,6 @@ def test_contexted_kv_attention_alibi(
block_table,
b_start_loc,
b_seq_len,
-                            b_ctx_len,
max_input_len,
k_scale,
v_scale,