[Model] Support Mamba (#6484)
This commit is contained in:
committed by
GitHub
parent
df3dcdf49d
commit
7342a7d7f8
@@ -271,7 +271,7 @@ class HasInnerState(Protocol):
|
||||
"""
|
||||
A flag that indicates this model has inner state.
|
||||
Models that has inner state usually need access to the scheduler_config
|
||||
for max_num_seqs ,etc... (Currently only used by Jamba)
|
||||
for max_num_seqs, etc. True for e.g. both Mamba and Jamba.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@@ -307,3 +307,46 @@ def has_inner_state(
|
||||
return isinstance(model, _HasInnerStateType)
|
||||
|
||||
return isinstance(model, HasInnerState)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class IsAttentionFree(Protocol):
|
||||
"""The interface required for all models like Mamba that lack attention,
|
||||
but do have state whose size is constant wrt the number of tokens."""
|
||||
|
||||
is_attention_free: ClassVar[Literal[True]] = True
|
||||
"""
|
||||
A flag that indicates this model has no attention.
|
||||
Used for block manager and attention backend selection.
|
||||
True for Mamba but not Jamba.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class _IsAttentionFreeType(Protocol):
|
||||
is_attention_free: ClassVar[Literal[True]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def is_attention_free(model: Type[object]) -> TypeIs[Type[IsAttentionFree]]:
|
||||
...
|
||||
|
||||
|
||||
def is_attention_free(
|
||||
model: Union[Type[object], object]
|
||||
) -> Union[TypeIs[Type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
|
||||
if isinstance(model, type):
|
||||
return isinstance(model, _IsAttentionFreeType)
|
||||
|
||||
return isinstance(model, IsAttentionFree)
|
||||
|
||||
Reference in New Issue
Block a user