[Config][Disaggregated] Add timeout configuration for the torch.store and add KVTransferConfig.kv_connector_extra_config (#14367)
Signed-off-by: Mathis Felardos <mathis@mistral.ai>
@@ -2837,6 +2837,9 @@ class KVTransferConfig(BaseModel):
     # The KV connector port, used to build distributed connection
     kv_port: int = 14579
 
+    # any extra config that the connector may need
+    kv_connector_extra_config: dict[str, Any] = {}
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
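The new field is a free-form dict, so connector-specific settings can ride along in KVTransferConfig without schema changes. A minimal sketch of how it might be populated, assuming the config class can be instantiated directly (import path and field values here are illustrative, not taken from this commit):

# Sketch: connector-specific knobs travel in kv_connector_extra_config.
# The "store_timeout" key is the one consumed later in this commit; any
# other keys are opaque to vLLM and interpreted only by the connector.
from vllm.config import KVTransferConfig  # import path assumed

cfg = KVTransferConfig(
    kv_connector="PyNcclConnector",
    kv_role="kv_producer",
    kv_connector_extra_config={"store_timeout": 600},
)

Since KVTransferConfig is a pydantic BaseModel, the mutable `{}` default is copied per instance rather than shared between instances.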
@@ -2896,6 +2899,9 @@ class KVTransferConfig(BaseModel):
         return self.kv_connector is not None and \
             self.kv_role in ["kv_consumer", "kv_both"]
 
+    def get_from_extra_config(self, key, default) -> Any:
+        return self.kv_connector_extra_config.get(key, default)
+
 
 class CompilationLevel:
     # constants for the levels of the compilation process
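The accessor simply delegates to dict.get(), so callers supply their own fallback and never see a KeyError for an absent knob. A self-contained illustration of the semantics:

# A present key wins; an absent key falls back to the caller-supplied
# default instead of raising.
extra = {"store_timeout": 600}

def get_from_extra_config(key, default):
    return extra.get(key, default)

assert get_from_extra_config("store_timeout", 300) == 600      # explicit value
assert get_from_extra_config("buffer_size", 1 << 20) == 1 << 20  # default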
@@ -59,11 +59,13 @@ class PyNcclPipe(KVPipeBase):
         self.device = self._select_device(device)
 
         # build distributed connection and send/recv implementation
+        store_timeout = self.config.get_from_extra_config("store_timeout", 300)
         self.group = StatelessProcessGroup.create(
             host=self.config.kv_ip,
             port=self.config.kv_port + port_offset,
             rank=self.kv_rank,
             world_size=self.kv_parallel_size,
+            store_timeout=store_timeout,
         )
         # add a barrier to make sure the connection is initiated properly
         self.group.barrier()
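The pipe now reads store_timeout from the extra config (defaulting to the previous behavior of 300 seconds) and forwards it to StatelessProcessGroup.create, so a consumer that comes up long before its producer can wait beyond the stock TCPStore timeout. A sketch of calling the group factory directly, with the import path assumed; note that the call blocks until world_size peers have connected or the timeout elapses:

from vllm.distributed.utils import StatelessProcessGroup  # path assumed

# Rank 0 hosts the underlying TCPStore; other ranks connect to it.
# With store_timeout=600, peers may arrive up to 10 minutes apart.
group = StatelessProcessGroup.create(
    host="127.0.0.1",
    port=14579,          # matches the kv_port default above
    rank=0,
    world_size=2,
    store_timeout=600,   # seconds; converted to a timedelta downstream
)
group.barrier()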
@@ -5,6 +5,7 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 import dataclasses
+import datetime
 import pickle
 import time
 from collections import deque
@@ -217,6 +218,7 @@ class StatelessProcessGroup:
         rank: int,
         world_size: int,
         data_expiration_seconds: int = 3600,
+        store_timeout: int = 300,
     ) -> "StatelessProcessGroup":
         """A replacement for `torch.distributed.init_process_group` that does not
         pollute the global state.
@@ -238,6 +240,7 @@ class StatelessProcessGroup:
             port=port,
             world_size=world_size,
             is_master=(rank == 0),
+            timeout=datetime.timedelta(seconds=store_timeout),
         )
 
         return StatelessProcessGroup(
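timeout is a standard parameter of PyTorch's TCPStore: it bounds the initial connection handshake as well as blocking store operations, and it expects a datetime.timedelta, hence the conversion from the plain integer seconds carried in the config. A standalone, single-process illustration against the torch API:

import datetime
from torch.distributed import TCPStore

# world_size=1 with is_master=True means the store is complete as soon
# as it starts; the timeout bounds connects and blocking get()/wait().
store = TCPStore(
    host_name="127.0.0.1",
    port=29500,
    world_size=1,
    is_master=True,
    timeout=datetime.timedelta(seconds=10),
)
store.set("ready", "1")
assert store.get("ready") == b"1"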