[Model] Support Mamba (#6484)

This commit is contained in:
Tyler Michael Smith
2024-10-11 11:40:06 -04:00
committed by GitHub
parent df3dcdf49d
commit 7342a7d7f8
29 changed files with 1603 additions and 343 deletions

View File

@@ -14,7 +14,8 @@ import torch.nn as nn
from vllm.logger import init_logger
from vllm.utils import is_hip
from .interfaces import supports_multimodal, supports_pp
from .interfaces import (has_inner_state, is_attention_free,
supports_multimodal, supports_pp)
from .interfaces_base import is_embedding_model, is_text_generation_model
logger = init_logger(__name__)
@@ -52,6 +53,7 @@ _TEXT_GENERATION_MODELS = {
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
@@ -157,6 +159,8 @@ class _ModelInfo:
is_embedding_model: bool
supports_multimodal: bool
supports_pp: bool
has_inner_state: bool
is_attention_free: bool
@staticmethod
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
@@ -165,6 +169,8 @@ class _ModelInfo:
is_embedding_model=is_embedding_model(model),
supports_multimodal=supports_multimodal(model),
supports_pp=supports_pp(model),
has_inner_state=has_inner_state(model),
is_attention_free=is_attention_free(model),
)
@@ -380,6 +386,14 @@ class _ModelRegistry:
) -> bool:
return self.inspect_model_cls(architectures).supports_pp
def model_has_inner_state(self, architectures: Union[str,
                                                     List[str]]) -> bool:
    """Return whether the model architecture maintains inner state.

    Delegates to :meth:`inspect_model_cls`, reading the
    ``has_inner_state`` flag recorded on the resolved ``_ModelInfo``
    (set from ``has_inner_state(model)`` in ``.interfaces``).
    ``architectures`` may be a single architecture name or a list of
    candidate names, matching ``inspect_model_cls``'s contract.
    """
    return self.inspect_model_cls(architectures).has_inner_state
def is_attention_free_model(self, architectures: Union[str,
                                                       List[str]]) -> bool:
    """Return whether the model architecture is attention-free.

    Delegates to :meth:`inspect_model_cls`, reading the
    ``is_attention_free`` flag recorded on the resolved ``_ModelInfo``
    (set from ``is_attention_free(model)`` in ``.interfaces``) —
    e.g. state-space models such as Mamba, added in this commit.
    ``architectures`` may be a single architecture name or a list of
    candidate names, matching ``inspect_model_cls``'s contract.
    """
    return self.inspect_model_cls(architectures).is_attention_free
ModelRegistry = _ModelRegistry({
model_arch: _LazyRegisteredModel(