[Misc] Add uninitialized params tracking for AutoWeightsLoader (#10327)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2024-11-18 09:07:46 +08:00
committed by GitHub
parent d1557e66d3
commit c4e464333e
74 changed files with 454 additions and 185 deletions

View File

@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
@@ -369,13 +369,15 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "w1", 0),
("gate_up_proj", "w3", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
@@ -402,3 +404,5 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params