[V1][Kernel] Refactor the prefix_prefill kernel so that the caller no longer has to pass in the context lengths (#13095)

Sage Moore
2025-02-22 05:25:41 -08:00
committed by GitHub
parent e109e598c7
commit 558db8083c
6 changed files with 12 additions and 31 deletions


@@ -150,17 +150,6 @@ class ROCmAttentionImpl(AttentionImpl):
             layer._v_scale,
         )
-        # TODO(sage): Refactor the context_attention_fwd kernel so that this
-        # overhead can be removed
-        context_lens = torch.empty_like(attn_metadata.seq_lens)
-        batch_size = len(attn_metadata.query_start_loc) - 1
-        assert len(context_lens) == batch_size
-        for i in range(batch_size):
-            query_start = attn_metadata.query_start_loc[i]
-            query_end = attn_metadata.query_start_loc[i + 1]
-            context_lens[i] = attn_metadata.seq_lens[i] - (query_end -
-                                                           query_start)
         # Compute attention and update output up to `num_actual_tokens`.
         context_attention_fwd(q=query[:num_actual_tokens],
                               k=key[:num_actual_tokens],
@@ -172,7 +161,6 @@ class ROCmAttentionImpl(AttentionImpl):
                               b_loc=attn_metadata.block_table,
                               b_start_loc=attn_metadata.query_start_loc,
                               b_seq_len=attn_metadata.seq_lens,
-                              b_ctx_len=context_lens,
                               max_input_len=attn_metadata.max_query_len,
                               k_scale=layer._k_scale,
                               v_scale=layer._v_scale,
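
For reference, the context lengths that this change stops computing on the host can be derived from metadata the kernel already receives (`b_seq_len` and `b_start_loc`): each request's context length is its total sequence length minus its query length, and the query length is the difference of consecutive `query_start_loc` offsets. Below is a minimal vectorized sketch of that relationship; `derive_context_lens` is a hypothetical helper for illustration, not code from this commit.

import torch

def derive_context_lens(query_start_loc: torch.Tensor,
                        seq_lens: torch.Tensor) -> torch.Tensor:
    # query_start_loc holds batch_size + 1 cumulative offsets; consecutive
    # differences give the number of new query tokens per request.
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    # Context (already-cached) length = total sequence length - query length,
    # the same quantity the deleted host-side loop computed per request.
    return seq_lens - query_lens

This reproduces the deleted loop's result without a Python-level loop; presumably the refactored kernel performs the equivalent per-batch subtraction internally from the `b_seq_len` and `b_start_loc` values it already loads, which is why `b_ctx_len` no longer needs to be passed in.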