[Config][Disaggregated] Add timeout configuration for the torch.store and add KVTransferConfig.kv_connector_extra_config (#14367)
Signed-off-by: Mathis Felardos <mathis@mistral.ai>
@@ -2837,6 +2837,9 @@ class KVTransferConfig(BaseModel):
     # The KV connector port, used to build distributed connection
     kv_port: int = 14579
 
+    # any extra config that the connector may need
+    kv_connector_extra_config: dict[str, Any] = {}
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
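The new field gives deployments a generic bag for connector-specific settings, so one-off options no longer need a dedicated config field. A minimal sketch of populating it (connector name and values are illustrative, not taken from this diff):

```python
from vllm.config import KVTransferConfig

# Illustrative values; only kv_connector_extra_config is new in this commit.
cfg = KVTransferConfig(
    kv_connector="PyNcclConnector",
    kv_role="kv_producer",
    kv_connector_extra_config={"store_timeout": 600},  # seconds
)
```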
@@ -2896,6 +2899,9 @@ class KVTransferConfig(BaseModel):
         return self.kv_connector is not None and \
             self.kv_role in ["kv_consumer", "kv_both"]
 
+    def get_from_extra_config(self, key, default) -> Any:
+        return self.kv_connector_extra_config.get(key, default)
+
 
 class CompilationLevel:
     # constants for the levels of the compilation process
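`get_from_extra_config` gives connectors a single lookup-with-fallback path instead of reaching into the dict directly. For example, with `cfg` as above (the key and default mirror the pipe code later in this commit):

```python
# Returns 600 here; falls back to 300 when the key is absent.
store_timeout = cfg.get_from_extra_config("store_timeout", 300)
```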
@@ -6,7 +6,7 @@
     - Distributed KV cache transmission using PyNccl pipes.
     - Non-blocking `insert`, blocking `drop_select`.
     - Use CPU signal pipe to avoid racing condition
     - Handles buffer size constraints and provide backpressure mechanism to
       stop the prefill instance when the decode instance is slow.
 """
 import threading
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 """
     This module implements a PyNccl pipe for sending and receiving
     Optional[torch.Tensor] between distributed ranks with advanced
     communication features.
 
     Key Features:
@@ -59,11 +59,13 @@ class PyNcclPipe(KVPipeBase):
         self.device = self._select_device(device)
 
         # build distributed connection and send/recv implementation
+        store_timeout = self.config.get_from_extra_config("store_timeout", 300)
         self.group = StatelessProcessGroup.create(
             host=self.config.kv_ip,
             port=self.config.kv_port + port_offset,
             rank=self.kv_rank,
             world_size=self.kv_parallel_size,
+            store_timeout=store_timeout,
         )
         # add a barrier to make sure the connection is initiated properly
         self.group.barrier()
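A condensed trace of the new plumbing, assuming `config` is a `KVTransferConfig` (rank and world size are placeholder values):

```python
from vllm.distributed.utils import StatelessProcessGroup

# 1. The pipe reads the knob, defaulting to 300 s when unset.
store_timeout = config.get_from_extra_config("store_timeout", 300)

# 2. The stateless group forwards it (per the utils.py hunks below) as a
#    datetime.timedelta to the TCPStore backing the connection.
group = StatelessProcessGroup.create(
    host=config.kv_ip,
    port=config.kv_port,
    rank=0,
    world_size=2,
    store_timeout=store_timeout,
)
```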
@@ -134,11 +136,11 @@ class PyNcclPipe(KVPipeBase):
         Create a buffer to receive the tensor based on the provided metadata.
 
         Parameters:
             - metadata: A dictionary with keys "dtype" and "shape", describing
               the tensor's data type and shape.
 
         Returns:
             - buffer: A tensor of the specified type and shape, allocated on
               self.device.
         """
         return torch.empty(metadata["shape"],
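The buffer-from-metadata step is small enough to show standalone; a hedged sketch with an invented shape and dtype:

```python
import torch

# Metadata travels ahead of the payload; the receiver allocates a matching
# buffer before the transfer fills it in.
metadata = {"dtype": torch.float16, "shape": torch.Size([2, 8, 128])}
buffer = torch.empty(metadata["shape"], dtype=metadata["dtype"], device="cpu")
```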
@@ -159,18 +161,18 @@ class PyNcclPipe(KVPipeBase):
         Receive the metadata dictionary from the target rank.
 
         Returns:
             - metadata: A dictionary with keys "dtype" and "shape" describing
               the tensor.
         """
         return self.group.recv_obj(self.target_rank_for_recv)
 
     def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
         """
         The actual implementation of sending the tensor and its metadata to the
         target rank.
 
         Parameters:
             - tensor: The input tensor to be sent, or None if no tensor is
               being sent.
         """
         metadata = self._make_metadata(tensor)
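The metadata builder invoked on the last line is not part of this diff; a plausible sketch consistent with the docstrings above (not the verbatim vLLM implementation):

```python
from typing import Optional

import torch

def make_metadata(tensor: Optional[torch.Tensor]) -> dict:
    # None is encoded as {"dtype": None, "shape": None} so "no tensor"
    # can travel through the same metadata channel.
    if tensor is None:
        return {"dtype": None, "shape": None}
    return {"dtype": tensor.dtype, "shape": tensor.shape}
```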
@@ -181,7 +183,7 @@ class PyNcclPipe(KVPipeBase):
 
     def _recv_impl(self) -> Optional[torch.Tensor]:
         """
         The actual implementation of receiving a tensor and its metadata from
         the target rank.
 
         Returns:
@@ -213,7 +215,7 @@ class PyNcclPipe(KVPipeBase):
 
     def block_if_full(self):
         """
         Block the current thread if the buffer size is larger than the
         threshold.
         """
         while self.buffer_size > self.buffer_size_thresh:
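The loop body is cut off by the hunk; a hedged sketch of how such a backpressure wait typically looks (poll interval assumed, not taken from the diff):

```python
import time

def block_if_full(pipe, poll_interval: float = 0.05) -> None:
    # Spin until the consumer drains enough queued bytes; this keeps a fast
    # prefill instance from flooding a slow decode instance.
    while pipe.buffer_size > pipe.buffer_size_thresh:
        time.sleep(poll_interval)
```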
@@ -222,7 +224,7 @@ class PyNcclPipe(KVPipeBase):
 
     def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
         """
         Sends a tensor and its metadata to the destination rank in a
         non-blocking way.
 
         Parameters:
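One common way to make the blocking `_send_impl` non-blocking is to hand it to a single-worker thread pool, which preserves send ordering while freeing the caller; a sketch of that pattern (not necessarily vLLM's exact code):

```python
from concurrent.futures import ThreadPoolExecutor

# One worker keeps sends ordered; submit() returns immediately.
transport_thread = ThreadPoolExecutor(max_workers=1)

def send_tensor_nonblocking(pipe, tensor) -> None:
    pipe.block_if_full()  # apply backpressure before queueing more data
    transport_thread.submit(pipe._send_impl, tensor)
```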
@@ -5,6 +5,7 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 import dataclasses
+import datetime
 import pickle
 import time
 from collections import deque
@@ -217,6 +218,7 @@ class StatelessProcessGroup:
         rank: int,
         world_size: int,
         data_expiration_seconds: int = 3600,
+        store_timeout: int = 300,
     ) -> "StatelessProcessGroup":
         """A replacement for `torch.distributed.init_process_group` that does not
         pollute the global state.
@@ -238,6 +240,7 @@ class StatelessProcessGroup:
             port=port,
             world_size=world_size,
             is_master=(rank == 0),
+            timeout=datetime.timedelta(seconds=store_timeout),
         )
 
         return StatelessProcessGroup(
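The timedelta lands on the `TCPStore` that backs the group, which is the store the commit title refers to. For reference, `torch.distributed.TCPStore` accepts its timeout the same way; a standalone sketch with placeholder host, port, and world size:

```python
import datetime

from torch.distributed import TCPStore

# The new store_timeout ultimately bounds how long peers may wait here.
store = TCPStore(
    host_name="127.0.0.1",
    port=14579,
    world_size=1,
    is_master=True,
    timeout=datetime.timedelta(seconds=300),
)
```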