[Core] Integrate fastsafetensors loader for loading model weights (#10647)
Signed-off-by: Manish Sethi <Manish.sethi1@ibm.com>
This commit is contained in:
@@ -49,9 +49,10 @@ from vllm.model_executor.model_loader.utils import (ParamMapping,
|
||||
set_default_torch_dtype)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
||||
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
|
||||
get_gguf_extra_tensor_names, get_lock, gguf_quant_weights_iterator,
|
||||
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
|
||||
fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
|
||||
filter_files_not_needed_for_inference, get_gguf_extra_tensor_names,
|
||||
get_lock, gguf_quant_weights_iterator, initialize_dummy_weights,
|
||||
np_cache_weights_iterator, pt_weights_iterator,
|
||||
runai_safetensors_weights_iterator, safetensors_weights_iterator)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
@@ -275,7 +276,8 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
# Some quantized models use .pt files for storing the weights.
|
||||
if load_format == LoadFormat.AUTO:
|
||||
allow_patterns = ["*.safetensors", "*.bin"]
|
||||
elif load_format == LoadFormat.SAFETENSORS:
|
||||
elif (load_format == LoadFormat.SAFETENSORS
|
||||
or load_format == LoadFormat.FASTSAFETENSORS):
|
||||
use_safetensors = True
|
||||
allow_patterns = ["*.safetensors"]
|
||||
elif load_format == LoadFormat.MISTRAL:
|
||||
@@ -357,10 +359,16 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
elif use_safetensors:
|
||||
weights_iterator = safetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
if self.load_config.load_format == LoadFormat.FASTSAFETENSORS:
|
||||
weights_iterator = fastsafetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
else:
|
||||
weights_iterator = safetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
else:
|
||||
weights_iterator = pt_weights_iterator(
|
||||
hf_weights_files,
|
||||
|
||||
@@ -38,6 +38,14 @@ except (ImportError, OSError):
|
||||
SafetensorsStreamer = runai_model_streamer.placeholder_attr(
|
||||
"SafetensorsStreamer")
|
||||
|
||||
try:
|
||||
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
|
||||
except ImportError:
|
||||
fastsafetensors = PlaceholderModule("fastsafetensors")
|
||||
SafeTensorsFileLoader = fastsafetensors.placeholder_attr(
|
||||
"SafeTensorsFileLoader")
|
||||
SingleGroup = fastsafetensors.placeholder_attr("SingleGroup")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# use system-level temp directory for file locks, so that multiple users
|
||||
@@ -452,6 +460,45 @@ def runai_safetensors_weights_iterator(
|
||||
yield from streamer.get_tensors()
|
||||
|
||||
|
||||
def fastsafetensors_weights_iterator(
    hf_weights_files: List[str],
    use_tqdm_on_load: bool,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Yield ``(name, tensor)`` pairs loaded with fastsafetensors.

    The checkpoint files are consumed in batches of ``group.size()`` files.
    Within a batch, the file at position ``i`` is registered for rank ``i``.
    A fresh ``SafeTensorsFileLoader`` is created for every batch and is
    closed when the batch is finished, when the consumer stops iterating,
    or when an error propagates.

    Args:
        hf_weights_files: Paths of the safetensors files to read.
        use_tqdm_on_load: Forwarded to ``enable_tqdm`` to decide whether the
            progress bar is shown.

    Yields:
        Tuples of the tensor key and the tensor returned by
        ``get_tensor`` for that key.
    """
    # Use the global process group when torch.distributed is initialized;
    # otherwise fall back to fastsafetensors' SingleGroup (presumably a
    # single-process stand-in -- confirm against the fastsafetensors docs).
    if not torch.distributed.is_initialized():
        group = SingleGroup()
    else:
        group = torch.distributed.group.WORLD

    # The CUDA device index is taken directly from this process's rank.
    device = torch.device(f'cuda:{group.rank()}')

    # Split the file list into consecutive chunks of one file per rank.
    world_size = group.size()
    file_batches = [
        hf_weights_files[start:start + world_size]
        for start in range(0, len(hf_weights_files), world_size)
    ]

    progress = tqdm(
        file_batches,
        desc="Loading safetensors using Fastsafetensor loader",
        disable=not enable_tqdm(use_tqdm_on_load),
        bar_format=_BAR_FORMAT,
    )
    for batch in progress:
        loader = SafeTensorsFileLoader(group, device)
        # Map rank -> single-element list holding that rank's file.
        loader.add_filenames(
            {rank: [path] for rank, path in enumerate(batch)})
        try:
            file_buffer = loader.copy_files_to_device()
            try:
                # Snapshot the key list before yielding anything.
                for name in list(file_buffer.key_to_rank_lidx.keys()):
                    tensor = file_buffer.get_tensor(name)
                    yield name, tensor
            finally:
                file_buffer.close()
        finally:
            loader.close()
|
||||
|
||||
|
||||
def pt_weights_iterator(
|
||||
hf_weights_files: List[str],
|
||||
use_tqdm_on_load: bool,
|
||||
|
||||
Reference in New Issue
Block a user