[Feat] Adds runai distributed streamer (#27230)

Signed-off-by: bbartels <benjamin@bartels.dev>
Signed-off-by: Benjamin Bartels <benjamin@bartels.dev>
Co-authored-by: omer-dayan <omdayan@nvidia.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Benjamin Bartels
2025-10-30 04:09:10 +00:00
committed by GitHub
parent 2ce5c5d3d6
commit 17d055f527
9 changed files with 39 additions and 11 deletions

View File

@@ -27,9 +27,16 @@ class RunaiModelStreamerLoader(BaseModelLoader):
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
self._is_distributed = False
if load_config.model_loader_extra_config:
extra_config = load_config.model_loader_extra_config
if "distributed" in extra_config and isinstance(
extra_config.get("distributed"), bool
):
self._is_distributed = extra_config.get("distributed")
if "concurrency" in extra_config and isinstance(
extra_config.get("concurrency"), int
):
@@ -92,8 +99,7 @@ class RunaiModelStreamerLoader(BaseModelLoader):
"""Get an iterator for the model weights based on the load format."""
hf_weights_files = self._prepare_weights(model_or_path, revision)
return runai_safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
hf_weights_files, self.load_config.use_tqdm_on_load, self._is_distributed
)
def download_model(self, model_config: ModelConfig) -> None: