[PERF] Add conv1d metadata to GDN attn (#25105)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
@@ -11,6 +11,7 @@ from vllm.attention.backends.placeholder_attn import (
|
||||
PlaceholderAttentionMetadata)
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
|
||||
from vllm.v1.attention.backends.mamba2_attn import (
|
||||
Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets)
|
||||
|
||||
@@ -45,8 +46,8 @@ class Mamba2Metadata:
|
||||
"""
|
||||
nums_dict: Optional[dict] = None
|
||||
cu_seqlen: Optional[int] = None
|
||||
batch_ptr: Optional[torch.tensor] = None
|
||||
token_chunk_offset_ptr: Optional[torch.tensor] = None
|
||||
batch_ptr: Optional[torch.Tensor] = None
|
||||
token_chunk_offset_ptr: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
|
||||
@@ -117,7 +118,8 @@ def prepare_mamba2_metadata(
|
||||
|
||||
def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
|
||||
mamba2_metadata: Union[Mamba2Metadata,
|
||||
Mamba2AttentionMetadata]):
|
||||
Mamba2AttentionMetadata,
|
||||
GDNAttentionMetadata]):
|
||||
"""
|
||||
this is triggered upon handling a new input at the first layer
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user