[Model] Jamba support (#4115)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: Erez Schwartz <erezs@ai21.com>
Co-authored-by: Mor Zusman <morz@ai21.com>
Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: Tomer Asida <tomera@ai21.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
@@ -386,9 +386,36 @@ class ModelConfig:
         return num_heads // parallel_config.tensor_parallel_size
 
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
-        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
+        total_num_hidden_layers = getattr(self.hf_text_config,
+                                          "num_hidden_layers", 0)
         return total_num_hidden_layers // parallel_config.pipeline_parallel_size
 
+    def contains_seqlen_agnostic_layers(
+            self, parallel_config: "ParallelConfig") -> bool:
+        """True for Mamba/SSM models (Jamba)"""
+        return self._get_num_seqlen_agnostic_layers(parallel_config) > 0
+
+    def get_layers_block_type(self,
+                              parallel_config: "ParallelConfig") -> List[str]:
+        num_layers = self.get_num_layers(parallel_config)
+        # Transformers supports layers_block_type @property
+        return getattr(self.hf_config, "layers_block_type",
+                       ["attention"] * num_layers)
+
+    def get_num_attention_layers(self,
+                                 parallel_config: "ParallelConfig") -> int:
+        return len([
+            t for t in self.get_layers_block_type(parallel_config)
+            if t == "attention"
+        ])
+
+    def _get_num_seqlen_agnostic_layers(
+            self, parallel_config: "ParallelConfig") -> int:
+        return len([
+            t for t in self.get_layers_block_type(parallel_config)
+            if t != "attention"
+        ])
+
 
 class CacheConfig:
     """Configuration for the KV cache.
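
For context, a minimal sketch of how the new helpers classify layers, assuming a Jamba-style Hugging Face config that exposes layers_block_type. The SimpleNamespace stand-in, the 8-layer count, and the attention/mamba layout below are illustrative assumptions, not values taken from this PR:

from types import SimpleNamespace
from typing import List

# Hypothetical stand-in for a Jamba-style HF text config: a mix of
# attention layers and seqlen-agnostic (Mamba/SSM) layers.
hf_config = SimpleNamespace(
    num_hidden_layers=8,
    layers_block_type=["attention", "mamba", "mamba", "mamba"] * 2,
)

def get_layers_block_type(config, num_layers: int) -> List[str]:
    # Mirrors the diff: fall back to an all-attention layout when the
    # config does not expose a layers_block_type attribute/@property.
    return getattr(config, "layers_block_type", ["attention"] * num_layers)

block_types = get_layers_block_type(hf_config, hf_config.num_hidden_layers)
num_attention = sum(1 for t in block_types if t == "attention")
num_seqlen_agnostic = sum(1 for t in block_types if t != "attention")

print(num_attention, num_seqlen_agnostic)  # 2 6
# A positive seqlen-agnostic count is what makes
# contains_seqlen_agnostic_layers() return True for Jamba.
print(num_seqlen_agnostic > 0)             # True

The getattr fallback is what keeps pure-transformer models working unchanged: configs without layers_block_type are treated as all-attention, so every layer gets a KV-cache slot as before.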