Support heterogeneous NemotronHPuzzle model (#32549)

Signed-off-by: Daniel Afrimi <dafrimi@nvidia.com>
This commit is contained in:
danielafrimi
2026-01-27 17:55:54 +02:00
committed by GitHub
parent f3a5ee705f
commit 83fb2d09e8
5 changed files with 75 additions and 5 deletions


@@ -603,4 +603,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"FalconMambaForCausalLM": MambaModelConfig,
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
"NemotronHForCausalLM": NemotronHForCausalLMConfig,
"NemotronHPuzzleForCausalLM": NemotronHForCausalLMConfig,
}
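The Puzzle variant deliberately reuses NemotronHForCausalLMConfig rather than adding a new config class. A minimal sketch of how this mapping gets consulted; the verify_and_update_config call site shown here is an assumption about vLLM internals, not part of this diff:

    arch = "NemotronHPuzzleForCausalLM"
    config_cls = MODELS_CONFIG_MAP.get(arch)  # -> NemotronHForCausalLMConfig
    if config_cls is not None:
        # Hypothetical call site; the real hook lives elsewhere in vLLM.
        config_cls.verify_and_update_config(vllm_config)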


@@ -354,8 +354,12 @@ class NemotronHMoEDecoderLayer(nn.Module):
         super().__init__()
         self.config = config
+        # Get per-layer config for heterogeneous models, if it exists
+        get_layer_config = getattr(config, "get_nemotron_h_config_for_layer", None)
+        layer_config = get_layer_config(layer_idx) if get_layer_config else config
         self.mixer = NemotronHMoE(
-            config,
+            layer_config,
             quant_config=quant_config,
             parallel_config=parallel_config,
             prefix=f"{prefix}.mixer",
@@ -479,6 +483,9 @@ class NemotronHAttention(nn.Module):
prefix=f"{prefix}.o_proj",
)
# Get per-layer sliding window from config (for heterogeneous models)
sliding_window = getattr(config, "sliding_window", None)
self.attn = Attention(
self.num_heads,
self.head_dim,
@@ -487,6 +494,7 @@ class NemotronHAttention(nn.Module):
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
+            per_layer_sliding_window=sliding_window,
         )

     def forward(
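Because sliding_window is read with getattr, a layer config that omits the attribute passes None to Attention and keeps global attention, while Puzzle layers can narrow their window individually. A toy illustration with invented values:

    class WindowedLayerConfig:
        sliding_window = 4096  # this layer attends over a 4096-token window

    class GlobalLayerConfig:
        pass  # no attribute -> getattr returns None -> full attention

    for cfg in (WindowedLayerConfig(), GlobalLayerConfig()):
        print(getattr(cfg, "sliding_window", None))  # 4096, then None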
@@ -514,8 +522,12 @@ class NemotronHAttentionDecoderLayer(nn.Module):
     ) -> None:
         super().__init__()
+        # Get per-layer config for heterogeneous models, if it exists
+        get_layer_config = getattr(config, "get_nemotron_h_config_for_layer", None)
+        layer_config = get_layer_config(layer_idx) if get_layer_config else config
         self.mixer = NemotronHAttention(
-            config,
+            layer_config,
             layer_idx,
             model_config,
             cache_config,
@@ -631,6 +643,34 @@ class NemotronHModel(nn.Module):
         hidden_states, _ = self.norm_f(hidden_states, residual)
         return hidden_states

+    def _get_max_n_routed_experts(self) -> int:
+        """Get max n_routed_experts from config or block_configs for puzzle models.
+
+        For heterogeneous models with varying expert counts per layer,
+        returns the MAX to ensure all expert weights can be loaded.
+        """
+        # First try the top-level attribute
+        n_routed_experts = getattr(self.config, "n_routed_experts", None)
+        if n_routed_experts is not None:
+            return n_routed_experts
+
+        # For puzzle models, get the MAX from all MoE blocks in block_configs
+        # (different layers may have different expert counts)
+        max_experts = 0
+        block_configs = getattr(self.config, "block_configs", None)
+        if block_configs:
+            for block in block_configs:
+                if isinstance(block, dict):
+                    if block.get("block_type") == "moe":
+                        max_experts = max(max_experts, block.get("n_routed_experts", 0))
+                else:
+                    # HF converts dicts to objects with attributes
+                    if getattr(block, "block_type", "") == "moe":
+                        max_experts = max(
+                            max_experts, getattr(block, "n_routed_experts", 0)
+                        )
+        return max_experts
+
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         if self.has_moe:
             # (param_name, weight_name, expert_id, shard_id)
@@ -643,7 +683,7 @@ class NemotronHModel(nn.Module):
                 ckpt_gate_proj_name="up_proj",
                 ckpt_down_proj_name="down_proj",
                 ckpt_up_proj_name="",
-                num_experts=self.config.n_routed_experts,
+                num_experts=self._get_max_n_routed_experts(),
                 num_redundant_experts=getattr(self, "num_redundant_experts", 0),
             )
             return expert_params_mapping
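Sizing FusedMoE for the maximum expert count means every layer's checkpoint weights fit, with narrower layers simply leaving slots unused. A toy check of the dict branch above, using invented block_configs values:

    block_configs = [
        {"block_type": "attention"},
        {"block_type": "moe", "n_routed_experts": 8},
        {"block_type": "moe", "n_routed_experts": 16},
    ]
    max_experts = 0
    for block in block_configs:
        if block.get("block_type") == "moe":
            max_experts = max(max_experts, block.get("n_routed_experts", 0))
    assert max_experts == 16  # weights are sized for the widest MoE layer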


@@ -163,6 +163,7 @@ _TEXT_GENERATION_MODELS = {
"MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
"NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
"NemotronHPuzzleForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
"Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),