[Misc] Add uninitialized params tracking for AutoWeightsLoader (#10327)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2024-11-18 09:07:46 +08:00
committed by GitHub
parent d1557e66d3
commit c4e464333e
74 changed files with 454 additions and 185 deletions

View File

@@ -22,7 +22,7 @@
# limitations under the License.
"""Inference-only Exaone model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
@@ -513,7 +513,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
@@ -523,6 +524,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
(".gate_up_proj", ".c_fc_1", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
@@ -543,6 +545,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
@@ -576,6 +579,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should