[Kernels] Enable Torch Symmetric Memory All-Reduce By Default (#24111)
Signed-off-by: ilmarkov <markovilya197@gmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -36,8 +36,8 @@ CUSTOM_ALL_REDUCE_MAX_SIZES = {
|
||||
"10.0": {
|
||||
2: 2 * MiB, # 2 MB
|
||||
4: 2 * MiB, # 2 MB
|
||||
6: 2 * MiB, # 2 MB
|
||||
8: 2 * MiB, # 2 MB
|
||||
6: 1 * MiB, # 1 MB
|
||||
8: 1 * MiB, # 1 MB
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -57,11 +57,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
self.ca_comm: Optional[CustomAllreduce] = None
|
||||
self.qr_comm: Optional[QuickAllReduce] = None
|
||||
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
|
||||
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
|
||||
self.symm_mem_comm = SymmMemCommunicator(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
if use_custom_allreduce and self.world_size > 1:
|
||||
# Initialize a custom fast all-reduce implementation.
|
||||
self.ca_comm = CustomAllreduce(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
symm_mem_enabled=(self.symm_mem_comm is not None
|
||||
and not self.symm_mem_comm.disabled),
|
||||
)
|
||||
|
||||
if current_platform.is_rocm():
|
||||
@@ -72,11 +80,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
# currently be an MI300 series.
|
||||
self.qr_comm = QuickAllReduce(group=self.cpu_group,
|
||||
device=self.device)
|
||||
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
|
||||
self.symm_mem_comm = SymmMemCommunicator(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
if self.use_all2all:
|
||||
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
|
||||
|
||||
@@ -54,7 +54,8 @@ class CustomAllreduce:
|
||||
def __init__(self,
|
||||
group: ProcessGroup,
|
||||
device: Union[int, str, torch.device],
|
||||
max_size=8192 * 1024) -> None:
|
||||
max_size=8192 * 1024,
|
||||
symm_mem_enabled=False) -> None:
|
||||
"""
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
@@ -111,7 +112,7 @@ class CustomAllreduce:
|
||||
self.device = device
|
||||
device_capability = current_platform.get_device_capability(
|
||||
).as_version_str()
|
||||
if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM
|
||||
if (current_platform.is_cuda() and symm_mem_enabled
|
||||
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
|
||||
max_size = min(
|
||||
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
|
||||
|
||||
@@ -27,8 +27,13 @@ class SymmMemCommunicator:
|
||||
"10.0": [6, 8],
|
||||
}
|
||||
|
||||
def __init__(self, group: ProcessGroup, device: Union[int, str,
|
||||
torch.device]):
|
||||
def __init__(
|
||||
self,
|
||||
group: ProcessGroup,
|
||||
device: Union[int, str, torch.device],
|
||||
# add options for testing
|
||||
force_multimem: Optional[bool] = None,
|
||||
max_size_override: Optional[int] = None):
|
||||
self.disabled = True
|
||||
|
||||
if not symm_mem_available:
|
||||
@@ -64,8 +69,17 @@ class SymmMemCommunicator:
|
||||
self.world_size,
|
||||
)
|
||||
return
|
||||
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
|
||||
self.world_size]
|
||||
# Use override max_size if provided, otherwise use default
|
||||
if max_size_override is not None:
|
||||
self.max_size = max_size_override
|
||||
logger.info(
|
||||
"SymmMemCommunicator: Using override max_size: %s bytes",
|
||||
self.max_size,
|
||||
)
|
||||
else:
|
||||
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[
|
||||
self.device_capability][self.world_size]
|
||||
|
||||
self.buffer = torch_symm_mem.empty(
|
||||
self.max_size // self.dtype.itemsize,
|
||||
device=self.device,
|
||||
@@ -76,6 +90,7 @@ class SymmMemCommunicator:
|
||||
logger.warning("SymmMemCommunicator: symmetric memory "
|
||||
"multicast operations are not supported.")
|
||||
return
|
||||
self.force_multimem = force_multimem
|
||||
self.disabled = False
|
||||
|
||||
def should_use_symm_mem(self, inp: torch.Tensor):
|
||||
@@ -98,8 +113,18 @@ class SymmMemCommunicator:
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
self.buffer[:inp.numel()].copy_(inp.view(-1))
|
||||
if self.world_size in self._WORLD_SIZES_MULTIMEM[
|
||||
self.device_capability]:
|
||||
|
||||
# Determine which algorithm to use
|
||||
use_multimem = False
|
||||
if self.force_multimem is not None:
|
||||
# Test override: use forced setting
|
||||
use_multimem = self.force_multimem
|
||||
else:
|
||||
# Normal logic: use multimem for supported world sizes
|
||||
use_multimem = self.world_size in self._WORLD_SIZES_MULTIMEM[
|
||||
self.device_capability]
|
||||
|
||||
if use_multimem:
|
||||
torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
|
||||
"sum",
|
||||
self.group.group_name)
|
||||
|
||||
Reference in New Issue
Block a user