[V1][Kernel] Refactor the prefix_prefill kernel so that the caller no longer has to pass in the context lengths (#13095)
This commit is contained in:
@@ -150,17 +150,6 @@ class ROCmAttentionImpl(AttentionImpl):
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
# TODO(sage): Refactor the context_attention_fwd kernel so that this
|
||||
# overhead can be removed
|
||||
context_lens = torch.empty_like(attn_metadata.seq_lens)
|
||||
batch_size = len(attn_metadata.query_start_loc) - 1
|
||||
assert len(context_lens) == batch_size
|
||||
for i in range(batch_size):
|
||||
query_start = attn_metadata.query_start_loc[i]
|
||||
query_end = attn_metadata.query_start_loc[i + 1]
|
||||
context_lens[i] = attn_metadata.seq_lens[i] - (query_end -
|
||||
query_start)
|
||||
|
||||
# Compute attention and update output up to `num_actual_tokens`.
|
||||
context_attention_fwd(q=query[:num_actual_tokens],
|
||||
k=key[:num_actual_tokens],
|
||||
@@ -172,7 +161,6 @@ class ROCmAttentionImpl(AttentionImpl):
|
||||
b_loc=attn_metadata.block_table,
|
||||
b_start_loc=attn_metadata.query_start_loc,
|
||||
b_seq_len=attn_metadata.seq_lens,
|
||||
b_ctx_len=context_lens,
|
||||
max_input_len=attn_metadata.max_query_len,
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._v_scale,
|
||||
|
||||
Reference in New Issue
Block a user