Support heterogeneous NemotronHPuzzle model (#32549)
Signed-off-by: <dafrimi@nvidia.com> Signed-off-by: Daniel Afrimi <dafrimi@nvidia.com> Signed-off-by: root <dafrimi@nvidia.com>
This commit is contained in:
@@ -603,4 +603,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"FalconMambaForCausalLM": MambaModelConfig,
|
||||
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
|
||||
"NemotronHForCausalLM": NemotronHForCausalLMConfig,
|
||||
"NemotronHPuzzleForCausalLM": NemotronHForCausalLMConfig,
|
||||
}
|
||||
|
||||
@@ -354,8 +354,12 @@ class NemotronHMoEDecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# Get per-layer config for heterogeneous models if it exists
|
||||
get_layer_config = getattr(config, "get_nemotron_h_config_for_layer", None)
|
||||
layer_config = get_layer_config(layer_idx) if get_layer_config else config
|
||||
|
||||
self.mixer = NemotronHMoE(
|
||||
config,
|
||||
layer_config,
|
||||
quant_config=quant_config,
|
||||
parallel_config=parallel_config,
|
||||
prefix=f"{prefix}.mixer",
|
||||
@@ -479,6 +483,9 @@ class NemotronHAttention(nn.Module):
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
# Get per-layer sliding window from config (for heterogeneous models)
|
||||
sliding_window = getattr(config, "sliding_window", None)
|
||||
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
@@ -487,6 +494,7 @@ class NemotronHAttention(nn.Module):
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
per_layer_sliding_window=sliding_window,
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -514,8 +522,12 @@ class NemotronHAttentionDecoderLayer(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Get per-layer config for heterogeneous models if it exists
|
||||
get_layer_config = getattr(config, "get_nemotron_h_config_for_layer", None)
|
||||
layer_config = get_layer_config(layer_idx) if get_layer_config else config
|
||||
|
||||
self.mixer = NemotronHAttention(
|
||||
config,
|
||||
layer_config,
|
||||
layer_idx,
|
||||
model_config,
|
||||
cache_config,
|
||||
@@ -631,6 +643,34 @@ class NemotronHModel(nn.Module):
|
||||
hidden_states, _ = self.norm_f(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def _get_max_n_routed_experts(self) -> int:
|
||||
"""Get max n_routed_experts from config or block_configs for puzzle models.
|
||||
|
||||
For heterogeneous models with varying expert counts per layer,
|
||||
returns the MAX to ensure all expert weights can be loaded.
|
||||
"""
|
||||
# First try top-level attribute
|
||||
n_routed_experts = getattr(self.config, "n_routed_experts", None)
|
||||
if n_routed_experts is not None:
|
||||
return n_routed_experts
|
||||
|
||||
# For puzzle models, get MAX from all MoE blocks in block_configs
|
||||
# (different layers may have different expert counts)
|
||||
max_experts = 0
|
||||
block_configs = getattr(self.config, "block_configs", None)
|
||||
if block_configs:
|
||||
for block in block_configs:
|
||||
if isinstance(block, dict):
|
||||
if block.get("block_type") == "moe":
|
||||
max_experts = max(max_experts, block.get("n_routed_experts", 0))
|
||||
else:
|
||||
# HF converts dicts to objects with attributes
|
||||
if getattr(block, "block_type", "") == "moe":
|
||||
max_experts = max(
|
||||
max_experts, getattr(block, "n_routed_experts", 0)
|
||||
)
|
||||
return max_experts
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
if self.has_moe:
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
@@ -643,7 +683,7 @@ class NemotronHModel(nn.Module):
|
||||
ckpt_gate_proj_name="up_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="",
|
||||
num_experts=self.config.n_routed_experts,
|
||||
num_experts=self._get_max_n_routed_experts(),
|
||||
num_redundant_experts=getattr(self, "num_redundant_experts", 0),
|
||||
)
|
||||
return expert_params_mapping
|
||||
|
||||
@@ -163,6 +163,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
|
||||
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
|
||||
"NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
|
||||
"NemotronHPuzzleForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
|
||||
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
|
||||
"Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
|
||||
"Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
|
||||
|
||||
Reference in New Issue
Block a user