[Misc] Add uninitialized params tracking for AutoWeightsLoader (#10327)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
|
||||
import math
|
||||
from typing import (Dict, Iterable, List, Literal, Mapping, NamedTuple,
|
||||
Optional, Tuple, TypedDict, Union)
|
||||
Optional, Set, Tuple, TypedDict, Union)
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@@ -751,9 +751,10 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
loader.load_weights(weights)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user