Add get_expert_mapping to NemotronHModel (for LoRA support) (#31539)

Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
Author: danisereb
Date: 2025-12-31 07:09:03 +02:00
Committed by: GitHub
Parent: 578c8f51f6
Commit: 108a2728f7

@@ -632,14 +632,7 @@ class NemotronHModel(nn.Module):
         hidden_states, _ = self.norm_f(hidden_states, residual)
         return hidden_states
 
-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-        ]
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         if self.has_moe:
             # (param_name, weight_name, expert_id, shard_id)
             expert_params_mapping = FusedMoE.make_expert_params_mapping(
@@ -653,8 +646,19 @@ class NemotronHModel(nn.Module):
                 num_experts=self.config.n_routed_experts,
                 num_redundant_experts=getattr(self, "num_redundant_experts", 0),
             )
-        else:
-            expert_params_mapping = []
+            return expert_params_mapping
+        return []
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        expert_params_mapping = self.get_expert_mapping()
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
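
For readers wiring this into LoRA tooling, here is a minimal sketch of how a loader loop typically consumes get_expert_mapping(). It is illustrative only: the helper name load_expert_weights_sketch and the weight_loader call pattern are assumptions modeled on vLLM's usual FusedMoE loading convention, not code from this commit.

from collections.abc import Iterable

import torch


def load_expert_weights_sketch(
    model, weights: Iterable[tuple[str, torch.Tensor]]
) -> set[str]:
    # Each mapping entry is (param_name, weight_name, expert_id, shard_id),
    # as the comment in the diff above states; the list is empty for
    # non-MoE configurations, so the inner loop is a no-op in that case.
    expert_params_mapping = model.get_expert_mapping()
    params_dict = dict(model.named_parameters())
    loaded: set[str] = set()
    for ckpt_name, loaded_weight in weights:
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name not in ckpt_name:
                continue
            # Rewrite the per-expert checkpoint name to the fused parameter name.
            target = ckpt_name.replace(weight_name, param_name)
            param = params_dict[target]
            # Assumed convention: the fused parameter carries a weight_loader
            # that scatters this expert's shard into the fused tensor.
            param.weight_loader(
                param, loaded_weight, target, shard_id=shard_id, expert_id=expert_id
            )
            loaded.add(target)
            break
    return loaded

Exposing the mapping as a public method, rather than building it inline in load_weights, is what lets external consumers such as the LoRA layer-matching code query which fused expert parameters a checkpoint name maps onto.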