[Kernels] Enable Torch Symmetric Memory All-Reduce By Default (#24111)

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Ilya Markov
2025-09-11 18:45:31 +02:00
committed by GitHub
parent bcbe2a4d9e
commit 1fdd5c42d7
7 changed files with 572 additions and 30 deletions

View File

@@ -36,8 +36,8 @@ CUSTOM_ALL_REDUCE_MAX_SIZES = {
"10.0": {
2: 2 * MiB, # 2 MB
4: 2 * MiB, # 2 MB
6: 2 * MiB, # 2 MB
8: 2 * MiB, # 2 MB
6: 1 * MiB, # 1 MB
8: 1 * MiB, # 1 MB
}
}

View File

@@ -57,11 +57,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
device=self.device,
)
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
group=self.cpu_group,
device=self.device,
symm_mem_enabled=(self.symm_mem_comm is not None
and not self.symm_mem_comm.disabled),
)
if current_platform.is_rocm():
@@ -72,11 +80,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
# currently be an MI300 series.
self.qr_comm = QuickAllReduce(group=self.cpu_group,
device=self.device)
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
device=self.device,
)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND

View File

@@ -54,7 +54,8 @@ class CustomAllreduce:
def __init__(self,
group: ProcessGroup,
device: Union[int, str, torch.device],
max_size=8192 * 1024) -> None:
max_size=8192 * 1024,
symm_mem_enabled=False) -> None:
"""
Args:
group: the process group to work on. If None, it will use the
@@ -111,7 +112,7 @@ class CustomAllreduce:
self.device = device
device_capability = current_platform.get_device_capability(
).as_version_str()
if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM
if (current_platform.is_cuda() and symm_mem_enabled
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
max_size = min(
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],

View File

@@ -27,8 +27,13 @@ class SymmMemCommunicator:
"10.0": [6, 8],
}
def __init__(self, group: ProcessGroup, device: Union[int, str,
torch.device]):
def __init__(
self,
group: ProcessGroup,
device: Union[int, str, torch.device],
# Test-only options: force the multimem path on/off and override the buffer max size.
force_multimem: Optional[bool] = None,
max_size_override: Optional[int] = None):
self.disabled = True
if not symm_mem_available:
@@ -64,8 +69,17 @@ class SymmMemCommunicator:
self.world_size,
)
return
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
self.world_size]
# Use override max_size if provided, otherwise use default
if max_size_override is not None:
self.max_size = max_size_override
logger.info(
"SymmMemCommunicator: Using override max_size: %s bytes",
self.max_size,
)
else:
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[
self.device_capability][self.world_size]
self.buffer = torch_symm_mem.empty(
self.max_size // self.dtype.itemsize,
device=self.device,
@@ -76,6 +90,7 @@ class SymmMemCommunicator:
logger.warning("SymmMemCommunicator: symmetric memory "
"multicast operations are not supported.")
return
self.force_multimem = force_multimem
self.disabled = False
def should_use_symm_mem(self, inp: torch.Tensor):
@@ -98,8 +113,18 @@ class SymmMemCommunicator:
if out is None:
out = torch.empty_like(inp)
self.buffer[:inp.numel()].copy_(inp.view(-1))
if self.world_size in self._WORLD_SIZES_MULTIMEM[
self.device_capability]:
# Determine which algorithm to use
use_multimem = False
if self.force_multimem is not None:
# Test override: use forced setting
use_multimem = self.force_multimem
else:
# Normal logic: use multimem for supported world sizes
use_multimem = self.world_size in self._WORLD_SIZES_MULTIMEM[
self.device_capability]
if use_multimem:
torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)