[Model] Support Mamba (#6484)
commit 7342a7d7f8 (parent df3dcdf49d), committed by GitHub
@@ -196,6 +196,9 @@ class ModelConfig:
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()
 
+        self.is_attention_free = self._init_attention_free()
+        self.has_inner_state = self._init_has_inner_state()
+
         self.override_neuron_config = override_neuron_config if is_neuron(
         ) else None
         self._verify_embedding_mode()
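Note: the two attributes added above drive the rest of this diff. A minimal sketch of how they are meant to be read; the helper functions below are assumptions for illustration, only the two attribute names come from this commit.

# Illustrative sketch only; just the two ModelConfig attributes are from this diff.
def needs_kv_cache(model_config) -> bool:
    # Attention-free models (pure SSMs such as Mamba) have no KV cache to page.
    return not model_config.is_attention_free

def manages_inner_state(model_config) -> bool:
    # Models with an inner state (Mamba-style) carry recurrent state between
    # decoding steps, tracked per request outside the paged KV cache.
    return model_config.has_inner_state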
@@ -216,6 +219,14 @@ class ModelConfig:
 
         return None
 
+    def _init_attention_free(self) -> bool:
+        architectures = getattr(self.hf_config, "architectures", [])
+        return ModelRegistry.is_attention_free_model(architectures)
+
+    def _init_has_inner_state(self) -> bool:
+        architectures = getattr(self.hf_config, "architectures", [])
+        return ModelRegistry.model_has_inner_state(architectures)
+
     def _verify_tokenizer_mode(self) -> None:
         tokenizer_mode = self.tokenizer_mode.lower()
         if tokenizer_mode not in ["auto", "slow", "mistral"]:
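The two _init_* helpers delegate to ModelRegistry. A rough sketch of what those registry queries amount to, assuming the registry keeps per-architecture flag sets; the set contents and standalone-function form are assumptions, the real ModelRegistry is structured differently.

from typing import List

# Assumed contents for illustration.
_ATTENTION_FREE_ARCHS = {"MambaForCausalLM"}
_ARCHS_WITH_INNER_STATE = {"MambaForCausalLM", "JambaForCausalLM"}

def is_attention_free_model(architectures: List[str]) -> bool:
    # True when every layer is an SSM block, i.e. the model has no attention at all.
    return any(arch in _ATTENTION_FREE_ARCHS for arch in architectures)

def model_has_inner_state(architectures: List[str]) -> bool:
    # True when the model keeps recurrent (SSM) state across decoding steps.
    return any(arch in _ARCHS_WITH_INNER_STATE for arch in architectures)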
@@ -438,6 +449,10 @@ class ModelConfig:
             # FlashAttention supports only head_size 32, 64, 128, 256,
             # we need to pad head_size 192 to 256
             return 256
+
+        if self.is_attention_free:
+            return 0
+
         if hasattr(self.hf_text_config, "head_dim"):
             return self.hf_text_config.head_dim
         # FIXME(woosuk): This may not be true for all models.
@@ -469,6 +484,9 @@ class ModelConfig:
             return getattr(self.hf_config.attn_config, "kv_n_heads",
                            self.hf_config.num_attention_heads)
 
+        if self.is_attention_free:
+            return 0
+
         attributes = [
             # For Falcon:
             "n_head_kv",
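With the head size and the KV-head count both reporting 0 for attention-free models, the usual per-block KV-cache accounting collapses to zero bytes, so no paged KV memory is reserved for a pure Mamba model. A worked sketch of that arithmetic; the function name and exact formula are assumptions mirroring the standard key-plus-value accounting.

def kv_cache_bytes_per_block(block_size: int, num_kv_heads: int, head_size: int,
                             num_attention_layers: int, dtype_bytes: int) -> int:
    # Keys and values are stored separately, hence the factor of 2.
    return 2 * block_size * num_kv_heads * head_size * num_attention_layers * dtype_bytes

# Typical attention model: 16 tokens/block, 8 KV heads, head size 128, 32 layers, fp16.
assert kv_cache_bytes_per_block(16, 8, 128, 32, 2) == 2_097_152
# Attention-free model: head size 0, 0 KV heads, 0 attention layers -> 0 bytes/block.
assert kv_cache_bytes_per_block(16, 0, 0, 0, 2) == 0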
@@ -511,31 +529,17 @@ class ModelConfig:
         start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
         return end - start
 
-    def contains_seqlen_agnostic_layers(
-            self, parallel_config: "ParallelConfig") -> bool:
-        """True for Mamba/SSM models (Jamba)"""
-        return self._get_num_seqlen_agnostic_layers(parallel_config) > 0
-
-    def get_layers_block_type(self,
-                              parallel_config: "ParallelConfig") -> List[str]:
-        num_layers = self.get_num_layers(parallel_config)
-        # Transformers supports layers_block_type @property
-        return getattr(self.hf_config, "layers_block_type",
-                       ["attention"] * num_layers)
-
     def get_num_attention_layers(self,
                                  parallel_config: "ParallelConfig") -> int:
-        return len([
-            t for t in self.get_layers_block_type(parallel_config)
-            if t == "attention"
-        ])
+        if self.is_attention_free:
+            return 0
 
-    def _get_num_seqlen_agnostic_layers(
-            self, parallel_config: "ParallelConfig") -> int:
-        return len([
-            t for t in self.get_layers_block_type(parallel_config)
-            if t != "attention"
-        ])
+        num_layers = self.get_num_layers(parallel_config)
+
+        # Transformers supports layers_block_type @property
+        layers = getattr(self.hf_config, "layers_block_type",
+                         ["attention"] * num_layers)
+        return len([t for t in layers if t == "attention"])
 
     def get_multimodal_config(self) -> "MultiModalConfig":
         """
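The rewritten get_num_attention_layers short-circuits to 0 for attention-free models and otherwise counts the "attention" entries in layers_block_type, falling back to all-attention when the HF config does not expose that property. A small usage sketch; the SimpleNamespace objects below are stand-ins, not real HF configs.

from types import SimpleNamespace

def count_attention_layers(hf_config, num_layers: int) -> int:
    layers = getattr(hf_config, "layers_block_type", ["attention"] * num_layers)
    return len([t for t in layers if t == "attention"])

# Hybrid Jamba-like model: layers_block_type mixes mamba and attention blocks.
hybrid = SimpleNamespace(layers_block_type=["mamba", "mamba", "attention", "mamba"])
print(count_attention_layers(hybrid, 4))   # -> 1

# Plain transformer: no layers_block_type, so every layer counts as attention.
plain = SimpleNamespace()
print(count_attention_layers(plain, 4))    # -> 4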
@@ -585,6 +589,7 @@ class CacheConfig:
         gpu_memory_utilization: float,
         swap_space: float,
         cache_dtype: str,
+        is_attention_free: bool = False,
         num_gpu_blocks_override: Optional[int] = None,
         sliding_window: Optional[int] = None,
         enable_prefix_caching: bool = False,
@@ -595,6 +600,7 @@ class CacheConfig:
         self.swap_space_bytes = swap_space * GiB_bytes
         self.num_gpu_blocks_override = num_gpu_blocks_override
         self.cache_dtype = cache_dtype
+        self.is_attention_free = is_attention_free
         self.sliding_window = sliding_window
         self.enable_prefix_caching = enable_prefix_caching
         self.cpu_offload_gb = cpu_offload_gb
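CacheConfig now records is_attention_free as well, so cache sizing code can branch on it without reaching back into ModelConfig. A hedged sketch of how a consumer might use the flag; the function and its arguments are assumptions, only the attribute comes from this diff.

def decide_num_gpu_blocks(cache_config, free_kv_cache_bytes: int,
                          bytes_per_block: int) -> int:
    if cache_config.is_attention_free:
        # No attention layers means there is no paged KV cache to allocate.
        return 0
    return free_kv_cache_bytes // bytes_per_block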