Add pt_load_map_location to allow loading to cuda (#16869)

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
Authored by Jerry Zhang on 2025-05-01 23:23:42 -07:00; committed by GitHub.
parent f192ca90e6
commit 109e15a335
6 changed files with 74 additions and 3 deletions

View File

@@ -384,6 +384,7 @@ class DefaultModelLoader(BaseModelLoader):
weights_iterator = pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.pt_load_map_location,
)
if current_platform.is_tpu():
@@ -890,6 +891,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
iterator = pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.pt_load_map_location,
)
for org_name, param in iterator:
# mapping weight names from transformers to vllm while preserving

View File

@@ -502,6 +502,7 @@ def fastsafetensors_weights_iterator(
def pt_weights_iterator(
hf_weights_files: List[str],
use_tqdm_on_load: bool,
pt_load_map_location: Union[str, dict[str, str]] = "cpu",
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model bin/pt files."""
for bin_file in tqdm(
@@ -510,7 +511,9 @@ def pt_weights_iterator(
disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
):
state = torch.load(bin_file, map_location="cpu", weights_only=True)
state = torch.load(bin_file,
map_location=pt_load_map_location,
weights_only=True)
yield from state.items()
del state