Add pt_load_map_location to allow loading to cuda (#16869)

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
Jerry Zhang
2025-05-01 23:23:42 -07:00
committed by GitHub
parent f192ca90e6
commit 109e15a335
6 changed files with 74 additions and 3 deletions

View File

@@ -1564,6 +1564,16 @@ class LoadConfig:
use_tqdm_on_load: bool = True
"""Whether to enable tqdm for showing progress bar when loading model
weights."""
pt_load_map_location: Union[str, dict[str, str]] = "cpu"
"""
pt_load_map_location: the map location for loading pytorch checkpoint, to
support loading checkpoints can only be loaded on certain devices like
"cuda", this is equivalent to {"": "cuda"}. Another supported format is
mapping from different devices like from GPU 1 to GPU 0:
{"cuda:1": "cuda:0"}. Note that when passed from command line, the strings
in dictionary needs to be double quoted for json parsing. For more details,
see original doc for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html
"""
def compute_hash(self) -> str:
"""