[Core] Integrate fastsafetensors loader for loading model weights (#10647)

Signed-off-by: Manish Sethi <Manish.sethi1@ibm.com>
Author: Manish Sethi
Date: 2025-03-24 11:08:02 -04:00 (committed by GitHub)
parent 9606d572ed
commit 761702fd19
11 changed files with 152 additions and 9 deletions
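Summary: this change wires a new fastsafetensors load format into the default model loader; the fastsafetensors library reads safetensors files with direct file-to-device transfers instead of host-side deserialization. A hedged usage sketch, assuming the optional package is installed and that the format string matches the LoadFormat member this commit adds (the model name is only illustrative):

    # pip install fastsafetensors   (optional dependency, assumed installed)
    from vllm import LLM

    # Assumption: LLM forwards load_format to the engine's LoadConfig, so the
    # loader path added in this commit is selected by name.
    llm = LLM(model="facebook/opt-125m", load_format="fastsafetensors")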


@@ -49,9 +49,10 @@ from vllm.model_executor.model_loader.utils import (ParamMapping,
                                                     set_default_torch_dtype)
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf, download_weights_from_hf,
-    filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
-    get_gguf_extra_tensor_names, get_lock, gguf_quant_weights_iterator,
-    initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
+    fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
+    filter_files_not_needed_for_inference, get_gguf_extra_tensor_names,
+    get_lock, gguf_quant_weights_iterator, initialize_dummy_weights,
+    np_cache_weights_iterator, pt_weights_iterator,
     runai_safetensors_weights_iterator, safetensors_weights_iterator)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
@@ -275,7 +276,8 @@ class DefaultModelLoader(BaseModelLoader):
         # Some quantized models use .pt files for storing the weights.
         if load_format == LoadFormat.AUTO:
             allow_patterns = ["*.safetensors", "*.bin"]
-        elif load_format == LoadFormat.SAFETENSORS:
+        elif (load_format == LoadFormat.SAFETENSORS
+              or load_format == LoadFormat.FASTSAFETENSORS):
             use_safetensors = True
             allow_patterns = ["*.safetensors"]
         elif load_format == LoadFormat.MISTRAL:
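The new branch compares against a LoadFormat.FASTSAFETENSORS member that this commit adds elsewhere among its 11 changed files, presumably in vllm/config.py where LoadFormat is defined. A minimal sketch of the assumed enum shape (existing members elided):

    import enum

    class LoadFormat(str, enum.Enum):
        # Only the shape matters here; the real enum has more members.
        AUTO = "auto"
        SAFETENSORS = "safetensors"
        FASTSAFETENSORS = "fastsafetensors"  # new in this commit (assumed value)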
@@ -357,10 +359,16 @@
                 self.load_config.use_tqdm_on_load,
             )
         elif use_safetensors:
-            weights_iterator = safetensors_weights_iterator(
-                hf_weights_files,
-                self.load_config.use_tqdm_on_load,
-            )
+            if self.load_config.load_format == LoadFormat.FASTSAFETENSORS:
+                weights_iterator = fastsafetensors_weights_iterator(
+                    hf_weights_files,
+                    self.load_config.use_tqdm_on_load,
+                )
+            else:
+                weights_iterator = safetensors_weights_iterator(
+                    hf_weights_files,
+                    self.load_config.use_tqdm_on_load,
+                )
         else:
             weights_iterator = pt_weights_iterator(
                 hf_weights_files,
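Both branches satisfy the same iterator contract, yielding (parameter name, tensor) pairs, so the dispatch stays local to the use_safetensors arm. A hypothetical consumer sketch (not code from this commit) showing why the two iterators are interchangeable:

    from typing import Iterable, Tuple
    import torch

    def consume_weights(model: torch.nn.Module,
                        weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
        # Any iterator with this contract plugs in, regardless of which
        # loader produced it.
        params = dict(model.named_parameters())
        for name, tensor in weights:
            params[name].data.copy_(tensor)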


@@ -38,6 +38,14 @@ except (ImportError, OSError):
     SafetensorsStreamer = runai_model_streamer.placeholder_attr(
         "SafetensorsStreamer")
 
+try:
+    from fastsafetensors import SafeTensorsFileLoader, SingleGroup
+except ImportError:
+    fastsafetensors = PlaceholderModule("fastsafetensors")
+    SafeTensorsFileLoader = fastsafetensors.placeholder_attr(
+        "SafeTensorsFileLoader")
+    SingleGroup = fastsafetensors.placeholder_attr("SingleGroup")
+
 logger = init_logger(__name__)
 
 # use system-level temp directory for file locks, so that multiple users
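As with the runai_model_streamer guard just above this hunk, the try/except keeps weight_utils importable when fastsafetensors is absent; PlaceholderModule defers the failure until the symbol is actually used. A hedged illustration of that behavior (assuming PlaceholderModule raises on first use, which is its documented purpose in vllm.utils):

    from vllm.utils import PlaceholderModule

    fastsafetensors = PlaceholderModule("fastsafetensors")
    SingleGroup = fastsafetensors.placeholder_attr("SingleGroup")

    # Importing and assigning succeed; only this call raises, with an error
    # pointing the user at the missing "fastsafetensors" package.
    pg = SingleGroup()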
@@ -452,6 +460,45 @@ def runai_safetensors_weights_iterator(
         yield from streamer.get_tensors()
 
 
+def fastsafetensors_weights_iterator(
+    hf_weights_files: List[str],
+    use_tqdm_on_load: bool,
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Iterate over the weights in the model safetensor files
+    using the fastsafetensors library."""
+    if torch.distributed.is_initialized():
+        pg = torch.distributed.group.WORLD
+    else:
+        pg = SingleGroup()
+
+    device = torch.device(f'cuda:{pg.rank()}')
+    weight_files_sub_lists = [
+        hf_weights_files[i:i + pg.size()]
+        for i in range(0, len(hf_weights_files), pg.size())
+    ]
+
+    for f_list in tqdm(
+            weight_files_sub_lists,
+            desc="Loading safetensors using Fastsafetensor loader",
+            disable=not enable_tqdm(use_tqdm_on_load),
+            bar_format=_BAR_FORMAT,
+    ):
+        loader = SafeTensorsFileLoader(pg, device)
+        rank_file_map = {i: [f] for i, f in enumerate(f_list)}
+        loader.add_filenames(rank_file_map)
+        try:
+            fb = loader.copy_files_to_device()
+            try:
+                keys = list(fb.key_to_rank_lidx.keys())
+                for k in keys:
+                    t = fb.get_tensor(k)
+                    yield k, t
+            finally:
+                fb.close()
+        finally:
+            loader.close()
+
+
 def pt_weights_iterator(
     hf_weights_files: List[str],
     use_tqdm_on_load: bool,
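A note on the sharding in fastsafetensors_weights_iterator: the file list is chunked into groups of pg.size(); within each group, rank r reads exactly one file, and the loader then makes every tensor in the group retrievable from all ranks via fb.get_tensor, so each rank yields the full key set. A hypothetical illustration of the partitioning with two ranks and five shard files (file names invented):

    files = [f"model-0000{i}.safetensors" for i in range(1, 6)]
    world_size = 2  # stands in for pg.size()

    sub_lists = [files[i:i + world_size]
                 for i in range(0, len(files), world_size)]
    # -> [['...00001', '...00002'], ['...00003', '...00004'], ['...00005']]
    # The last group is short, so rank 1 reads no file in that round.

    rank_file_map = {r: [f] for r, f in enumerate(sub_lists[0])}
    # -> {0: ['model-00001.safetensors'], 1: ['model-00002.safetensors']}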