[Kernel] Use out arg in flash_attn_varlen_func (#10811)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2024-12-01 17:55:39 -08:00
committed by GitHub
parent b7954776fd
commit 073a4bd1c0
3 changed files with 21 additions and 7 deletions

View File

@@ -205,10 +205,12 @@ def unified_v1_flash_attention(
v_scale,
)
attn_output = flash_attn_varlen_func(
# Compute attention and update output up to `num_actual_tokens`.
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=attn_metadata.query_start_loc,
max_seqlen_q=attn_metadata.max_query_len,
cu_seqlens_k=attn_metadata.seq_start_loc,
@@ -220,8 +222,6 @@ def unified_v1_flash_attention(
block_table=attn_metadata.block_table,
softcap=logits_soft_cap,
)
# TODO(woosuk): Remove this unnecessary copy.
output[:num_actual_tokens].copy_(attn_output)
def unified_v1_flash_attention_fake(