[Misc] Add uninitialized params tracking for AutoWeightsLoader (#10327)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2024-11-18 09:07:46 +08:00
committed by GitHub
parent d1557e66d3
commit c4e464333e
74 changed files with 454 additions and 185 deletions

View File

@@ -1,7 +1,7 @@
 import itertools
 from dataclasses import dataclass, field
 from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
-                    Optional, Protocol, Tuple, Union, overload)
+                    Optional, Protocol, Set, Tuple, Union, overload)
 import torch
 import torch.nn as nn
@@ -172,8 +172,9 @@ class AutoWeightsLoader:
         if module != self.module:
             module_load_weights = getattr(module, "load_weights", None)
             if callable(module_load_weights):
-                module_load_weights(weights)
-                return
+                loaded_params = module_load_weights(weights)
+                yield from map(lambda x: self._get_qualname(base_prefix, x),
+                               loaded_params)
 
         child_modules = dict(module.named_children())
         child_params = dict(module.named_parameters(recurse=False))
@@ -222,11 +223,11 @@ class AutoWeightsLoader:
         weights: Iterable[Tuple[str, torch.Tensor]],
         *,
         mapper: Optional[WeightsMapper] = None,
-    ) -> List[str]:
+    ) -> Set[str]:
         if mapper is not None:
             weights = mapper.apply(weights)
 
-        autoloaded_weights = list(self._load_module("", self.module, weights))
+        autoloaded_weights = set(self._load_module("", self.module, weights))
         return autoloaded_weights