[Performance] Split FlashAttn attention and cache update (#25954)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Luka Govedič <luka.govedic@gmail.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <luka.govedic@gmail.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
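In rough terms, the split separates the two things a fused attention forward used to do at once: scattering the new keys/values into the paged KV cache, and computing attention against that cache. A backend now declares via `forward_includes_kv_cache_update` whether its `forward` already performs the cache write; when it does not, the update runs as an explicit separate step, which is what the `CrossAttentionImpl` wrapper in the diff below does. A minimal single-head sketch of that dispatch, where `update_kv_cache` and `attention_only` are hypothetical stand-ins for the real kernels (paging and masking omitted for brevity):

import torch

# Hypothetical stand-in for the cache-update half of the split: a
# reshape_and_cache-style scatter of new K/V into the paged cache.
def update_kv_cache(
    key: torch.Tensor,           # [num_tokens, head_dim]
    value: torch.Tensor,         # [num_tokens, head_dim]
    kv_cache: torch.Tensor,      # [2, num_slots, head_dim]
    slot_mapping: torch.Tensor,  # [num_tokens], destination slot per token
) -> None:
    kv_cache[0, slot_mapping] = key
    kv_cache[1, slot_mapping] = value

# Hypothetical stand-in for the attention half: reads the cache,
# never writes it (single head, no masking, for brevity).
def attention_only(query: torch.Tensor, kv_cache: torch.Tensor) -> torch.Tensor:
    k, v = kv_cache[0], kv_cache[1]
    scores = (query @ k.T) / (k.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

FORWARD_INCLUDES_KV_CACHE_UPDATE = False  # mirrors the new backend attribute

def forward(query, key, value, kv_cache, slot_mapping):
    # Same conditional shape as CrossAttentionImpl.forward in the diff:
    # only write the cache explicitly if forward() does not already fuse it.
    if not FORWARD_INCLUDES_KV_CACHE_UPDATE:
        update_kv_cache(key, value, kv_cache, slot_mapping)
    return attention_only(query, kv_cache)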
@@ -15,7 +15,7 @@ from vllm.v1.attention.backend import (
     AttentionMetadata,
     AttentionType,
     CommonAttentionMetadata,
-    subclass_attention_backend,
+    subclass_attention_backend_with_overrides,
 )
 from vllm.v1.attention.selector import get_attn_backend
 from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
@@ -72,6 +72,7 @@ def create_cross_attention_backend(
 ) -> type[AttentionBackend]:
     prefix = "CrossAttention_"
     underlying_builder = underlying_attn_backend.get_builder_cls()
+    underlying_impl = underlying_attn_backend.get_impl_cls()
 
     class CrossAttentionBuilder(underlying_builder):  # type: ignore
         def build(
@@ -106,18 +107,60 @@ def create_cross_attention_backend(
             )
 
             # NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here
-            new_metadata.slot_mapping = _get_cross_slot_mapping(
+            slot_mapping = _get_cross_slot_mapping(
                 new_metadata.encoder_seq_lens_cpu,
                 new_metadata.block_table_tensor,
                 self.kv_cache_spec,
                 self.device,
             )
-            return super().build(common_prefix_len, new_metadata, fast_build)
+            attn_metadata = super().build(common_prefix_len, new_metadata, fast_build)
+            attn_metadata.slot_mapping = slot_mapping
+            return attn_metadata
 
-    attn_backend = subclass_attention_backend(
+    # NOTE(Lucas): we need a custom impl so we can use the slot-mapping computed by
+    # `CrossAttentionBuilder` instead of the one computed by `BlockTable`
+    # (gpu_model_runner)
+    class CrossAttentionImpl(underlying_impl):  # type: ignore[valid-type,misc]
+        def forward(
+            self,
+            layer: torch.nn.Module,
+            query: torch.Tensor,
+            key: torch.Tensor,
+            value: torch.Tensor,
+            kv_cache: torch.Tensor,
+            attn_metadata: AttentionMetadata,
+            output: torch.Tensor | None = None,
+            output_scale: torch.Tensor | None = None,
+            output_block_scale: torch.Tensor | None = None,
+        ) -> torch.Tensor:
+            if (
+                not underlying_attn_backend.forward_includes_kv_cache_update
+                and attn_metadata is not None
+            ):
+                self.do_kv_cache_update(
+                    layer, key, value, kv_cache, attn_metadata.slot_mapping
+                )
+
+            return super().forward(
+                layer,
+                query,
+                key,
+                value,
+                kv_cache,
+                attn_metadata,
+                output,
+                output_scale,
+                output_block_scale,
+            )
+
+    attn_backend = subclass_attention_backend_with_overrides(
         name_prefix=prefix,
         attention_backend_cls=underlying_attn_backend,
-        builder_cls=CrossAttentionBuilder,
+        overrides={
+            "get_builder_cls": lambda: CrossAttentionBuilder,
+            "get_impl_cls": lambda: CrossAttentionImpl,
+            "forward_includes_kv_cache_update": True,
+        },
     )
 
     return attn_backend
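Aside on the builder change: `_get_cross_slot_mapping` is called with encoder sequence lengths and a block table, which suggests the usual paged-KV slot arithmetic, where token t of request r lands in slot block_table[r, t // block_size] * block_size + t % block_size. A hypothetical re-derivation of that mapping (not the helper's actual code) also shows why it has to come from the cross-attention builder rather than the decoder-oriented `BlockTable` in gpu_model_runner: it is driven by encoder lengths, not decoder positions. The built metadata then carries this mapping to `do_kv_cache_update` above.

import torch

# Hypothetical reconstruction of a _get_cross_slot_mapping-style helper,
# assuming the standard paged-KV layout; the real helper's signature and
# internals may differ.
def cross_slot_mapping_sketch(
    encoder_seq_lens: torch.Tensor,  # [num_reqs], encoder tokens per request
    block_table: torch.Tensor,       # [num_reqs, max_blocks], block ids
    block_size: int,
) -> torch.Tensor:
    slots = []
    for req, seq_len in enumerate(encoder_seq_lens.tolist()):
        tok = torch.arange(int(seq_len))
        # Flat slot = block id of the token's page * page size + offset.
        block_ids = block_table[req, tok // block_size]
        slots.append(block_ids * block_size + tok % block_size)
    if not slots:
        return torch.empty(0, dtype=torch.long)
    return torch.cat(slots)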
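Finally, the new `subclass_attention_backend_with_overrides` helper replaces `subclass_attention_backend` at this call site. Judging only from how it is called, it presumably produces a renamed subclass of the underlying backend with each entry of `overrides` installed on the class. A generic sketch of that pattern, using a hypothetical `subclass_with_overrides` and `DummyBackend` (not vLLM's implementation):

# Generic sketch of the override mechanic, assuming a type()-based
# dynamic subclass; vLLM's actual helper may differ.
def subclass_with_overrides(name_prefix: str, base: type, overrides: dict) -> type:
    attrs = {}
    for name, value in overrides.items():
        if name.startswith("get_") and callable(value):
            # Accessors like get_builder_cls / get_impl_cls stay
            # callable on the class itself.
            attrs[name] = staticmethod(value)
        else:
            # Plain class attributes, e.g. forward_includes_kv_cache_update.
            attrs[name] = value
    return type(name_prefix + base.__name__, (base,), attrs)

# Toy usage mirroring the call in the diff:
class DummyBackend:
    forward_includes_kv_cache_update = False

CrossBackend = subclass_with_overrides(
    "CrossAttention_",
    DummyBackend,
    {"forward_includes_kv_cache_update": True},
)
assert CrossBackend.__name__ == "CrossAttention_DummyBackend"
assert CrossBackend.forward_includes_kv_cache_update is True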