[Performance] Split FlashAttn attention and cache update (#25954)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Luka Govedič <luka.govedic@gmail.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <luka.govedic@gmail.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
This commit is contained in:
ElizaWszola, 2026-01-24 02:28:06 +01:00 (committed by GitHub)
parent 0118cdcc02, commit a28b94e6ef
21 changed files with 458 additions and 68 deletions
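
The idea behind the split: writing new keys/values into the paged KV cache is a scatter keyed by a per-token slot mapping, independent of the attention math itself, so it can be exposed as a standalone `do_kv_cache_update` step, with the `forward_includes_kv_cache_update` flag advertising whether a backend's `forward` still performs the write internally. A minimal sketch of the cache-write half under an assumed `[2, num_slots, num_kv_heads, head_size]` cache layout (the function name is hypothetical; real backends use fused kernels for this):

import torch


def kv_cache_scatter(
    key: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
    value: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
    kv_cache: torch.Tensor,  # [2, num_slots, num_kv_heads, head_size] (assumed)
    slot_mapping: torch.Tensor,  # [num_tokens], flat cache slot per token
) -> None:
    # The whole "cache update" half of the split: scatter each token's K/V
    # row into its assigned slot; attention then reads back from kv_cache.
    kv_cache[0].index_copy_(0, slot_mapping, key)
    kv_cache[1].index_copy_(0, slot_mapping, value)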


@@ -15,7 +15,7 @@ from vllm.v1.attention.backend import (
     AttentionMetadata,
     AttentionType,
     CommonAttentionMetadata,
-    subclass_attention_backend,
+    subclass_attention_backend_with_overrides,
 )
 from vllm.v1.attention.selector import get_attn_backend
 from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec

@@ -72,6 +72,7 @@ def create_cross_attention_backend(
 ) -> type[AttentionBackend]:
     prefix = "CrossAttention_"
     underlying_builder = underlying_attn_backend.get_builder_cls()
+    underlying_impl = underlying_attn_backend.get_impl_cls()
 
     class CrossAttentionBuilder(underlying_builder):  # type: ignore
         def build(
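
The `underlying_impl` local captured above plays the same role as `underlying_builder`: the concrete impl class is only known at runtime, and the next hunk uses it as the base class of a dynamically defined subclass. A tiny, hypothetical illustration of that pattern (the wrapper name and print are made up):

def make_traced(base_cls: type) -> type:
    # Subclass a base resolved only at runtime and specialize one method via
    # super() -- the same pattern CrossAttentionBuilder/Impl use below.
    class Traced(base_cls):  # type: ignore[valid-type,misc]
        def forward(self, *args, **kwargs):
            print(f"calling forward via {base_cls.__name__}")
            return super().forward(*args, **kwargs)

    return Traced
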
@@ -106,18 +107,60 @@ def create_cross_attention_backend(
             )
             # NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here
-            new_metadata.slot_mapping = _get_cross_slot_mapping(
+            slot_mapping = _get_cross_slot_mapping(
                 new_metadata.encoder_seq_lens_cpu,
                 new_metadata.block_table_tensor,
                 self.kv_cache_spec,
                 self.device,
             )
-            return super().build(common_prefix_len, new_metadata, fast_build)
+            attn_metadata = super().build(common_prefix_len, new_metadata, fast_build)
+            attn_metadata.slot_mapping = slot_mapping
+            return attn_metadata
 
-    attn_backend = subclass_attention_backend(
+    # NOTE(Lucas): we need a custom impl so we can use the slot-mapping computed by
+    # `CrossAttentionBuilder` instead of the one computed by `BlockTable`
+    # (gpu_model_runner)
+    class CrossAttentionImpl(underlying_impl):  # type: ignore[valid-type,misc]
+        def forward(
+            self,
+            layer: torch.nn.Module,
+            query: torch.Tensor,
+            key: torch.Tensor,
+            value: torch.Tensor,
+            kv_cache: torch.Tensor,
+            attn_metadata: AttentionMetadata,
+            output: torch.Tensor | None = None,
+            output_scale: torch.Tensor | None = None,
+            output_block_scale: torch.Tensor | None = None,
+        ) -> torch.Tensor:
+            if (
+                not underlying_attn_backend.forward_includes_kv_cache_update
+                and attn_metadata is not None
+            ):
+                self.do_kv_cache_update(
+                    layer, key, value, kv_cache, attn_metadata.slot_mapping
+                )
+            return super().forward(
+                layer,
+                query,
+                key,
+                value,
+                kv_cache,
+                attn_metadata,
+                output,
+                output_scale,
+                output_block_scale,
+            )
+
+    attn_backend = subclass_attention_backend_with_overrides(
         name_prefix=prefix,
         attention_backend_cls=underlying_attn_backend,
-        builder_cls=CrossAttentionBuilder,
+        overrides={
+            "get_builder_cls": lambda: CrossAttentionBuilder,
+            "get_impl_cls": lambda: CrossAttentionImpl,
+            "forward_includes_kv_cache_update": True,
+        },
     )
     return attn_backend
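
The diff shows only the call site of the renamed helper; its definition lives in the module imported in the first hunk. A minimal sketch of what a helper with this interface could look like, assuming the zero-argument callables in `overrides` stand in for accessor methods and plain values become class attributes:

from typing import Any


def subclass_attention_backend_with_overrides(
    name_prefix: str,
    attention_backend_cls: type,
    overrides: dict[str, Any],
) -> type:
    # Callables such as `lambda: CrossAttentionBuilder` replace accessors like
    # `get_builder_cls`; wrap them in staticmethod so lookups on the subclass
    # are unbound. Plain values (e.g. forward_includes_kv_cache_update=True)
    # are installed directly as class attributes.
    namespace = {
        name: staticmethod(value) if callable(value) else value
        for name, value in overrides.items()
    }
    return type(
        name_prefix + attention_backend_cls.__name__,
        (attention_backend_cls,),
        namespace,
    )

The overrides close the loop described in the NOTE(Lucas) comment: the builder computes the cross-attention slot mapping, `CrossAttentionImpl.forward` consumes it via `do_kv_cache_update`, and `forward_includes_kv_cache_update` is flipped to `True` on the subclass because the wrapped `forward` now performs the cache write itself.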