Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -33,17 +33,19 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def naive_multicast(self, x: torch.Tensor,
|
||||
cu_tokens_across_sp_cpu: torch.Tensor,
|
||||
is_sequence_parallel: bool) -> torch.Tensor:
|
||||
assert (len(x.shape) == 2)
|
||||
buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
def naive_multicast(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_tokens_across_sp_cpu: torch.Tensor,
|
||||
is_sequence_parallel: bool,
|
||||
) -> torch.Tensor:
|
||||
assert len(x.shape) == 2
|
||||
buffer = torch.empty(
|
||||
(cu_tokens_across_sp_cpu[-1], x.size(1)), device=x.device, dtype=x.dtype
|
||||
)
|
||||
|
||||
rank = self.rank if is_sequence_parallel else self.dp_rank
|
||||
world_size = (self.world_size
|
||||
if is_sequence_parallel else self.dp_world_size)
|
||||
world_size = self.world_size if is_sequence_parallel else self.dp_world_size
|
||||
|
||||
start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
|
||||
end = cu_tokens_across_sp_cpu[rank]
|
||||
@@ -59,24 +61,23 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
is_sequence_parallel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
|
||||
|
||||
hidden_states = self.naive_multicast(hidden_states,
|
||||
cu_tokens_across_sp_cpu,
|
||||
is_sequence_parallel)
|
||||
router_logits = self.naive_multicast(router_logits,
|
||||
cu_tokens_across_sp_cpu,
|
||||
is_sequence_parallel)
|
||||
hidden_states = self.naive_multicast(
|
||||
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
|
||||
)
|
||||
router_logits = self.naive_multicast(
|
||||
router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel
|
||||
)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
ep_rank = self.rank if is_sequence_parallel else self.dp_rank
|
||||
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
@@ -107,13 +108,12 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
is_sequence_parallel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Gather hidden_states and router_logits from all dp ranks.
|
||||
"""
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
|
||||
@@ -124,19 +124,16 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Reduce-scatter hidden_states across all dp ranks.
|
||||
"""
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
hidden_states = dist_group.reduce_scatterv(hidden_states,
|
||||
dim=0,
|
||||
sizes=sizes)
|
||||
hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes)
|
||||
return hidden_states
|
||||
|
||||
def destroy(self):
|
||||
@@ -149,24 +146,35 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
assert has_pplx(
|
||||
), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa
|
||||
assert has_pplx(), (
|
||||
"pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."
|
||||
) # noqa
|
||||
super().__init__(cpu_group)
|
||||
|
||||
if self.internode:
|
||||
# inter-node communication needs nvshmem,
|
||||
# intra-node communication uses p2p mapping directly
|
||||
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_get_unique_id,
|
||||
nvshmem_init)
|
||||
from pplx_kernels.nvshmem import (
|
||||
nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_get_unique_id,
|
||||
nvshmem_init,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Initialize NVSHMEM for pplx_kernels: "
|
||||
"rank=%d, world size=%d", self.rank, self.world_size)
|
||||
uid = nvshmem_get_unique_id(
|
||||
) if self.rank == 0 else nvshmem_alloc_empty_unique_id()
|
||||
dist.broadcast(uid,
|
||||
src=dist.get_process_group_ranks(self.cpu_group)[0],
|
||||
group=self.cpu_group)
|
||||
"Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d",
|
||||
self.rank,
|
||||
self.world_size,
|
||||
)
|
||||
uid = (
|
||||
nvshmem_get_unique_id()
|
||||
if self.rank == 0
|
||||
else nvshmem_alloc_empty_unique_id()
|
||||
)
|
||||
dist.broadcast(
|
||||
uid,
|
||||
src=dist.get_process_group_ranks(self.cpu_group)[0],
|
||||
group=self.cpu_group,
|
||||
)
|
||||
logger.debug("PPLX NVSHMEM UID = %s", uid)
|
||||
nvshmem_init(uid, self.rank, self.world_size)
|
||||
|
||||
@@ -174,21 +182,23 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
import pplx_kernels as pplx
|
||||
|
||||
return self.handle_cache.get_or_create(
|
||||
kwargs, pplx.AllToAll.internode
|
||||
if self.internode else pplx.AllToAll.intranode)
|
||||
kwargs,
|
||||
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
|
||||
)
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
is_sequence_parallel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
@@ -198,6 +208,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
||||
|
||||
if self.internode:
|
||||
from pplx_kernels.nvshmem import nvshmem_finalize
|
||||
|
||||
logger.debug("PPLX NVSHMEM finalize")
|
||||
nvshmem_finalize()
|
||||
|
||||
@@ -208,8 +219,9 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
assert has_deep_ep(
|
||||
), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa
|
||||
assert has_deep_ep(), (
|
||||
"DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels."
|
||||
) # noqa
|
||||
super().__init__(cpu_group)
|
||||
self.handle_cache = Cache()
|
||||
|
||||
@@ -224,13 +236,13 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
is_sequence_parallel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
@@ -260,23 +272,27 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
|
||||
assert num_rdma_bytes is not None
|
||||
assert num_qps_per_rank is not None
|
||||
return dict(group=self.cpu_group,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=False,
|
||||
num_qps_per_rank=num_qps_per_rank)
|
||||
return dict(
|
||||
group=self.cpu_group,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=False,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
|
||||
assert len(kwargs) == 0, (
|
||||
"DeepEPHTAll2AllManager expects no arguments. All the required "
|
||||
"args are computed in the Manager itself.")
|
||||
"args are computed in the Manager itself."
|
||||
)
|
||||
|
||||
import deep_ep
|
||||
|
||||
buffer_kwargs = self._make_all2all_kwargs()
|
||||
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
||||
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
|
||||
buffer_kwargs, deep_ep.Buffer)
|
||||
buffer_kwargs, deep_ep.Buffer
|
||||
)
|
||||
return handle
|
||||
|
||||
def set_num_sms(self, num_sms: int):
|
||||
@@ -323,14 +339,17 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
|
||||
hidden=token_hidden_size,
|
||||
num_ranks=num_ep_ranks,
|
||||
num_experts=num_global_experts)
|
||||
num_experts=num_global_experts,
|
||||
)
|
||||
|
||||
assert num_rdma_bytes is not None
|
||||
return dict(group=self.cpu_group,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=num_qps_per_rank)
|
||||
return dict(
|
||||
group=self.cpu_group,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
"""
|
||||
@@ -338,10 +357,12 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
_make_all2all_kwargs.
|
||||
"""
|
||||
import deep_ep
|
||||
|
||||
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
|
||||
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
||||
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
|
||||
buffer_kwargs, deep_ep.Buffer)
|
||||
buffer_kwargs, deep_ep.Buffer
|
||||
)
|
||||
return handle
|
||||
|
||||
# DeepEP LL uses RDMA so no SMs are used for communication
|
||||
@@ -355,12 +376,15 @@ class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
assert has_flashinfer_all2all(
|
||||
), "flashinfer all2all module not found. Please install/check flashinfer" # noqa
|
||||
assert has_flashinfer_all2all(), (
|
||||
"flashinfer all2all module not found. Please install/check flashinfer"
|
||||
) # noqa
|
||||
super().__init__(cpu_group)
|
||||
logger.debug(
|
||||
"Initialize for flashinfer All2All "
|
||||
"rank=%d, world size=%d", self.rank, self.world_size)
|
||||
"Initialize for flashinfer All2All rank=%d, world size=%d",
|
||||
self.rank,
|
||||
self.world_size,
|
||||
)
|
||||
self.initialized = False
|
||||
self.alltoall_info = None
|
||||
|
||||
@@ -375,8 +399,7 @@ class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
return
|
||||
|
||||
self.cleanup()
|
||||
logger.debug("making map: "
|
||||
"rank=%d, world size=%d", rank, world_size)
|
||||
logger.debug("making map: rank=%d, world size=%d", rank, world_size)
|
||||
self.mapping = Mapping(
|
||||
world_size,
|
||||
rank,
|
||||
@@ -385,25 +408,28 @@ class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
)
|
||||
|
||||
from vllm.distributed.device_communicators.mnnvl_compat import (
|
||||
CustomCommunicator)
|
||||
CustomCommunicator,
|
||||
)
|
||||
|
||||
dp_config = MnnvlConfig(
|
||||
comm_backend=CustomCommunicator(get_dp_group().cpu_group),
|
||||
fabric_page_size=1 << 29, # 512MB
|
||||
allocation_granularity=0 # Auto-detect
|
||||
allocation_granularity=0, # Auto-detect
|
||||
)
|
||||
|
||||
self.workspace_tensor = MnnvlMoe.get_moe_workspaces(
|
||||
self.mapping, dp_config)
|
||||
self.workspace_tensor = MnnvlMoe.get_moe_workspaces(self.mapping, dp_config)
|
||||
self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace(
|
||||
self.mapping, dp_config)
|
||||
self.mapping, dp_config
|
||||
)
|
||||
|
||||
self.world_size = world_size
|
||||
self.rank = rank
|
||||
self.gpus_per_node = gpus_per_node
|
||||
self.initialized = True
|
||||
|
||||
logger.info("FlashInfer All2All initialized for rank %s, size %s",
|
||||
rank, world_size)
|
||||
logger.info(
|
||||
"FlashInfer All2All initialized for rank %s, size %s", rank, world_size
|
||||
)
|
||||
|
||||
def ensure_alltoall_workspace_initialized(self):
|
||||
"""Ensure workspace is initialized"""
|
||||
@@ -426,8 +452,11 @@ class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up workspace"""
|
||||
if self.initialized and self.workspace_tensor is not None \
|
||||
and self.prepare_workspace_tensor is not None:
|
||||
if (
|
||||
self.initialized
|
||||
and self.workspace_tensor is not None
|
||||
and self.prepare_workspace_tensor is not None
|
||||
):
|
||||
try:
|
||||
del self.workspace_tensor
|
||||
del self.prepare_workspace_tensor
|
||||
|
||||
@@ -19,8 +19,7 @@ import torch.multiprocessing as mp
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (cuda_device_count_stateless,
|
||||
update_environment_variables)
|
||||
from vllm.utils import cuda_device_count_stateless, update_environment_variables
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -39,7 +38,7 @@ CUSTOM_ALL_REDUCE_MAX_SIZES = {
|
||||
4: 2 * MiB, # 2 MB
|
||||
6: 1 * MiB, # 1 MB
|
||||
8: 1 * MiB, # 1 MB
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
|
||||
@@ -54,7 +53,7 @@ SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
|
||||
4: 32 * MiB, # 32 MB
|
||||
6: 128 * MiB, # 128 MB
|
||||
8: 128 * MiB, # 128 MB
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
|
||||
@@ -63,14 +62,15 @@ NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
|
||||
4: 2 * MiB, # 2 MB
|
||||
8: 1 * MiB, # 1 MB
|
||||
},
|
||||
"always_use_above_world_size": 8 # Always use symm mem for world_size > 8
|
||||
"always_use_above_world_size": 8, # Always use symm mem for world_size > 8
|
||||
}
|
||||
|
||||
|
||||
def should_nccl_symm_mem_allreduce(world_size: int,
|
||||
input_tensor: torch.Tensor) -> bool:
|
||||
def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) -> bool:
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
is_symmetric_memory_enabled)
|
||||
is_symmetric_memory_enabled,
|
||||
)
|
||||
|
||||
if not is_symmetric_memory_enabled():
|
||||
return False
|
||||
if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
|
||||
@@ -78,18 +78,18 @@ def should_nccl_symm_mem_allreduce(world_size: int,
|
||||
threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size)
|
||||
if threshold is not None and input_tensor.nbytes >= threshold:
|
||||
return True
|
||||
return (world_size
|
||||
> NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"])
|
||||
return world_size > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"]
|
||||
|
||||
|
||||
def producer(batch_src: Sequence[int],
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices: Optional[str] = None):
|
||||
def producer(
|
||||
batch_src: Sequence[int],
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices: Optional[str] = None,
|
||||
):
|
||||
if cuda_visible_devices is not None:
|
||||
update_environment_variables(
|
||||
{"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
|
||||
lib = CudaRTLibrary()
|
||||
for i in batch_src:
|
||||
@@ -115,14 +115,15 @@ def producer(batch_src: Sequence[int],
|
||||
lib.cudaDeviceReset()
|
||||
|
||||
|
||||
def consumer(batch_tgt: Sequence[int],
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices: Optional[str] = None):
|
||||
def consumer(
|
||||
batch_tgt: Sequence[int],
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices: Optional[str] = None,
|
||||
):
|
||||
if cuda_visible_devices is not None:
|
||||
update_environment_variables(
|
||||
{"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
|
||||
lib = CudaRTLibrary()
|
||||
for j in batch_tgt:
|
||||
@@ -198,12 +199,26 @@ def can_actually_p2p(
|
||||
producer_queue = smp.Queue()
|
||||
consumer_queue = smp.Queue()
|
||||
result_queue = smp.Queue()
|
||||
p_src = smp.Process(target=producer,
|
||||
args=(batch_src, producer_queue, consumer_queue,
|
||||
result_queue, cuda_visible_devices))
|
||||
p_tgt = smp.Process(target=consumer,
|
||||
args=(batch_tgt, producer_queue, consumer_queue,
|
||||
result_queue, cuda_visible_devices))
|
||||
p_src = smp.Process(
|
||||
target=producer,
|
||||
args=(
|
||||
batch_src,
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices,
|
||||
),
|
||||
)
|
||||
p_tgt = smp.Process(
|
||||
target=consumer,
|
||||
args=(
|
||||
batch_tgt,
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices,
|
||||
),
|
||||
)
|
||||
p_src.start()
|
||||
p_tgt.start()
|
||||
p_src.join()
|
||||
@@ -216,7 +231,10 @@ def can_actually_p2p(
|
||||
if a != b:
|
||||
logger.warning(
|
||||
"Two processes do not agree on the P2P access"
|
||||
" status on %d -> %d, treat as disabled.", src, tgt)
|
||||
" status on %d -> %d, treat as disabled.",
|
||||
src,
|
||||
tgt,
|
||||
)
|
||||
result.append(False)
|
||||
else:
|
||||
result.append(a)
|
||||
@@ -255,12 +273,14 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
|
||||
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
|
||||
|
||||
path = os.path.join(
|
||||
envs.VLLM_CACHE_ROOT,
|
||||
f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json")
|
||||
envs.VLLM_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
|
||||
)
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
if ((not is_distributed or get_world_group().local_rank == 0)
|
||||
and (not os.path.exists(path))):
|
||||
|
||||
if (not is_distributed or get_world_group().local_rank == 0) and (
|
||||
not os.path.exists(path)
|
||||
):
|
||||
# only the local master process (with local_rank == 0) can
|
||||
# enter this block to calculate the cache
|
||||
logger.info("generating GPU P2P access cache in %s", path)
|
||||
@@ -279,11 +299,10 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
|
||||
# we don't use the output of the subprocess directly,
|
||||
# because the subprocess might produce logging output
|
||||
with tempfile.NamedTemporaryFile() as output_file:
|
||||
input_bytes = pickle.dumps(
|
||||
(batch_src, batch_tgt, output_file.name))
|
||||
returned = subprocess.run([sys.executable, __file__],
|
||||
input=input_bytes,
|
||||
capture_output=True)
|
||||
input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name))
|
||||
returned = subprocess.run(
|
||||
[sys.executable, __file__], input=input_bytes, capture_output=True
|
||||
)
|
||||
# check if the subprocess is successful
|
||||
try:
|
||||
returned.check_returncode()
|
||||
@@ -292,7 +311,8 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
|
||||
raise RuntimeError(
|
||||
f"Error happened when batch testing "
|
||||
f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
|
||||
f"{returned.stderr.decode()}") from e
|
||||
f"{returned.stderr.decode()}"
|
||||
) from e
|
||||
with open(output_file.name, "rb") as f:
|
||||
result = pickle.load(f)
|
||||
for _i, _j, r in zip(batch_src, batch_tgt, result):
|
||||
|
||||
@@ -10,7 +10,6 @@ from torch.distributed import ProcessGroup
|
||||
|
||||
|
||||
class Cache:
|
||||
|
||||
def __init__(self):
|
||||
self._cache: WeakValueDictionary = WeakValueDictionary()
|
||||
self._lock = threading.RLock() # Reentrant lock for thread safety
|
||||
@@ -35,9 +34,11 @@ class All2AllManagerBase:
|
||||
self.cpu_group = cpu_group
|
||||
|
||||
# compute some common properties
|
||||
from vllm.distributed.parallel_state import (get_dp_group,
|
||||
get_tp_group,
|
||||
in_the_same_node_as)
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_dp_group,
|
||||
get_tp_group,
|
||||
in_the_same_node_as,
|
||||
)
|
||||
|
||||
# all2all lives in ep group, which is merged from dp and tp group
|
||||
self.dp_group = get_dp_group()
|
||||
@@ -63,10 +64,12 @@ class All2AllManagerBase:
|
||||
# and reuse it for the same config.
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False):
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_num_sms(self, num_sms: int):
|
||||
@@ -75,9 +78,7 @@ class All2AllManagerBase:
|
||||
def max_sms_used(self) -> Optional[int]:
|
||||
return None # None means it could use the whole GPU
|
||||
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False):
|
||||
def combine(self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False):
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
@@ -92,11 +93,13 @@ class DeviceCommunicatorBase:
|
||||
communication backend), the `device_group` will also be given.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
unique_name: str = ""):
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
unique_name: str = "",
|
||||
):
|
||||
self.device = device or torch.device("cpu")
|
||||
self.cpu_group = cpu_group
|
||||
self.device_group = device_group
|
||||
@@ -106,11 +109,11 @@ class DeviceCommunicatorBase:
|
||||
self.ranks = dist.get_process_group_ranks(cpu_group)
|
||||
self.global_rank = dist.get_rank()
|
||||
self.global_world_size = dist.get_world_size()
|
||||
self.rank_in_group = dist.get_group_rank(self.cpu_group,
|
||||
self.global_rank)
|
||||
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
|
||||
|
||||
use_ep = False
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
config = get_current_vllm_config()
|
||||
if config is not None:
|
||||
# as long as we use data parallel (coupled data parallel
|
||||
@@ -134,41 +137,39 @@ class DeviceCommunicatorBase:
|
||||
# NOTE: we have to use concat-style all-gather here,
|
||||
# stack-style all-gather has compatibility issues with
|
||||
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
|
||||
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
|
||||
output_size = (input_size[0] * self.world_size,) + input_size[1:]
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(output_size,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device)
|
||||
output_tensor = torch.empty(
|
||||
output_size, dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
# All-gather.
|
||||
dist.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
|
||||
# Reshape
|
||||
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
|
||||
output_tensor = output_tensor.reshape((self.world_size,) + input_size)
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(input_size[:dim] +
|
||||
(self.world_size *
|
||||
input_size[dim], ) +
|
||||
input_size[dim + 1:])
|
||||
output_tensor = output_tensor.reshape(
|
||||
input_size[:dim]
|
||||
+ (self.world_size * input_size[dim],)
|
||||
+ input_size[dim + 1 :]
|
||||
)
|
||||
return output_tensor
|
||||
|
||||
def all_gatherv(
|
||||
self,
|
||||
input_: Union[torch.Tensor, list[torch.Tensor]],
|
||||
dim: int = 0,
|
||||
sizes: Optional[list[int]] = None
|
||||
sizes: Optional[list[int]] = None,
|
||||
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def reduce_scatter(self,
|
||||
input_: torch.Tensor,
|
||||
dim: int = -1) -> torch.Tensor:
|
||||
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
world_size = self.world_size
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
|
||||
)
|
||||
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
@@ -180,30 +181,28 @@ class DeviceCommunicatorBase:
|
||||
|
||||
assert input_tensor.shape[0] % world_size == 0
|
||||
chunk_size = input_tensor.shape[0] // world_size
|
||||
output_shape = (chunk_size, ) + input_tensor.shape[1:]
|
||||
output_shape = (chunk_size,) + input_tensor.shape[1:]
|
||||
|
||||
output_tensor = torch.empty(output_shape,
|
||||
dtype=input_tensor.dtype,
|
||||
device=input_tensor.device)
|
||||
output_tensor = torch.empty(
|
||||
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
|
||||
)
|
||||
|
||||
# Perform reduce-scatter operation
|
||||
torch.distributed.reduce_scatter_tensor(output_tensor,
|
||||
input_tensor,
|
||||
group=self.device_group)
|
||||
torch.distributed.reduce_scatter_tensor(
|
||||
output_tensor, input_tensor, group=self.device_group
|
||||
)
|
||||
|
||||
# Reshape before returning
|
||||
return output_tensor.movedim(0, dim).contiguous()
|
||||
|
||||
def reduce_scatterv(self,
|
||||
input_: torch.Tensor,
|
||||
dim: int = -1,
|
||||
sizes: Optional[list[int]] = None) -> torch.Tensor:
|
||||
def reduce_scatterv(
|
||||
self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def gather(self,
|
||||
input_: torch.Tensor,
|
||||
dst: int = 0,
|
||||
dim: int = -1) -> Optional[torch.Tensor]:
|
||||
def gather(
|
||||
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
NOTE: We assume that the input tensor is on the same device across
|
||||
all the ranks.
|
||||
@@ -211,7 +210,8 @@ class DeviceCommunicatorBase:
|
||||
"""
|
||||
world_size = self.world_size
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
|
||||
)
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
@@ -222,10 +222,9 @@ class DeviceCommunicatorBase:
|
||||
else:
|
||||
gather_list = None
|
||||
# Gather.
|
||||
torch.distributed.gather(input_,
|
||||
gather_list,
|
||||
dst=self.ranks[dst],
|
||||
group=self.device_group)
|
||||
torch.distributed.gather(
|
||||
input_, gather_list, dst=self.ranks[dst], group=self.device_group
|
||||
)
|
||||
if self.rank_in_group == dst:
|
||||
output_tensor = torch.cat(gather_list, dim=dim)
|
||||
else:
|
||||
@@ -239,10 +238,9 @@ class DeviceCommunicatorBase:
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
|
||||
|
||||
def recv(self,
|
||||
size: torch.Size,
|
||||
dtype: torch.dtype,
|
||||
src: Optional[int] = None) -> torch.Tensor:
|
||||
def recv(
|
||||
self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
"""Receives a tensor from the source rank."""
|
||||
"""NOTE: `src` is the local rank of the source rank."""
|
||||
if src is None:
|
||||
@@ -255,8 +253,7 @@ class DeviceCommunicatorBase:
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
def prepare_communication_buffer_for_model(self,
|
||||
model: torch.nn.Module) -> None:
|
||||
def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Prepare the communication buffer for the model.
|
||||
"""
|
||||
@@ -264,11 +261,14 @@ class DeviceCommunicatorBase:
|
||||
return
|
||||
|
||||
moe_modules = [
|
||||
module for module in model.modules()
|
||||
module
|
||||
for module in model.modules()
|
||||
# TODO(bnell): Should use isinstance but can't. Maybe search for
|
||||
# presence of quant_method.init_prepare_finalize?
|
||||
if (module.__class__.__name__ == "FusedMoE"
|
||||
or module.__class__.__name__ == "SharedFusedMoE")
|
||||
if (
|
||||
module.__class__.__name__ == "FusedMoE"
|
||||
or module.__class__.__name__ == "SharedFusedMoE"
|
||||
)
|
||||
]
|
||||
for module in moe_modules:
|
||||
module.quant_method.init_prepare_finalize(module)
|
||||
@@ -277,7 +277,7 @@ class DeviceCommunicatorBase:
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
is_sequence_parallel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Dispatch the hidden states and router logits to the appropriate device.
|
||||
@@ -285,9 +285,9 @@ class DeviceCommunicatorBase:
|
||||
"""
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Combine the hidden states and router logits from the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
|
||||
@@ -15,30 +15,30 @@ from .base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
|
||||
class CpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def __init__(self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
unique_name: str = ""):
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
unique_name: str = "",
|
||||
):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
self.dist_module = torch.distributed
|
||||
|
||||
if (current_platform.get_cpu_architecture()
|
||||
== CpuArchEnum.X86) and hasattr(
|
||||
torch.ops._C,
|
||||
"init_shm_manager") and (unique_name.startswith("tp")
|
||||
or unique_name.startswith("pp")):
|
||||
if (
|
||||
(current_platform.get_cpu_architecture() == CpuArchEnum.X86)
|
||||
and hasattr(torch.ops._C, "init_shm_manager")
|
||||
and (unique_name.startswith("tp") or unique_name.startswith("pp"))
|
||||
):
|
||||
self.dist_module = _CPUSHMDistributed(self)
|
||||
|
||||
def all_reduce(self, input_):
|
||||
self.dist_module.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
def gather(self,
|
||||
input_: torch.Tensor,
|
||||
dst: int = 0,
|
||||
dim: int = -1) -> Optional[torch.Tensor]:
|
||||
def gather(
|
||||
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
NOTE: We assume that the input tensor is on the same device across
|
||||
all the ranks.
|
||||
@@ -46,7 +46,8 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
"""
|
||||
world_size = self.world_size
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
|
||||
)
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
@@ -58,10 +59,9 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
gather_list = None
|
||||
|
||||
# Gather.
|
||||
self.dist_module.gather(input_,
|
||||
gather_list,
|
||||
dst=self.ranks[dst],
|
||||
group=self.device_group)
|
||||
self.dist_module.gather(
|
||||
input_, gather_list, dst=self.ranks[dst], group=self.device_group
|
||||
)
|
||||
|
||||
if self.rank_in_group == dst:
|
||||
output_tensor = torch.cat(gather_list, dim=dim)
|
||||
@@ -77,23 +77,24 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
# NOTE: we have to use concat-style all-gather here,
|
||||
# stack-style all-gather has compatibility issues with
|
||||
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
|
||||
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
|
||||
output_size = (input_size[0] * self.world_size,) + input_size[1:]
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(output_size,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device)
|
||||
output_tensor = torch.empty(
|
||||
output_size, dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
# All-gather.
|
||||
self.dist_module.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
self.dist_module.all_gather_into_tensor(
|
||||
output_tensor, input_, group=self.device_group
|
||||
)
|
||||
|
||||
# Reshape
|
||||
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
|
||||
output_tensor = output_tensor.reshape((self.world_size,) + input_size)
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(input_size[:dim] +
|
||||
(self.world_size *
|
||||
input_size[dim], ) +
|
||||
input_size[dim + 1:])
|
||||
output_tensor = output_tensor.reshape(
|
||||
input_size[:dim]
|
||||
+ (self.world_size * input_size[dim],)
|
||||
+ input_size[dim + 1 :]
|
||||
)
|
||||
return output_tensor
|
||||
|
||||
def send_tensor_dict(
|
||||
@@ -111,7 +112,6 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
|
||||
class _CPUSHMDistributed:
|
||||
|
||||
def __init__(self, communicator: CpuCommunicator):
|
||||
instance_identifier = os.environ["VLLM_DIST_IDENT"]
|
||||
unique_name = communicator.unique_name
|
||||
@@ -139,24 +139,32 @@ class _CPUSHMDistributed:
|
||||
|
||||
return handle
|
||||
|
||||
def all_reduce(self,
|
||||
input: torch.Tensor,
|
||||
group: Optional[ProcessGroup] = None) -> None:
|
||||
def all_reduce(
|
||||
self, input: torch.Tensor, group: Optional[ProcessGroup] = None
|
||||
) -> None:
|
||||
torch.ops._C.shm_allreduce(self.handle, input)
|
||||
|
||||
def gather(self,
|
||||
input: torch.Tensor,
|
||||
gather_list: Optional[list[torch.Tensor]],
|
||||
dst: int = -1,
|
||||
group: Optional[ProcessGroup] = None) -> None:
|
||||
def gather(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
gather_list: Optional[list[torch.Tensor]],
|
||||
dst: int = -1,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
) -> None:
|
||||
# Note: different from the torch gather, here we use local dst rank.
|
||||
torch.ops._C.shm_gather(self.handle, input, gather_list,
|
||||
torch.distributed.get_group_rank(group, dst))
|
||||
torch.ops._C.shm_gather(
|
||||
self.handle,
|
||||
input,
|
||||
gather_list,
|
||||
torch.distributed.get_group_rank(group, dst),
|
||||
)
|
||||
|
||||
def all_gather_into_tensor(self,
|
||||
output: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
group: Optional[ProcessGroup] = None) -> None:
|
||||
def all_gather_into_tensor(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
) -> None:
|
||||
torch.ops._C.shm_all_gather(self.handle, input, output)
|
||||
|
||||
def send_tensor_dict(
|
||||
@@ -169,11 +177,11 @@ class _CPUSHMDistributed:
|
||||
size_list = []
|
||||
for v in value_list:
|
||||
if not isinstance(v, torch.Tensor):
|
||||
raise RuntimeError(
|
||||
"CpuCommunicator only supports sending tensors.")
|
||||
raise RuntimeError("CpuCommunicator only supports sending tensors.")
|
||||
size_list.append(v.size())
|
||||
key_size_tensor = torch.frombuffer(pickle.dumps([key_list, size_list]),
|
||||
dtype=torch.uint8)
|
||||
key_size_tensor = torch.frombuffer(
|
||||
pickle.dumps([key_list, size_list]), dtype=torch.uint8
|
||||
)
|
||||
value_list.append(key_size_tensor)
|
||||
|
||||
torch.ops._C.shm_send_tensor_list(self.handle, value_list, dst)
|
||||
|
||||
@@ -8,11 +8,12 @@ from torch.distributed import ProcessGroup
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.all_reduce_utils import (
|
||||
should_nccl_symm_mem_allreduce)
|
||||
from vllm.distributed.device_communicators.pynccl import (
|
||||
register_nccl_symmetric_ops)
|
||||
should_nccl_symm_mem_allreduce,
|
||||
)
|
||||
from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric_ops
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
is_symmetric_memory_enabled)
|
||||
is_symmetric_memory_enabled,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -22,20 +23,21 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CudaCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def __init__(self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
unique_name: str = ""):
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
unique_name: str = "",
|
||||
):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
if "tp" not in unique_name:
|
||||
# custom allreduce or torch symm mem can be used only by tp
|
||||
use_custom_allreduce = False
|
||||
use_torch_symm_mem = False
|
||||
else:
|
||||
from vllm.distributed.parallel_state import (
|
||||
_ENABLE_CUSTOM_ALL_REDUCE)
|
||||
from vllm.distributed.parallel_state import _ENABLE_CUSTOM_ALL_REDUCE
|
||||
|
||||
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
|
||||
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
|
||||
|
||||
@@ -44,13 +46,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
# lazy import to avoid documentation build error
|
||||
from vllm.distributed.device_communicators.custom_all_reduce import (
|
||||
CustomAllreduce)
|
||||
from vllm.distributed.device_communicators.pynccl import (
|
||||
PyNcclCommunicator)
|
||||
CustomAllreduce,
|
||||
)
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.device_communicators.quick_all_reduce import (
|
||||
QuickAllReduce)
|
||||
from vllm.distributed.device_communicators.symm_mem import (
|
||||
SymmMemCommunicator)
|
||||
QuickAllReduce,
|
||||
)
|
||||
from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
|
||||
|
||||
self.pynccl_comm: Optional[PyNcclCommunicator] = None
|
||||
if self.world_size > 1:
|
||||
@@ -75,8 +77,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
self.ca_comm = CustomAllreduce(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
symm_mem_enabled=(self.symm_mem_comm is not None
|
||||
and not self.symm_mem_comm.disabled),
|
||||
symm_mem_enabled=(
|
||||
self.symm_mem_comm is not None and not self.symm_mem_comm.disabled
|
||||
),
|
||||
)
|
||||
|
||||
if current_platform.is_rocm():
|
||||
@@ -85,35 +88,39 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
# Based on quickreduce (https://github.com/mk1-project/quickreduce).
|
||||
# If it's a rocm, 'use_custom_allreduce==True' means it must
|
||||
# currently be an MI300 series.
|
||||
self.qr_comm = QuickAllReduce(group=self.cpu_group,
|
||||
device=self.device)
|
||||
self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device)
|
||||
|
||||
if self.use_all2all:
|
||||
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
|
||||
if all2all_backend == "naive":
|
||||
from .all2all import NaiveAll2AllManager
|
||||
|
||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||
logger.info("Using naive all2all manager.")
|
||||
elif all2all_backend == "allgather_reducescatter":
|
||||
from .all2all import AgRsAll2AllManager
|
||||
|
||||
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
|
||||
logger.info("Using AllGather-ReduceScatter all2all manager.")
|
||||
elif all2all_backend == "pplx":
|
||||
from .all2all import PPLXAll2AllManager
|
||||
|
||||
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
|
||||
logger.info("Using PPLX all2all manager.")
|
||||
elif all2all_backend == "deepep_high_throughput":
|
||||
from .all2all import DeepEPHTAll2AllManager
|
||||
|
||||
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
|
||||
logger.info("Using DeepEP High-Throughput all2all manager.")
|
||||
elif all2all_backend == "deepep_low_latency":
|
||||
from .all2all import DeepEPLLAll2AllManager
|
||||
|
||||
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
|
||||
logger.info("Using DeepEP Low-Latency all2all manager.")
|
||||
elif all2all_backend == "flashinfer_all2allv":
|
||||
from .all2all import FlashInferAllToAllManager
|
||||
self.all2all_manager = FlashInferAllToAllManager(
|
||||
self.cpu_group)
|
||||
|
||||
self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
|
||||
logger.info("Using Flashinfer all2allv manager.")
|
||||
else:
|
||||
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
|
||||
@@ -121,28 +128,34 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
def all_reduce(self, input_):
|
||||
# since currently we perform copy input -> symm_input -> out-of-place AR
|
||||
# return symm_output, we don't need to check if input is symmetric
|
||||
if self.pynccl_comm is not None and \
|
||||
should_nccl_symm_mem_allreduce(self.pynccl_comm.world_size,input_):
|
||||
if self.pynccl_comm is not None and should_nccl_symm_mem_allreduce(
|
||||
self.pynccl_comm.world_size, input_
|
||||
):
|
||||
out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_)
|
||||
if out is not None:
|
||||
return out
|
||||
# always try quick reduce first, then custom allreduce,
|
||||
# and then pynccl. (quick reduce just for ROCM MI3*)
|
||||
qr_comm = self.qr_comm
|
||||
if qr_comm is not None and not qr_comm.disabled and \
|
||||
qr_comm.should_quick_allreduce(input_):
|
||||
if (
|
||||
qr_comm is not None
|
||||
and not qr_comm.disabled
|
||||
and qr_comm.should_quick_allreduce(input_)
|
||||
):
|
||||
out = qr_comm.quick_all_reduce(input_)
|
||||
assert out is not None
|
||||
return out
|
||||
ca_comm = self.ca_comm
|
||||
if ca_comm is not None and not ca_comm.disabled and \
|
||||
ca_comm.should_custom_ar(input_):
|
||||
if (
|
||||
ca_comm is not None
|
||||
and not ca_comm.disabled
|
||||
and ca_comm.should_custom_ar(input_)
|
||||
):
|
||||
out = ca_comm.custom_all_reduce(input_)
|
||||
assert out is not None
|
||||
return out
|
||||
symm_mem_comm = self.symm_mem_comm
|
||||
if symm_mem_comm is not None and \
|
||||
symm_mem_comm.should_use_symm_mem(input_):
|
||||
if symm_mem_comm is not None and symm_mem_comm.should_use_symm_mem(input_):
|
||||
out = symm_mem_comm.all_reduce(input_)
|
||||
assert out is not None
|
||||
return out
|
||||
@@ -176,21 +189,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
assert input_tensor.shape[0] % world_size == 0
|
||||
chunk_size = input_tensor.shape[0] // world_size
|
||||
output_shape = (chunk_size, ) + input_tensor.shape[1:]
|
||||
output_shape = (chunk_size,) + input_tensor.shape[1:]
|
||||
|
||||
output = torch.empty(output_shape,
|
||||
dtype=input_tensor.dtype,
|
||||
device=input_tensor.device)
|
||||
output = torch.empty(
|
||||
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
|
||||
)
|
||||
|
||||
pynccl_comm.reduce_scatter(output, input_tensor)
|
||||
|
||||
# Reshape before returning
|
||||
return output.movedim(0, dim).contiguous()
|
||||
|
||||
def reduce_scatterv(self,
|
||||
input_: torch.Tensor,
|
||||
dim: int = -1,
|
||||
sizes: Optional[list[int]] = None):
|
||||
def reduce_scatterv(
|
||||
self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None
|
||||
):
|
||||
world_size = self.world_size
|
||||
pynccl_comm = self.pynccl_comm
|
||||
assert pynccl_comm is not None
|
||||
@@ -209,11 +221,11 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
else:
|
||||
assert input_tensor.shape[0] % world_size == 0
|
||||
chunk_size = input_tensor.shape[0] // world_size
|
||||
output_shape = (chunk_size, ) + input_tensor.shape[1:]
|
||||
output_shape = (chunk_size,) + input_tensor.shape[1:]
|
||||
|
||||
output = torch.empty(output_shape,
|
||||
dtype=input_tensor.dtype,
|
||||
device=input_tensor.device)
|
||||
output = torch.empty(
|
||||
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
|
||||
)
|
||||
|
||||
if sizes is not None:
|
||||
pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
|
||||
@@ -235,10 +247,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
else:
|
||||
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
|
||||
|
||||
def recv(self,
|
||||
size: torch.Size,
|
||||
dtype: torch.dtype,
|
||||
src: Optional[int] = None) -> torch.Tensor:
|
||||
def recv(
|
||||
self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
"""Receives a tensor from the source rank."""
|
||||
"""NOTE: `src` is the local rank of the source rank."""
|
||||
if src is None:
|
||||
@@ -261,10 +272,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
self.all2all_manager.destroy()
|
||||
self.all2all_manager = None
|
||||
|
||||
def all_gatherv(self,
|
||||
input_: Union[torch.Tensor, list[torch.Tensor]],
|
||||
dim: int = 0,
|
||||
sizes: Optional[list[int]] = None):
|
||||
def all_gatherv(
|
||||
self,
|
||||
input_: Union[torch.Tensor, list[torch.Tensor]],
|
||||
dim: int = 0,
|
||||
sizes: Optional[list[int]] = None,
|
||||
):
|
||||
if dim != 0:
|
||||
raise NotImplementedError("only dim 0 all-gatherv is supported")
|
||||
world_size = self.world_size
|
||||
@@ -276,20 +289,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
if sizes is not None and all(s == sizes[0] for s in sizes):
|
||||
sizes = None
|
||||
|
||||
def _all_gather_single(input_: torch.Tensor,
|
||||
sizes: Optional[list[int]] = None):
|
||||
def _all_gather_single(input_: torch.Tensor, sizes: Optional[list[int]] = None):
|
||||
input_size = input_.size()
|
||||
if sizes is not None:
|
||||
assert len(sizes) == world_size
|
||||
assert input_.shape[dim] == sizes[self.rank_in_group], (
|
||||
f"{input_.shape[dim]} != {sizes[self.rank_in_group]}")
|
||||
output_size = (sum(sizes), ) + input_size[1:]
|
||||
f"{input_.shape[dim]} != {sizes[self.rank_in_group]}"
|
||||
)
|
||||
output_size = (sum(sizes),) + input_size[1:]
|
||||
else:
|
||||
output_size = (input_size[0] * world_size, ) + input_size[1:]
|
||||
output_size = (input_size[0] * world_size,) + input_size[1:]
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(output_size,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device)
|
||||
output_tensor = torch.empty(
|
||||
output_size, dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
if sizes is not None:
|
||||
pynccl_comm.all_gatherv(output_tensor, input_, sizes=sizes)
|
||||
else:
|
||||
@@ -311,17 +324,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
is_sequence_parallel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states, router_logits = self.all2all_manager.dispatch(
|
||||
hidden_states, router_logits, is_sequence_parallel)
|
||||
hidden_states, router_logits, is_sequence_parallel
|
||||
)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states = self.all2all_manager.combine(hidden_states,
|
||||
is_sequence_parallel)
|
||||
hidden_states = self.all2all_manager.combine(
|
||||
hidden_states, is_sequence_parallel
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
@@ -42,7 +42,7 @@ def find_loaded_library(lib_name) -> Optional[str]:
|
||||
the file `/proc/self/maps` contains the memory maps of the process, which includes the
|
||||
shared libraries loaded by the process. We can use this file to find the path of the
|
||||
a loaded library.
|
||||
""" # noqa
|
||||
""" # noqa
|
||||
found = False
|
||||
with open("/proc/self/maps") as f:
|
||||
for line in f:
|
||||
@@ -57,8 +57,9 @@ def find_loaded_library(lib_name) -> Optional[str]:
|
||||
start = line.index("/")
|
||||
path = line[start:].strip()
|
||||
filename = path.split("/")[-1]
|
||||
assert filename.rpartition(".so")[0].startswith(lib_name), \
|
||||
assert filename.rpartition(".so")[0].startswith(lib_name), (
|
||||
f"Unexpected filename: {filename} for library {lib_name}"
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
@@ -70,30 +71,38 @@ class CudaRTLibrary:
|
||||
Function("cudaDeviceSynchronize", cudaError_t, []),
|
||||
# cudaError_t cudaDeviceReset ( void )
|
||||
Function("cudaDeviceReset", cudaError_t, []),
|
||||
|
||||
# const char* cudaGetErrorString ( cudaError_t error )
|
||||
Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
|
||||
|
||||
# cudaError_t cudaMalloc ( void** devPtr, size_t size )
|
||||
Function("cudaMalloc", cudaError_t,
|
||||
[ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
|
||||
Function(
|
||||
"cudaMalloc",
|
||||
cudaError_t,
|
||||
[ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t],
|
||||
),
|
||||
# cudaError_t cudaFree ( void* devPtr )
|
||||
Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
|
||||
# cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
|
||||
Function("cudaMemset", cudaError_t,
|
||||
[ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
|
||||
Function(
|
||||
"cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
|
||||
),
|
||||
# cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
|
||||
Function("cudaMemcpy", cudaError_t, [
|
||||
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind
|
||||
]),
|
||||
|
||||
Function(
|
||||
"cudaMemcpy",
|
||||
cudaError_t,
|
||||
[ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind],
|
||||
),
|
||||
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
|
||||
Function("cudaIpcGetMemHandle", cudaError_t,
|
||||
[ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
|
||||
Function(
|
||||
"cudaIpcGetMemHandle",
|
||||
cudaError_t,
|
||||
[ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p],
|
||||
),
|
||||
# cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
|
||||
Function("cudaIpcOpenMemHandle", cudaError_t, [
|
||||
ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint
|
||||
]),
|
||||
Function(
|
||||
"cudaIpcOpenMemHandle",
|
||||
cudaError_t,
|
||||
[ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint],
|
||||
),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
@@ -109,11 +118,10 @@ class CudaRTLibrary:
|
||||
so_file = find_loaded_library("libcudart")
|
||||
if so_file is None:
|
||||
so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var
|
||||
assert so_file is not None, \
|
||||
(
|
||||
"libcudart is not loaded in the current process, "
|
||||
"try setting VLLM_CUDART_SO_PATH"
|
||||
)
|
||||
assert so_file is not None, (
|
||||
"libcudart is not loaded in the current process, "
|
||||
"try setting VLLM_CUDART_SO_PATH"
|
||||
)
|
||||
if so_file not in CudaRTLibrary.path_to_library_cache:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
CudaRTLibrary.path_to_library_cache[so_file] = lib
|
||||
@@ -154,27 +162,29 @@ class CudaRTLibrary:
|
||||
def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))
|
||||
|
||||
def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,
|
||||
count: int) -> None:
|
||||
def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))
|
||||
|
||||
def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,
|
||||
count: int) -> None:
|
||||
def cudaMemcpy(
|
||||
self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int
|
||||
) -> None:
|
||||
cudaMemcpyDefault = 4
|
||||
kind = cudaMemcpyDefault
|
||||
self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))
|
||||
|
||||
def cudaIpcGetMemHandle(self,
|
||||
devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
|
||||
def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
|
||||
handle = cudaIpcMemHandle_t()
|
||||
self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"](
|
||||
ctypes.byref(handle), devPtr))
|
||||
self.CUDART_CHECK(
|
||||
self.funcs["cudaIpcGetMemHandle"](ctypes.byref(handle), devPtr)
|
||||
)
|
||||
return handle
|
||||
|
||||
def cudaIpcOpenMemHandle(self,
|
||||
handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
|
||||
def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
|
||||
cudaIpcMemLazyEnablePeerAccess = 1
|
||||
devPtr = ctypes.c_void_p()
|
||||
self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"](
|
||||
ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))
|
||||
self.CUDART_CHECK(
|
||||
self.funcs["cudaIpcOpenMemHandle"](
|
||||
ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess
|
||||
)
|
||||
)
|
||||
return devPtr
|
||||
|
||||
@@ -11,7 +11,9 @@ from torch.distributed import ProcessGroup
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed.device_communicators.all_reduce_utils import (
|
||||
CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check)
|
||||
CUSTOM_ALL_REDUCE_MAX_SIZES,
|
||||
gpu_p2p_access_check,
|
||||
)
|
||||
from vllm.distributed.parallel_state import in_the_same_node_as
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
@@ -32,8 +34,7 @@ def _can_p2p(rank: int, world_size: int) -> bool:
|
||||
if i == rank:
|
||||
continue
|
||||
if envs.VLLM_SKIP_P2P_CHECK:
|
||||
logger.info(
|
||||
"Skipping P2P check and trusting the driver's P2P report.")
|
||||
logger.info("Skipping P2P check and trusting the driver's P2P report.")
|
||||
return torch.cuda.can_device_access_peer(rank, i)
|
||||
if not gpu_p2p_access_check(rank, i):
|
||||
return False
|
||||
@@ -41,21 +42,23 @@ def _can_p2p(rank: int, world_size: int) -> bool:

def is_weak_contiguous(inp: torch.Tensor):
return inp.is_contiguous() or (inp.storage().nbytes() -
inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size())
return inp.is_contiguous() or (
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size()
)


class CustomAllreduce:

_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]

# max_size: max supported allreduce size
def __init__(self,
group: ProcessGroup,
device: Union[int, str, torch.device],
max_size=8192 * 1024,
symm_mem_enabled=False) -> None:
def __init__(
self,
group: ProcessGroup,
device: Union[int, str, torch.device],
max_size=8192 * 1024,
symm_mem_enabled=False,
) -> None:
"""
Args:
group: the process group to work on. If None, it will use the
@@ -72,20 +75,24 @@ class CustomAllreduce:
if not custom_ar:
# disable because of missing custom allreduce library
# e.g. in a non-GPU environment
logger.info("Custom allreduce is disabled because "
"of missing custom allreduce library")
logger.info(
"Custom allreduce is disabled because "
"of missing custom allreduce library"
)
return

self.group = group

assert dist.get_backend(group) != dist.Backend.NCCL, (
"CustomAllreduce should be attached to a non-NCCL group.")
"CustomAllreduce should be attached to a non-NCCL group."
)

if not all(in_the_same_node_as(group, source_rank=0)):
# No need to initialize custom allreduce for multi-node case.
logger.warning(
"Custom allreduce is disabled because this process group"
" spans across nodes.")
" spans across nodes."
)
return

rank = dist.get_rank(group=self.group)
@@ -100,7 +107,9 @@ class CustomAllreduce:
"Custom allreduce is disabled due to an unsupported world"
" size: %d. Supported world sizes: %s. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly.",
world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
world_size,
str(CustomAllreduce._SUPPORTED_WORLD_SIZES),
)
return

if isinstance(device, int):
@@ -110,13 +119,15 @@ class CustomAllreduce:
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
device_capability = current_platform.get_device_capability(
).as_version_str()
if (current_platform.is_cuda() and symm_mem_enabled
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
device_capability = current_platform.get_device_capability().as_version_str()
if (
current_platform.is_cuda()
and symm_mem_enabled
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES
):
max_size = min(
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
max_size)
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], max_size
)
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
@@ -124,12 +135,9 @@ class CustomAllreduce:
device_ids = list(range(cuda_device_count_stateless()))

physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id],
dtype=torch.int,
device="cpu")
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
gather_list = [
torch.tensor([0], dtype=torch.int, device="cpu")
for _ in range(world_size)
torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size)
]
dist.all_gather(gather_list, tensor, group=self.group)
physical_device_ids = [t.item() for t in gather_list]
@@ -138,13 +146,13 @@ class CustomAllreduce:
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
assert current_platform.is_cuda_alike()
fully_connected = current_platform.is_fully_connected(
physical_device_ids)
fully_connected = current_platform.is_fully_connected(physical_device_ids)
if world_size > 2 and not fully_connected:
logger.warning(
"Custom allreduce is disabled because it's not supported on"
" more than two PCIe-only GPUs. To silence this warning, "
"specify disable_custom_all_reduce=True explicitly.")
"specify disable_custom_all_reduce=True explicitly."
)
return
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
@@ -154,16 +162,17 @@ class CustomAllreduce:
logger.warning(
"Custom allreduce is disabled because your platform lacks "
"GPU P2P capability or P2P test failed. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly.")
"warning, specify disable_custom_all_reduce=True explicitly."
)
return

self.disabled = False
# Buffers memory are owned by this Python class and passed to C++.
# Metadata composes of two parts: metadata for synchronization and a
# temporary buffer for storing intermediate allreduce results.
self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
group=group,
uncached=True)
self.meta_ptrs = self.create_shared_buffer(
ops.meta_size() + max_size, group=group, uncached=True
)
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
@@ -172,21 +181,22 @@ class CustomAllreduce:
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
# is enough for 131072 such tuples. The largest model I've seen only
# needs less than 10000 of registered tuples.
self.rank_data = torch.empty(8 * 1024 * 1024,
dtype=torch.uint8,
device=self.device)
self.rank_data = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
)
self.max_size = max_size
self.rank = rank
self.world_size = world_size
self.fully_connected = fully_connected
self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank,
self.fully_connected)
self._ptr = ops.init_custom_ar(
self.meta_ptrs, self.rank_data, rank, self.fully_connected
)
ops.register_buffer(self._ptr, self.buffer_ptrs)

@contextmanager
def capture(self):
"""
The main responsibility of this context manager is the
The main responsibility of this context manager is the
`register_graph_buffers` call at the end of the context.
It records all the buffer addresses used in the CUDA graph.
"""
@@ -204,15 +214,13 @@ class CustomAllreduce:
# We cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
all_data = [[None, None]
for _ in range(dist.get_world_size(group=self.group))]
all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
all_data[self.rank] = [handle, offset]
ranks = sorted(dist.get_process_group_ranks(group=self.group))
for i, rank in enumerate(ranks):
dist.broadcast_object_list(all_data[i],
src=rank,
group=self.group,
device="cpu")
dist.broadcast_object_list(
all_data[i], src=rank, group=self.group, device="cpu"
)
# Unpack list of tuples to tuple of lists.
handles = [d[0] for d in all_data]  # type: ignore
offsets = [d[1] for d in all_data]  # type: ignore
@@ -233,13 +241,11 @@ class CustomAllreduce:
return inp_size < self.max_size
return False

def all_reduce(self,
inp: torch.Tensor,
*,
out: torch.Tensor = None,
registered: bool = False):
def all_reduce(
self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: bool = False
):
"""Performs an out-of-place all reduce.

If registered is True, this assumes inp's pointer is already
IPC-registered. Otherwise, inp is first copied into a pre-registered
buffer.
@@ -249,8 +255,9 @@ class CustomAllreduce:
if registered:
ops.all_reduce(self._ptr, inp, out, 0, 0)
else:
ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank],
self.max_size)
ops.all_reduce(
self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
)
return out

def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
@@ -283,9 +290,11 @@ class CustomAllreduce:
self.close()

@staticmethod
def create_shared_buffer(size_in_bytes: int,
group: Optional[ProcessGroup] = None,
uncached: Optional[bool] = False) -> list[int]:
def create_shared_buffer(
size_in_bytes: int,
group: Optional[ProcessGroup] = None,
uncached: Optional[bool] = False,
) -> list[int]:
pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)

world_size = dist.get_world_size(group=group)
@@ -302,9 +311,11 @@ class CustomAllreduce:
return pointers

@staticmethod
def free_shared_buffer(pointers: list[int],
group: Optional[ProcessGroup] = None,
rank: Optional[int] = None) -> None:
def free_shared_buffer(
pointers: list[int],
group: Optional[ProcessGroup] = None,
rank: Optional[int] = None,
) -> None:
if rank is None:
rank = dist.get_rank(group=group)
if ops is not None:
@@ -9,7 +9,6 @@ assert has_flashinfer_all2all(), "Flashinfer alltoallv module cannot be found"

class CustomCommunicator(CommBackend):

def __init__(self, group):
self._group = group
@@ -24,5 +23,5 @@ class CustomCommunicator(CommBackend):
dist.all_gather_object(gathered, data, group=self._group)
return gathered

def Split(self, color: int, key: int) -> 'CustomCommunicator':
def Split(self, color: int, key: int) -> "CustomCommunicator":
return self
@@ -10,8 +10,14 @@ from torch.distributed import ProcessGroup, ReduceOp

import vllm.envs as envs
from vllm.distributed.device_communicators.pynccl_wrapper import (
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
ncclRedOpTypeEnum, ncclUniqueId)
NCCLLibrary,
buffer_type,
cudaStream_t,
ncclComm_t,
ncclDataTypeEnum,
ncclRedOpTypeEnum,
ncclUniqueId,
)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import current_stream
@@ -23,7 +29,8 @@ _NCCL_SYMM_OPS_REGISTERED = False
|
||||
|
||||
def register_nccl_symmetric_ops(pynccl_comm):
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
nccl_symm_mem_context)
|
||||
nccl_symm_mem_context,
|
||||
)
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
global _NCCL_SYMM_OPS_REGISTERED
|
||||
@@ -31,8 +38,7 @@ def register_nccl_symmetric_ops(pynccl_comm):
|
||||
return
|
||||
_NCCL_SYMM_OPS_REGISTERED = True
|
||||
|
||||
def all_reduce_symmetric_with_copy_impl(
|
||||
input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
def all_reduce_symmetric_with_copy_impl(input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
with nccl_symm_mem_context(pynccl_comm):
|
||||
symm_input = torch.empty_like(input_tensor)
|
||||
symm_output = torch.empty_like(input_tensor)
|
||||
@@ -40,8 +46,7 @@ def register_nccl_symmetric_ops(pynccl_comm):
|
||||
symm_output = pynccl_comm.all_reduce(symm_input, symm_output)
|
||||
return symm_output
|
||||
|
||||
def all_reduce_symmetric_with_copy_fake(
|
||||
input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
def all_reduce_symmetric_with_copy_fake(input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
return torch.empty_like(input_tensor)
|
||||
|
||||
direct_register_custom_op(
|
||||
@@ -52,7 +57,6 @@ def register_nccl_symmetric_ops(pynccl_comm):
|
||||
|
||||
|
||||
class PyNcclCommunicator:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group: Union[ProcessGroup, StatelessProcessGroup],
|
||||
@@ -73,7 +77,8 @@ class PyNcclCommunicator:
|
||||
if not isinstance(group, StatelessProcessGroup):
|
||||
assert dist.is_initialized()
|
||||
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
||||
"PyNcclCommunicator should be attached to a non-NCCL group.")
|
||||
"PyNcclCommunicator should be attached to a non-NCCL group."
|
||||
)
|
||||
# note: this rank is the rank in the group
|
||||
self.rank = dist.get_rank(group)
|
||||
self.world_size = dist.get_world_size(group)
|
||||
@@ -132,7 +137,8 @@ class PyNcclCommunicator:
|
||||
# current cuda device to the specified one
|
||||
with torch.cuda.device(device):
|
||||
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
|
||||
self.world_size, self.unique_id, self.rank)
|
||||
self.world_size, self.unique_id, self.rank
|
||||
)
|
||||
|
||||
stream = current_stream()
|
||||
# A small all_reduce for warmup.
|
||||
@@ -141,11 +147,13 @@ class PyNcclCommunicator:
|
||||
stream.synchronize()
|
||||
del data
|
||||
|
||||
def all_reduce(self,
|
||||
in_tensor: torch.Tensor,
|
||||
out_tensor: torch.Tensor = None,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None) -> torch.Tensor:
|
||||
def all_reduce(
|
||||
self,
|
||||
in_tensor: torch.Tensor,
|
||||
out_tensor: torch.Tensor = None,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None,
|
||||
) -> torch.Tensor:
|
||||
if self.disabled:
|
||||
return None
|
||||
# nccl communicator created on a specific device
|
||||
@@ -153,25 +161,28 @@ class PyNcclCommunicator:
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert in_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {in_tensor.device}")
|
||||
f"but the input tensor is on {in_tensor.device}"
|
||||
)
|
||||
|
||||
if out_tensor is None:
|
||||
out_tensor = torch.empty_like(in_tensor)
|
||||
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
|
||||
buffer_type(out_tensor.data_ptr()),
|
||||
in_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(in_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op), self.comm,
|
||||
cudaStream_t(stream.cuda_stream))
|
||||
self.nccl.ncclAllReduce(
|
||||
buffer_type(in_tensor.data_ptr()),
|
||||
buffer_type(out_tensor.data_ptr()),
|
||||
in_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(in_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op),
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
return out_tensor
|
||||
|
||||
def all_gather(self,
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
stream=None):
|
||||
def all_gather(
|
||||
self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
|
||||
):
|
||||
if self.disabled:
|
||||
return
|
||||
# nccl communicator created on a specific device
|
||||
@@ -179,14 +190,18 @@ class PyNcclCommunicator:
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert input_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}")
|
||||
f"but the input tensor is on {input_tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclAllGather(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
|
||||
cudaStream_t(stream.cuda_stream))
|
||||
buffer_type(output_tensor.data_ptr()),
|
||||
input_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def all_gatherv(
|
||||
self,
|
||||
@@ -202,14 +217,15 @@ class PyNcclCommunicator:
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert input_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}")
|
||||
f"but the input tensor is on {input_tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
assert output_tensor.shape[0] == sum(sizes)
|
||||
split_offset = 0
|
||||
self.nccl.ncclGroupStart()
|
||||
for root, split_size in enumerate(sizes):
|
||||
dst_slice = output_tensor[split_offset:split_offset + split_size]
|
||||
dst_slice = output_tensor[split_offset : split_offset + split_size]
|
||||
self.nccl.ncclBroadcast(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(dst_slice.data_ptr()),
|
||||
@@ -222,11 +238,13 @@ class PyNcclCommunicator:
|
||||
split_offset += split_size
|
||||
self.nccl.ncclGroupEnd()
|
||||
|
||||
def reduce_scatter(self,
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None):
|
||||
def reduce_scatter(
|
||||
self,
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None,
|
||||
):
|
||||
if self.disabled:
|
||||
return
|
||||
# nccl communicator created on a specific device
|
||||
@@ -234,15 +252,19 @@ class PyNcclCommunicator:
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert input_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}")
|
||||
f"but the input tensor is on {input_tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclReduceScatter(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
|
||||
buffer_type(output_tensor.data_ptr()),
|
||||
output_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op), self.comm,
|
||||
cudaStream_t(stream.cuda_stream))
|
||||
ncclRedOpTypeEnum.from_torch(op),
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def reduce_scatterv(
|
||||
self,
|
||||
@@ -259,20 +281,25 @@ class PyNcclCommunicator:
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert input_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {input_tensor.device}")
|
||||
f"but the input tensor is on {input_tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
|
||||
split_offset = 0
|
||||
self.nccl.ncclGroupStart()
|
||||
for root, split_size in enumerate(sizes):
|
||||
chunk = input_tensor[split_offset:split_offset + split_size, ...]
|
||||
chunk = input_tensor[split_offset : split_offset + split_size, ...]
|
||||
self.nccl.ncclReduce(
|
||||
buffer_type(chunk.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()), chunk.numel(),
|
||||
buffer_type(output_tensor.data_ptr()),
|
||||
chunk.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op), root, self.comm,
|
||||
cudaStream_t(stream.cuda_stream))
|
||||
ncclRedOpTypeEnum.from_torch(op),
|
||||
root,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
split_offset += split_size
|
||||
self.nccl.ncclGroupEnd()
|
||||
|
||||
@@ -281,31 +308,44 @@ class PyNcclCommunicator:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype), dst,
|
||||
self.comm, cudaStream_t(stream.cuda_stream))
|
||||
self.nccl.ncclSend(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
dst,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def recv(self, tensor: torch.Tensor, src: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype), src,
|
||||
self.comm, cudaStream_t(stream.cuda_stream))
|
||||
self.nccl.ncclRecv(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
src,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
f"but the input tensor is on {tensor.device}"
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
if src == self.rank:
|
||||
@@ -315,9 +355,15 @@ class PyNcclCommunicator:
|
||||
else:
|
||||
sendbuff = buffer_type()
|
||||
recvbuff = buffer_type(tensor.data_ptr())
|
||||
self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype), src,
|
||||
self.comm, cudaStream_t(stream.cuda_stream))
|
||||
self.nccl.ncclBroadcast(
|
||||
sendbuff,
|
||||
recvbuff,
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
src,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def group_start(self):
|
||||
self.nccl.ncclGroupStart()
|
||||
@@ -334,8 +380,7 @@ class PyNcclCommunicator:
|
||||
)
|
||||
|
||||
def register_comm_window_raw(self, ptr: int, size: int):
|
||||
return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr),
|
||||
size, 1)
|
||||
return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1)
|
||||
|
||||
def deregister_comm_window(self, window):
|
||||
return self.nccl.ncclCommWindowDeregister(self.comm, window)
|
||||
|
||||
@@ -98,7 +98,9 @@ def compile_nccl_allocator():
|
||||
"This is expected if NCCL headers are not available. "
|
||||
"optionally set VLLM_NCCL_INCLUDE_PATH to point to a directory "
|
||||
"containing the NCCL header. "
|
||||
"Error: %s", str(e))
|
||||
"Error: %s",
|
||||
str(e),
|
||||
)
|
||||
|
||||
|
||||
def get_nccl_mem_pool():
|
||||
@@ -125,21 +127,24 @@ atexit.register(_cleanup_nccl_allocator_wrapper)
|
||||
|
||||
|
||||
class nccl_symm_mem_context:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pynccl_comm: PyNcclCommunicator,
|
||||
disabled: bool = False,
|
||||
):
|
||||
self.disabled = (disabled or not is_symmetric_memory_enabled()
|
||||
or pynccl_comm.world_size == 1
|
||||
or not current_platform.is_cuda()
|
||||
or get_nccl_mem_pool() is None or version.parse(
|
||||
torch.__version__) < version.parse("2.8.0.a0"))
|
||||
self.disabled = (
|
||||
disabled
|
||||
or not is_symmetric_memory_enabled()
|
||||
or pynccl_comm.world_size == 1
|
||||
or not current_platform.is_cuda()
|
||||
or get_nccl_mem_pool() is None
|
||||
or version.parse(torch.__version__) < version.parse("2.8.0.a0")
|
||||
)
|
||||
if self.disabled:
|
||||
self.pynccl_comm: Optional[PyNcclCommunicator] = None
|
||||
self._mem_pool_ctx: contextlib.AbstractContextManager[
|
||||
Any] = contextlib.nullcontext()
|
||||
self._mem_pool_ctx: contextlib.AbstractContextManager[Any] = (
|
||||
contextlib.nullcontext()
|
||||
)
|
||||
self.is_graph_capture = None
|
||||
self.device = None
|
||||
else:
|
||||
@@ -151,16 +156,16 @@ class nccl_symm_mem_context:
|
||||
def __enter__(self):
|
||||
if self.disabled:
|
||||
return self
assert (
self.pynccl_comm
is not None), "Symmetric memory requires pynccl to be initialized"
assert (
self.pynccl_comm.nccl_version >= 22703
), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
assert self.pynccl_comm is not None, (
"Symmetric memory requires pynccl to be initialized"
)
assert self.pynccl_comm.nccl_version >= 22703, (
|
||||
"NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
|
||||
)
|
||||
if self.is_graph_capture:
|
||||
assert (
|
||||
_graph_pool_id
|
||||
is not None), "graph_pool_id is not set under graph capture"
|
||||
assert _graph_pool_id is not None, (
|
||||
"graph_pool_id is not set under graph capture"
|
||||
)
|
||||
# Pause graph memory pool to use symmetric memory with cuda graph
|
||||
torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
|
||||
self._mem_pool_ctx.__enter__()
|
||||
@@ -179,8 +184,8 @@ class nccl_symm_mem_context:
|
||||
for segment in _cached_pool_snapshot:
|
||||
if segment["address"] not in _registered_base_addrs:
|
||||
self.pynccl_comm.register_comm_window_raw(
|
||||
segment["address"], segment["total_size"])
|
||||
segment["address"], segment["total_size"]
|
||||
)
|
||||
_registered_base_addrs.add(segment["address"])
|
||||
if self.is_graph_capture:
|
||||
torch._C._cuda_beginAllocateCurrentThreadToPool(
|
||||
self.device, _graph_pool_id)
|
||||
torch._C._cuda_beginAllocateCurrentThreadToPool(self.device, _graph_pool_id)
|
||||
|
||||
@@ -133,88 +133,141 @@ class NCCLLibrary:
|
||||
# const char* ncclGetErrorString(ncclResult_t result)
|
||||
Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
|
||||
# ncclResult_t ncclGetVersion(int *version);
|
||||
Function("ncclGetVersion", ncclResult_t,
|
||||
[ctypes.POINTER(ctypes.c_int)]),
|
||||
Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]),
|
||||
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
|
||||
Function("ncclGetUniqueId", ncclResult_t,
|
||||
[ctypes.POINTER(ncclUniqueId)]),
|
||||
Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]),
|
||||
# ncclResult_t ncclCommInitRank(
|
||||
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
|
||||
# note that ncclComm_t is a pointer type, so the first argument
|
||||
# is a pointer to a pointer
|
||||
Function("ncclCommInitRank", ncclResult_t, [
|
||||
ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
|
||||
ctypes.c_int
|
||||
]),
|
||||
Function(
|
||||
"ncclCommInitRank",
|
||||
ncclResult_t,
|
||||
[ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int],
|
||||
),
|
||||
# ncclResult_t ncclAllReduce(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||
# cudaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument
|
||||
# is a pointer
|
||||
Function("ncclAllReduce", ncclResult_t, [
|
||||
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
|
||||
ncclRedOp_t, ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
Function(
|
||||
"ncclAllReduce",
|
||||
ncclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ncclDataType_t,
|
||||
ncclRedOp_t,
|
||||
ncclComm_t,
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclReduce(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclRedOp_t op, int root,
|
||||
# ncclComm_t comm, cudaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument
|
||||
# is a pointer
|
||||
Function("ncclReduce", ncclResult_t, [
|
||||
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
|
||||
ncclRedOp_t, ctypes.c_int, ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
Function(
|
||||
"ncclReduce",
|
||||
ncclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ncclDataType_t,
|
||||
ncclRedOp_t,
|
||||
ctypes.c_int,
|
||||
ncclComm_t,
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclAllGather(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclComm_t comm,
|
||||
# cudaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument
|
||||
# is a pointer
|
||||
Function("ncclAllGather", ncclResult_t, [
|
||||
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
|
||||
ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
Function(
|
||||
"ncclAllGather",
|
||||
ncclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ncclDataType_t,
|
||||
ncclComm_t,
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclReduceScatter(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||
# cudaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument
|
||||
# is a pointer
|
||||
Function("ncclReduceScatter", ncclResult_t, [
|
||||
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
|
||||
ncclRedOp_t, ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
Function(
|
||||
"ncclReduceScatter",
|
||||
ncclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ncclDataType_t,
|
||||
ncclRedOp_t,
|
||||
ncclComm_t,
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclSend(
|
||||
# const void* sendbuff, size_t count, ncclDataType_t datatype,
|
||||
# int dest, ncclComm_t comm, cudaStream_t stream);
|
||||
Function("ncclSend", ncclResult_t, [
|
||||
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
|
||||
ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
Function(
|
||||
"ncclSend",
|
||||
ncclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ncclDataType_t,
|
||||
ctypes.c_int,
|
||||
ncclComm_t,
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclRecv(
|
||||
# void* recvbuff, size_t count, ncclDataType_t datatype,
|
||||
# int src, ncclComm_t comm, cudaStream_t stream);
|
||||
Function("ncclRecv", ncclResult_t, [
|
||||
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
|
||||
ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
Function(
|
||||
"ncclRecv",
|
||||
ncclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ncclDataType_t,
|
||||
ctypes.c_int,
|
||||
ncclComm_t,
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclBroadcast(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, int root, ncclComm_t comm,
|
||||
# cudaStream_t stream);
|
||||
Function("ncclBroadcast", ncclResult_t, [
|
||||
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
|
||||
ctypes.c_int, ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
Function(
|
||||
"ncclBroadcast",
|
||||
ncclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ncclDataType_t,
|
||||
ctypes.c_int,
|
||||
ncclComm_t,
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# be cautious! this is a collective call, it will block until all
|
||||
# processes in the communicator have called this function.
|
||||
# because Python object destruction can happen in random order,
|
||||
@@ -241,8 +294,7 @@ class NCCLLibrary:
|
||||
),
|
||||
# ncclResult_t ncclCommWindowDeregister(
|
||||
# ncclComm_t comm, ncclWindow_t win);
|
||||
Function("ncclCommWindowDeregister", ncclResult_t,
|
||||
[ncclComm_t, ncclWindow_t]),
|
||||
Function("ncclCommWindowDeregister", ncclResult_t, [ncclComm_t, ncclWindow_t]),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
@@ -254,7 +306,6 @@ class NCCLLibrary:
|
||||
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
|
||||
so_file = so_file or find_nccl_library()
|
||||
|
||||
try:
|
||||
@@ -270,8 +321,10 @@ class NCCLLibrary:
|
||||
"or it does not support the current platform %s. "
|
||||
"If you already have the library, please set the "
|
||||
"environment variable VLLM_NCCL_SO_PATH"
|
||||
" to point to the correct nccl library path.", so_file,
|
||||
platform.platform())
|
||||
" to point to the correct nccl library path.",
|
||||
so_file,
|
||||
platform.platform(),
|
||||
)
|
||||
raise e
|
||||
|
||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||
@@ -284,15 +337,18 @@ class NCCLLibrary:
|
||||
_funcs[func.name] = f
|
||||
except AttributeError:
|
||||
if func.name in [
|
||||
"ncclCommWindowRegister",
|
||||
"ncclCommWindowDeregister"
|
||||
"ncclCommWindowRegister",
|
||||
"ncclCommWindowDeregister",
|
||||
]:
|
||||
if envs.VLLM_USE_NCCL_SYMM_MEM:
|
||||
logger.warning_once(
|
||||
"The symbol %s is not found in the NCCL "
|
||||
"library %s. To enable VLLM_USE_NCCL_SYMM_MEM "
|
||||
" please update your NCCL version to >= "
|
||||
"2.27.03.", func.name, so_file)
|
||||
"2.27.03.",
|
||||
func.name,
|
||||
so_file,
|
||||
)
|
||||
if current_platform.is_rocm():
|
||||
# Having an exception here on ROCm platform is
|
||||
# not allowed during graph capturing
|
||||
@@ -325,88 +381,153 @@ class NCCLLibrary:
|
||||
|
||||
def ncclGetUniqueId(self) -> ncclUniqueId:
|
||||
unique_id = ncclUniqueId()
|
||||
self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
|
||||
ctypes.byref(unique_id)))
|
||||
self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id)))
|
||||
return unique_id
|
||||
|
||||
def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId:
|
||||
if len(data) != 128:
|
||||
raise ValueError(
|
||||
f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes")
|
||||
f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes"
|
||||
)
|
||||
unique_id = ncclUniqueId()
|
||||
ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128)
|
||||
return unique_id
|
||||
|
||||
def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
|
||||
rank: int) -> ncclComm_t:
|
||||
def ncclCommInitRank(
|
||||
self, world_size: int, unique_id: ncclUniqueId, rank: int
|
||||
) -> ncclComm_t:
|
||||
comm = ncclComm_t()
|
||||
self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
|
||||
world_size, unique_id,
|
||||
rank))
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclCommInitRank"](
|
||||
ctypes.byref(comm), world_size, unique_id, rank
|
||||
)
|
||||
)
|
||||
return comm
|
||||
|
||||
def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, op: int, comm: ncclComm_t,
|
||||
stream: cudaStream_t) -> None:
|
||||
def ncclAllReduce(
|
||||
self,
|
||||
sendbuff: buffer_type,
|
||||
recvbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
op: int,
|
||||
comm: ncclComm_t,
|
||||
stream: cudaStream_t,
|
||||
) -> None:
|
||||
# `datatype` actually should be `ncclDataType_t`
|
||||
# and `op` should be `ncclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
|
||||
datatype, op, comm,
|
||||
stream))
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclAllReduce"](
|
||||
sendbuff, recvbuff, count, datatype, op, comm, stream
|
||||
)
|
||||
)
|
||||
|
||||
def ncclReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, op: int, root: int,
|
||||
comm: ncclComm_t, stream: cudaStream_t) -> None:
|
||||
def ncclReduce(
|
||||
self,
|
||||
sendbuff: buffer_type,
|
||||
recvbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
op: int,
|
||||
root: int,
|
||||
comm: ncclComm_t,
|
||||
stream: cudaStream_t,
|
||||
) -> None:
|
||||
# `datatype` actually should be `ncclDataType_t`
|
||||
# and `op` should be `ncclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.NCCL_CHECK(self._funcs["ncclReduce"](sendbuff, recvbuff, count,
|
||||
datatype, op, root, comm,
|
||||
stream))
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclReduce"](
|
||||
sendbuff, recvbuff, count, datatype, op, root, comm, stream
|
||||
)
|
||||
)
|
||||
|
||||
def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, op: int, comm: ncclComm_t,
|
||||
stream: cudaStream_t) -> None:
|
||||
def ncclReduceScatter(
|
||||
self,
|
||||
sendbuff: buffer_type,
|
||||
recvbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
op: int,
|
||||
comm: ncclComm_t,
|
||||
stream: cudaStream_t,
|
||||
) -> None:
|
||||
# `datatype` actually should be `ncclDataType_t`
|
||||
# and `op` should be `ncclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff,
|
||||
count, datatype, op,
|
||||
comm, stream))
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclReduceScatter"](
|
||||
sendbuff, recvbuff, count, datatype, op, comm, stream
|
||||
)
|
||||
)
|
||||
|
||||
def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, comm: ncclComm_t,
|
||||
stream: cudaStream_t) -> None:
|
||||
def ncclAllGather(
|
||||
self,
|
||||
sendbuff: buffer_type,
|
||||
recvbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
comm: ncclComm_t,
|
||||
stream: cudaStream_t,
|
||||
) -> None:
|
||||
# `datatype` actually should be `ncclDataType_t`
|
||||
# which is an aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count,
|
||||
datatype, comm, stream))
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclAllGather"](
|
||||
sendbuff, recvbuff, count, datatype, comm, stream
|
||||
)
|
||||
)
|
||||
|
||||
def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
|
||||
dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
|
||||
dest, comm, stream))
|
||||
def ncclSend(
|
||||
self,
|
||||
sendbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
dest: int,
|
||||
comm: ncclComm_t,
|
||||
stream: cudaStream_t,
|
||||
) -> None:
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream)
|
||||
)
|
||||
|
||||
def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
|
||||
src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
|
||||
comm, stream))
|
||||
def ncclRecv(
|
||||
self,
|
||||
recvbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
src: int,
|
||||
comm: ncclComm_t,
|
||||
stream: cudaStream_t,
|
||||
) -> None:
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)
|
||||
)
|
||||
|
||||
def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, root: int, comm: ncclComm_t,
|
||||
stream: cudaStream_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count,
|
||||
datatype, root, comm,
|
||||
stream))
|
||||
def ncclBroadcast(
|
||||
self,
|
||||
sendbuff: buffer_type,
|
||||
recvbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
root: int,
|
||||
comm: ncclComm_t,
|
||||
stream: cudaStream_t,
|
||||
) -> None:
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclBroadcast"](
|
||||
sendbuff, recvbuff, count, datatype, root, comm, stream
|
||||
)
|
||||
)
|
||||
|
||||
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
|
||||
@@ -417,19 +538,27 @@ class NCCLLibrary:
|
||||
def ncclGroupEnd(self) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
|
||||
|
||||
def ncclCommWindowRegister(self, comm: ncclComm_t, buff: buffer_type,
|
||||
size: int, win_flags: int) -> ncclWindow_t:
|
||||
def ncclCommWindowRegister(
|
||||
self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int
|
||||
) -> ncclWindow_t:
|
||||
window = ncclWindow_t()
|
||||
self.NCCL_CHECK(self._funcs["ncclCommWindowRegister"](
|
||||
comm, buff, size, ctypes.byref(window), win_flags))
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclCommWindowRegister"](
|
||||
comm, buff, size, ctypes.byref(window), win_flags
|
||||
)
|
||||
)
|
||||
return window
|
||||
|
||||
def ncclCommWindowDeregister(self, comm: ncclComm_t,
|
||||
window: ncclWindow_t) -> None:
|
||||
def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
|
||||
"ncclComm_t", "cudaStream_t", "buffer_type"
|
||||
"NCCLLibrary",
|
||||
"ncclDataTypeEnum",
|
||||
"ncclRedOpTypeEnum",
|
||||
"ncclUniqueId",
|
||||
"ncclComm_t",
|
||||
"cudaStream_t",
|
||||
"buffer_type",
|
||||
]
|
||||
|
||||
@@ -27,9 +27,10 @@ except Exception:
|
||||
|
||||
|
||||
def is_weak_contiguous(inp: torch.Tensor):
|
||||
return inp.is_contiguous() or (inp.storage().nbytes() -
|
||||
inp.storage_offset() * inp.element_size()
|
||||
== inp.numel() * inp.element_size())
|
||||
return inp.is_contiguous() or (
|
||||
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
||||
== inp.numel() * inp.element_size()
|
||||
)
|
||||
|
||||
|
||||
class QuickReduceRegime(Enum):
|
||||
@@ -44,7 +45,6 @@ MB = 1024 * 1024
|
||||
|
||||
|
||||
class QuickAllReduce:
|
||||
|
||||
_SUPPORTED_WORLD_SIZES = [2, 4, 8]
|
||||
_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
|
||||
# The following data is based on kernel tests.
|
||||
@@ -58,20 +58,21 @@ class QuickAllReduce:
|
||||
(torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],
|
||||
}
|
||||
|
||||
def __init__(self, group: ProcessGroup,
|
||||
device: Union[int, str, torch.device]) -> None:
|
||||
def __init__(
|
||||
self, group: ProcessGroup, device: Union[int, str, torch.device]
|
||||
) -> None:
|
||||
"""
|
||||
Custom allreduce provides non-destructive acceleration and is
|
||||
Custom allreduce provides non-destructive acceleration and is
|
||||
available for CUDA and ROCm MI300 series.
|
||||
|
||||
Custom quick allreduce leverages quantization for further
|
||||
acceleration on ROCm. It currently supports Q8, Q6, and Q4
|
||||
Custom quick allreduce leverages quantization for further
|
||||
acceleration on ROCm. It currently supports Q8, Q6, and Q4
|
||||
quantization formats and FP(float16, bfloat16).
|
||||
|
||||
Quick allreduce is designed as a complement to custom allreduce.
|
||||
Its initialization requires even stricter conditions.
|
||||
Quick allreduce is designed as a complement to custom allreduce.
|
||||
Its initialization requires even stricter conditions.
|
||||
|
||||
Only the ROCm MI300 series is supported for quick allreduce at
|
||||
Only the ROCm MI300 series is supported for quick allreduce at
|
||||
this time.
|
||||
|
||||
Args:
|
||||
@@ -93,18 +94,23 @@ class QuickAllReduce:
|
||||
if not quick_ar:
|
||||
# disable because of missing quick reduce library
|
||||
# e.g. in a cuda environment
|
||||
logger.info("Custom quick allreduce is disabled because "
|
||||
"of missing custom quick allreduce library")
|
||||
logger.info(
|
||||
"Custom quick allreduce is disabled because "
|
||||
"of missing custom quick allreduce library"
|
||||
)
|
||||
return
|
||||
|
||||
self.group = group
|
||||
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
||||
"Custom quick allreduce should be attached to a non-NCCL group.")
|
||||
"Custom quick allreduce should be attached to a non-NCCL group."
|
||||
)
|
||||
if not all(in_the_same_node_as(group, source_rank=0)):
|
||||
# No need to initialize custom quick allreduce for
|
||||
# multi-node case.
|
||||
logger.warning("Custom quick allreduce is disabled because this "
|
||||
"process group spans across nodes.")
|
||||
logger.warning(
|
||||
"Custom quick allreduce is disabled because this "
|
||||
"process group spans across nodes."
|
||||
)
|
||||
return
|
||||
rank = dist.get_rank(group=self.group)
|
||||
world_size = dist.get_world_size(group=self.group)
|
||||
@@ -118,7 +124,9 @@ class QuickAllReduce:
|
||||
logger.warning(
|
||||
"Custom quick allreduce is disabled due to an "
|
||||
"unsupported world size: %d. Supported world sizes: %s.",
|
||||
world_size, str(QuickAllReduce._SUPPORTED_WORLD_SIZES))
|
||||
world_size,
|
||||
str(QuickAllReduce._SUPPORTED_WORLD_SIZES),
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(device, int):
|
||||
@@ -134,9 +142,7 @@ class QuickAllReduce:
|
||||
else:
|
||||
device_ids = list(range(cuda_device_count_stateless()))
|
||||
physical_device_id = device_ids[device.index]
|
||||
tensor = torch.tensor([physical_device_id],
|
||||
dtype=torch.int,
|
||||
device="cpu")
|
||||
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
|
||||
gather_list = [
|
||||
torch.tensor([0], dtype=torch.int, device="cpu")
|
||||
for _ in range(self.world_size)
|
||||
@@ -148,12 +154,12 @@ class QuickAllReduce:
|
||||
# where custom quick allreduce is not supported
|
||||
# this checks hardware and driver support for NVLink
|
||||
assert current_platform.is_cuda_alike()
|
||||
self.fully_connected = current_platform.is_fully_connected(
|
||||
physical_device_ids)
|
||||
self.fully_connected = current_platform.is_fully_connected(physical_device_ids)
|
||||
if self.world_size > 2 and not self.fully_connected:
|
||||
logger.debug(
|
||||
"Custom quick allreduce is disabled because it's not supported "
|
||||
"on more than two PCIe-only GPUs. ")
|
||||
"on more than two PCIe-only GPUs. "
|
||||
)
|
||||
return
|
||||
|
||||
self.init_quick_all_reduce()
|
||||
@@ -169,24 +175,31 @@ class QuickAllReduce:
|
||||
"Custom quick allreduce:",
|
||||
f"Invalid quantization level: {regime_str}. "
|
||||
"Supported levels: "
|
||||
f"{list(QuickReduceRegime.__members__.keys())}")
|
||||
f"{list(QuickReduceRegime.__members__.keys())}",
|
||||
)
|
||||
return
|
||||
|
||||
if regime_str == "NONE":
|
||||
logger.debug("Custom quick allreduce is disabled based "
|
||||
"on env variable "
|
||||
"VLLM_ROCM_QUICK_REDUCE_QUANTIZATION='NONE'")
|
||||
logger.debug(
|
||||
"Custom quick allreduce is disabled based "
|
||||
"on env variable "
|
||||
"VLLM_ROCM_QUICK_REDUCE_QUANTIZATION='NONE'"
|
||||
)
|
||||
return
|
||||
self.qr_quant_level = QuickReduceRegime[regime_str]
|
||||
vllm_config = get_current_vllm_config()
|
||||
if vllm_config is not None and \
|
||||
hasattr(vllm_config, "model_config") and \
|
||||
hasattr(vllm_config.model_config, "dtype"):
|
||||
if (
|
||||
vllm_config is not None
|
||||
and hasattr(vllm_config, "model_config")
|
||||
and hasattr(vllm_config.model_config, "dtype")
|
||||
):
|
||||
dtype = vllm_config.model_config.dtype
|
||||
if dtype not in [torch.float16, torch.bfloat16]:
|
||||
logger.debug(
"Custom quick allreduce disabled: only supports "
"float16 and bfloat16, but got %s.", dtype)
"float16 and bfloat16, but got %s.",
dtype,
)
return
|
||||
|
||||
if dtype == torch.bfloat16 and self.use_fp16_kernels:
|
||||
@@ -194,7 +207,8 @@ class QuickAllReduce:
|
||||
"Custom quick allreduce: BF16 inputs will be converted "
|
||||
"to FP16 to improve performance. set "
|
||||
"envs.VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16=0 "
|
||||
"to turn off.")
|
||||
"to turn off."
|
||||
)
|
||||
|
||||
# VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB
|
||||
qr_max_size = envs.VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB
|
||||
@@ -206,8 +220,7 @@ class QuickAllReduce:
|
||||
)
|
||||
qr_max_size = qr_max_size * MB
|
||||
self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size)
|
||||
self.qr_max_size = qr_max_size if qr_max_size is not None \
|
||||
else ops.qr_max_size()
|
||||
self.qr_max_size = qr_max_size if qr_max_size is not None else ops.qr_max_size()
|
||||
self.create_shared_buffer()
|
||||
self.disabled = False
|
||||
|
||||
@@ -217,16 +230,15 @@ class QuickAllReduce:
|
||||
try:
|
||||
props = torch.cuda.get_device_properties(0)
|
||||
gcn_arch = getattr(props, "gcnArchName", "")
|
||||
supported_archs = ['gfx94', 'gfx95']
|
||||
supported_archs = ["gfx94", "gfx95"]
|
||||
return any(gfx in gcn_arch for gfx in supported_archs)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to determine ROCm for quick allreduce: %s",
|
||||
e)
|
||||
logger.warning("Failed to determine ROCm for quick allreduce: %s", e)
|
||||
return False
|
||||
|
||||
def create_shared_buffer(self):
|
||||
"""
|
||||
Creates a shared buffer for quickreduce.
|
||||
Creates a shared buffer for quickreduce.
|
||||
Has to be called after init_custom_qr
|
||||
"""
|
||||
handle = ops.qr_get_handle(self._ptr)
|
||||
@@ -253,9 +265,11 @@ class QuickAllReduce:
|
||||
dtype = inp.dtype
|
||||
if self.use_fp16_kernels:
|
||||
dtype = torch.float16
|
||||
return inp_size <= self.qr_max_size and \
|
||||
inp_size >= self._QR_MIN_SIZE[(dtype, self.world_size)]\
|
||||
[self.qr_quant_level.value]
|
||||
return (
|
||||
inp_size <= self.qr_max_size
|
||||
and inp_size
|
||||
>= self._QR_MIN_SIZE[(dtype, self.world_size)][self.qr_quant_level.value]
|
||||
)
|
||||
|
||||
def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None):
|
||||
"""Performs an out-of-place custom quick all reduce."""
|
||||
@@ -263,8 +277,9 @@ class QuickAllReduce:
|
||||
# as QR uses static IPC buffer.
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
ops.qr_all_reduce(self._ptr, inp, out, self.qr_quant_level.value,
|
||||
self.use_fp16_kernels)
|
||||
ops.qr_all_reduce(
|
||||
self._ptr, inp, out, self.qr_quant_level.value, self.use_fp16_kernels
|
||||
)
|
||||
return out
|
||||
|
||||
def close(self):
|
||||
|
||||
@@ -6,12 +6,12 @@ from typing import Any, Optional
|
||||
import ray
|
||||
import torch
|
||||
from ray.exceptions import RayChannelError
|
||||
from ray.experimental.channel.communicator import (Communicator,
|
||||
TorchTensorAllocator)
|
||||
from ray.experimental.channel.communicator import Communicator, TorchTensorAllocator
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from vllm.distributed.device_communicators.base_device_communicator import (
|
||||
DeviceCommunicatorBase)
|
||||
DeviceCommunicatorBase,
|
||||
)
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import current_stream
|
||||
@@ -59,11 +59,11 @@ class RayPPCommunicator(Communicator):
|
||||
self._rank: Optional[int] = None
|
||||
self._actor_handles = actor_handles
|
||||
if use_communication_streams:
|
||||
raise NotImplementedError(
|
||||
"use_communication_streams is not supported")
|
||||
raise NotImplementedError("use_communication_streams is not supported")
|
||||
if cuda_stream is not None and cuda_stream != current_stream():
|
||||
raise ValueError(
|
||||
"cuda_stream other than the current stream is not supported")
|
||||
"cuda_stream other than the current stream is not supported"
|
||||
)
|
||||
|
||||
if rank is not None:
|
||||
# Rank is not None, this is Ray worker
|
||||
@@ -99,13 +99,14 @@ class RayPPCommunicator(Communicator):
|
||||
|
||||
# Ray actor IDs are 32-character hex strings (128 bits)
|
||||
ACTOR_ID_LEN = 32
|
||||
actor_id_bytes = actor_id_str.encode('utf-8')
|
||||
assert len(
|
||||
actor_id_bytes
|
||||
) == ACTOR_ID_LEN, f"Unexpected actor ID length: {len(actor_id_bytes)}"
|
||||
actor_id_bytes = actor_id_str.encode("utf-8")
|
||||
assert len(actor_id_bytes) == ACTOR_ID_LEN, (
|
||||
f"Unexpected actor ID length: {len(actor_id_bytes)}"
|
||||
)
|
||||
|
||||
actor_id_tensor = torch.frombuffer(
|
||||
actor_id_bytes, dtype=torch.uint8).to(self._comm.device)
|
||||
actor_id_tensor = torch.frombuffer(actor_id_bytes, dtype=torch.uint8).to(
|
||||
self._comm.device
|
||||
)
|
||||
|
||||
# All-gather full actor IDs from all actors
|
||||
gathered_ids = self._comm.all_gather(actor_id_tensor, dim=0)
|
||||
@@ -115,9 +116,8 @@ class RayPPCommunicator(Communicator):
|
||||
for rank in range(self._world_size):
|
||||
start_idx = rank * ACTOR_ID_LEN
|
||||
end_idx = (rank + 1) * ACTOR_ID_LEN
|
||||
actor_bytes = gathered_ids[start_idx:end_idx].cpu().numpy(
|
||||
).tobytes()
|
||||
actor_id = actor_bytes.decode('utf-8')
|
||||
actor_bytes = gathered_ids[start_idx:end_idx].cpu().numpy().tobytes()
|
||||
actor_id = actor_bytes.decode("utf-8")
|
||||
self._actor_id_to_rank[actor_id] = rank
|
||||
|
||||
def initialize(self, rank: int) -> None:
|
||||
@@ -131,9 +131,10 @@ class RayPPCommunicator(Communicator):
|
||||
"""
|
||||
Return the given actor's rank using device communicator collective ops.
|
||||
"""
|
||||
assert hasattr(self, '_actor_id_to_rank'), (
|
||||
assert hasattr(self, "_actor_id_to_rank"), (
|
||||
"Actor rank mapping not built. "
|
||||
"This should have been done during initialization.")
|
||||
"This should have been done during initialization."
|
||||
)
|
||||
|
||||
actor_id_str = actor._actor_id.hex()
|
||||
|
||||
|
||||
@@ -14,14 +14,24 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import zmq
|
||||
from torch.distributed import ProcessGroup
|
||||
from zmq import IPV6 # type: ignore
|
||||
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
||||
from zmq import ( # type: ignore
|
||||
IPV6, # type: ignore
|
||||
SUB,
|
||||
SUBSCRIBE,
|
||||
XPUB,
|
||||
XPUB_VERBOSE,
|
||||
Context,
|
||||
)
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.utils import StatelessProcessGroup, sched_yield
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path,
|
||||
is_valid_ipv6_address)
|
||||
from vllm.utils import (
|
||||
get_ip,
|
||||
get_open_port,
|
||||
get_open_zmq_ipc_path,
|
||||
is_valid_ipv6_address,
|
||||
)
|
||||
|
||||
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
|
||||
|
||||
@@ -29,7 +39,6 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SpinTimer:
|
||||
|
||||
def record_activity(self):
|
||||
pass
|
||||
|
||||
@@ -66,12 +75,13 @@ class SpinSleepTimer(SpinTimer):
|
||||
|
||||
|
||||
class ShmRingBuffer:
|
||||
|
||||
def __init__(self,
|
||||
n_reader: int,
|
||||
max_chunk_bytes: int,
|
||||
max_chunks: int,
|
||||
name: Optional[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
n_reader: int,
|
||||
max_chunk_bytes: int,
|
||||
max_chunks: int,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
A shared memory ring buffer implementation for broadcast communication.
|
||||
Essentially, it is a queue where only one will `enqueue` and multiple
|
||||
@@ -120,13 +130,14 @@ class ShmRingBuffer:
|
||||
created object to other processes by pickling it. The other processes will
|
||||
get the name of the shared memory and open it, so that they can access the
|
||||
same shared memory buffer.
|
||||
"""# noqa
|
||||
""" # noqa
|
||||
self.n_reader = n_reader
|
||||
self.metadata_size = 1 + n_reader
|
||||
self.max_chunk_bytes = max_chunk_bytes
|
||||
self.max_chunks = max_chunks
|
||||
self.total_bytes_of_buffer = (self.max_chunk_bytes +
|
||||
self.metadata_size) * self.max_chunks
|
||||
self.total_bytes_of_buffer = (
|
||||
self.max_chunk_bytes + self.metadata_size
|
||||
) * self.max_chunks
|
||||
self.data_offset = 0
|
||||
self.metadata_offset = self.max_chunk_bytes * self.max_chunks
|
||||
|
||||
@@ -134,10 +145,10 @@ class ShmRingBuffer:
# we are creating a buffer
self.is_creator = True
self.shared_memory = shared_memory.SharedMemory(
create=True, size=self.total_bytes_of_buffer)
create=True, size=self.total_bytes_of_buffer
)
# initialize the metadata section to 0
with self.shared_memory.buf[self.
metadata_offset:] as metadata_buffer:
with self.shared_memory.buf[self.metadata_offset :] as metadata_buffer:
torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
else:
# we are opening an existing buffer
@@ -145,8 +156,10 @@ class ShmRingBuffer:
# fix to https://stackoverflow.com/q/62748654/9191338
# Python incorrectly tracks shared memory even if it is not
# created by the process. The following patch is a workaround.
with patch("multiprocessing.resource_tracker.register",
lambda *args, **kwargs: None):
with patch(
"multiprocessing.resource_tracker.register",
lambda *args, **kwargs: None,
):
try:
self.shared_memory = shared_memory.SharedMemory(name=name)
# See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa
@@ -154,8 +167,7 @@ class ShmRingBuffer:
# so the shared memory block size may be larger or equal
# to the requested size. The size parameter is ignored
# when attaching to an existing block.
assert (self.shared_memory.size
>= self.total_bytes_of_buffer)
assert self.shared_memory.size >= self.total_bytes_of_buffer
except FileNotFoundError:
# we might deserialize the object in a different node
# in this case, this object is not used,
@@ -163,8 +175,12 @@ class ShmRingBuffer:
pass

def handle(self):
return (self.n_reader, self.max_chunk_bytes, self.max_chunks,
self.shared_memory.name)
return (
self.n_reader,
self.max_chunk_bytes,
self.max_chunks,
self.shared_memory.name,
)

def __reduce__(self):
return (
@@ -204,7 +220,6 @@ class Handle:


class MessageQueue:

def __init__(
self,
n_reader, # number of all readers
@@ -228,8 +243,7 @@ class MessageQueue:
# for local readers, we will:
# 1. create a shared memory ring buffer to communicate small data
# 2. create a publish-subscribe socket to communicate large data
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
max_chunks)
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, max_chunks)

# XPUB is very similar to PUB,
# except that it can receive subscription messages
@@ -279,8 +293,7 @@ class MessageQueue:

self.handle = Handle(
local_reader_ranks=local_reader_ranks,
buffer_handle=self.buffer.handle()
if self.buffer is not None else None,
buffer_handle=self.buffer.handle() if self.buffer is not None else None,
local_subscribe_addr=local_subscribe_addr,
remote_subscribe_addr=remote_subscribe_addr,
remote_addr_ipv6=remote_addr_ipv6,
@@ -315,8 +328,9 @@ class MessageQueue:

self.remote_socket = None

self._read_spin_timer = SpinSleepTimer(
) if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer()
self._read_spin_timer = (
SpinSleepTimer() if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer()
)
else:
self.buffer = None # type: ignore
self.current_idx = -1
@@ -399,7 +413,8 @@ class MessageQueue:
" in %s seconds. This typically happens when some"
" processes are hanging or doing some"
" time-consuming work (e.g. compilation)",
VLLM_RINGBUFFER_WARNING_INTERVAL)
VLLM_RINGBUFFER_WARNING_INTERVAL,
)
n_warning += 1

continue
@@ -423,15 +438,16 @@ class MessageQueue:
metadata_buffer[i] = 0
# mark the block as written
metadata_buffer[0] = 1
self.current_idx = (self.current_idx +
1) % self.buffer.max_chunks
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
break

@contextmanager
def acquire_read(self,
timeout: Optional[float] = None,
cancel: Optional[Event] = None,
indefinite: bool = False):
def acquire_read(
self,
timeout: Optional[float] = None,
cancel: Optional[Event] = None,
indefinite: bool = False,
):
assert self._is_local_reader, "Only readers can acquire read"
start_time = time.monotonic()
n_warning = 1
@@ -460,15 +476,16 @@ class MessageQueue:
raise TimeoutError

# if we wait for a long time, log a message
if not indefinite and (elapsed
> VLLM_RINGBUFFER_WARNING_INTERVAL *
n_warning):
if not indefinite and (
elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning
):
logger.info(
"No available shared memory broadcast block found"
" in %s seconds. This typically happens when some"
" processes are hanging or doing some"
" time-consuming work (e.g. compilation).",
VLLM_RINGBUFFER_WARNING_INTERVAL)
VLLM_RINGBUFFER_WARNING_INTERVAL,
)
n_warning += 1

continue
@@ -480,14 +497,13 @@ class MessageQueue:
# caller has read from the buffer
# set the read flag
metadata_buffer[self.local_reader_rank + 1] = 1
self.current_idx = (self.current_idx +
1) % self.buffer.max_chunks
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks

self._read_spin_timer.record_activity()
break

def enqueue(self, obj, timeout: Optional[float] = None):
""" Write to message queue with optional timeout (in seconds) """
"""Write to message queue with optional timeout (in seconds)"""
assert self._is_writer, "Only writers can enqueue"
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
if self.n_local_reader > 0:
@@ -498,15 +514,17 @@ class MessageQueue:
else:
with self.acquire_write(timeout) as buf:
buf[0] = 0 # not overflow
buf[1:len(serialized_obj) + 1] = serialized_obj
buf[1 : len(serialized_obj) + 1] = serialized_obj
if self.n_remote_reader > 0:
self.remote_socket.send(serialized_obj)

def dequeue(self,
timeout: Optional[float] = None,
cancel: Optional[Event] = None,
indefinite: bool = False):
""" Read from message queue with optional timeout (in seconds) """
def dequeue(
self,
timeout: Optional[float] = None,
cancel: Optional[Event] = None,
indefinite: bool = False,
):
"""Read from message queue with optional timeout (in seconds)"""
if self._is_local_reader:
with self.acquire_read(timeout, cancel, indefinite) as buf:
overflow = buf[0] == 1
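
The enqueue/dequeue pair above frames every message as a 1-byte overflow flag followed by a pickled payload inside one shared-memory chunk. The sketch below (not part of the diff) replays that framing with a plain bytearray standing in for the chunk, so the flag-and-slice logic can be read without the shared-memory plumbing.

```python
import pickle

MAX_CHUNK_BYTES = 1024
buf = bytearray(MAX_CHUNK_BYTES)  # stand-in for one shared-memory chunk

obj = {"step": 1, "msg": "hello"}
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)

# Writer side: byte 0 is the overflow flag, the payload starts at byte 1.
if len(serialized_obj) >= MAX_CHUNK_BYTES:
    buf[0] = 1  # too big for the chunk; the real code falls back to the socket path
else:
    buf[0] = 0  # not overflow
    buf[1 : len(serialized_obj) + 1] = serialized_obj

# Reader side: check the flag, then unpickle from byte 1 onwards
# (pickle stops at its STOP opcode, so trailing zero bytes are harmless).
overflow = buf[0] == 1
if not overflow:
    assert pickle.loads(bytes(buf[1:])) == obj
```
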
@@ -539,11 +557,12 @@ class MessageQueue:
return self.dequeue()

@staticmethod
def create_from_process_group(pg: Union[ProcessGroup,
StatelessProcessGroup],
max_chunk_bytes,
max_chunks,
writer_rank=0) -> "MessageQueue":
def create_from_process_group(
pg: Union[ProcessGroup, StatelessProcessGroup],
max_chunk_bytes,
max_chunks,
writer_rank=0,
) -> "MessageQueue":
if isinstance(pg, ProcessGroup):
group_rank = dist.get_rank(pg)
group_world_size = dist.get_world_size(pg)
@@ -554,6 +573,7 @@ class MessageQueue:
global_ranks = list(range(pg.world_size))

from vllm.distributed.parallel_state import in_the_same_node_as

status = in_the_same_node_as(pg, source_rank=writer_rank)
same_node_ranks = [i for i, s in enumerate(status) if s]
n_reader = group_world_size - 1
@@ -570,17 +590,17 @@ class MessageQueue:
)
handle = buffer_io.export_handle()
if isinstance(pg, ProcessGroup):
dist.broadcast_object_list([handle],
src=global_ranks[writer_rank],
group=pg)
dist.broadcast_object_list(
[handle], src=global_ranks[writer_rank], group=pg
)
else:
pg.broadcast_obj(handle, writer_rank)
else:
if isinstance(pg, ProcessGroup):
recv = [None]
dist.broadcast_object_list(recv,
src=global_ranks[writer_rank],
group=pg)
dist.broadcast_object_list(
recv, src=global_ranks[writer_rank], group=pg
)
handle = recv[0] # type: ignore
else:
handle = pg.broadcast_obj(None, writer_rank)

@@ -24,63 +24,63 @@ class SingleWriterShmRingBuffer:
A single-writer, multiple-reader ring buffer implementation using shared
memory. This class provides a thread-safe ring buffer where one process
can write data while multiple processes/threads can read from it.


Architecture:
- Uses shared memory for cross-process communication
- Maintains metadata for each allocated buffer chunk in the writer process
- Supports custom "is_free_fn" functions to determine when buffers can be
reused
- Each buffer chunk contains: `[4-byte id][4-byte size][actual_data]`


Key Concepts:
- monotonic_id_start/end: Track the range of active buffer IDs
- data_buffer_start/end: Track the physical memory range in use
- Automatic wraparound when reaching buffer end
- Lazy garbage collection based on is_free_fn checks


Example Usage Scenarios:


Scenario 1: Simple Linear Allocation
```
Buffer size: 100 bytes
Initial state: [................................................. ]
^start=end(0)


After allocating 20 bytes (id=0):
[id:0|size:20|data........][...................................]
^start(0) ^end(28)

After allocating 30 bytes (id=1):

After allocating 30 bytes (id=1):
[id:0|size:20|data........][id:1|size:30|data..............][..]
^start(0) ^end(66)
```


Scenario 2: Memory Reclamation
```
Before freeing (both buffers still in use):
[id:0|size:20|data........][id:1|size:30|data..............][..]
^start(0) ^end(66)


After id:0 is marked free by readers:
[FREED.................... ][id:1|size:30|data..............][..]
^start(28) ^end(66)


After both are freed:
[FREED..............................................][..]
^start=end(66)
```


Scenario 3: Wraparound Allocation (continuing from Scenario 2)
```
Starting from after memory reclamation in Scenario 2:
[FREED..............................................][..]
^start=end(66)


Allocate 40 bytes (id=2) - only 34 bytes available at end, so wraparound:
[id:2|size:40|data........................][FREED.............][..]
^end(148) ^start(66)
```


Scenario 4: Error Handling - Out of Space
```
Starting from after wraparound allocation in Scenario 3:
@@ -91,17 +91,17 @@ class SingleWriterShmRingBuffer:
occupied_size_new = end + size - start = 148 + 28 - 66 > buffer_size(100)
-> Raises MemoryError: "Not enough space in the data buffer"
```


Thread Safety:
- Single writer: Only one process/thread should write (allocate_buf)
- Multiple readers: Multiple processes/threads can read (access_buf)
- Multiple readers: Multiple processes/threads can read (access_buf)
- Reader synchronization handled by is_free_fn callback
- Writer handles garbage collection (free_buf) based on reader feedback


Memory Layout per Buffer Chunk:
`[4-byte monotonic_id][4-byte chunk_size][actual_data...]`
^metadata_start ^data_start


The monotonic_id ensures data integrity - readers can verify they're
accessing the correct data even after buffer wraparound or reuse.
"""
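
The per-chunk header described in this docstring can be exercised on its own. The sketch below (not part of the diff) packs and unpacks the `[4-byte monotonic_id][4-byte chunk_size]` header with little-endian signed integers, mirroring `int2byte`/`byte2int` in this class; the payload and id values are made up.

```python
ID_NBYTES = 4  # bytes used for the monotonic id
MD_SIZE = 8    # id + size header, as described in the docstring above

def int2byte(value: int) -> bytes:
    return value.to_bytes(4, "little", signed=True)

def byte2int(raw: bytes) -> int:
    return int.from_bytes(raw, "little", signed=True)

payload = b"hello ring buffer"
monotonic_id = 7
chunk_size = MD_SIZE + len(payload)

# [4-byte monotonic_id][4-byte chunk_size][actual_data...]
chunk = int2byte(monotonic_id) + int2byte(chunk_size) + payload

assert byte2int(chunk[:ID_NBYTES]) == 7
assert byte2int(chunk[ID_NBYTES:MD_SIZE]) == chunk_size
assert chunk[MD_SIZE:] == payload
```
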
@@ -131,15 +131,16 @@ class SingleWriterShmRingBuffer:
self.monotonic_id_end: self.data_buffer_end
} # monotonic_id -> start address
self.shared_memory = shared_memory.SharedMemory(
create=True, size=self.data_buffer_size, name=name)
create=True, size=self.data_buffer_size, name=name
)
else:
# we are opening an existing buffer
# fix to https://stackoverflow.com/q/62748654/9191338
# Python incorrectly tracks shared memory even if it is not
# created by the process. The following patch is a workaround.
with patch(
"multiprocessing.resource_tracker.register",
lambda *args, **kwargs: None,
"multiprocessing.resource_tracker.register",
lambda *args, **kwargs: None,
):
self.shared_memory = shared_memory.SharedMemory(name=name)
# See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa
@@ -149,8 +150,11 @@ class SingleWriterShmRingBuffer:
# when attaching to an existing block.
assert self.shared_memory.size >= self.data_buffer_size

logger.debug("Shared memory created/opened with name: %s, size: %d",
self.shared_memory.name, self.data_buffer_size)
logger.debug(
"Shared memory created/opened with name: %s, size: %d",
self.shared_memory.name,
self.data_buffer_size,
)

def handle(self):
return (
@@ -182,19 +186,20 @@ class SingleWriterShmRingBuffer:
return int.from_bytes(byte_data, "little", signed=True)

def allocate_buf(self, size: int) -> tuple[int, int]:
'''
"""
Allocate a buffer `MD_SIZE` + `size` bytes in the shared memory.
Memory layout:
`[4-byte monotonic_id][4-byte size][buffer data...]`
'''
"""
assert self.is_writer, "Only the writer can allocate buffers."
assert size > 0, "Size must be greater than 0"
size += self.MD_SIZE # add metadata size to the buffer size
# reset to beginning if the buffer does have enough contiguous space
buffer_end_reset = self.data_buffer_end % self.data_buffer_size
if buffer_end_reset + size > self.data_buffer_size:
buffer_end_reset = (self.data_buffer_end // self.data_buffer_size +
1) * self.data_buffer_size
buffer_end_reset = (
self.data_buffer_end // self.data_buffer_size + 1
) * self.data_buffer_size
else: # no reset needed
buffer_end_reset = self.data_buffer_end

@@ -203,21 +208,24 @@ class SingleWriterShmRingBuffer:
# exceeds the start of the data buffer
occupied_size_new = buffer_end_reset + size - self.data_buffer_start
if occupied_size_new > self.data_buffer_size:
raise MemoryError("Not enough space in the data buffer, "
"try calling free_buf() to free up space")
raise MemoryError(
"Not enough space in the data buffer, "
"try calling free_buf() to free up space"
)
self.data_buffer_end = buffer_end_reset

# first 4 bytes as the monotonic id
buf_idx = self.data_buffer_end % self.data_buffer_size
self.shared_memory.buf[buf_idx:buf_idx + self.ID_NBYTES] = \
self.int2byte(self.monotonic_id_end)
self.shared_memory.buf[buf_idx : buf_idx + self.ID_NBYTES] = self.int2byte(
self.monotonic_id_end
)
# next 4 bytes as the size of the data buffer
self.shared_memory.buf[buf_idx + self.ID_NBYTES: \
buf_idx + self.MD_SIZE] = self.int2byte(size)
self.shared_memory.buf[buf_idx + self.ID_NBYTES : buf_idx + self.MD_SIZE] = (
self.int2byte(size)
)

# record metadata
self.metadata[self.monotonic_id_end %
self.ID_MAX] = self.data_buffer_end
self.metadata[self.monotonic_id_end % self.ID_MAX] = self.data_buffer_end
# update buffer and monotonic id indices
current_buffer_end = self.data_buffer_end
current_id_end = self.monotonic_id_end
@@ -230,23 +238,26 @@ class SingleWriterShmRingBuffer:
buf_idx = address % self.data_buffer_size

# read metadata
metadata_buff = self.shared_memory.buf[buf_idx:buf_idx + self.MD_SIZE]
id = self.byte2int(metadata_buff[:self.ID_NBYTES])
size = self.byte2int(metadata_buff[self.ID_NBYTES:self.MD_SIZE])
metadata_buff = self.shared_memory.buf[buf_idx : buf_idx + self.MD_SIZE]
id = self.byte2int(metadata_buff[: self.ID_NBYTES])
size = self.byte2int(metadata_buff[self.ID_NBYTES : self.MD_SIZE])

# yield the data buffer and metadata
data_buff = self.shared_memory.buf[buf_idx + self.MD_SIZE:buf_idx +
size]
with (memoryview(data_buff) as data_view, ):
data_buff = self.shared_memory.buf[buf_idx + self.MD_SIZE : buf_idx + size]
with (
memoryview(data_buff) as data_view,
):
yield data_view, (id, size)

def free_buf(self,
is_free_fn: Callable[[int, memoryview], bool],
nbytes: Optional[int] = None) -> Iterable[int]:
'''
def free_buf(
self,
is_free_fn: Callable[[int, memoryview], bool],
nbytes: Optional[int] = None,
) -> Iterable[int]:
"""
Free a buffer of the given size. This is a no-op in shared memory,
but we need to keep track of the metadata.


If freed memory spreads across the end and start of the ring buffer,
the actual freed memory will be in two segments. In this case there
still might not be a contiguous space of `nbytes` available.
@@ -254,13 +265,15 @@ class SingleWriterShmRingBuffer:
Args:
nbytes (int, optional): The size of the buffer to free. If None,
frees the maximum size of the ring buffer.
'''
"""

assert self.is_writer, "Only the writer can free buffers."
logger.debug(
"Freeing up space in the ring buffer, "
"monotonic_id_start: %d, monotonic_id_end: %d",
self.monotonic_id_start, self.monotonic_id_end)
self.monotonic_id_start,
self.monotonic_id_end,
)
monotonic_id_before = self.monotonic_id_start
# if nbytes is None, free up the maximum size of the ring buffer
if nbytes is None:
@@ -272,8 +285,9 @@ class SingleWriterShmRingBuffer:
if is_free_fn(self.monotonic_id_start, data_buff):
# check passed, we can free the buffer
del self.metadata[self.monotonic_id_start]
self.monotonic_id_start = ((self.monotonic_id_start + 1) %
self.ID_MAX)
self.monotonic_id_start = (
self.monotonic_id_start + 1
) % self.ID_MAX
self.data_buffer_start = address
freed_bytes += metadata[1]
else:
@@ -282,8 +296,11 @@ class SingleWriterShmRingBuffer:

logger.debug(
"Freed %d bytes from the ring buffer, "
"monotonic_id_start: %d, monotonic_id_end: %d", freed_bytes,
self.monotonic_id_start, self.monotonic_id_end)
"monotonic_id_start: %d, monotonic_id_end: %d",
freed_bytes,
self.monotonic_id_start,
self.monotonic_id_end,
)

# buffer wrap around
if self.data_buffer_start >= self.data_buffer_size:
@@ -295,12 +312,12 @@ class SingleWriterShmRingBuffer:
if monotonic_id_after >= monotonic_id_before:
return range(monotonic_id_before, monotonic_id_after)
else:
return chain(range(monotonic_id_before, self.ID_MAX),
range(0, monotonic_id_after))
return chain(
range(monotonic_id_before, self.ID_MAX), range(0, monotonic_id_after)
)


class ObjectSerde(ABC):

@abstractmethod
def serialize(self, value: Any) -> tuple[Any, int, bytes, int]:
"""Serialize an object to bytes."""
@@ -313,7 +330,6 @@ class ObjectSerde(ABC):


class MsgpackSerde(ObjectSerde):

def __init__(self):
# Delayed import to avoid circular dependency
from vllm.multimodal.inputs import MultiModalKwargsItem
@@ -325,8 +341,8 @@ class MsgpackSerde(ObjectSerde):
self._mm_kwargs_item_cls = MultiModalKwargsItem

def serialize(
self,
value: Any) -> tuple[Union[bytes, list[bytes]], int, bytes, int]:
self, value: Any
) -> tuple[Union[bytes, list[bytes]], int, bytes, int]:
len_arr = None
if isinstance(value, (torch.Tensor, self._mm_kwargs_item_cls)):
type_name = type(value).__name__
@@ -339,8 +355,9 @@ class MsgpackSerde(ObjectSerde):
nbytes = len(value)

object_metadata = (type_name, nbytes, len_arr)
serialized_metadata = pickle.dumps(object_metadata,
protocol=pickle.HIGHEST_PROTOCOL)
serialized_metadata = pickle.dumps(
object_metadata, protocol=pickle.HIGHEST_PROTOCOL
)
return value, nbytes, serialized_metadata, len(serialized_metadata)

def deserialize(self, data_view: memoryview) -> Any:
@@ -353,7 +370,7 @@ class MsgpackSerde(ObjectSerde):
obj = []
start_idx = 0
for length in len_arr:
item_bytes = serialized_data[start_idx:start_idx + length]
item_bytes = serialized_data[start_idx : start_idx + length]
obj.append(item_bytes)
start_idx += length
obj = self.tensor_decoder.decode(obj)
@@ -361,15 +378,14 @@ class MsgpackSerde(ObjectSerde):
obj = []
start_idx = 0
for length in len_arr:
item_bytes = serialized_data[start_idx:start_idx + length]
item_bytes = serialized_data[start_idx : start_idx + length]
obj.append(item_bytes)
start_idx += length
obj = self.mm_decoder.decode(obj)
elif type_name == bytes.__name__:
obj = pickle.loads(serialized_data)
else:
raise ValueError(
f"Unsupported object type '{type_name}' in metadata")
raise ValueError(f"Unsupported object type '{type_name}' in metadata")

return obj
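
The `len_arr` bookkeeping used by `serialize`/`deserialize` above boils down to concatenating variable-length byte chunks and splitting them back by length. The toy round trip below (not part of the diff) shows that step with plain byte strings in place of the msgpack encoders.

```python
# Toy illustration of the len_arr round trip used by MsgpackSerde above.
chunks = [b"tensor-header", b"tensor-data-0123456789", b"tail"]

# Writer side: remember each chunk's length, then store them back to back.
len_arr = [len(c) for c in chunks]
stream = b"".join(chunks)
nbytes = len(stream)

# Reader side: walk the stream with len_arr to recover the original chunks.
recovered, start_idx = [], 0
for length in len_arr:
    recovered.append(stream[start_idx : start_idx + length])
    start_idx += length

assert recovered == chunks and start_idx == nbytes
```
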
@@ -388,18 +404,18 @@ class SingleWriterShmObjectStorage:
A single-writer, multiple-reader object storage system built on top of a
shared memory ring buffer. Provides key-value storage with automatic memory
management and cross-process serialization support.


This storage system follows a FIFO (First-In-First-Out) eviction policy
where the oldest objects are automatically freed when memory runs low.
Memory is reclaimed based on reader reference counting - objects are only
freed when all readers have finished accessing them.


Architecture:
- Single writer process can put(key, value) objects
- Multiple reader processes can get(address, monotonic_id) objects
- Built on SingleWriterShmRingBuffer for efficient shared memory management
- Thread-safe operations with reader synchronization via locks


Key Features:
- FIFO Eviction: Oldest objects are evicted first when memory is full
- Reference Counting: Objects are only freed when no readers are
@@ -414,7 +430,7 @@ class SingleWriterShmObjectStorage:

Memory Layout per Object:
`[4-byte reference_count][metadata_size][serialized_object_data]`


Thread Safety:
- Writer operations (put, clear) are single-threaded by design
- Reader operations (get) are thread-safe with lock-based reference
@@ -482,18 +498,17 @@ class SingleWriterShmObjectStorage:
md_bytes: int,
data_view: memoryview,
) -> None:
data_view[self.flag_bytes:self.flag_bytes + md_bytes] = metadata
data_view[self.flag_bytes : self.flag_bytes + md_bytes] = metadata
if isinstance(data, bytes):
data_view[-data_bytes:] = data
elif isinstance(data, list):
start_idx = self.flag_bytes + md_bytes
for item_bytes in data:
item_size = len(item_bytes)
data_view[start_idx:start_idx + item_size] = item_bytes
data_view[start_idx : start_idx + item_size] = item_bytes
start_idx += item_size
else:
raise ValueError(
f"Unsupported data type for serialization: {type(data)}")
raise ValueError(f"Unsupported data type for serialization: {type(data)}")

def increment_writer_flag(self, id: int) -> None:
"""Set the in-use flag for the writer."""
@@ -509,8 +524,9 @@ class SingleWriterShmObjectStorage:
"""Free unused buffers in the ring buffer."""
# try to free up 2*max_object_size bytes of space in the ring buffer,
# since the buffer might be fragmented
freed_ids = self.ring_buffer.free_buf(self.default_is_free_check,
2 * self.max_object_size)
freed_ids = self.ring_buffer.free_buf(
self.default_is_free_check, 2 * self.max_object_size
)
# update the metadata after freeing up space
for freed_id in freed_ids:
key_to_free = self.id_index[freed_id]
@@ -537,7 +553,7 @@ class SingleWriterShmObjectStorage:
Store a key-value pair in the object storage.
Attempts to free max_object_size bytes using FIFO order
when the ring buffer runs out of space during a put() operation.


Args:
key: String key to identify the object
value: Any serializable Python object
@@ -550,15 +566,17 @@ class SingleWriterShmObjectStorage:
if key in self.key_index:
raise ValueError(f"Key '{key}' already exists in the storage.")

object_data, data_bytes, object_metadata, md_bytes = \
self.ser_de.serialize(value)
object_data, data_bytes, object_metadata, md_bytes = self.ser_de.serialize(
value
)
buffer_size = self.flag_bytes + data_bytes + md_bytes

# Sanity checks
if buffer_size > self.max_object_size:
raise ValueError(
f"Serialized object size ({buffer_size} bytes) exceeds "
f"max object size ({self.max_object_size} bytes)")
f"max object size ({self.max_object_size} bytes)"
)

# Allocate new buffer
try:
@@ -570,9 +588,10 @@ class SingleWriterShmObjectStorage:

# Write data to buffer
with self.ring_buffer.access_buf(address) as (data_view, metadata):
data_view[:self.flag_bytes] = self.ring_buffer.int2byte(0)
self.copy_to_buffer(object_data, data_bytes, object_metadata,
md_bytes, data_view)
data_view[: self.flag_bytes] = self.ring_buffer.int2byte(0)
self.copy_to_buffer(
object_data, data_bytes, object_metadata, md_bytes, data_view
)
self.increment_writer_flag(monotonic_id)

# Update key index
@@ -587,14 +606,15 @@ class SingleWriterShmObjectStorage:
if buf_metadata[0] != monotonic_id:
raise ValueError(
f"Data for address:id '{address}:{monotonic_id}'"
" has been modified or is invalid.")
" has been modified or is invalid."
)

obj = self.ser_de.deserialize(data_view[self.flag_bytes:])
obj = self.ser_de.deserialize(data_view[self.flag_bytes :])

# decrease the in-use flag for reader reads
if self._reader_lock is not None:
with self._reader_lock:
self.increment_reader_flag(data_view[:self.flag_bytes])
self.increment_reader_flag(data_view[: self.flag_bytes])
else:
# if self._reader_lock is None, it means we are the writer
# in this case, we do not need to decrease the reader count
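
The `put()`/`get()` pair above hands readers an `(address, monotonic_id)` reference and re-validates the id on every read, because FIFO eviction may recycle the slot. The toy, purely in-memory stand-in below (not part of the diff; `slots`, `toy_put`, and `toy_get` are invented names) shows why that check matters.

```python
# Toy stand-in for the address/id validation done by get() above.
# "slots" plays the role of the ring buffer: address -> (monotonic_id, payload).
slots: dict[int, tuple[int, bytes]] = {}

def toy_put(address: int, monotonic_id: int, payload: bytes) -> tuple[int, int]:
    slots[address] = (monotonic_id, payload)  # writer stamps the slot with its id
    return address, monotonic_id              # reference handed to readers

def toy_get(address: int, monotonic_id: int) -> bytes:
    stored_id, payload = slots[address]
    if stored_id != monotonic_id:              # slot was recycled for a newer object
        raise ValueError(
            f"Data for address:id '{address}:{monotonic_id}' "
            "has been modified or is invalid."
        )
    return payload

ref = toy_put(address=0, monotonic_id=3, payload=b"old")
toy_put(address=0, monotonic_id=4, payload=b"new")  # eviction reused the slot
try:
    toy_get(*ref)
except ValueError:
    pass  # stale reference detected, as intended
```
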
@@ -614,7 +634,8 @@ class SingleWriterShmObjectStorage:

@staticmethod
def create_from_handle(
handle: ShmObjectStorageHandle) -> "SingleWriterShmObjectStorage":
handle: ShmObjectStorageHandle,
) -> "SingleWriterShmObjectStorage":
logger.debug("Creating storage from handle: %s", handle)
ring_buffer = SingleWriterShmRingBuffer(*handle.ring_buffer_handle)
return SingleWriterShmObjectStorage(

@@ -7,7 +7,8 @@ import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.distributed.device_communicators.all_reduce_utils import (
SYMM_MEM_ALL_REDUCE_MAX_SIZES)
SYMM_MEM_ALL_REDUCE_MAX_SIZES,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform

@@ -28,20 +29,20 @@ class SymmMemCommunicator:
}

def __init__(
self,
group: ProcessGroup,
device: Union[int, str, torch.device],
# add options for testing
force_multimem: Optional[bool] = None,
max_size_override: Optional[int] = None):
self,
group: ProcessGroup,
device: Union[int, str, torch.device],
# add options for testing
force_multimem: Optional[bool] = None,
max_size_override: Optional[int] = None,
):
self.disabled = True

if not symm_mem_available:
return

if not current_platform.is_cuda():
logger.warning("SymmMemCommunicator: symmetric "
"memory is not available.")
logger.warning("SymmMemCommunicator: symmetric memory is not available.")
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
@@ -52,8 +53,9 @@ class SymmMemCommunicator:
self.device = device
self.group = group
self.world_size = dist.get_world_size(self.group)
self.device_capability = current_platform.get_device_capability(
).as_version_str()
self.device_capability = (
current_platform.get_device_capability().as_version_str()
)
if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
logger.warning(
"SymmMemCommunicator: Device capability %s not supported, "
@@ -61,8 +63,7 @@ class SymmMemCommunicator:
self.device_capability,
)
return
if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[
self.device_capability]:
if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]:
logger.warning(
"SymmMemCommunicator: World size %d not supported, "
"communicator is not available.",
@@ -77,8 +78,9 @@ class SymmMemCommunicator:
self.max_size,
)
else:
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[
self.device_capability][self.world_size]
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
self.world_size
]

self.buffer = torch_symm_mem.empty(
self.max_size // self.dtype.itemsize,
@@ -87,8 +89,10 @@ class SymmMemCommunicator:
)
handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
if handle.multicast_ptr == 0:
logger.warning("SymmMemCommunicator: symmetric memory "
"multicast operations are not supported.")
logger.warning(
"SymmMemCommunicator: symmetric memory "
"multicast operations are not supported."
)
return
self.force_multimem = force_multimem
self.disabled = False
@@ -104,15 +108,13 @@ class SymmMemCommunicator:
return inp_size < self.max_size

def all_reduce(
self,
inp: torch.Tensor,
*,
out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None
) -> Optional[torch.Tensor]:
if not self.should_use_symm_mem(inp):
return None
if out is None:
out = torch.empty_like(inp)
self.buffer[:inp.numel()].copy_(inp.view(-1))
self.buffer[: inp.numel()].copy_(inp.view(-1))

# Determine which algorithm to use
use_multimem = False
@@ -121,16 +123,17 @@ class SymmMemCommunicator:
use_multimem = self.force_multimem
else:
# Normal logic: use multimem for supported world sizes
use_multimem = self.world_size in self._WORLD_SIZES_MULTIMEM[
self.device_capability]
use_multimem = (
self.world_size in self._WORLD_SIZES_MULTIMEM[self.device_capability]
)

if use_multimem:
torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
torch.ops.symm_mem.multimem_all_reduce_(
self.buffer[: inp.numel()], "sum", self.group.group_name
)
else:
torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
out.copy_(self.buffer[:inp.numel()].view(out.shape))
torch.ops.symm_mem.two_shot_all_reduce_(
self.buffer[: inp.numel()], "sum", self.group.group_name
)
out.copy_(self.buffer[: inp.numel()].view(out.shape))
return out
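
The algorithm choice inside `all_reduce` above reduces to a small decision rule between the multimem and two-shot kernels. The sketch below (not part of the diff) mirrors that rule with plain Python values; `_WORLD_SIZES_MULTIMEM` here is a made-up stand-in for the class attribute, not the real table.

```python
from typing import Optional

# Made-up stand-in for SymmMemCommunicator._WORLD_SIZES_MULTIMEM.
_WORLD_SIZES_MULTIMEM = {"9.0": [4, 6, 8], "10.0": [6, 8]}

def pick_algorithm(
    world_size: int,
    device_capability: str,
    force_multimem: Optional[bool] = None,
) -> str:
    """Mirror of the use_multimem decision in all_reduce above."""
    if force_multimem is not None:
        use_multimem = force_multimem  # test override wins
    else:
        use_multimem = world_size in _WORLD_SIZES_MULTIMEM[device_capability]
    return "multimem_all_reduce_" if use_multimem else "two_shot_all_reduce_"

assert pick_algorithm(8, "9.0") == "multimem_all_reduce_"
assert pick_algorithm(2, "9.0") == "two_shot_all_reduce_"
assert pick_algorithm(2, "9.0", force_multimem=True) == "multimem_all_reduce_"
```
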
@@ -14,8 +14,9 @@ from vllm.platforms.tpu import USE_TPU_COMMONS

from .base_device_communicator import DeviceCommunicatorBase

USE_RAY = parallel_config = get_current_vllm_config(
).parallel_config.distributed_executor_backend == "ray"
USE_RAY = parallel_config = (
get_current_vllm_config().parallel_config.distributed_executor_backend == "ray"
)

logger = init_logger(__name__)

@@ -27,18 +28,21 @@ if not USE_TPU_COMMONS:
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
from torch_xla.distributed.xla_multiprocessing import (
create_optimized_replica_groups)
create_optimized_replica_groups,
)

if USE_RAY:
from vllm.executor import ray_utils


class TpuCommunicator(DeviceCommunicatorBase):

def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
def __init__(
self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = "",
):
super().__init__(cpu_group, device, device_group, unique_name)

# NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
@@ -98,5 +102,7 @@ class TpuCommunicator(DeviceCommunicatorBase):

if USE_TPU_COMMONS:
from tpu_commons.distributed.device_communicators import (
TpuCommunicator as TpuCommonsCommunicator)
TpuCommunicator as TpuCommonsCommunicator,
)

TpuCommunicator = TpuCommonsCommunicator # type: ignore

@@ -16,12 +16,13 @@ logger = init_logger(__name__)


class XpuCommunicator(DeviceCommunicatorBase):

def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
def __init__(
self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = "",
):
super().__init__(cpu_group, device, device_group, unique_name)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
@@ -29,10 +30,12 @@ class XpuCommunicator(DeviceCommunicatorBase):
logger.warning(
"`%s` all2all manager is not supported on XPU."
"Falling back to `naive` all2all manager for XPU.",
all2all_backend)
all2all_backend,
)
all2all_backend = "naive"
if all2all_backend == "naive":
from .all2all import NaiveAll2AllManager

self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")

@@ -40,12 +43,12 @@ class XpuCommunicator(DeviceCommunicatorBase):
dist.all_reduce(input_, group=self.device_group)
return input_

def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
def gather(
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
)
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
@@ -53,20 +56,19 @@ class XpuCommunicator(DeviceCommunicatorBase):
# cluster so we use all_gather instead for now.
input_size = input_.size()
# Allocate output tensor.
output_tensor = torch.empty((self.world_size, ) + input_size,
dtype=input_.dtype,
device=input_.device)
output_tensor = torch.empty(
(self.world_size,) + input_size, dtype=input_.dtype, device=input_.device
)
# All-gather.
dist.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
if self.rank_in_group == dst:
# Reshape
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(self.world_size *
input_size[dim], ) +
input_size[dim + 1:])
output_tensor = output_tensor.reshape(
input_size[:dim]
+ (self.world_size * input_size[dim],)
+ input_size[dim + 1 :]
)
else:
output_tensor = None
return output_tensor
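
The reshape at the end of `gather()` above stitches the gathered world-size axis back into `dim`. The CPU-only sketch below (not part of the diff) fakes an already all-gathered tensor for `world_size=2` and checks that the same `movedim`/`reshape` sequence produces the expected shape.

```python
import torch

world_size, dim = 2, 1
input_size = torch.Size([3, 5])  # per-rank tensor shape

# Fake "all-gathered" result: one stacked copy per rank along a new leading axis.
output_tensor = torch.arange(world_size * 3 * 5).reshape((world_size,) + tuple(input_size))

# Same reshape as in gather() above: move the rank axis next to `dim`, then merge.
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(
    input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
)
assert output_tensor.shape == (3, 10)
```
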
@@ -78,17 +80,19 @@ class XpuCommunicator(DeviceCommunicatorBase):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits, is_sequence_parallel)
hidden_states, router_logits, is_sequence_parallel
)
return hidden_states, router_logits

def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states,
is_sequence_parallel)
hidden_states = self.all2all_manager.combine(
hidden_states, is_sequence_parallel
)
return hidden_states