[Kernel] Use out arg in flash_attn_varlen_func (#10811)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -205,10 +205,12 @@ def unified_v1_flash_attention(
|
||||
v_scale,
|
||||
)
|
||||
|
||||
attn_output = flash_attn_varlen_func(
|
||||
# Compute attention and update output up to `num_actual_tokens`.
|
||||
flash_attn_varlen_func(
|
||||
q=query[:num_actual_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=output[:num_actual_tokens],
|
||||
cu_seqlens_q=attn_metadata.query_start_loc,
|
||||
max_seqlen_q=attn_metadata.max_query_len,
|
||||
cu_seqlens_k=attn_metadata.seq_start_loc,
|
||||
@@ -220,8 +222,6 @@ def unified_v1_flash_attention(
|
||||
block_table=attn_metadata.block_table,
|
||||
softcap=logits_soft_cap,
|
||||
)
|
||||
# TODO(woosuk): Remove this unnecessary copy.
|
||||
output[:num_actual_tokens].copy_(attn_output)
|
||||
|
||||
|
||||
def unified_v1_flash_attention_fake(
|
||||
|
||||
Reference in New Issue
Block a user