Add pt_load_map_location to allow loading to cuda (#16869)
Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
@@ -1564,6 +1564,16 @@ class LoadConfig:
|
||||
use_tqdm_on_load: bool = True
|
||||
"""Whether to enable tqdm for showing progress bar when loading model
|
||||
weights."""
|
||||
pt_load_map_location: Union[str, dict[str, str]] = "cpu"
|
||||
"""
|
||||
pt_load_map_location: the map location for loading pytorch checkpoint, to
|
||||
support loading checkpoints can only be loaded on certain devices like
|
||||
"cuda", this is equivalent to {"": "cuda"}. Another supported format is
|
||||
mapping from different devices like from GPU 1 to GPU 0:
|
||||
{"cuda:1": "cuda:0"}. Note that when passed from command line, the strings
|
||||
in dictionary needs to be double quoted for json parsing. For more details,
|
||||
see original doc for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html
|
||||
"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user