[PERF] Add conv1d metadata to GDN attn (#25105)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Vadim Gimpelson
2025-09-18 18:27:49 +04:00
committed by GitHub
parent 01a583fea4
commit 072d7e53e5
5 changed files with 24 additions and 8 deletions


@@ -11,6 +11,7 @@ from vllm.attention.backends.placeholder_attn import (
     PlaceholderAttentionMetadata)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.platforms import current_platform
+from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
 from vllm.v1.attention.backends.mamba2_attn import (
     Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets)
@@ -45,8 +46,8 @@ class Mamba2Metadata:
     """
     nums_dict: Optional[dict] = None
     cu_seqlen: Optional[int] = None
-    batch_ptr: Optional[torch.tensor] = None
-    token_chunk_offset_ptr: Optional[torch.tensor] = None
+    batch_ptr: Optional[torch.Tensor] = None
+    token_chunk_offset_ptr: Optional[torch.Tensor] = None


 def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
@@ -117,7 +118,8 @@ def prepare_mamba2_metadata(

 def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
                     mamba2_metadata: Union[Mamba2Metadata,
-                                           Mamba2AttentionMetadata]):
+                                           Mamba2AttentionMetadata,
+                                           GDNAttentionMetadata]):
     """
     this is triggered upon handling a new input at the first layer
     """