[Misc] Add uninitialized params tracking for AutoWeightsLoader (#10327)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2024-11-18 09:07:46 +08:00
committed by GitHub
parent d1557e66d3
commit c4e464333e
74 changed files with 454 additions and 185 deletions

View File

@@ -3,7 +3,8 @@
"""Inference-only ChatGLM model compatible with THUDM weights."""
from argparse import Namespace
from array import array
from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict
from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple,
TypedDict)
import torch
from PIL import Image
@@ -645,7 +646,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
# Merge two ColumnParallelLinear into one MergedColumnParallelLinear
merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
"transformer.vision.linear_proj.merged_proj.weight": {
@@ -655,6 +657,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
}
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
is_weight_to_be_merge = False
for _, merged_weight_dict in merged_weights_dict.items():
@@ -677,6 +680,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
for combined_name, merged_weight_dict in merged_weights_dict.items():
if combined_name in params_dict:
@@ -686,3 +690,5 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, combined_weight)
loaded_params.add(combined_name)
return loaded_params