[Attention] Unify mamba and attention backend selection (#23171)
Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -758,6 +761,11 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
def mamba_type(self) -> str:
|
||||
return "mamba2"
|
||||
|
||||
def get_attn_backend(self) -> type["AttentionBackend"]:
|
||||
from vllm.v1.attention.backends.mamba2_attn import (
|
||||
Mamba2AttentionBackend)
|
||||
return Mamba2AttentionBackend
|
||||
|
||||
|
||||
def mamba_mixer2(
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user