[Misc] Add uninitialized params tracking for AutoWeightsLoader (#10327)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -34,7 +34,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -345,7 +345,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@@ -353,6 +354,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
("qkv_proj", "v_proj", "v")
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
@@ -383,3 +385,5 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
Reference in New Issue
Block a user