[Misc] Add uninitialized params tracking for AutoWeightsLoader (#10327)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import itertools
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
|
||||
Optional, Protocol, Tuple, Union, overload)
|
||||
Optional, Protocol, Set, Tuple, Union, overload)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -172,8 +172,9 @@ class AutoWeightsLoader:
|
||||
if module != self.module:
|
||||
module_load_weights = getattr(module, "load_weights", None)
|
||||
if callable(module_load_weights):
|
||||
module_load_weights(weights)
|
||||
return
|
||||
loaded_params = module_load_weights(weights)
|
||||
yield from map(lambda x: self._get_qualname(base_prefix, x),
|
||||
loaded_params)
|
||||
|
||||
child_modules = dict(module.named_children())
|
||||
child_params = dict(module.named_parameters(recurse=False))
|
||||
@@ -222,11 +223,11 @@ class AutoWeightsLoader:
|
||||
weights: Iterable[Tuple[str, torch.Tensor]],
|
||||
*,
|
||||
mapper: Optional[WeightsMapper] = None,
|
||||
) -> List[str]:
|
||||
) -> Set[str]:
|
||||
if mapper is not None:
|
||||
weights = mapper.apply(weights)
|
||||
|
||||
autoloaded_weights = list(self._load_module("", self.module, weights))
|
||||
autoloaded_weights = set(self._load_module("", self.module, weights))
|
||||
return autoloaded_weights
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user