Add get_expert_mapping to NemotronHModel (for LoRA support) (#31539)
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
@@ -632,14 +632,7 @@ class NemotronHModel(nn.Module):
         hidden_states, _ = self.norm_f(hidden_states, residual)
         return hidden_states
 
-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-        ]
-
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         if self.has_moe:
             # (param_name, weight_name, expert_id, shard_id)
             expert_params_mapping = FusedMoE.make_expert_params_mapping(
@@ -653,8 +646,19 @@ class NemotronHModel(nn.Module):
                 num_experts=self.config.n_routed_experts,
                 num_redundant_experts=getattr(self, "num_redundant_experts", 0),
             )
-        else:
-            expert_params_mapping = []
+            return expert_params_mapping
+
+        return []
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        expert_params_mapping = self.get_expert_mapping()
 
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
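As a reading aid, here is a minimal, self-contained sketch of how tuples in the (param_name, weight_name, expert_id, shard_id) format returned by get_expert_mapping() are typically consumed when matching per-expert checkpoint weights onto fused MoE parameters. The concrete mapping strings, the module path, and the print statement standing in for a weight loader are illustrative assumptions, not code from this change or from FusedMoE.

```python
# Hypothetical mapping entries in the (param_name, weight_name, expert_id,
# shard_id) format named by the comment in get_expert_mapping(); the exact
# strings produced by FusedMoE.make_expert_params_mapping may differ.
expert_params_mapping = [
    ("experts.w13_", "experts.0.gate_proj.", 0, "w1"),
    ("experts.w2_",  "experts.0.down_proj.", 0, "w2"),
    ("experts.w13_", "experts.0.up_proj.",   0, "w3"),
]

# Hypothetical per-expert weight names as they might appear in a checkpoint.
checkpoint_weights = {
    "backbone.layers.3.mixer.experts.0.up_proj.weight": "tensor-up",
    "backbone.layers.3.mixer.experts.0.down_proj.weight": "tensor-down",
}

for ckpt_name, tensor in checkpoint_weights.items():
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in ckpt_name:
            continue
        # Rewrite the per-expert checkpoint name to the fused parameter name;
        # a real loader would then pass (tensor, expert_id, shard_id) to that
        # fused parameter's weight loader.
        fused_name = ckpt_name.replace(weight_name, param_name)
        print(f"{ckpt_name} -> {fused_name} (expert {expert_id}, shard {shard_id})")
        break
```

Per the commit title, exposing this mapping through a dedicated method, rather than building it only inside load_weights, appears to be what lets the LoRA machinery query the expert layout of NemotronHModel.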