[Attention] Unify mamba and attention backend selection (#23171)
Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
@@ -4,7 +4,10 @@
 import copy
 import math
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import TYPE_CHECKING, Optional, Union
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
 
 import regex as re
 import torch
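The `TYPE_CHECKING` guard above makes `AttentionBackend` visible to static type checkers without importing it at runtime, which is why the method below can annotate its return type as the string `type["AttentionBackend"]` while the concrete class is only loaded lazily where it is needed; a common reason for this pattern is avoiding circular imports between model code and backend modules. A minimal, self-contained sketch of the same pattern, using `decimal.Decimal` as a stand-in for the guarded import (not part of this PR):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen only by static type checkers (mypy, pyright); never executed.
        from decimal import Decimal


    def parse_price(raw: str) -> "Decimal":
        # The runtime import is deferred to the call site, mirroring how the
        # diff defers importing LinearAttentionBackend inside get_attn_backend.
        from decimal import Decimal
        return Decimal(raw)


    print(parse_price("19.99"))  # prints 19.99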
@@ -339,6 +342,11 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
     def mamba_type(self) -> str:
         return "linear_attention"
 
+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        from vllm.v1.attention.backends.linear_attn import (
+            LinearAttentionBackend)
+        return LinearAttentionBackend
+
     def get_state_dtype(self) -> tuple[torch.dtype]:
         return MambaStateDtypeCalculator.linear_attention_state_dtype(
             self.model_config.dtype,
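With this change, each mamba-style layer reports its own attention backend through `get_attn_backend()`, so backend selection no longer requires a central mapping from `mamba_type` strings to backend classes. A minimal sketch of the unified interface as a caller might use it; everything here except the `mamba_type` and `get_attn_backend` names is hypothetical, not vLLM's actual API:

    from abc import ABC, abstractmethod


    class AttentionBackend:
        """Stand-in for vllm.attention.backends.abstract.AttentionBackend."""
        name = "base"


    class LinearAttentionBackend(AttentionBackend):
        name = "linear_attention"


    class MambaBase(ABC):
        """Hypothetical reduction of the MambaBase interface."""

        @property
        @abstractmethod
        def mamba_type(self) -> str: ...

        @abstractmethod
        def get_attn_backend(self) -> type[AttentionBackend]: ...


    class LinearAttentionLayer(MambaBase):
        @property
        def mamba_type(self) -> str:
            return "linear_attention"

        def get_attn_backend(self) -> type[AttentionBackend]:
            # The layer owns its backend choice; callers no longer
            # translate mamba_type strings into backend classes.
            return LinearAttentionBackend


    layer = LinearAttentionLayer()
    backend_cls = layer.get_attn_backend()  # unified selection path
    print(backend_cls.name)  # prints: linear_attention

Keeping the backend choice on the layer means adding a new mamba variant only requires implementing `get_attn_backend()` on that layer, rather than extending a shared dispatch table.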