Set weights_only=True when using torch.load() (#12366)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
@@ -93,7 +93,7 @@ def convert_bin_to_safetensor_file(
|
||||
pt_filename: str,
|
||||
sf_filename: str,
|
||||
) -> None:
|
||||
loaded = torch.load(pt_filename, map_location="cpu")
|
||||
loaded = torch.load(pt_filename, map_location="cpu", weights_only=True)
|
||||
if "state_dict" in loaded:
|
||||
loaded = loaded["state_dict"]
|
||||
shared = _shared_pointers(loaded)
|
||||
@@ -381,7 +381,9 @@ def np_cache_weights_iterator(
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
state = torch.load(bin_file,
|
||||
map_location="cpu",
|
||||
weights_only=True)
|
||||
for name, param in state.items():
|
||||
param_path = os.path.join(np_folder, name)
|
||||
with open(param_path, "wb") as f:
|
||||
@@ -447,7 +449,7 @@ def pt_weights_iterator(
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
state = torch.load(bin_file, map_location="cpu", weights_only=True)
|
||||
yield from state.items()
|
||||
del state
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user