[Misc] add use_tqdm_on_load to reduce logs (#14407)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
Aaron Pham
2025-03-08 08:57:46 -05:00
committed by GitHub
parent 03fe18ae0f
commit 0b7f06b447
4 changed files with 54 additions and 22 deletions

View File

@@ -354,11 +354,18 @@ class DefaultModelLoader(BaseModelLoader):
self.load_config.download_dir,
hf_folder,
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
elif use_safetensors:
weights_iterator = safetensors_weights_iterator(hf_weights_files)
weights_iterator = safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
else:
weights_iterator = pt_weights_iterator(hf_weights_files)
weights_iterator = pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
if current_platform.is_tpu():
# In PyTorch XLA, we should call `xm.mark_step` frequently so that
@@ -806,9 +813,15 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
if use_safetensors:
iterator = safetensors_weights_iterator(hf_weights_files)
iterator = safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
else:
iterator = pt_weights_iterator(hf_weights_files)
iterator = pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
for org_name, param in iterator:
# mapping weight names from transformers to vllm while preserving
# original names.
@@ -1396,7 +1409,10 @@ class RunaiModelStreamerLoader(BaseModelLoader):
revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
hf_weights_files = self._prepare_weights(model_or_path, revision)
return runai_safetensors_weights_iterator(hf_weights_files)
return runai_safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
def download_model(self, model_config: ModelConfig) -> None:
"""Download model if necessary"""

View File

@@ -366,16 +366,22 @@ def filter_files_not_needed_for_inference(
# tqdm bar format shared by all weight-loading progress bars: description,
# integer percentage, shard counter, elapsed/remaining time and rate.
# The trailing "\n" makes each refresh print on its own line instead of
# redrawing in place, which keeps captured logs readable.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
def enable_tqdm(use_tqdm_on_load: bool):
    """Decide whether a tqdm progress bar should be shown while loading weights.

    Returns True only when the caller requested progress output AND this
    process is either running without torch.distributed initialized or is
    distributed rank 0 — so a multi-rank run emits a single progress bar.
    """
    if not use_tqdm_on_load:
        return False
    dist = torch.distributed
    return (not dist.is_initialized()) or dist.get_rank() == 0
def np_cache_weights_iterator(
model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
hf_weights_files: List[str]
model_name_or_path: str,
cache_dir: Optional[str],
hf_folder: str,
hf_weights_files: List[str],
use_tqdm_on_load: bool,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model np files.
Will dump the model weights to numpy files if they are not already dumped.
"""
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
# Convert the model weights from torch tensors to numpy arrays for
# faster loading.
np_folder = os.path.join(hf_folder, "np")
@@ -389,7 +395,7 @@ def np_cache_weights_iterator(
for bin_file in tqdm(
hf_weights_files,
desc="Loading np_cache checkpoint shards",
disable=not enable_tqdm,
disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
):
state = torch.load(bin_file,
@@ -414,15 +420,14 @@ def np_cache_weights_iterator(
def safetensors_weights_iterator(
hf_weights_files: List[str]
hf_weights_files: List[str],
use_tqdm_on_load: bool,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
for st_file in tqdm(
hf_weights_files,
desc="Loading safetensors checkpoint shards",
disable=not enable_tqdm,
disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
):
with safe_open(st_file, framework="pt") as f:
@@ -432,16 +437,15 @@ def safetensors_weights_iterator(
def runai_safetensors_weights_iterator(
hf_weights_files: List[str]
hf_weights_files: List[str],
use_tqdm_on_load: bool,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
with SafetensorsStreamer() as streamer:
for st_file in tqdm(
hf_weights_files,
desc="Loading safetensors using Runai Model Streamer",
disable=not enable_tqdm,
disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
):
streamer.stream_file(st_file)
@@ -449,15 +453,14 @@ def runai_safetensors_weights_iterator(
def pt_weights_iterator(
hf_weights_files: List[str]
hf_weights_files: List[str],
use_tqdm_on_load: bool,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model bin/pt files."""
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
for bin_file in tqdm(
hf_weights_files,
desc="Loading pt checkpoint shards",
disable=not enable_tqdm,
disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
):
state = torch.load(bin_file, map_location="cpu", weights_only=True)