[Model] Support Mamba (#6484)
This commit is contained in:
committed by
GitHub
parent
df3dcdf49d
commit
7342a7d7f8
@@ -14,7 +14,8 @@ import torch.nn as nn
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import is_hip
|
||||
|
||||
from .interfaces import supports_multimodal, supports_pp
|
||||
from .interfaces import (has_inner_state, is_attention_free,
|
||||
supports_multimodal, supports_pp)
|
||||
from .interfaces_base import is_embedding_model, is_text_generation_model
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -52,6 +53,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
# For decapoda-research/llama-*
|
||||
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
|
||||
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
|
||||
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
|
||||
@@ -157,6 +159,8 @@ class _ModelInfo:
|
||||
is_embedding_model: bool
|
||||
supports_multimodal: bool
|
||||
supports_pp: bool
|
||||
has_inner_state: bool
|
||||
is_attention_free: bool
|
||||
|
||||
@staticmethod
|
||||
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
|
||||
@@ -165,6 +169,8 @@ class _ModelInfo:
|
||||
is_embedding_model=is_embedding_model(model),
|
||||
supports_multimodal=supports_multimodal(model),
|
||||
supports_pp=supports_pp(model),
|
||||
has_inner_state=has_inner_state(model),
|
||||
is_attention_free=is_attention_free(model),
|
||||
)
|
||||
|
||||
|
||||
@@ -380,6 +386,14 @@ class _ModelRegistry:
|
||||
) -> bool:
|
||||
return self.inspect_model_cls(architectures).supports_pp
|
||||
|
||||
def model_has_inner_state(self, architectures: Union[str,
                                                     List[str]]) -> bool:
    """Return whether any resolved model class carries internal state.

    Looks up the registered model for *architectures* (a single
    architecture name or a list of candidates) and reports the
    ``has_inner_state`` flag recorded on its ``_ModelInfo`` — True for
    stateful models such as Mamba, which keep recurrent state between
    steps.
    """
    info = self.inspect_model_cls(architectures)
    return info.has_inner_state
|
||||
|
||||
def is_attention_free_model(self, architectures: Union[str,
                                                       List[str]]) -> bool:
    """Return whether any resolved model class is attention-free.

    Resolves *architectures* (one name or a list of candidates) via
    ``inspect_model_cls`` and reports the ``is_attention_free`` flag of
    the matching ``_ModelInfo`` — True for architectures (e.g. pure
    state-space models like Mamba) that use no attention layers.
    """
    info = self.inspect_model_cls(architectures)
    return info.is_attention_free
|
||||
|
||||
|
||||
ModelRegistry = _ModelRegistry({
|
||||
model_arch: _LazyRegisteredModel(
|
||||
|
||||
Reference in New Issue
Block a user