[Performance] Split FlashAttn attention and cache update (#25954)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Luka Govedič <luka.govedic@gmail.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <luka.govedic@gmail.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
This commit is contained in:
@@ -9,6 +9,7 @@ import torch
|
||||
from tests.v1.attention.utils import (
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config,
|
||||
try_backend_includes_kv_cache_update,
|
||||
try_get_attention_backend,
|
||||
)
|
||||
from vllm.config import ParallelConfig, SpeculativeConfig
|
||||
@@ -120,6 +121,14 @@ def forward_attention(
|
||||
key = k.view(-1, num_kv_heads, dim_per_head)
|
||||
value = v.view(-1, num_kv_heads, dim_per_head)
|
||||
output = torch.empty_like(query)
|
||||
if not try_backend_includes_kv_cache_update(backend):
|
||||
instance.do_kv_cache_update(
|
||||
layer=layer,
|
||||
key=key,
|
||||
value=value,
|
||||
kv_cache=kv_cache,
|
||||
slot_mapping=attn_metadata.slot_mapping,
|
||||
)
|
||||
return instance.forward(
|
||||
layer=layer,
|
||||
query=query,
|
||||
|
||||
Reference in New Issue
Block a user