[Model] Support Mamba (#6484)

This commit is contained in:
Tyler Michael Smith
2024-10-11 11:40:06 -04:00
committed by GitHub
parent df3dcdf49d
commit 7342a7d7f8
29 changed files with 1603 additions and 343 deletions

View File

@@ -196,6 +196,9 @@ class ModelConfig:
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
self.is_attention_free = self._init_attention_free()
self.has_inner_state = self._init_has_inner_state()
self.override_neuron_config = override_neuron_config if is_neuron(
) else None
self._verify_embedding_mode()
@@ -216,6 +219,14 @@ class ModelConfig:
return None
def _init_attention_free(self) -> bool:
architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.is_attention_free_model(architectures)
def _init_has_inner_state(self) -> bool:
architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.model_has_inner_state(architectures)
def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow", "mistral"]:
@@ -438,6 +449,10 @@ class ModelConfig:
# FlashAttention supports only head_size 32, 64, 128, 256,
# we need to pad head_size 192 to 256
return 256
if self.is_attention_free:
return 0
if hasattr(self.hf_text_config, "head_dim"):
return self.hf_text_config.head_dim
# FIXME(woosuk): This may not be true for all models.
@@ -469,6 +484,9 @@ class ModelConfig:
return getattr(self.hf_config.attn_config, "kv_n_heads",
self.hf_config.num_attention_heads)
if self.is_attention_free:
return 0
attributes = [
# For Falcon:
"n_head_kv",
@@ -511,31 +529,17 @@ class ModelConfig:
start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
return end - start
def contains_seqlen_agnostic_layers(
self, parallel_config: "ParallelConfig") -> bool:
"""True for Mamba/SSM models (Jamba)"""
return self._get_num_seqlen_agnostic_layers(parallel_config) > 0
def get_layers_block_type(self,
parallel_config: "ParallelConfig") -> List[str]:
num_layers = self.get_num_layers(parallel_config)
# Transformers supports layers_block_type @property
return getattr(self.hf_config, "layers_block_type",
["attention"] * num_layers)
def get_num_attention_layers(self,
parallel_config: "ParallelConfig") -> int:
return len([
t for t in self.get_layers_block_type(parallel_config)
if t == "attention"
])
if self.is_attention_free:
return 0
def _get_num_seqlen_agnostic_layers(
self, parallel_config: "ParallelConfig") -> int:
return len([
t for t in self.get_layers_block_type(parallel_config)
if t != "attention"
])
num_layers = self.get_num_layers(parallel_config)
# Transformers supports layers_block_type @property
layers = getattr(self.hf_config, "layers_block_type",
["attention"] * num_layers)
return len([t for t in layers if t == "attention"])
def get_multimodal_config(self) -> "MultiModalConfig":
"""
@@ -585,6 +589,7 @@ class CacheConfig:
gpu_memory_utilization: float,
swap_space: float,
cache_dtype: str,
is_attention_free: bool = False,
num_gpu_blocks_override: Optional[int] = None,
sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False,
@@ -595,6 +600,7 @@ class CacheConfig:
self.swap_space_bytes = swap_space * GiB_bytes
self.num_gpu_blocks_override = num_gpu_blocks_override
self.cache_dtype = cache_dtype
self.is_attention_free = is_attention_free
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
self.cpu_offload_gb = cpu_offload_gb