[Model] Jamba support (#4115)

Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: Erez Schwartz <erezs@ai21.com>
Co-authored-by: Mor Zusman <morz@ai21.com>
Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: Tomer Asida <tomera@ai21.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
This commit is contained in:
Mor Zusman
2024-07-03 02:11:29 +03:00
committed by GitHub
parent ee93f4f92a
commit 9d6a8daa87
21 changed files with 1192 additions and 34 deletions

View File

@@ -386,9 +386,36 @@ class ModelConfig:
return num_heads // parallel_config.tensor_parallel_size
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_text_config.num_hidden_layers
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
def contains_seqlen_agnostic_layers(
self, parallel_config: "ParallelConfig") -> bool:
"""True for Mamba/SSM models (Jamba)"""
return self._get_num_seqlen_agnostic_layers(parallel_config) > 0
def get_layers_block_type(self,
parallel_config: "ParallelConfig") -> List[str]:
num_layers = self.get_num_layers(parallel_config)
# Transformers supports layers_block_type @property
return getattr(self.hf_config, "layers_block_type",
["attention"] * num_layers)
def get_num_attention_layers(self,
parallel_config: "ParallelConfig") -> int:
return len([
t for t in self.get_layers_block_type(parallel_config)
if t == "attention"
])
def _get_num_seqlen_agnostic_layers(
self, parallel_config: "ParallelConfig") -> int:
return len([
t for t in self.get_layers_block_type(parallel_config)
if t != "attention"
])
class CacheConfig:
"""Configuration for the KV cache.