Add pt_load_map_location to allow loading to cuda (#16869)
Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
@@ -384,6 +384,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
weights_iterator = pt_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
self.load_config.pt_load_map_location,
|
||||
)
|
||||
|
||||
if current_platform.is_tpu():
|
||||
@@ -890,6 +891,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
iterator = pt_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
self.load_config.pt_load_map_location,
|
||||
)
|
||||
for org_name, param in iterator:
|
||||
# mapping weight names from transformers to vllm while preserving
|
||||
|
||||
@@ -502,6 +502,7 @@ def fastsafetensors_weights_iterator(
|
||||
def pt_weights_iterator(
|
||||
hf_weights_files: List[str],
|
||||
use_tqdm_on_load: bool,
|
||||
pt_load_map_location: Union[str, dict[str, str]] = "cpu",
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model bin/pt files."""
|
||||
for bin_file in tqdm(
|
||||
@@ -510,7 +511,9 @@ def pt_weights_iterator(
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
state = torch.load(bin_file, map_location="cpu", weights_only=True)
|
||||
state = torch.load(bin_file,
|
||||
map_location=pt_load_map_location,
|
||||
weights_only=True)
|
||||
yield from state.items()
|
||||
del state
|
||||
|
||||
|
||||
Reference in New Issue
Block a user