[Model loader]: support multi-thread model weight loading (#23928)

Signed-off-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com>
Signed-off-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
Yang Kaiyong
2025-09-09 02:49:39 +08:00
committed by GitHub
parent 7be141b2c5
commit 43d9ad03ba
2 changed files with 105 additions and 12 deletions

View File

@@ -18,8 +18,9 @@ from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf,
fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference, maybe_download_from_modelscope,
np_cache_weights_iterator, pt_weights_iterator,
safetensors_weights_iterator)
multi_thread_pt_weights_iterator,
multi_thread_safetensors_weights_iterator, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.platforms import current_platform
logger = init_logger(__name__)
@@ -28,6 +29,9 @@ logger = init_logger(__name__)
class DefaultModelLoader(BaseModelLoader):
"""Model loader that can load different file types from disk."""
# default number of thread when enable multithread weight loading
DEFAULT_NUM_THREADS = 8
@dataclasses.dataclass
class Source:
"""A source for weights."""
@@ -52,9 +56,15 @@ class DefaultModelLoader(BaseModelLoader):
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
extra_config = load_config.model_loader_extra_config
allowed_keys = {"enable_multithread_load", "num_threads"}
unexpected_keys = set(extra_config.keys()) - allowed_keys
if unexpected_keys:
raise ValueError(f"Unexpected extra config keys for load format "
f"{load_config.load_format}: "
f"{unexpected_keys}")
def _prepare_weights(
self,
@@ -145,6 +155,7 @@ class DefaultModelLoader(BaseModelLoader):
self, source: "Source"
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
extra_config = self.load_config.model_loader_extra_config
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
source.model_or_path, source.revision, source.fall_back_to_pt,
source.allow_patterns_overrides)
@@ -165,16 +176,34 @@ class DefaultModelLoader(BaseModelLoader):
self.load_config.use_tqdm_on_load,
)
else:
weights_iterator = safetensors_weights_iterator(
if extra_config.get("enable_multithread_load"):
weights_iterator = (
multi_thread_safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
max_workers=extra_config.get(
"num_threads", self.DEFAULT_NUM_THREADS),
))
else:
weights_iterator = safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
else:
if extra_config.get("enable_multithread_load"):
weights_iterator = multi_thread_pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.pt_load_map_location,
max_workers=extra_config.get("num_threads",
self.DEFAULT_NUM_THREADS),
)
else:
weights_iterator = pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.pt_load_map_location,
)
else:
weights_iterator = pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.pt_load_map_location,
)
if current_platform.is_tpu():
from vllm.platforms.tpu import USE_TPU_COMMONS