[Model] Update support for NemotronNAS models (#15008)

Signed-off-by: Nave Assaf <nassaf@nvidia.com>
Author: Naveassaf
Date: 2025-03-31 15:35:14 +03:00
Committed by: GitHub
Parent: 555aa21905
Commit: 3aa2b6a637
8 changed files with 524 additions and 133 deletions

@@ -411,6 +411,7 @@ class ModelConfig:
         self.is_attention_free = self._init_attention_free()
         self.is_hybrid = self._init_is_hybrid()
+        self.has_noops = self._init_has_noops()
         self.has_inner_state = self._init_has_inner_state()
         if current_platform.is_neuron():
@@ -510,6 +511,10 @@ class ModelConfig:
     def _init_is_hybrid(self) -> bool:
         return self.registry.is_hybrid_model(self.architectures)
 
+    def _init_has_noops(self) -> bool:
+        architectures = getattr(self.hf_config, "architectures", [])
+        return self.registry.is_noops_model(architectures)
+
     def _init_has_inner_state(self) -> bool:
         return self.registry.model_has_inner_state(self.architectures)
@@ -872,6 +877,14 @@ class ModelConfig:
             return getattr(self.hf_config.attn_config, "kv_n_heads",
                            self.hf_config.num_attention_heads)
 
+        if self.hf_config.model_type == "nemotron-nas":
+            for block in self.hf_config.block_configs:
+                if not block.attention.no_op:
+                    return self.hf_config.num_attention_heads \
+                        // block.attention.n_heads_in_group
+
+            raise RuntimeError("Couldn't determine number of kv heads")
+
         if self.is_attention_free:
             return 0
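
For readers unfamiliar with the nemotron-nas config layout, here is a minimal, self-contained sketch of what the new nemotron-nas branch shown above computes: the first block whose attention is not a no-op supplies n_heads_in_group, and the KV-head count is num_attention_heads divided by that group size. The Attention/Block dataclasses and the concrete numbers below are illustrative assumptions, not vLLM or Hugging Face classes.

from dataclasses import dataclass

@dataclass
class Attention:
    # Toy stand-in for a nemotron-nas block's attention config (assumed shape).
    no_op: bool
    n_heads_in_group: int | None = None

@dataclass
class Block:
    # Toy stand-in for one entry of hf_config.block_configs (assumed shape).
    attention: Attention

num_attention_heads = 32
block_configs = [
    Block(Attention(no_op=True)),                       # no-op block: skipped
    Block(Attention(no_op=False, n_heads_in_group=8)),  # first real attention block
]

for block in block_configs:
    if not block.attention.no_op:
        # 32 query heads grouped 8 per KV head -> 4 KV heads
        print(num_attention_heads // block.attention.n_heads_in_group)  # 4
        break
else:
    raise RuntimeError("Couldn't determine number of kv heads")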
@@ -940,7 +953,9 @@ class ModelConfig:
         # This function relies on 'layers_block_type' in hf_config,
         # for w/o this attribute, we will need to have workarounds like so
         attn_block_type = block_type == LayerBlockType.attention
-        is_transformer = not self.is_hybrid and not self.is_attention_free
+        is_transformer = not self.is_hybrid and \
+                         not self.has_noops and \
+                         not self.is_attention_free
         start, end = self.get_layers_start_end_indices(parallel_config)
 
         if is_transformer:
@@ -951,6 +966,10 @@ class ModelConfig:
             # Note that this code assumes there
             # is only one type of attention-free block type.
             return 0 if attn_block_type else end - start
+        elif self.has_noops:
+            block_configs = self.hf_config.block_configs
+            return sum(not bc.attention.no_op
+                       for bc in block_configs[start:end])
         else:
             # Hybrid model
             layers_block_type_value = getattr(self.hf_config,
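
To make the new has_noops counting concrete, here is a minimal sketch with assumed toy data (the flag list and slice below are made up, not taken from any real config): within the [start, end) slice of blocks assigned to a stage, only blocks whose attention is not a no-op count as attention layers.

# Per-block attention.no_op flags for a hypothetical 6-block nemotron-nas model.
no_op_flags = [False, False, True, False, True, False]
start, end = 1, 5  # hypothetical slice of block_configs owned by this stage

num_attention_layers = sum(not no_op for no_op in no_op_flags[start:end])
print(num_attention_layers)  # 2 -> blocks 1 and 3 have attention; 2 and 4 are no-ops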