[Misc] Add uninitialized params tracking for AutoWeightsLoader (#10327)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2024-11-18 09:07:46 +08:00
committed by GitHub
parent d1557e66d3
commit c4e464333e
74 changed files with 454 additions and 185 deletions

View File

@@ -8,7 +8,7 @@ import math
import re
from functools import partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
-Optional, Tuple, TypedDict, Union)
+Optional, Set, Tuple, TypedDict, Union)
import numpy as np
import torch
@@ -964,13 +964,15 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
-def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+def load_weights(self, weights: Iterable[Tuple[str,
+torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "w2", 0),
("gate_up_proj", "w1", 1),
]
params_dict = dict(self.named_parameters())
+loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
@@ -999,6 +1001,8 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
+loaded_params.add(name)
+return loaded_params
class QWenLLM(QWenBaseModel):