[Feat] Adds runai distributed streamer (#27230)
Signed-off-by: bbartels <benjamin@bartels.dev> Signed-off-by: Benjamin Bartels <benjamin@bartels.dev> Co-authored-by: omer-dayan <omdayan@nvidia.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -27,9 +27,16 @@ class RunaiModelStreamerLoader(BaseModelLoader):
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
|
||||
self._is_distributed = False
|
||||
if load_config.model_loader_extra_config:
|
||||
extra_config = load_config.model_loader_extra_config
|
||||
|
||||
if "distributed" in extra_config and isinstance(
|
||||
extra_config.get("distributed"), bool
|
||||
):
|
||||
self._is_distributed = extra_config.get("distributed")
|
||||
|
||||
if "concurrency" in extra_config and isinstance(
|
||||
extra_config.get("concurrency"), int
|
||||
):
|
||||
@@ -92,8 +99,7 @@ class RunaiModelStreamerLoader(BaseModelLoader):
|
||||
"""Get an iterator for the model weights based on the load format."""
|
||||
hf_weights_files = self._prepare_weights(model_or_path, revision)
|
||||
return runai_safetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
hf_weights_files, self.load_config.use_tqdm_on_load, self._is_distributed
|
||||
)
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
|
||||
Reference in New Issue
Block a user