[Performance] Split FlashAttn attention and cache update (#25954)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Luka Govedič <luka.govedic@gmail.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <luka.govedic@gmail.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
This commit is contained in:
ElizaWszola
2026-01-24 02:28:06 +01:00
committed by GitHub
parent 0118cdcc02
commit a28b94e6ef
21 changed files with 458 additions and 68 deletions

View File

@@ -9,6 +9,7 @@ import torch
from tests.v1.attention.utils import (
create_standard_kv_cache_spec,
create_vllm_config,
try_backend_includes_kv_cache_update,
try_get_attention_backend,
)
from vllm.config import ParallelConfig, SpeculativeConfig
@@ -120,6 +121,14 @@ def forward_attention(
key = k.view(-1, num_kv_heads, dim_per_head)
value = v.view(-1, num_kv_heads, dim_per_head)
output = torch.empty_like(query)
if not try_backend_includes_kv_cache_update(backend):
instance.do_kv_cache_update(
layer=layer,
key=key,
value=value,
kv_cache=kv_cache,
slot_mapping=attn_metadata.slot_mapping,
)
return instance.forward(
layer=layer,
query=query,