Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Committed via GitHub: 2025-10-05 15:06:22 +01:00
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions
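Most of the hunks below are mechanical: yapf's paren-aligned continuation style is rewritten into the black-compatible layout produced by `ruff format`, and import blocks are regrouped by ruff's isort-compatible lint rules; no logic changes. A minimal sketch of the style shift, shaped like the first hunk below (illustration only, not part of the diff):

import torch

def make_buffer(x: torch.Tensor, num_tokens: int) -> torch.Tensor:
    # yapf used to align continuation arguments under the opening parenthesis:
    #   buffer = torch.empty((num_tokens, x.size(1)),
    #                        device=x.device,
    #                        dtype=x.dtype)
    # ruff format keeps the call on as few lines as fit the line limit,
    # otherwise it explodes it one argument per line with a trailing comma.
    buffer = torch.empty(
        (num_tokens, x.size(1)), device=x.device, dtype=x.dtype
    )
    return buffer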


@@ -33,17 +33,19 @@ class NaiveAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group):
super().__init__(cpu_group)
def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_sp_cpu: torch.Tensor,
is_sequence_parallel: bool) -> torch.Tensor:
assert (len(x.shape) == 2)
buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
device=x.device,
dtype=x.dtype)
def naive_multicast(
self,
x: torch.Tensor,
cu_tokens_across_sp_cpu: torch.Tensor,
is_sequence_parallel: bool,
) -> torch.Tensor:
assert len(x.shape) == 2
buffer = torch.empty(
(cu_tokens_across_sp_cpu[-1], x.size(1)), device=x.device, dtype=x.dtype
)
rank = self.rank if is_sequence_parallel else self.dp_rank
world_size = (self.world_size
if is_sequence_parallel else self.dp_world_size)
world_size = self.world_size if is_sequence_parallel else self.dp_world_size
start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
end = cu_tokens_across_sp_cpu[rank]
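For readers unfamiliar with the indexing here: `cu_tokens_across_sp_cpu` holds cumulative token counts across ranks, so rank r owns rows [cu[r-1], cu[r]) of the gathered buffer (rank 0 starts at row 0). A standalone sketch of that slicing with hypothetical values, mirroring the `start`/`end` lines above:

import torch

cu_tokens = torch.tensor([3, 5, 9])  # rank 0 holds 3 tokens, rank 1 holds 2, rank 2 holds 4
for rank in range(len(cu_tokens)):
    start = 0 if rank == 0 else int(cu_tokens[rank - 1])
    end = int(cu_tokens[rank])
    print(f"rank {rank}: rows {start}:{end}")  # 0:3, 3:5, 5:9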
@@ -59,24 +61,23 @@ class NaiveAll2AllManager(All2AllManagerBase):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
hidden_states = self.naive_multicast(hidden_states,
cu_tokens_across_sp_cpu,
is_sequence_parallel)
router_logits = self.naive_multicast(router_logits,
cu_tokens_across_sp_cpu,
is_sequence_parallel)
hidden_states = self.naive_multicast(
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
)
router_logits = self.naive_multicast(
router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel
)
return hidden_states, router_logits
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
ep_rank = self.rank if is_sequence_parallel else self.dp_rank
dp_metadata = get_forward_context().dp_metadata
@@ -107,13 +108,12 @@ class AgRsAll2AllManager(All2AllManagerBase):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Gather hidden_states and router_logits from all dp ranks.
"""
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
@@ -124,19 +124,16 @@ class AgRsAll2AllManager(All2AllManagerBase):
)
return hidden_states, router_logits
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
"""
Reduce-scatter hidden_states across all dp ranks.
"""
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
hidden_states = dist_group.reduce_scatterv(hidden_states,
dim=0,
sizes=sizes)
hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes)
return hidden_states
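The dispatch/combine pair above is shape-symmetric: `dispatch` all-gathers variable-sized chunks from every DP rank along dim 0, and `combine` reduce-scatters the result back to each rank's own chunk, with `sizes` giving the per-rank token counts. A plain shape sketch under assumed values (no real process group involved):

import torch

sizes = [4, 2, 6]          # tokens contributed by each DP rank
hidden_size, rank = 8, 1

local = torch.randn(sizes[rank], hidden_size)       # this rank's input
gathered = torch.empty(sum(sizes), hidden_size)     # shape after all_gatherv(dim=0)
combined = torch.empty(sizes[rank], hidden_size)    # shape after reduce_scatterv(dim=0)
print(local.shape, gathered.shape, combined.shape)  # (2, 8) (12, 8) (2, 8)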
def destroy(self):
@@ -149,24 +146,35 @@ class PPLXAll2AllManager(All2AllManagerBase):
"""
def __init__(self, cpu_group):
assert has_pplx(
), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa
assert has_pplx(), (
"pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."
) # noqa
super().__init__(cpu_group)
if self.internode:
# inter-node communication needs nvshmem,
# intra-node communication uses p2p mapping directly
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id,
nvshmem_init)
from pplx_kernels.nvshmem import (
nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id,
nvshmem_init,
)
logger.debug(
"Initialize NVSHMEM for pplx_kernels: "
"rank=%d, world size=%d", self.rank, self.world_size)
uid = nvshmem_get_unique_id(
) if self.rank == 0 else nvshmem_alloc_empty_unique_id()
dist.broadcast(uid,
src=dist.get_process_group_ranks(self.cpu_group)[0],
group=self.cpu_group)
"Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d",
self.rank,
self.world_size,
)
uid = (
nvshmem_get_unique_id()
if self.rank == 0
else nvshmem_alloc_empty_unique_id()
)
dist.broadcast(
uid,
src=dist.get_process_group_ranks(self.cpu_group)[0],
group=self.cpu_group,
)
logger.debug("PPLX NVSHMEM UID = %s", uid)
nvshmem_init(uid, self.rank, self.world_size)
@@ -174,21 +182,23 @@ class PPLXAll2AllManager(All2AllManagerBase):
def get_handle(self, kwargs):
import pplx_kernels as pplx
return self.handle_cache.get_or_create(
kwargs, pplx.AllToAll.internode
if self.internode else pplx.AllToAll.intranode)
kwargs,
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
)
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
@@ -198,6 +208,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
if self.internode:
from pplx_kernels.nvshmem import nvshmem_finalize
logger.debug("PPLX NVSHMEM finalize")
nvshmem_finalize()
@@ -208,8 +219,9 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
"""
def __init__(self, cpu_group):
assert has_deep_ep(
), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa
assert has_deep_ep(), (
"DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels."
) # noqa
super().__init__(cpu_group)
self.handle_cache = Cache()
@@ -224,13 +236,13 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
@@ -260,23 +272,27 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
assert num_rdma_bytes is not None
assert num_qps_per_rank is not None
return dict(group=self.cpu_group,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank)
return dict(
group=self.cpu_group,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank,
)
def get_handle(self, kwargs):
assert len(kwargs) == 0, (
"DeepEPHTAll2AllManager expects no arguments. All the required "
"args are computed in the Manager itself.")
"args are computed in the Manager itself."
)
import deep_ep
buffer_kwargs = self._make_all2all_kwargs()
logger.debug("DeepEP all2all args %s", buffer_kwargs)
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
buffer_kwargs, deep_ep.Buffer)
buffer_kwargs, deep_ep.Buffer
)
return handle
def set_num_sms(self, num_sms: int):
@@ -323,14 +339,17 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
hidden=token_hidden_size,
num_ranks=num_ep_ranks,
num_experts=num_global_experts)
num_experts=num_global_experts,
)
assert num_rdma_bytes is not None
return dict(group=self.cpu_group,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_qps_per_rank)
return dict(
group=self.cpu_group,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_qps_per_rank,
)
def get_handle(self, kwargs):
"""
@@ -338,10 +357,12 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
_make_all2all_kwargs.
"""
import deep_ep
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
logger.debug("DeepEP all2all args %s", buffer_kwargs)
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
buffer_kwargs, deep_ep.Buffer)
buffer_kwargs, deep_ep.Buffer
)
return handle
# DeepEP LL uses RDMA so no SMs are used for communication
@@ -355,12 +376,15 @@ class FlashInferAllToAllManager(All2AllManagerBase):
"""
def __init__(self, cpu_group):
assert has_flashinfer_all2all(
), "flashinfer all2all module not found. Please install/check flashinfer" # noqa
assert has_flashinfer_all2all(), (
"flashinfer all2all module not found. Please install/check flashinfer"
) # noqa
super().__init__(cpu_group)
logger.debug(
"Initialize for flashinfer All2All "
"rank=%d, world size=%d", self.rank, self.world_size)
"Initialize for flashinfer All2All rank=%d, world size=%d",
self.rank,
self.world_size,
)
self.initialized = False
self.alltoall_info = None
@@ -375,8 +399,7 @@ class FlashInferAllToAllManager(All2AllManagerBase):
return
self.cleanup()
logger.debug("making map: "
"rank=%d, world size=%d", rank, world_size)
logger.debug("making map: rank=%d, world size=%d", rank, world_size)
self.mapping = Mapping(
world_size,
rank,
@@ -385,25 +408,28 @@ class FlashInferAllToAllManager(All2AllManagerBase):
)
from vllm.distributed.device_communicators.mnnvl_compat import (
CustomCommunicator)
CustomCommunicator,
)
dp_config = MnnvlConfig(
comm_backend=CustomCommunicator(get_dp_group().cpu_group),
fabric_page_size=1 << 29, # 512MB
allocation_granularity=0 # Auto-detect
allocation_granularity=0, # Auto-detect
)
self.workspace_tensor = MnnvlMoe.get_moe_workspaces(
self.mapping, dp_config)
self.workspace_tensor = MnnvlMoe.get_moe_workspaces(self.mapping, dp_config)
self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace(
self.mapping, dp_config)
self.mapping, dp_config
)
self.world_size = world_size
self.rank = rank
self.gpus_per_node = gpus_per_node
self.initialized = True
logger.info("FlashInfer All2All initialized for rank %s, size %s",
rank, world_size)
logger.info(
"FlashInfer All2All initialized for rank %s, size %s", rank, world_size
)
def ensure_alltoall_workspace_initialized(self):
"""Ensure workspace is initialized"""
@@ -426,8 +452,11 @@ class FlashInferAllToAllManager(All2AllManagerBase):
def cleanup(self):
"""Clean up workspace"""
if self.initialized and self.workspace_tensor is not None \
and self.prepare_workspace_tensor is not None:
if (
self.initialized
and self.workspace_tensor is not None
and self.prepare_workspace_tensor is not None
):
try:
del self.workspace_tensor
del self.prepare_workspace_tensor


@@ -19,8 +19,7 @@ import torch.multiprocessing as mp
import vllm.envs as envs
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.logger import init_logger
from vllm.utils import (cuda_device_count_stateless,
update_environment_variables)
from vllm.utils import cuda_device_count_stateless, update_environment_variables
logger = init_logger(__name__)
@@ -39,7 +38,7 @@ CUSTOM_ALL_REDUCE_MAX_SIZES = {
4: 2 * MiB, # 2 MB
6: 1 * MiB, # 1 MB
8: 1 * MiB, # 1 MB
}
},
}
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
@@ -54,7 +53,7 @@ SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
4: 32 * MiB, # 32 MB
6: 128 * MiB, # 128 MB
8: 128 * MiB, # 128 MB
}
},
}
NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
@@ -63,14 +62,15 @@ NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
4: 2 * MiB, # 2 MB
8: 1 * MiB, # 1 MB
},
"always_use_above_world_size": 8 # Always use symm mem for world_size > 8
"always_use_above_world_size": 8, # Always use symm mem for world_size > 8
}
def should_nccl_symm_mem_allreduce(world_size: int,
input_tensor: torch.Tensor) -> bool:
def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) -> bool:
from vllm.distributed.device_communicators.pynccl_allocator import (
is_symmetric_memory_enabled)
is_symmetric_memory_enabled,
)
if not is_symmetric_memory_enabled():
return False
if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
@@ -78,18 +78,18 @@ def should_nccl_symm_mem_allreduce(world_size: int,
threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size)
if threshold is not None and input_tensor.nbytes >= threshold:
return True
return (world_size
> NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"])
return world_size > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"]
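A worked example of the threshold logic above, using only the values visible in this hunk (the real function additionally requires `is_symmetric_memory_enabled()` and a minimum world size before reaching this point):

MiB = 1024 * 1024
thresholds = {4: 2 * MiB, 8: 1 * MiB}   # copied from NCCL_SYMM_MEM_ALL_REDUCE_CONFIG above
always_above = 8                        # "always_use_above_world_size"

def would_use_symm_mem(world_size: int, nbytes: int) -> bool:
    threshold = thresholds.get(world_size)
    if threshold is not None and nbytes >= threshold:
        return True
    return world_size > always_above

print(would_use_symm_mem(4, 3 * MiB))  # True: 3 MiB clears the 2 MiB threshold for world size 4
print(would_use_symm_mem(4, 1 * MiB))  # False: under threshold and not above the world-size cutoff
print(would_use_symm_mem(16, 1))       # True: world size beyond always_use_above_world_size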
def producer(batch_src: Sequence[int],
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices: Optional[str] = None):
def producer(
batch_src: Sequence[int],
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices: Optional[str] = None,
):
if cuda_visible_devices is not None:
update_environment_variables(
{"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
lib = CudaRTLibrary()
for i in batch_src:
@@ -115,14 +115,15 @@ def producer(batch_src: Sequence[int],
lib.cudaDeviceReset()
def consumer(batch_tgt: Sequence[int],
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices: Optional[str] = None):
def consumer(
batch_tgt: Sequence[int],
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices: Optional[str] = None,
):
if cuda_visible_devices is not None:
update_environment_variables(
{"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
lib = CudaRTLibrary()
for j in batch_tgt:
@@ -198,12 +199,26 @@ def can_actually_p2p(
producer_queue = smp.Queue()
consumer_queue = smp.Queue()
result_queue = smp.Queue()
p_src = smp.Process(target=producer,
args=(batch_src, producer_queue, consumer_queue,
result_queue, cuda_visible_devices))
p_tgt = smp.Process(target=consumer,
args=(batch_tgt, producer_queue, consumer_queue,
result_queue, cuda_visible_devices))
p_src = smp.Process(
target=producer,
args=(
batch_src,
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices,
),
)
p_tgt = smp.Process(
target=consumer,
args=(
batch_tgt,
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices,
),
)
p_src.start()
p_tgt.start()
p_src.join()
@@ -216,7 +231,10 @@ def can_actually_p2p(
if a != b:
logger.warning(
"Two processes do not agree on the P2P access"
" status on %d -> %d, treat as disabled.", src, tgt)
" status on %d -> %d, treat as disabled.",
src,
tgt,
)
result.append(False)
else:
result.append(a)
@@ -255,12 +273,14 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
path = os.path.join(
envs.VLLM_CACHE_ROOT,
f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json")
envs.VLLM_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
)
os.makedirs(os.path.dirname(path), exist_ok=True)
from vllm.distributed.parallel_state import get_world_group
if ((not is_distributed or get_world_group().local_rank == 0)
and (not os.path.exists(path))):
if (not is_distributed or get_world_group().local_rank == 0) and (
not os.path.exists(path)
):
# only the local master process (with local_rank == 0) can
# enter this block to calculate the cache
logger.info("generating GPU P2P access cache in %s", path)
@@ -279,11 +299,10 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
# we don't use the output of the subprocess directly,
# because the subprocess might produce logging output
with tempfile.NamedTemporaryFile() as output_file:
input_bytes = pickle.dumps(
(batch_src, batch_tgt, output_file.name))
returned = subprocess.run([sys.executable, __file__],
input=input_bytes,
capture_output=True)
input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name))
returned = subprocess.run(
[sys.executable, __file__], input=input_bytes, capture_output=True
)
# check if the subprocess is successful
try:
returned.check_returncode()
@@ -292,7 +311,8 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
raise RuntimeError(
f"Error happened when batch testing "
f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
f"{returned.stderr.decode()}") from e
f"{returned.stderr.decode()}"
) from e
with open(output_file.name, "rb") as f:
result = pickle.load(f)
for _i, _j, r in zip(batch_src, batch_tgt, result):


@@ -10,7 +10,6 @@ from torch.distributed import ProcessGroup
class Cache:
def __init__(self):
self._cache: WeakValueDictionary = WeakValueDictionary()
self._lock = threading.RLock() # Reentrant lock for thread safety
@@ -35,9 +34,11 @@ class All2AllManagerBase:
self.cpu_group = cpu_group
# compute some common properties
from vllm.distributed.parallel_state import (get_dp_group,
get_tp_group,
in_the_same_node_as)
from vllm.distributed.parallel_state import (
get_dp_group,
get_tp_group,
in_the_same_node_as,
)
# all2all lives in ep group, which is merged from dp and tp group
self.dp_group = get_dp_group()
@@ -63,10 +64,12 @@ class All2AllManagerBase:
# and reuse it for the same config.
raise NotImplementedError
def dispatch(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False):
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
):
raise NotImplementedError
def set_num_sms(self, num_sms: int):
@@ -75,9 +78,7 @@ class All2AllManagerBase:
def max_sms_used(self) -> Optional[int]:
return None # None means it could use the whole GPU
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False):
def combine(self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False):
raise NotImplementedError
def destroy(self):
@@ -92,11 +93,13 @@ class DeviceCommunicatorBase:
communication backend), the `device_group` will also be given.
"""
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
def __init__(
self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = "",
):
self.device = device or torch.device("cpu")
self.cpu_group = cpu_group
self.device_group = device_group
@@ -106,11 +109,11 @@ class DeviceCommunicatorBase:
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group,
self.global_rank)
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
use_ep = False
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if config is not None:
# as long as we use data parallel (coupled data parallel
@@ -134,41 +137,39 @@ class DeviceCommunicatorBase:
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
output_size = (input_size[0] * self.world_size,) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
# All-gather.
dist.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
output_tensor = output_tensor.reshape((self.world_size,) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(self.world_size *
input_size[dim], ) +
input_size[dim + 1:])
output_tensor = output_tensor.reshape(
input_size[:dim]
+ (self.world_size * input_size[dim],)
+ input_size[dim + 1 :]
)
return output_tensor
def all_gatherv(
self,
input_: Union[torch.Tensor, list[torch.Tensor]],
dim: int = 0,
sizes: Optional[list[int]] = None
sizes: Optional[list[int]] = None,
) -> Union[torch.Tensor, list[torch.Tensor]]:
raise NotImplementedError
def reduce_scatter(self,
input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
)
if dim < 0:
# Convert negative dim to positive.
@@ -180,30 +181,28 @@ class DeviceCommunicatorBase:
assert input_tensor.shape[0] % world_size == 0
chunk_size = input_tensor.shape[0] // world_size
output_shape = (chunk_size, ) + input_tensor.shape[1:]
output_shape = (chunk_size,) + input_tensor.shape[1:]
output_tensor = torch.empty(output_shape,
dtype=input_tensor.dtype,
device=input_tensor.device)
output_tensor = torch.empty(
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
)
# Perform reduce-scatter operation
torch.distributed.reduce_scatter_tensor(output_tensor,
input_tensor,
group=self.device_group)
torch.distributed.reduce_scatter_tensor(
output_tensor, input_tensor, group=self.device_group
)
# Reshape before returning
return output_tensor.movedim(0, dim).contiguous()
def reduce_scatterv(self,
input_: torch.Tensor,
dim: int = -1,
sizes: Optional[list[int]] = None) -> torch.Tensor:
def reduce_scatterv(
self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None
) -> torch.Tensor:
raise NotImplementedError
def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
def gather(
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
@@ -211,7 +210,8 @@ class DeviceCommunicatorBase:
"""
world_size = self.world_size
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
)
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
@@ -222,10 +222,9 @@ class DeviceCommunicatorBase:
else:
gather_list = None
# Gather.
torch.distributed.gather(input_,
gather_list,
dst=self.ranks[dst],
group=self.device_group)
torch.distributed.gather(
input_, gather_list, dst=self.ranks[dst], group=self.device_group
)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
@@ -239,10 +238,9 @@ class DeviceCommunicatorBase:
dst = (self.rank_in_group + 1) % self.world_size
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
def recv(
self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None
) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
@@ -255,8 +253,7 @@ class DeviceCommunicatorBase:
def destroy(self):
pass
def prepare_communication_buffer_for_model(self,
model: torch.nn.Module) -> None:
def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None:
"""
Prepare the communication buffer for the model.
"""
@@ -264,11 +261,14 @@ class DeviceCommunicatorBase:
return
moe_modules = [
module for module in model.modules()
module
for module in model.modules()
# TODO(bnell): Should use isinstance but can't. Maybe search for
# presence of quant_method.init_prepare_finalize?
if (module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE")
if (
module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE"
)
]
for module in moe_modules:
module.quant_method.init_prepare_finalize(module)
@@ -277,7 +277,7 @@ class DeviceCommunicatorBase:
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Dispatch the hidden states and router logits to the appropriate device.
@@ -285,9 +285,9 @@ class DeviceCommunicatorBase:
"""
return hidden_states, router_logits
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.


@@ -15,30 +15,30 @@ from .base_device_communicator import DeviceCommunicatorBase
class CpuCommunicator(DeviceCommunicatorBase):
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
def __init__(
self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = "",
):
super().__init__(cpu_group, device, device_group, unique_name)
self.dist_module = torch.distributed
if (current_platform.get_cpu_architecture()
== CpuArchEnum.X86) and hasattr(
torch.ops._C,
"init_shm_manager") and (unique_name.startswith("tp")
or unique_name.startswith("pp")):
if (
(current_platform.get_cpu_architecture() == CpuArchEnum.X86)
and hasattr(torch.ops._C, "init_shm_manager")
and (unique_name.startswith("tp") or unique_name.startswith("pp"))
):
self.dist_module = _CPUSHMDistributed(self)
def all_reduce(self, input_):
self.dist_module.all_reduce(input_, group=self.device_group)
return input_
def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
def gather(
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
@@ -46,7 +46,8 @@ class CpuCommunicator(DeviceCommunicatorBase):
"""
world_size = self.world_size
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
)
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
@@ -58,10 +59,9 @@ class CpuCommunicator(DeviceCommunicatorBase):
gather_list = None
# Gather.
self.dist_module.gather(input_,
gather_list,
dst=self.ranks[dst],
group=self.device_group)
self.dist_module.gather(
input_, gather_list, dst=self.ranks[dst], group=self.device_group
)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
@@ -77,23 +77,24 @@ class CpuCommunicator(DeviceCommunicatorBase):
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
output_size = (input_size[0] * self.world_size,) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
# All-gather.
self.dist_module.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
self.dist_module.all_gather_into_tensor(
output_tensor, input_, group=self.device_group
)
# Reshape
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
output_tensor = output_tensor.reshape((self.world_size,) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(self.world_size *
input_size[dim], ) +
input_size[dim + 1:])
output_tensor = output_tensor.reshape(
input_size[:dim]
+ (self.world_size * input_size[dim],)
+ input_size[dim + 1 :]
)
return output_tensor
def send_tensor_dict(
@@ -111,7 +112,6 @@ class CpuCommunicator(DeviceCommunicatorBase):
class _CPUSHMDistributed:
def __init__(self, communicator: CpuCommunicator):
instance_identifier = os.environ["VLLM_DIST_IDENT"]
unique_name = communicator.unique_name
@@ -139,24 +139,32 @@ class _CPUSHMDistributed:
return handle
def all_reduce(self,
input: torch.Tensor,
group: Optional[ProcessGroup] = None) -> None:
def all_reduce(
self, input: torch.Tensor, group: Optional[ProcessGroup] = None
) -> None:
torch.ops._C.shm_allreduce(self.handle, input)
def gather(self,
input: torch.Tensor,
gather_list: Optional[list[torch.Tensor]],
dst: int = -1,
group: Optional[ProcessGroup] = None) -> None:
def gather(
self,
input: torch.Tensor,
gather_list: Optional[list[torch.Tensor]],
dst: int = -1,
group: Optional[ProcessGroup] = None,
) -> None:
# Note: different from the torch gather, here we use local dst rank.
torch.ops._C.shm_gather(self.handle, input, gather_list,
torch.distributed.get_group_rank(group, dst))
torch.ops._C.shm_gather(
self.handle,
input,
gather_list,
torch.distributed.get_group_rank(group, dst),
)
def all_gather_into_tensor(self,
output: torch.Tensor,
input: torch.Tensor,
group: Optional[ProcessGroup] = None) -> None:
def all_gather_into_tensor(
self,
output: torch.Tensor,
input: torch.Tensor,
group: Optional[ProcessGroup] = None,
) -> None:
torch.ops._C.shm_all_gather(self.handle, input, output)
def send_tensor_dict(
@@ -169,11 +177,11 @@ class _CPUSHMDistributed:
size_list = []
for v in value_list:
if not isinstance(v, torch.Tensor):
raise RuntimeError(
"CpuCommunicator only supports sending tensors.")
raise RuntimeError("CpuCommunicator only supports sending tensors.")
size_list.append(v.size())
key_size_tensor = torch.frombuffer(pickle.dumps([key_list, size_list]),
dtype=torch.uint8)
key_size_tensor = torch.frombuffer(
pickle.dumps([key_list, size_list]), dtype=torch.uint8
)
value_list.append(key_size_tensor)
torch.ops._C.shm_send_tensor_list(self.handle, value_list, dst)


@@ -8,11 +8,12 @@ from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.distributed.device_communicators.all_reduce_utils import (
should_nccl_symm_mem_allreduce)
from vllm.distributed.device_communicators.pynccl import (
register_nccl_symmetric_ops)
should_nccl_symm_mem_allreduce,
)
from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric_ops
from vllm.distributed.device_communicators.pynccl_allocator import (
is_symmetric_memory_enabled)
is_symmetric_memory_enabled,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -22,20 +23,21 @@ logger = init_logger(__name__)
class CudaCommunicator(DeviceCommunicatorBase):
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
def __init__(
self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = "",
):
super().__init__(cpu_group, device, device_group, unique_name)
if "tp" not in unique_name:
# custom allreduce or torch symm mem can be used only by tp
use_custom_allreduce = False
use_torch_symm_mem = False
else:
from vllm.distributed.parallel_state import (
_ENABLE_CUSTOM_ALL_REDUCE)
from vllm.distributed.parallel_state import _ENABLE_CUSTOM_ALL_REDUCE
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
@@ -44,13 +46,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce)
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator)
CustomAllreduce,
)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce)
from vllm.distributed.device_communicators.symm_mem import (
SymmMemCommunicator)
QuickAllReduce,
)
from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if self.world_size > 1:
@@ -75,8 +77,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.ca_comm = CustomAllreduce(
group=self.cpu_group,
device=self.device,
symm_mem_enabled=(self.symm_mem_comm is not None
and not self.symm_mem_comm.disabled),
symm_mem_enabled=(
self.symm_mem_comm is not None and not self.symm_mem_comm.disabled
),
)
if current_platform.is_rocm():
@@ -85,35 +88,39 @@ class CudaCommunicator(DeviceCommunicatorBase):
# Based on quickreduce (https://github.com/mk1-project/quickreduce).
# If it's a rocm, 'use_custom_allreduce==True' means it must
# currently be an MI300 series.
self.qr_comm = QuickAllReduce(group=self.cpu_group,
device=self.device)
self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
elif all2all_backend == "allgather_reducescatter":
from .all2all import AgRsAll2AllManager
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
logger.info("Using AllGather-ReduceScatter all2all manager.")
elif all2all_backend == "pplx":
from .all2all import PPLXAll2AllManager
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
logger.info("Using PPLX all2all manager.")
elif all2all_backend == "deepep_high_throughput":
from .all2all import DeepEPHTAll2AllManager
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
logger.info("Using DeepEP High-Throughput all2all manager.")
elif all2all_backend == "deepep_low_latency":
from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
logger.info("Using DeepEP Low-Latency all2all manager.")
elif all2all_backend == "flashinfer_all2allv":
from .all2all import FlashInferAllToAllManager
self.all2all_manager = FlashInferAllToAllManager(
self.cpu_group)
self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
logger.info("Using Flashinfer all2allv manager.")
else:
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
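The branch above is driven entirely by the `VLLM_ALL2ALL_BACKEND` environment variable (read through `vllm.envs`). A small sketch of selecting a backend before launching vLLM, listing the values this chain accepts:

import os

VALID_BACKENDS = {
    "naive",
    "allgather_reducescatter",
    "pplx",
    "deepep_high_throughput",
    "deepep_low_latency",
    "flashinfer_all2allv",
}
os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_low_latency"
assert os.environ["VLLM_ALL2ALL_BACKEND"] in VALID_BACKENDS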
@@ -121,28 +128,34 @@ class CudaCommunicator(DeviceCommunicatorBase):
def all_reduce(self, input_):
# since currently we perform copy input -> symm_input -> out-of-place AR
# return symm_output, we don't need to check if input is symmetric
if self.pynccl_comm is not None and \
should_nccl_symm_mem_allreduce(self.pynccl_comm.world_size,input_):
if self.pynccl_comm is not None and should_nccl_symm_mem_allreduce(
self.pynccl_comm.world_size, input_
):
out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_)
if out is not None:
return out
# always try quick reduce first, then custom allreduce,
# and then pynccl. (quick reduce just for ROCM MI3*)
qr_comm = self.qr_comm
if qr_comm is not None and not qr_comm.disabled and \
qr_comm.should_quick_allreduce(input_):
if (
qr_comm is not None
and not qr_comm.disabled
and qr_comm.should_quick_allreduce(input_)
):
out = qr_comm.quick_all_reduce(input_)
assert out is not None
return out
ca_comm = self.ca_comm
if ca_comm is not None and not ca_comm.disabled and \
ca_comm.should_custom_ar(input_):
if (
ca_comm is not None
and not ca_comm.disabled
and ca_comm.should_custom_ar(input_)
):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
symm_mem_comm = self.symm_mem_comm
if symm_mem_comm is not None and \
symm_mem_comm.should_use_symm_mem(input_):
if symm_mem_comm is not None and symm_mem_comm.should_use_symm_mem(input_):
out = symm_mem_comm.all_reduce(input_)
assert out is not None
return out
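The method above tries the fastest path that will accept the tensor and falls through otherwise: the NCCL symmetric-memory copy path, then quick reduce (ROCm MI300-class), then custom allreduce, then the torch symmetric-memory communicator, and finally PyNCCL/torch.distributed. A plain-Python sketch of that ordering (hypothetical helper, not the real communicator objects):

def pick_allreduce_path(nccl_symm_ok, qr_ok, ca_ok, symm_mem_ok):
    if nccl_symm_ok:
        return "nccl_symm_mem"
    if qr_ok:
        return "quick_reduce"
    if ca_ok:
        return "custom_allreduce"
    if symm_mem_ok:
        return "torch_symm_mem"
    return "pynccl"

print(pick_allreduce_path(False, False, True, True))  # -> "custom_allreduce"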
@@ -176,21 +189,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
assert input_tensor.shape[0] % world_size == 0
chunk_size = input_tensor.shape[0] // world_size
output_shape = (chunk_size, ) + input_tensor.shape[1:]
output_shape = (chunk_size,) + input_tensor.shape[1:]
output = torch.empty(output_shape,
dtype=input_tensor.dtype,
device=input_tensor.device)
output = torch.empty(
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
)
pynccl_comm.reduce_scatter(output, input_tensor)
# Reshape before returning
return output.movedim(0, dim).contiguous()
def reduce_scatterv(self,
input_: torch.Tensor,
dim: int = -1,
sizes: Optional[list[int]] = None):
def reduce_scatterv(
self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None
):
world_size = self.world_size
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
@@ -209,11 +221,11 @@ class CudaCommunicator(DeviceCommunicatorBase):
else:
assert input_tensor.shape[0] % world_size == 0
chunk_size = input_tensor.shape[0] // world_size
output_shape = (chunk_size, ) + input_tensor.shape[1:]
output_shape = (chunk_size,) + input_tensor.shape[1:]
output = torch.empty(output_shape,
dtype=input_tensor.dtype,
device=input_tensor.device)
output = torch.empty(
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
)
if sizes is not None:
pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
@@ -235,10 +247,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
else:
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
def recv(
self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None
) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
@@ -261,10 +272,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.all2all_manager.destroy()
self.all2all_manager = None
def all_gatherv(self,
input_: Union[torch.Tensor, list[torch.Tensor]],
dim: int = 0,
sizes: Optional[list[int]] = None):
def all_gatherv(
self,
input_: Union[torch.Tensor, list[torch.Tensor]],
dim: int = 0,
sizes: Optional[list[int]] = None,
):
if dim != 0:
raise NotImplementedError("only dim 0 all-gatherv is supported")
world_size = self.world_size
@@ -276,20 +289,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
if sizes is not None and all(s == sizes[0] for s in sizes):
sizes = None
def _all_gather_single(input_: torch.Tensor,
sizes: Optional[list[int]] = None):
def _all_gather_single(input_: torch.Tensor, sizes: Optional[list[int]] = None):
input_size = input_.size()
if sizes is not None:
assert len(sizes) == world_size
assert input_.shape[dim] == sizes[self.rank_in_group], (
f"{input_.shape[dim]} != {sizes[self.rank_in_group]}")
output_size = (sum(sizes), ) + input_size[1:]
f"{input_.shape[dim]} != {sizes[self.rank_in_group]}"
)
output_size = (sum(sizes),) + input_size[1:]
else:
output_size = (input_size[0] * world_size, ) + input_size[1:]
output_size = (input_size[0] * world_size,) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
if sizes is not None:
pynccl_comm.all_gatherv(output_tensor, input_, sizes=sizes)
else:
@@ -311,17 +324,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits, is_sequence_parallel)
hidden_states, router_logits, is_sequence_parallel
)
return hidden_states, router_logits
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states,
is_sequence_parallel)
hidden_states = self.all2all_manager.combine(
hidden_states, is_sequence_parallel
)
return hidden_states


@@ -42,7 +42,7 @@ def find_loaded_library(lib_name) -> Optional[str]:
the file `/proc/self/maps` contains the memory maps of the process, which includes the
shared libraries loaded by the process. We can use this file to find the path of
a loaded library.
""" # noqa
""" # noqa
found = False
with open("/proc/self/maps") as f:
for line in f:
@@ -57,8 +57,9 @@ def find_loaded_library(lib_name) -> Optional[str]:
start = line.index("/")
path = line[start:].strip()
filename = path.split("/")[-1]
assert filename.rpartition(".so")[0].startswith(lib_name), \
assert filename.rpartition(".so")[0].startswith(lib_name), (
f"Unexpected filename: {filename} for library {lib_name}"
)
return path
@@ -70,30 +71,38 @@ class CudaRTLibrary:
Function("cudaDeviceSynchronize", cudaError_t, []),
# cudaError_t cudaDeviceReset ( void )
Function("cudaDeviceReset", cudaError_t, []),
# const char* cudaGetErrorString ( cudaError_t error )
Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
# cudaError_t cudaMalloc ( void** devPtr, size_t size )
Function("cudaMalloc", cudaError_t,
[ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
Function(
"cudaMalloc",
cudaError_t,
[ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t],
),
# cudaError_t cudaFree ( void* devPtr )
Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
# cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
Function("cudaMemset", cudaError_t,
[ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
Function(
"cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
),
# cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
Function("cudaMemcpy", cudaError_t, [
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind
]),
Function(
"cudaMemcpy",
cudaError_t,
[ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind],
),
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
Function("cudaIpcGetMemHandle", cudaError_t,
[ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
Function(
"cudaIpcGetMemHandle",
cudaError_t,
[ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p],
),
# cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
Function("cudaIpcOpenMemHandle", cudaError_t, [
ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint
]),
Function(
"cudaIpcOpenMemHandle",
cudaError_t,
[ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint],
),
]
# class attribute to store the mapping from the path to the library
@@ -109,11 +118,10 @@ class CudaRTLibrary:
so_file = find_loaded_library("libcudart")
if so_file is None:
so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var
assert so_file is not None, \
(
"libcudart is not loaded in the current process, "
"try setting VLLM_CUDART_SO_PATH"
)
assert so_file is not None, (
"libcudart is not loaded in the current process, "
"try setting VLLM_CUDART_SO_PATH"
)
if so_file not in CudaRTLibrary.path_to_library_cache:
lib = ctypes.CDLL(so_file)
CudaRTLibrary.path_to_library_cache[so_file] = lib
@@ -154,27 +162,29 @@ class CudaRTLibrary:
def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))
def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,
count: int) -> None:
def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None:
self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))
def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,
count: int) -> None:
def cudaMemcpy(
self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int
) -> None:
cudaMemcpyDefault = 4
kind = cudaMemcpyDefault
self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))
def cudaIpcGetMemHandle(self,
devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
handle = cudaIpcMemHandle_t()
self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"](
ctypes.byref(handle), devPtr))
self.CUDART_CHECK(
self.funcs["cudaIpcGetMemHandle"](ctypes.byref(handle), devPtr)
)
return handle
def cudaIpcOpenMemHandle(self,
handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
cudaIpcMemLazyEnablePeerAccess = 1
devPtr = ctypes.c_void_p()
self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"](
ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))
self.CUDART_CHECK(
self.funcs["cudaIpcOpenMemHandle"](
ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess
)
)
return devPtr
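A hedged usage sketch of the wrapper above. It assumes a Python method exists per `Function` binding (a `cudaMalloc` method is implied by the binding even though it falls outside this hunk) and that `libcudart` is already loaded in a CUDA-capable process:

from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

lib = CudaRTLibrary()
ptr = lib.cudaMalloc(1024)             # device pointer (ctypes.c_void_p); assumed method
lib.cudaMemset(ptr, 0, 1024)           # zero the allocation
handle = lib.cudaIpcGetMemHandle(ptr)  # IPC handle a peer process can open
# a peer would call lib.cudaIpcOpenMemHandle(handle) to map the same memory
lib.cudaFree(ptr)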


@@ -11,7 +11,9 @@ from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.device_communicators.all_reduce_utils import (
CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check)
CUSTOM_ALL_REDUCE_MAX_SIZES,
gpu_p2p_access_check,
)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -32,8 +34,7 @@ def _can_p2p(rank: int, world_size: int) -> bool:
if i == rank:
continue
if envs.VLLM_SKIP_P2P_CHECK:
logger.info(
"Skipping P2P check and trusting the driver's P2P report.")
logger.info("Skipping P2P check and trusting the driver's P2P report.")
return torch.cuda.can_device_access_peer(rank, i)
if not gpu_p2p_access_check(rank, i):
return False
@@ -41,21 +42,23 @@ def _can_p2p(rank: int, world_size: int) -> bool:
def is_weak_contiguous(inp: torch.Tensor):
return inp.is_contiguous() or (inp.storage().nbytes() -
inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size())
return inp.is_contiguous() or (
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size()
)
class CustomAllreduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
# max_size: max supported allreduce size
def __init__(self,
group: ProcessGroup,
device: Union[int, str, torch.device],
max_size=8192 * 1024,
symm_mem_enabled=False) -> None:
def __init__(
self,
group: ProcessGroup,
device: Union[int, str, torch.device],
max_size=8192 * 1024,
symm_mem_enabled=False,
) -> None:
"""
Args:
group: the process group to work on. If None, it will use the
@@ -72,20 +75,24 @@ class CustomAllreduce:
if not custom_ar:
# disable because of missing custom allreduce library
# e.g. in a non-GPU environment
logger.info("Custom allreduce is disabled because "
"of missing custom allreduce library")
logger.info(
"Custom allreduce is disabled because "
"of missing custom allreduce library"
)
return
self.group = group
assert dist.get_backend(group) != dist.Backend.NCCL, (
"CustomAllreduce should be attached to a non-NCCL group.")
"CustomAllreduce should be attached to a non-NCCL group."
)
if not all(in_the_same_node_as(group, source_rank=0)):
# No need to initialize custom allreduce for multi-node case.
logger.warning(
"Custom allreduce is disabled because this process group"
" spans across nodes.")
" spans across nodes."
)
return
rank = dist.get_rank(group=self.group)
@@ -100,7 +107,9 @@ class CustomAllreduce:
"Custom allreduce is disabled due to an unsupported world"
" size: %d. Supported world sizes: %s. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly.",
world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
world_size,
str(CustomAllreduce._SUPPORTED_WORLD_SIZES),
)
return
if isinstance(device, int):
@@ -110,13 +119,15 @@ class CustomAllreduce:
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
device_capability = current_platform.get_device_capability(
).as_version_str()
if (current_platform.is_cuda() and symm_mem_enabled
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
device_capability = current_platform.get_device_capability().as_version_str()
if (
current_platform.is_cuda()
and symm_mem_enabled
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES
):
max_size = min(
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
max_size)
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], max_size
)
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
@@ -124,12 +135,9 @@ class CustomAllreduce:
device_ids = list(range(cuda_device_count_stateless()))
physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id],
dtype=torch.int,
device="cpu")
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
gather_list = [
torch.tensor([0], dtype=torch.int, device="cpu")
for _ in range(world_size)
torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size)
]
dist.all_gather(gather_list, tensor, group=self.group)
physical_device_ids = [t.item() for t in gather_list]
@@ -138,13 +146,13 @@ class CustomAllreduce:
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
assert current_platform.is_cuda_alike()
fully_connected = current_platform.is_fully_connected(
physical_device_ids)
fully_connected = current_platform.is_fully_connected(physical_device_ids)
if world_size > 2 and not fully_connected:
logger.warning(
"Custom allreduce is disabled because it's not supported on"
" more than two PCIe-only GPUs. To silence this warning, "
"specify disable_custom_all_reduce=True explicitly.")
"specify disable_custom_all_reduce=True explicitly."
)
return
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
@@ -154,16 +162,17 @@ class CustomAllreduce:
logger.warning(
"Custom allreduce is disabled because your platform lacks "
"GPU P2P capability or P2P test failed. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly.")
"warning, specify disable_custom_all_reduce=True explicitly."
)
return
self.disabled = False
# Buffers memory are owned by this Python class and passed to C++.
# Metadata composes of two parts: metadata for synchronization and a
# temporary buffer for storing intermediate allreduce results.
self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
group=group,
uncached=True)
self.meta_ptrs = self.create_shared_buffer(
ops.meta_size() + max_size, group=group, uncached=True
)
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
@@ -172,21 +181,22 @@ class CustomAllreduce:
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
# is enough for 131072 such tuples. The largest model I've seen only
# needs less than 10000 of registered tuples.
self.rank_data = torch.empty(8 * 1024 * 1024,
dtype=torch.uint8,
device=self.device)
self.rank_data = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
)
self.max_size = max_size
self.rank = rank
self.world_size = world_size
self.fully_connected = fully_connected
self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank,
self.fully_connected)
self._ptr = ops.init_custom_ar(
self.meta_ptrs, self.rank_data, rank, self.fully_connected
)
ops.register_buffer(self._ptr, self.buffer_ptrs)
@contextmanager
def capture(self):
"""
The main responsibility of this context manager is the
The main responsibility of this context manager is the
It records all the buffer addresses used in the CUDA graph.
"""
@@ -204,15 +214,13 @@ class CustomAllreduce:
# We cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
all_data = [[None, None]
for _ in range(dist.get_world_size(group=self.group))]
all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
all_data[self.rank] = [handle, offset]
ranks = sorted(dist.get_process_group_ranks(group=self.group))
for i, rank in enumerate(ranks):
dist.broadcast_object_list(all_data[i],
src=rank,
group=self.group,
device="cpu")
dist.broadcast_object_list(
all_data[i], src=rank, group=self.group, device="cpu"
)
# Unpack list of tuples to tuple of lists.
handles = [d[0] for d in all_data] # type: ignore
offsets = [d[1] for d in all_data] # type: ignore
@@ -233,13 +241,11 @@ class CustomAllreduce:
return inp_size < self.max_size
return False
def all_reduce(self,
inp: torch.Tensor,
*,
out: torch.Tensor = None,
registered: bool = False):
def all_reduce(
self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: bool = False
):
"""Performs an out-of-place all reduce.
If registered is True, this assumes inp's pointer is already
IPC-registered. Otherwise, inp is first copied into a pre-registered
buffer.
@@ -249,8 +255,9 @@ class CustomAllreduce:
if registered:
ops.all_reduce(self._ptr, inp, out, 0, 0)
else:
ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank],
self.max_size)
ops.all_reduce(
self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
)
return out
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
@@ -283,9 +290,11 @@ class CustomAllreduce:
self.close()
@staticmethod
def create_shared_buffer(size_in_bytes: int,
group: Optional[ProcessGroup] = None,
uncached: Optional[bool] = False) -> list[int]:
def create_shared_buffer(
size_in_bytes: int,
group: Optional[ProcessGroup] = None,
uncached: Optional[bool] = False,
) -> list[int]:
pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
world_size = dist.get_world_size(group=group)
@@ -302,9 +311,11 @@ class CustomAllreduce:
return pointers
@staticmethod
def free_shared_buffer(pointers: list[int],
group: Optional[ProcessGroup] = None,
rank: Optional[int] = None) -> None:
def free_shared_buffer(
pointers: list[int],
group: Optional[ProcessGroup] = None,
rank: Optional[int] = None,
) -> None:
if rank is None:
rank = dist.get_rank(group=group)
if ops is not None:


@@ -9,7 +9,6 @@ assert has_flashinfer_all2all(), "Flashinfer alltoallv module cannot be found"
class CustomCommunicator(CommBackend):
def __init__(self, group):
self._group = group
@@ -24,5 +23,5 @@ class CustomCommunicator(CommBackend):
dist.all_gather_object(gathered, data, group=self._group)
return gathered
def Split(self, color: int, key: int) -> 'CustomCommunicator':
def Split(self, color: int, key: int) -> "CustomCommunicator":
return self


@@ -10,8 +10,14 @@ from torch.distributed import ProcessGroup, ReduceOp
import vllm.envs as envs
from vllm.distributed.device_communicators.pynccl_wrapper import (
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
ncclRedOpTypeEnum, ncclUniqueId)
NCCLLibrary,
buffer_type,
cudaStream_t,
ncclComm_t,
ncclDataTypeEnum,
ncclRedOpTypeEnum,
ncclUniqueId,
)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import current_stream
@@ -23,7 +29,8 @@ _NCCL_SYMM_OPS_REGISTERED = False
def register_nccl_symmetric_ops(pynccl_comm):
from vllm.distributed.device_communicators.pynccl_allocator import (
nccl_symm_mem_context)
nccl_symm_mem_context,
)
from vllm.utils import direct_register_custom_op
global _NCCL_SYMM_OPS_REGISTERED
@@ -31,8 +38,7 @@ def register_nccl_symmetric_ops(pynccl_comm):
return
_NCCL_SYMM_OPS_REGISTERED = True
def all_reduce_symmetric_with_copy_impl(
input_tensor: torch.Tensor) -> torch.Tensor:
def all_reduce_symmetric_with_copy_impl(input_tensor: torch.Tensor) -> torch.Tensor:
with nccl_symm_mem_context(pynccl_comm):
symm_input = torch.empty_like(input_tensor)
symm_output = torch.empty_like(input_tensor)
@@ -40,8 +46,7 @@ def register_nccl_symmetric_ops(pynccl_comm):
symm_output = pynccl_comm.all_reduce(symm_input, symm_output)
return symm_output
def all_reduce_symmetric_with_copy_fake(
input_tensor: torch.Tensor) -> torch.Tensor:
def all_reduce_symmetric_with_copy_fake(input_tensor: torch.Tensor) -> torch.Tensor:
return torch.empty_like(input_tensor)
direct_register_custom_op(
@@ -52,7 +57,6 @@ def register_nccl_symmetric_ops(pynccl_comm):
class PyNcclCommunicator:
def __init__(
self,
group: Union[ProcessGroup, StatelessProcessGroup],
@@ -73,7 +77,8 @@ class PyNcclCommunicator:
if not isinstance(group, StatelessProcessGroup):
assert dist.is_initialized()
assert dist.get_backend(group) != dist.Backend.NCCL, (
"PyNcclCommunicator should be attached to a non-NCCL group.")
"PyNcclCommunicator should be attached to a non-NCCL group."
)
# note: this rank is the rank in the group
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group)
@@ -132,7 +137,8 @@ class PyNcclCommunicator:
# current cuda device to the specified one
with torch.cuda.device(device):
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
self.world_size, self.unique_id, self.rank)
self.world_size, self.unique_id, self.rank
)
stream = current_stream()
# A small all_reduce for warmup.
@@ -141,11 +147,13 @@ class PyNcclCommunicator:
stream.synchronize()
del data
def all_reduce(self,
in_tensor: torch.Tensor,
out_tensor: torch.Tensor = None,
op: ReduceOp = ReduceOp.SUM,
stream=None) -> torch.Tensor:
def all_reduce(
self,
in_tensor: torch.Tensor,
out_tensor: torch.Tensor = None,
op: ReduceOp = ReduceOp.SUM,
stream=None,
) -> torch.Tensor:
if self.disabled:
return None
# nccl communicator created on a specific device
@@ -153,25 +161,28 @@ class PyNcclCommunicator:
# otherwise it will cause "illegal memory access"
assert in_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {in_tensor.device}")
f"but the input tensor is on {in_tensor.device}"
)
if out_tensor is None:
out_tensor = torch.empty_like(in_tensor)
if stream is None:
stream = current_stream()
self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
buffer_type(out_tensor.data_ptr()),
in_tensor.numel(),
ncclDataTypeEnum.from_torch(in_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))
self.nccl.ncclAllReduce(
buffer_type(in_tensor.data_ptr()),
buffer_type(out_tensor.data_ptr()),
in_tensor.numel(),
ncclDataTypeEnum.from_torch(in_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op),
self.comm,
cudaStream_t(stream.cuda_stream),
)
return out_tensor
def all_gather(self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
stream=None):
def all_gather(
self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
):
if self.disabled:
return
# nccl communicator created on a specific device
@@ -179,14 +190,18 @@ class PyNcclCommunicator:
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
f"but the input tensor is on {input_tensor.device}"
)
if stream is None:
stream = current_stream()
self.nccl.ncclAllGather(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
cudaStream_t(stream.cuda_stream))
buffer_type(output_tensor.data_ptr()),
input_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
self.comm,
cudaStream_t(stream.cuda_stream),
)
def all_gatherv(
self,
@@ -202,14 +217,15 @@ class PyNcclCommunicator:
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
f"but the input tensor is on {input_tensor.device}"
)
if stream is None:
stream = current_stream()
assert output_tensor.shape[0] == sum(sizes)
split_offset = 0
self.nccl.ncclGroupStart()
for root, split_size in enumerate(sizes):
dst_slice = output_tensor[split_offset:split_offset + split_size]
dst_slice = output_tensor[split_offset : split_offset + split_size]
self.nccl.ncclBroadcast(
buffer_type(input_tensor.data_ptr()),
buffer_type(dst_slice.data_ptr()),
@@ -222,11 +238,13 @@ class PyNcclCommunicator:
split_offset += split_size
self.nccl.ncclGroupEnd()
def reduce_scatter(self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None):
def reduce_scatter(
self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None,
):
if self.disabled:
return
# nccl communicator created on a specific device
@@ -234,15 +252,19 @@ class PyNcclCommunicator:
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
f"but the input tensor is on {input_tensor.device}"
)
if stream is None:
stream = current_stream()
self.nccl.ncclReduceScatter(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
buffer_type(output_tensor.data_ptr()),
output_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))
ncclRedOpTypeEnum.from_torch(op),
self.comm,
cudaStream_t(stream.cuda_stream),
)
def reduce_scatterv(
self,
@@ -259,20 +281,25 @@ class PyNcclCommunicator:
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
f"but the input tensor is on {input_tensor.device}"
)
if stream is None:
stream = current_stream()
split_offset = 0
self.nccl.ncclGroupStart()
for root, split_size in enumerate(sizes):
chunk = input_tensor[split_offset:split_offset + split_size, ...]
chunk = input_tensor[split_offset : split_offset + split_size, ...]
self.nccl.ncclReduce(
buffer_type(chunk.data_ptr()),
buffer_type(output_tensor.data_ptr()), chunk.numel(),
buffer_type(output_tensor.data_ptr()),
chunk.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), root, self.comm,
cudaStream_t(stream.cuda_stream))
ncclRedOpTypeEnum.from_torch(op),
root,
self.comm,
cudaStream_t(stream.cuda_stream),
)
split_offset += split_size
self.nccl.ncclGroupEnd()
@@ -281,31 +308,44 @@ class PyNcclCommunicator:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
f"but the input tensor is on {tensor.device}"
)
if stream is None:
stream = current_stream()
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), dst,
self.comm, cudaStream_t(stream.cuda_stream))
self.nccl.ncclSend(
buffer_type(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
dst,
self.comm,
cudaStream_t(stream.cuda_stream),
)
def recv(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
f"but the input tensor is on {tensor.device}"
)
if stream is None:
stream = current_stream()
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))
self.nccl.ncclRecv(
buffer_type(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
src,
self.comm,
cudaStream_t(stream.cuda_stream),
)
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
f"but the input tensor is on {tensor.device}"
)
if stream is None:
stream = current_stream()
if src == self.rank:
@@ -315,9 +355,15 @@ class PyNcclCommunicator:
else:
sendbuff = buffer_type()
recvbuff = buffer_type(tensor.data_ptr())
self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))
self.nccl.ncclBroadcast(
sendbuff,
recvbuff,
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
src,
self.comm,
cudaStream_t(stream.cuda_stream),
)
def group_start(self):
self.nccl.ncclGroupStart()
@@ -334,8 +380,7 @@ class PyNcclCommunicator:
)
def register_comm_window_raw(self, ptr: int, size: int):
return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr),
size, 1)
return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1)
def deregister_comm_window(self, window):
return self.nccl.ncclCommWindowDeregister(self.comm, window)
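For orientation, a rough end-to-end sketch of how this communicator is typically driven (not part of this diff); the host/port values and the two-process launch are assumptions, and a real run needs one CUDA device per rank.

```
import torch
from vllm.distributed.utils import StatelessProcessGroup
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

# Run one copy of this per rank, with rank set to 0 and 1 respectively.
rank, world_size = 0, 2
pg = StatelessProcessGroup.create(
    host="127.0.0.1", port=29500, rank=rank, world_size=world_size
)
comm = PyNcclCommunicator(pg, device=rank)

x = torch.ones(16, device=f"cuda:{rank}")
y = comm.all_reduce(x)  # out-of-place: each element should now equal world_size
```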

View File

@@ -98,7 +98,9 @@ def compile_nccl_allocator():
"This is expected if NCCL headers are not available. "
"optionally set VLLM_NCCL_INCLUDE_PATH to point to a directory "
"containing the NCCL header. "
"Error: %s", str(e))
"Error: %s",
str(e),
)
def get_nccl_mem_pool():
@@ -125,21 +127,24 @@ atexit.register(_cleanup_nccl_allocator_wrapper)
class nccl_symm_mem_context:
def __init__(
self,
pynccl_comm: PyNcclCommunicator,
disabled: bool = False,
):
self.disabled = (disabled or not is_symmetric_memory_enabled()
or pynccl_comm.world_size == 1
or not current_platform.is_cuda()
or get_nccl_mem_pool() is None or version.parse(
torch.__version__) < version.parse("2.8.0.a0"))
self.disabled = (
disabled
or not is_symmetric_memory_enabled()
or pynccl_comm.world_size == 1
or not current_platform.is_cuda()
or get_nccl_mem_pool() is None
or version.parse(torch.__version__) < version.parse("2.8.0.a0")
)
if self.disabled:
self.pynccl_comm: Optional[PyNcclCommunicator] = None
self._mem_pool_ctx: contextlib.AbstractContextManager[
Any] = contextlib.nullcontext()
self._mem_pool_ctx: contextlib.AbstractContextManager[Any] = (
contextlib.nullcontext()
)
self.is_graph_capture = None
self.device = None
else:
@@ -151,16 +156,16 @@ class nccl_symm_mem_context:
def __enter__(self):
if self.disabled:
return self
assert (
self.pynccl_comm
is not None), "Symmetric memory requires pynccl to be initalized"
assert (
self.pynccl_comm.nccl_version >= 22703
), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
assert self.pynccl_comm is not None, (
"Symmetric memory requires pynccl to be initalized"
)
assert self.pynccl_comm.nccl_version >= 22703, (
"NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
)
if self.is_graph_capture:
assert (
_graph_pool_id
is not None), "graph_pool_id is not set under graph capture"
assert _graph_pool_id is not None, (
"graph_pool_id is not set under graph capture"
)
# Pause graph memory pool to use symmetric memory with cuda graph
torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
self._mem_pool_ctx.__enter__()
@@ -179,8 +184,8 @@ class nccl_symm_mem_context:
for segment in _cached_pool_snapshot:
if segment["address"] not in _registered_base_addrs:
self.pynccl_comm.register_comm_window_raw(
segment["address"], segment["total_size"])
segment["address"], segment["total_size"]
)
_registered_base_addrs.add(segment["address"])
if self.is_graph_capture:
torch._C._cuda_beginAllocateCurrentThreadToPool(
self.device, _graph_pool_id)
torch._C._cuda_beginAllocateCurrentThreadToPool(self.device, _graph_pool_id)
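A minimal sketch of how this context manager is meant to be used, mirroring all_reduce_symmetric_with_copy_impl earlier in this diff; `pynccl_comm` is assumed to be an initialized PyNcclCommunicator, and the context silently degrades to a no-op when any of the conditions checked in __init__ disable it.

```
import torch

with nccl_symm_mem_context(pynccl_comm):
    # Tensors allocated here come from the NCCL-registered memory pool
    # (or from the default allocator if the context disabled itself).
    symm_in = torch.empty(4096, device="cuda")
    symm_out = torch.empty_like(symm_in)

symm_in.fill_(1.0)
out = pynccl_comm.all_reduce(symm_in, symm_out)
```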

View File

@@ -133,88 +133,141 @@ class NCCLLibrary:
# const char* ncclGetErrorString(ncclResult_t result)
Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
# ncclResult_t ncclGetVersion(int *version);
Function("ncclGetVersion", ncclResult_t,
[ctypes.POINTER(ctypes.c_int)]),
Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]),
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
Function("ncclGetUniqueId", ncclResult_t,
[ctypes.POINTER(ncclUniqueId)]),
Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]),
# ncclResult_t ncclCommInitRank(
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
# note that ncclComm_t is a pointer type, so the first argument
# is a pointer to a pointer
Function("ncclCommInitRank", ncclResult_t, [
ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
ctypes.c_int
]),
Function(
"ncclCommInitRank",
ncclResult_t,
[ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int],
),
# ncclResult_t ncclAllReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclAllReduce", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ncclComm_t, cudaStream_t
]),
Function(
"ncclAllReduce",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ncclRedOp_t,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, int root,
# ncclComm_t comm, cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclReduce", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ctypes.c_int, ncclComm_t, cudaStream_t
]),
Function(
"ncclReduce",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ncclRedOp_t,
ctypes.c_int,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclAllGather(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclAllGather", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclComm_t, cudaStream_t
]),
Function(
"ncclAllGather",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclReduceScatter(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclReduceScatter", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ncclComm_t, cudaStream_t
]),
Function(
"ncclReduceScatter",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ncclRedOp_t,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclSend(
# const void* sendbuff, size_t count, ncclDataType_t datatype,
# int dest, ncclComm_t comm, cudaStream_t stream);
Function("ncclSend", ncclResult_t, [
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
ncclComm_t, cudaStream_t
]),
Function(
"ncclSend",
ncclResult_t,
[
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ctypes.c_int,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclRecv(
# void* recvbuff, size_t count, ncclDataType_t datatype,
# int src, ncclComm_t comm, cudaStream_t stream);
Function("ncclRecv", ncclResult_t, [
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
ncclComm_t, cudaStream_t
]),
Function(
"ncclRecv",
ncclResult_t,
[
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ctypes.c_int,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclBroadcast(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, int root, ncclComm_t comm,
# cudaStream_t stream);
Function("ncclBroadcast", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ctypes.c_int, ncclComm_t, cudaStream_t
]),
Function(
"ncclBroadcast",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ctypes.c_int,
ncclComm_t,
cudaStream_t,
],
),
# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
@@ -241,8 +294,7 @@ class NCCLLibrary:
),
# ncclResult_t ncclCommWindowDeregister(
# ncclComm_t comm, ncclWindow_t win);
Function("ncclCommWindowDeregister", ncclResult_t,
[ncclComm_t, ncclWindow_t]),
Function("ncclCommWindowDeregister", ncclResult_t, [ncclComm_t, ncclWindow_t]),
]
# class attribute to store the mapping from the path to the library
@@ -254,7 +306,6 @@ class NCCLLibrary:
path_to_dict_mapping: dict[str, dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
so_file = so_file or find_nccl_library()
try:
@@ -270,8 +321,10 @@ class NCCLLibrary:
"or it does not support the current platform %s. "
"If you already have the library, please set the "
"environment variable VLLM_NCCL_SO_PATH"
" to point to the correct nccl library path.", so_file,
platform.platform())
" to point to the correct nccl library path.",
so_file,
platform.platform(),
)
raise e
if so_file not in NCCLLibrary.path_to_dict_mapping:
@@ -284,15 +337,18 @@ class NCCLLibrary:
_funcs[func.name] = f
except AttributeError:
if func.name in [
"ncclCommWindowRegister",
"ncclCommWindowDeregister"
"ncclCommWindowRegister",
"ncclCommWindowDeregister",
]:
if envs.VLLM_USE_NCCL_SYMM_MEM:
logger.warning_once(
"The symbol %s is not found in the NCCL "
"library %s. To enable VLLM_USE_NCCL_SYMM_MEM "
" please update your NCCL version to >= "
"2.27.03.", func.name, so_file)
"2.27.03.",
func.name,
so_file,
)
if current_platform.is_rocm():
# Having an exception here on ROCm platform is
# not allowed during graph capturing
@@ -325,88 +381,153 @@ class NCCLLibrary:
def ncclGetUniqueId(self) -> ncclUniqueId:
unique_id = ncclUniqueId()
self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
ctypes.byref(unique_id)))
self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id)))
return unique_id
def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId:
if len(data) != 128:
raise ValueError(
f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes")
f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes"
)
unique_id = ncclUniqueId()
ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128)
return unique_id
def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
rank: int) -> ncclComm_t:
def ncclCommInitRank(
self, world_size: int, unique_id: ncclUniqueId, rank: int
) -> ncclComm_t:
comm = ncclComm_t()
self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
world_size, unique_id,
rank))
self.NCCL_CHECK(
self._funcs["ncclCommInitRank"](
ctypes.byref(comm), world_size, unique_id, rank
)
)
return comm
def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
def ncclAllReduce(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
op: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
datatype, op, comm,
stream))
self.NCCL_CHECK(
self._funcs["ncclAllReduce"](
sendbuff, recvbuff, count, datatype, op, comm, stream
)
)
def ncclReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, root: int,
comm: ncclComm_t, stream: cudaStream_t) -> None:
def ncclReduce(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
op: int,
root: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(self._funcs["ncclReduce"](sendbuff, recvbuff, count,
datatype, op, root, comm,
stream))
self.NCCL_CHECK(
self._funcs["ncclReduce"](
sendbuff, recvbuff, count, datatype, op, root, comm, stream
)
)
def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
def ncclReduceScatter(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
op: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff,
count, datatype, op,
comm, stream))
self.NCCL_CHECK(
self._funcs["ncclReduceScatter"](
sendbuff, recvbuff, count, datatype, op, comm, stream
)
)
def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
def ncclAllGather(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
# `datatype` actually should be `ncclDataType_t`
# which is an aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count,
datatype, comm, stream))
self.NCCL_CHECK(
self._funcs["ncclAllGather"](
sendbuff, recvbuff, count, datatype, comm, stream
)
)
def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
dest, comm, stream))
def ncclSend(
self,
sendbuff: buffer_type,
count: int,
datatype: int,
dest: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
self.NCCL_CHECK(
self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream)
)
def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
comm, stream))
def ncclRecv(
self,
recvbuff: buffer_type,
count: int,
datatype: int,
src: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
self.NCCL_CHECK(
self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)
)
def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, root: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count,
datatype, root, comm,
stream))
def ncclBroadcast(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
root: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
self.NCCL_CHECK(
self._funcs["ncclBroadcast"](
sendbuff, recvbuff, count, datatype, root, comm, stream
)
)
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
@@ -417,19 +538,27 @@ class NCCLLibrary:
def ncclGroupEnd(self) -> None:
self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
def ncclCommWindowRegister(self, comm: ncclComm_t, buff: buffer_type,
size: int, win_flags: int) -> ncclWindow_t:
def ncclCommWindowRegister(
self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int
) -> ncclWindow_t:
window = ncclWindow_t()
self.NCCL_CHECK(self._funcs["ncclCommWindowRegister"](
comm, buff, size, ctypes.byref(window), win_flags))
self.NCCL_CHECK(
self._funcs["ncclCommWindowRegister"](
comm, buff, size, ctypes.byref(window), win_flags
)
)
return window
def ncclCommWindowDeregister(self, comm: ncclComm_t,
window: ncclWindow_t) -> None:
def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
__all__ = [
"NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
"ncclComm_t", "cudaStream_t", "buffer_type"
"NCCLLibrary",
"ncclDataTypeEnum",
"ncclRedOpTypeEnum",
"ncclUniqueId",
"ncclComm_t",
"cudaStream_t",
"buffer_type",
]
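As a small sketch of how this ctypes wrapper is loaded and queried (not part of this diff); it assumes the wrapper exposes one Python helper per Function entry above, as it does for ncclGetUniqueId.

```
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary

lib = NCCLLibrary()            # optionally pass an explicit libnccl .so path
print(lib.ncclGetVersion())    # version reported by the loaded NCCL library
uid = lib.ncclGetUniqueId()    # 128-byte id to hand to the other ranks
```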

View File

@@ -27,9 +27,10 @@ except Exception:
def is_weak_contiguous(inp: torch.Tensor):
return inp.is_contiguous() or (inp.storage().nbytes() -
inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size())
return inp.is_contiguous() or (
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size()
)
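A quick illustration of what this predicate accepts, as a self-contained sketch meant to run alongside the helper above:

```
import torch

x = torch.arange(8, dtype=torch.float16)
assert is_weak_contiguous(x)        # contiguous, so trivially accepted

t = x.reshape(2, 4).t()             # non-contiguous view over the whole storage
assert is_weak_contiguous(t)        # accepted: remaining bytes == numel * itemsize

s = x[::2]                          # strided view covering only half the storage
assert not is_weak_contiguous(s)    # rejected by both branches of the check
```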
class QuickReduceRegime(Enum):
@@ -44,7 +45,6 @@ MB = 1024 * 1024
class QuickAllReduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 8]
_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
# The following data is based on kernel tests.
@@ -58,20 +58,21 @@ class QuickAllReduce:
(torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],
}
def __init__(self, group: ProcessGroup,
device: Union[int, str, torch.device]) -> None:
def __init__(
self, group: ProcessGroup, device: Union[int, str, torch.device]
) -> None:
"""
        Custom allreduce provides non-destructive acceleration and is
        available for CUDA and ROCm MI300 series.
        Custom quick allreduce leverages quantization for further
        acceleration on ROCm. It currently supports Q8, Q6, and Q4
        quantization formats and FP(float16, bfloat16).
        Quick allreduce is designed as a complement to custom allreduce.
        Its initialization requires even stricter conditions.
        Only the ROCm MI300 series is supported for quick allreduce at
        this time.
Args:
@@ -93,18 +94,23 @@ class QuickAllReduce:
if not quick_ar:
# disable because of missing quick reduce library
# e.g. in a cuda environment
logger.info("Custom quick allreduce is disabled because "
"of missing custom quick allreduce library")
logger.info(
"Custom quick allreduce is disabled because "
"of missing custom quick allreduce library"
)
return
self.group = group
assert dist.get_backend(group) != dist.Backend.NCCL, (
"Custom quick allreduce should be attached to a non-NCCL group.")
"Custom quick allreduce should be attached to a non-NCCL group."
)
if not all(in_the_same_node_as(group, source_rank=0)):
# No need to initialize custom quick allreduce for
# multi-node case.
logger.warning("Custom quick allreduce is disabled because this "
"process group spans across nodes.")
logger.warning(
"Custom quick allreduce is disabled because this "
"process group spans across nodes."
)
return
rank = dist.get_rank(group=self.group)
world_size = dist.get_world_size(group=self.group)
@@ -118,7 +124,9 @@ class QuickAllReduce:
logger.warning(
"Custom quick allreduce is disabled due to an "
"unsupported world size: %d. Supported world sizes: %s.",
world_size, str(QuickAllReduce._SUPPORTED_WORLD_SIZES))
world_size,
str(QuickAllReduce._SUPPORTED_WORLD_SIZES),
)
return
if isinstance(device, int):
@@ -134,9 +142,7 @@ class QuickAllReduce:
else:
device_ids = list(range(cuda_device_count_stateless()))
physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id],
dtype=torch.int,
device="cpu")
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
gather_list = [
torch.tensor([0], dtype=torch.int, device="cpu")
for _ in range(self.world_size)
@@ -148,12 +154,12 @@ class QuickAllReduce:
# where custom quick allreduce is not supported
# this checks hardware and driver support for NVLink
assert current_platform.is_cuda_alike()
self.fully_connected = current_platform.is_fully_connected(
physical_device_ids)
self.fully_connected = current_platform.is_fully_connected(physical_device_ids)
if self.world_size > 2 and not self.fully_connected:
logger.debug(
"Custom quick allreduce is disabled because it's not supported "
"on more than two PCIe-only GPUs. ")
"on more than two PCIe-only GPUs. "
)
return
self.init_quick_all_reduce()
@@ -169,24 +175,31 @@ class QuickAllReduce:
"Custom quick allreduce:",
f"Invalid quantization level: {regime_str}. "
"Supported levels: "
f"{list(QuickReduceRegime.__members__.keys())}")
f"{list(QuickReduceRegime.__members__.keys())}",
)
return
if regime_str == "NONE":
logger.debug("Custom quick allreduce is disabled based "
"on env variable "
"VLLM_ROCM_QUICK_REDUCE_QUANTIZATION='NONE'")
logger.debug(
"Custom quick allreduce is disabled based "
"on env variable "
"VLLM_ROCM_QUICK_REDUCE_QUANTIZATION='NONE'"
)
return
self.qr_quant_level = QuickReduceRegime[regime_str]
vllm_config = get_current_vllm_config()
if vllm_config is not None and \
hasattr(vllm_config, "model_config") and \
hasattr(vllm_config.model_config, "dtype"):
if (
vllm_config is not None
and hasattr(vllm_config, "model_config")
and hasattr(vllm_config.model_config, "dtype")
):
dtype = vllm_config.model_config.dtype
if dtype not in [torch.float16, torch.bfloat16]:
logger.debug(
"Custom quick allreduce disabled: only supports "
"float16 and float16, but get %s.", dtype)
"float16 and float16, but get %s.",
dtype,
)
return
if dtype == torch.bfloat16 and self.use_fp16_kernels:
@@ -194,7 +207,8 @@ class QuickAllReduce:
"Custom quick allreduce: BF16 inputs will be converted "
"to FP16 to improve performance. set "
"envs.VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16=0 "
"to turn off.")
"to turn off."
)
# VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB
qr_max_size = envs.VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB
@@ -206,8 +220,7 @@ class QuickAllReduce:
)
qr_max_size = qr_max_size * MB
self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size)
self.qr_max_size = qr_max_size if qr_max_size is not None \
else ops.qr_max_size()
self.qr_max_size = qr_max_size if qr_max_size is not None else ops.qr_max_size()
self.create_shared_buffer()
self.disabled = False
@@ -217,16 +230,15 @@ class QuickAllReduce:
try:
props = torch.cuda.get_device_properties(0)
gcn_arch = getattr(props, "gcnArchName", "")
supported_archs = ['gfx94', 'gfx95']
supported_archs = ["gfx94", "gfx95"]
return any(gfx in gcn_arch for gfx in supported_archs)
except Exception as e:
logger.warning("Failed to determine ROCm for quick allreduce: %s",
e)
logger.warning("Failed to determine ROCm for quick allreduce: %s", e)
return False
def create_shared_buffer(self):
"""
        Creates a shared buffer for quickreduce.
Has to be called after init_custom_qr
"""
handle = ops.qr_get_handle(self._ptr)
@@ -253,9 +265,11 @@ class QuickAllReduce:
dtype = inp.dtype
if self.use_fp16_kernels:
dtype = torch.float16
return inp_size <= self.qr_max_size and \
inp_size >= self._QR_MIN_SIZE[(dtype, self.world_size)]\
[self.qr_quant_level.value]
return (
inp_size <= self.qr_max_size
and inp_size
>= self._QR_MIN_SIZE[(dtype, self.world_size)][self.qr_quant_level.value]
)
def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None):
"""Performs an out-of-place custom quick all reduce."""
@@ -263,8 +277,9 @@ class QuickAllReduce:
# as QR uses static IPC buffer.
if out is None:
out = torch.empty_like(inp)
ops.qr_all_reduce(self._ptr, inp, out, self.qr_quant_level.value,
self.use_fp16_kernels)
ops.qr_all_reduce(
self._ptr, inp, out, self.qr_quant_level.value, self.use_fp16_kernels
)
return out
def close(self):
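A hedged usage sketch for the two methods above; `qr` stands for an initialized QuickAllReduce on a supported MI300-class GPU, and the size-gate method name is an assumption since its signature is not shown in these hunks.

```
import torch

inp = torch.randn(4 * 1024 * 1024, dtype=torch.float16, device="cuda")

if not qr.disabled and qr.should_quick_allreduce(inp):   # gate name assumed
    out = qr.quick_all_reduce(inp)   # quantized, out-of-place all-reduce
else:
    out = None                       # caller falls back to custom/NCCL all-reduce
```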

View File

@@ -6,12 +6,12 @@ from typing import Any, Optional
import ray
import torch
from ray.exceptions import RayChannelError
from ray.experimental.channel.communicator import (Communicator,
TorchTensorAllocator)
from ray.experimental.channel.communicator import Communicator, TorchTensorAllocator
from torch.distributed import ReduceOp
from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase)
DeviceCommunicatorBase,
)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.utils import current_stream
@@ -59,11 +59,11 @@ class RayPPCommunicator(Communicator):
self._rank: Optional[int] = None
self._actor_handles = actor_handles
if use_communication_streams:
raise NotImplementedError(
"use_communication_streams is not supported")
raise NotImplementedError("use_communication_streams is not supported")
if cuda_stream is not None and cuda_stream != current_stream():
raise ValueError(
"cuda_stream other than the current stream is not supported")
"cuda_stream other than the current stream is not supported"
)
if rank is not None:
# Rank is not None, this is Ray worker
@@ -99,13 +99,14 @@ class RayPPCommunicator(Communicator):
# Ray actor IDs are 32-character hex strings (128 bits)
ACTOR_ID_LEN = 32
actor_id_bytes = actor_id_str.encode('utf-8')
assert len(
actor_id_bytes
) == ACTOR_ID_LEN, f"Unexpected actor ID length: {len(actor_id_bytes)}"
actor_id_bytes = actor_id_str.encode("utf-8")
assert len(actor_id_bytes) == ACTOR_ID_LEN, (
f"Unexpected actor ID length: {len(actor_id_bytes)}"
)
actor_id_tensor = torch.frombuffer(
actor_id_bytes, dtype=torch.uint8).to(self._comm.device)
actor_id_tensor = torch.frombuffer(actor_id_bytes, dtype=torch.uint8).to(
self._comm.device
)
# All-gather full actor IDs from all actors
gathered_ids = self._comm.all_gather(actor_id_tensor, dim=0)
@@ -115,9 +116,8 @@ class RayPPCommunicator(Communicator):
for rank in range(self._world_size):
start_idx = rank * ACTOR_ID_LEN
end_idx = (rank + 1) * ACTOR_ID_LEN
actor_bytes = gathered_ids[start_idx:end_idx].cpu().numpy(
).tobytes()
actor_id = actor_bytes.decode('utf-8')
actor_bytes = gathered_ids[start_idx:end_idx].cpu().numpy().tobytes()
actor_id = actor_bytes.decode("utf-8")
self._actor_id_to_rank[actor_id] = rank
def initialize(self, rank: int) -> None:
@@ -131,9 +131,10 @@ class RayPPCommunicator(Communicator):
"""
Return the given actor's rank using device communicator collective ops.
"""
assert hasattr(self, '_actor_id_to_rank'), (
assert hasattr(self, "_actor_id_to_rank"), (
"Actor rank mapping not built. "
"This should have been done during initialization.")
"This should have been done during initialization."
)
actor_id_str = actor._actor_id.hex()

View File

@@ -14,14 +14,24 @@ import torch
import torch.distributed as dist
import zmq
from torch.distributed import ProcessGroup
from zmq import IPV6 # type: ignore
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
from zmq import ( # type: ignore
IPV6, # type: ignore
SUB,
SUBSCRIBE,
XPUB,
XPUB_VERBOSE,
Context,
)
import vllm.envs as envs
from vllm.distributed.utils import StatelessProcessGroup, sched_yield
from vllm.logger import init_logger
from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path,
is_valid_ipv6_address)
from vllm.utils import (
get_ip,
get_open_port,
get_open_zmq_ipc_path,
is_valid_ipv6_address,
)
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
@@ -29,7 +39,6 @@ logger = init_logger(__name__)
class SpinTimer:
def record_activity(self):
pass
@@ -66,12 +75,13 @@ class SpinSleepTimer(SpinTimer):
class ShmRingBuffer:
def __init__(self,
n_reader: int,
max_chunk_bytes: int,
max_chunks: int,
name: Optional[str] = None):
def __init__(
self,
n_reader: int,
max_chunk_bytes: int,
max_chunks: int,
name: Optional[str] = None,
):
"""
A shared memory ring buffer implementation for broadcast communication.
Essentially, it is a queue where only one will `enqueue` and multiple
@@ -120,13 +130,14 @@ class ShmRingBuffer:
created object to other processes by pickling it. The other processes will
get the name of the shared memory and open it, so that they can access the
same shared memory buffer.
"""# noqa
""" # noqa
self.n_reader = n_reader
self.metadata_size = 1 + n_reader
self.max_chunk_bytes = max_chunk_bytes
self.max_chunks = max_chunks
self.total_bytes_of_buffer = (self.max_chunk_bytes +
self.metadata_size) * self.max_chunks
self.total_bytes_of_buffer = (
self.max_chunk_bytes + self.metadata_size
) * self.max_chunks
self.data_offset = 0
self.metadata_offset = self.max_chunk_bytes * self.max_chunks
@@ -134,10 +145,10 @@ class ShmRingBuffer:
# we are creating a buffer
self.is_creator = True
self.shared_memory = shared_memory.SharedMemory(
create=True, size=self.total_bytes_of_buffer)
create=True, size=self.total_bytes_of_buffer
)
# initialize the metadata section to 0
with self.shared_memory.buf[self.
metadata_offset:] as metadata_buffer:
with self.shared_memory.buf[self.metadata_offset :] as metadata_buffer:
torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
else:
# we are opening an existing buffer
@@ -145,8 +156,10 @@ class ShmRingBuffer:
# fix to https://stackoverflow.com/q/62748654/9191338
# Python incorrectly tracks shared memory even if it is not
# created by the process. The following patch is a workaround.
with patch("multiprocessing.resource_tracker.register",
lambda *args, **kwargs: None):
with patch(
"multiprocessing.resource_tracker.register",
lambda *args, **kwargs: None,
):
try:
self.shared_memory = shared_memory.SharedMemory(name=name)
# See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa
@@ -154,8 +167,7 @@ class ShmRingBuffer:
# so the shared memory block size may be larger or equal
# to the requested size. The size parameter is ignored
# when attaching to an existing block.
assert (self.shared_memory.size
>= self.total_bytes_of_buffer)
assert self.shared_memory.size >= self.total_bytes_of_buffer
except FileNotFoundError:
# we might deserialize the object in a different node
# in this case, this object is not used,
@@ -163,8 +175,12 @@ class ShmRingBuffer:
pass
def handle(self):
return (self.n_reader, self.max_chunk_bytes, self.max_chunks,
self.shared_memory.name)
return (
self.n_reader,
self.max_chunk_bytes,
self.max_chunks,
self.shared_memory.name,
)
def __reduce__(self):
return (
@@ -204,7 +220,6 @@ class Handle:
class MessageQueue:
def __init__(
self,
n_reader, # number of all readers
@@ -228,8 +243,7 @@ class MessageQueue:
# for local readers, we will:
# 1. create a shared memory ring buffer to communicate small data
# 2. create a publish-subscribe socket to communicate large data
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
max_chunks)
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, max_chunks)
# XPUB is very similar to PUB,
# except that it can receive subscription messages
@@ -279,8 +293,7 @@ class MessageQueue:
self.handle = Handle(
local_reader_ranks=local_reader_ranks,
buffer_handle=self.buffer.handle()
if self.buffer is not None else None,
buffer_handle=self.buffer.handle() if self.buffer is not None else None,
local_subscribe_addr=local_subscribe_addr,
remote_subscribe_addr=remote_subscribe_addr,
remote_addr_ipv6=remote_addr_ipv6,
@@ -315,8 +328,9 @@ class MessageQueue:
self.remote_socket = None
self._read_spin_timer = SpinSleepTimer(
) if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer()
self._read_spin_timer = (
SpinSleepTimer() if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer()
)
else:
self.buffer = None # type: ignore
self.current_idx = -1
@@ -399,7 +413,8 @@ class MessageQueue:
" in %s seconds. This typically happens when some"
" processes are hanging or doing some"
" time-consuming work (e.g. compilation)",
VLLM_RINGBUFFER_WARNING_INTERVAL)
VLLM_RINGBUFFER_WARNING_INTERVAL,
)
n_warning += 1
continue
@@ -423,15 +438,16 @@ class MessageQueue:
metadata_buffer[i] = 0
# mark the block as written
metadata_buffer[0] = 1
self.current_idx = (self.current_idx +
1) % self.buffer.max_chunks
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
break
@contextmanager
def acquire_read(self,
timeout: Optional[float] = None,
cancel: Optional[Event] = None,
indefinite: bool = False):
def acquire_read(
self,
timeout: Optional[float] = None,
cancel: Optional[Event] = None,
indefinite: bool = False,
):
assert self._is_local_reader, "Only readers can acquire read"
start_time = time.monotonic()
n_warning = 1
@@ -460,15 +476,16 @@ class MessageQueue:
raise TimeoutError
# if we wait for a long time, log a message
if not indefinite and (elapsed
> VLLM_RINGBUFFER_WARNING_INTERVAL *
n_warning):
if not indefinite and (
elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning
):
logger.info(
"No available shared memory broadcast block found"
" in %s seconds. This typically happens when some"
" processes are hanging or doing some"
" time-consuming work (e.g. compilation).",
VLLM_RINGBUFFER_WARNING_INTERVAL)
VLLM_RINGBUFFER_WARNING_INTERVAL,
)
n_warning += 1
continue
@@ -480,14 +497,13 @@ class MessageQueue:
# caller has read from the buffer
# set the read flag
metadata_buffer[self.local_reader_rank + 1] = 1
self.current_idx = (self.current_idx +
1) % self.buffer.max_chunks
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
self._read_spin_timer.record_activity()
break
def enqueue(self, obj, timeout: Optional[float] = None):
""" Write to message queue with optional timeout (in seconds) """
"""Write to message queue with optional timeout (in seconds)"""
assert self._is_writer, "Only writers can enqueue"
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
if self.n_local_reader > 0:
@@ -498,15 +514,17 @@ class MessageQueue:
else:
with self.acquire_write(timeout) as buf:
buf[0] = 0 # not overflow
buf[1:len(serialized_obj) + 1] = serialized_obj
buf[1 : len(serialized_obj) + 1] = serialized_obj
if self.n_remote_reader > 0:
self.remote_socket.send(serialized_obj)
def dequeue(self,
timeout: Optional[float] = None,
cancel: Optional[Event] = None,
indefinite: bool = False):
""" Read from message queue with optional timeout (in seconds) """
def dequeue(
self,
timeout: Optional[float] = None,
cancel: Optional[Event] = None,
indefinite: bool = False,
):
"""Read from message queue with optional timeout (in seconds)"""
if self._is_local_reader:
with self.acquire_read(timeout, cancel, indefinite) as buf:
overflow = buf[0] == 1
@@ -539,11 +557,12 @@ class MessageQueue:
return self.dequeue()
@staticmethod
def create_from_process_group(pg: Union[ProcessGroup,
StatelessProcessGroup],
max_chunk_bytes,
max_chunks,
writer_rank=0) -> "MessageQueue":
def create_from_process_group(
pg: Union[ProcessGroup, StatelessProcessGroup],
max_chunk_bytes,
max_chunks,
writer_rank=0,
) -> "MessageQueue":
if isinstance(pg, ProcessGroup):
group_rank = dist.get_rank(pg)
group_world_size = dist.get_world_size(pg)
@@ -554,6 +573,7 @@ class MessageQueue:
global_ranks = list(range(pg.world_size))
from vllm.distributed.parallel_state import in_the_same_node_as
status = in_the_same_node_as(pg, source_rank=writer_rank)
same_node_ranks = [i for i, s in enumerate(status) if s]
n_reader = group_world_size - 1
@@ -570,17 +590,17 @@ class MessageQueue:
)
handle = buffer_io.export_handle()
if isinstance(pg, ProcessGroup):
dist.broadcast_object_list([handle],
src=global_ranks[writer_rank],
group=pg)
dist.broadcast_object_list(
[handle], src=global_ranks[writer_rank], group=pg
)
else:
pg.broadcast_obj(handle, writer_rank)
else:
if isinstance(pg, ProcessGroup):
recv = [None]
dist.broadcast_object_list(recv,
src=global_ranks[writer_rank],
group=pg)
dist.broadcast_object_list(
recv, src=global_ranks[writer_rank], group=pg
)
handle = recv[0] # type: ignore
else:
handle = pg.broadcast_obj(None, writer_rank)
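A rough single-node sketch of the writer/reader roles for MessageQueue (not part of this diff); the constructor defaults and the exact handle hand-off are assumptions, while export_handle, wait_until_ready, enqueue, and dequeue follow this module.

```
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

# Writer process
mq = MessageQueue(n_reader=1, n_local_reader=1)
handle = mq.export_handle()          # pickle this and pass it to the reader
mq.wait_until_ready()
mq.enqueue({"step": 1})

# Reader process, given `handle` from the writer
reader = MessageQueue.create_from_handle(handle, rank=0)
reader.wait_until_ready()
obj = reader.dequeue()               # {"step": 1}
```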

View File

@@ -24,63 +24,63 @@ class SingleWriterShmRingBuffer:
A single-writer, multiple-reader ring buffer implementation using shared
memory. This class provides a thread-safe ring buffer where one process
can write data while multiple processes/threads can read from it.
Architecture:
- Uses shared memory for cross-process communication
- Maintains metadata for each allocated buffer chunk in the writer process
- Supports custom "is_free_fn" functions to determine when buffers can be
reused
- Each buffer chunk contains: `[4-byte id][4-byte size][actual_data]`
Key Concepts:
- monotonic_id_start/end: Track the range of active buffer IDs
- data_buffer_start/end: Track the physical memory range in use
- Automatic wraparound when reaching buffer end
- Lazy garbage collection based on is_free_fn checks
Example Usage Scenarios:
Scenario 1: Simple Linear Allocation
```
Buffer size: 100 bytes
Initial state: [................................................. ]
^start=end(0)
After allocating 20 bytes (id=0):
[id:0|size:20|data........][...................................]
^start(0) ^end(28)
    After allocating 30 bytes (id=1):
[id:0|size:20|data........][id:1|size:30|data..............][..]
^start(0) ^end(66)
```
Scenario 2: Memory Reclamation
```
Before freeing (both buffers still in use):
[id:0|size:20|data........][id:1|size:30|data..............][..]
^start(0) ^end(66)
After id:0 is marked free by readers:
[FREED.................... ][id:1|size:30|data..............][..]
^start(28) ^end(66)
After both are freed:
[FREED..............................................][..]
^start=end(66)
```
Scenario 3: Wraparound Allocation (continuing from Scenario 2)
```
Starting from after memory reclamation in Scenario 2:
[FREED..............................................][..]
^start=end(66)
Allocate 40 bytes (id=2) - only 34 bytes available at end, so wraparound:
[id:2|size:40|data........................][FREED.............][..]
^end(148) ^start(66)
```
Scenario 4: Error Handling - Out of Space
```
Starting from after wraparound allocation in Scenario 3:
@@ -91,17 +91,17 @@ class SingleWriterShmRingBuffer:
occupied_size_new = end + size - start = 148 + 28 - 66 > buffer_size(100)
-> Raises MemoryError: "Not enough space in the data buffer"
```
Thread Safety:
- Single writer: Only one process/thread should write (allocate_buf)
    - Multiple readers: Multiple processes/threads can read (access_buf)
- Reader synchronization handled by is_free_fn callback
- Writer handles garbage collection (free_buf) based on reader feedback
Memory Layout per Buffer Chunk:
`[4-byte monotonic_id][4-byte chunk_size][actual_data...]`
^metadata_start ^data_start
The monotonic_id ensures data integrity - readers can verify they're
accessing the correct data even after buffer wraparound or reuse.
"""
@@ -131,15 +131,16 @@ class SingleWriterShmRingBuffer:
self.monotonic_id_end: self.data_buffer_end
} # monotonic_id -> start address
self.shared_memory = shared_memory.SharedMemory(
create=True, size=self.data_buffer_size, name=name)
create=True, size=self.data_buffer_size, name=name
)
else:
# we are opening an existing buffer
# fix to https://stackoverflow.com/q/62748654/9191338
# Python incorrectly tracks shared memory even if it is not
# created by the process. The following patch is a workaround.
with patch(
"multiprocessing.resource_tracker.register",
lambda *args, **kwargs: None,
"multiprocessing.resource_tracker.register",
lambda *args, **kwargs: None,
):
self.shared_memory = shared_memory.SharedMemory(name=name)
# See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa
@@ -149,8 +150,11 @@ class SingleWriterShmRingBuffer:
# when attaching to an existing block.
assert self.shared_memory.size >= self.data_buffer_size
logger.debug("Shared memory created/opened with name: %s, size: %d",
self.shared_memory.name, self.data_buffer_size)
logger.debug(
"Shared memory created/opened with name: %s, size: %d",
self.shared_memory.name,
self.data_buffer_size,
)
def handle(self):
return (
@@ -182,19 +186,20 @@ class SingleWriterShmRingBuffer:
return int.from_bytes(byte_data, "little", signed=True)
def allocate_buf(self, size: int) -> tuple[int, int]:
'''
"""
Allocate a buffer `MD_SIZE` + `size` bytes in the shared memory.
Memory layout:
`[4-byte monotonic_id][4-byte size][buffer data...]`
'''
"""
assert self.is_writer, "Only the writer can allocate buffers."
assert size > 0, "Size must be greater than 0"
size += self.MD_SIZE # add metadata size to the buffer size
        # reset to beginning if the buffer does not have enough contiguous space
buffer_end_reset = self.data_buffer_end % self.data_buffer_size
if buffer_end_reset + size > self.data_buffer_size:
buffer_end_reset = (self.data_buffer_end // self.data_buffer_size +
1) * self.data_buffer_size
buffer_end_reset = (
self.data_buffer_end // self.data_buffer_size + 1
) * self.data_buffer_size
else: # no reset needed
buffer_end_reset = self.data_buffer_end
@@ -203,21 +208,24 @@ class SingleWriterShmRingBuffer:
# exceeds the start of the data buffer
occupied_size_new = buffer_end_reset + size - self.data_buffer_start
if occupied_size_new > self.data_buffer_size:
raise MemoryError("Not enough space in the data buffer, "
"try calling free_buf() to free up space")
raise MemoryError(
"Not enough space in the data buffer, "
"try calling free_buf() to free up space"
)
self.data_buffer_end = buffer_end_reset
# first 4 bytes as the monotonic id
buf_idx = self.data_buffer_end % self.data_buffer_size
self.shared_memory.buf[buf_idx:buf_idx + self.ID_NBYTES] = \
self.int2byte(self.monotonic_id_end)
self.shared_memory.buf[buf_idx : buf_idx + self.ID_NBYTES] = self.int2byte(
self.monotonic_id_end
)
# next 4 bytes as the size of the data buffer
self.shared_memory.buf[buf_idx + self.ID_NBYTES: \
buf_idx + self.MD_SIZE] = self.int2byte(size)
self.shared_memory.buf[buf_idx + self.ID_NBYTES : buf_idx + self.MD_SIZE] = (
self.int2byte(size)
)
# record metadata
self.metadata[self.monotonic_id_end %
self.ID_MAX] = self.data_buffer_end
self.metadata[self.monotonic_id_end % self.ID_MAX] = self.data_buffer_end
# update buffer and monotonic id indices
current_buffer_end = self.data_buffer_end
current_id_end = self.monotonic_id_end
@@ -230,23 +238,26 @@ class SingleWriterShmRingBuffer:
buf_idx = address % self.data_buffer_size
# read metadata
metadata_buff = self.shared_memory.buf[buf_idx:buf_idx + self.MD_SIZE]
id = self.byte2int(metadata_buff[:self.ID_NBYTES])
size = self.byte2int(metadata_buff[self.ID_NBYTES:self.MD_SIZE])
metadata_buff = self.shared_memory.buf[buf_idx : buf_idx + self.MD_SIZE]
id = self.byte2int(metadata_buff[: self.ID_NBYTES])
size = self.byte2int(metadata_buff[self.ID_NBYTES : self.MD_SIZE])
# yield the data buffer and metadata
data_buff = self.shared_memory.buf[buf_idx + self.MD_SIZE:buf_idx +
size]
with (memoryview(data_buff) as data_view, ):
data_buff = self.shared_memory.buf[buf_idx + self.MD_SIZE : buf_idx + size]
with (
memoryview(data_buff) as data_view,
):
yield data_view, (id, size)
def free_buf(self,
is_free_fn: Callable[[int, memoryview], bool],
nbytes: Optional[int] = None) -> Iterable[int]:
'''
def free_buf(
self,
is_free_fn: Callable[[int, memoryview], bool],
nbytes: Optional[int] = None,
) -> Iterable[int]:
"""
Free a buffer of the given size. This is a no-op in shared memory,
but we need to keep track of the metadata.
If freed memory spreads across the end and start of the ring buffer,
the actual freed memory will be in two segments. In this case there
still might not be a contiguous space of `nbytes` available.
@@ -254,13 +265,15 @@ class SingleWriterShmRingBuffer:
Args:
nbytes (int, optional): The size of the buffer to free. If None,
frees the maximum size of the ring buffer.
'''
"""
assert self.is_writer, "Only the writer can free buffers."
logger.debug(
"Freeing up space in the ring buffer, "
"monotonic_id_start: %d, monotonic_id_end: %d",
self.monotonic_id_start, self.monotonic_id_end)
self.monotonic_id_start,
self.monotonic_id_end,
)
monotonic_id_before = self.monotonic_id_start
# if nbytes is None, free up the maximum size of the ring buffer
if nbytes is None:
@@ -272,8 +285,9 @@ class SingleWriterShmRingBuffer:
if is_free_fn(self.monotonic_id_start, data_buff):
# check passed, we can free the buffer
del self.metadata[self.monotonic_id_start]
self.monotonic_id_start = ((self.monotonic_id_start + 1) %
self.ID_MAX)
self.monotonic_id_start = (
self.monotonic_id_start + 1
) % self.ID_MAX
self.data_buffer_start = address
freed_bytes += metadata[1]
else:
@@ -282,8 +296,11 @@ class SingleWriterShmRingBuffer:
logger.debug(
"Freed %d bytes from the ring buffer, "
"monotonic_id_start: %d, monotonic_id_end: %d", freed_bytes,
self.monotonic_id_start, self.monotonic_id_end)
"monotonic_id_start: %d, monotonic_id_end: %d",
freed_bytes,
self.monotonic_id_start,
self.monotonic_id_end,
)
# buffer wrap around
if self.data_buffer_start >= self.data_buffer_size:
@@ -295,12 +312,12 @@ class SingleWriterShmRingBuffer:
if monotonic_id_after >= monotonic_id_before:
return range(monotonic_id_before, monotonic_id_after)
else:
return chain(range(monotonic_id_before, self.ID_MAX),
range(0, monotonic_id_after))
return chain(
range(monotonic_id_before, self.ID_MAX), range(0, monotonic_id_after)
)
class ObjectSerde(ABC):
@abstractmethod
def serialize(self, value: Any) -> tuple[Any, int, bytes, int]:
"""Serialize an object to bytes."""
@@ -313,7 +330,6 @@ class ObjectSerde(ABC):
class MsgpackSerde(ObjectSerde):
def __init__(self):
# Delayed import to avoid circular dependency
from vllm.multimodal.inputs import MultiModalKwargsItem
@@ -325,8 +341,8 @@ class MsgpackSerde(ObjectSerde):
self._mm_kwargs_item_cls = MultiModalKwargsItem
def serialize(
self,
value: Any) -> tuple[Union[bytes, list[bytes]], int, bytes, int]:
self, value: Any
) -> tuple[Union[bytes, list[bytes]], int, bytes, int]:
len_arr = None
if isinstance(value, (torch.Tensor, self._mm_kwargs_item_cls)):
type_name = type(value).__name__
@@ -339,8 +355,9 @@ class MsgpackSerde(ObjectSerde):
nbytes = len(value)
object_metadata = (type_name, nbytes, len_arr)
serialized_metadata = pickle.dumps(object_metadata,
protocol=pickle.HIGHEST_PROTOCOL)
serialized_metadata = pickle.dumps(
object_metadata, protocol=pickle.HIGHEST_PROTOCOL
)
return value, nbytes, serialized_metadata, len(serialized_metadata)
def deserialize(self, data_view: memoryview) -> Any:
@@ -353,7 +370,7 @@ class MsgpackSerde(ObjectSerde):
obj = []
start_idx = 0
for length in len_arr:
item_bytes = serialized_data[start_idx:start_idx + length]
item_bytes = serialized_data[start_idx : start_idx + length]
obj.append(item_bytes)
start_idx += length
obj = self.tensor_decoder.decode(obj)
@@ -361,15 +378,14 @@ class MsgpackSerde(ObjectSerde):
obj = []
start_idx = 0
for length in len_arr:
item_bytes = serialized_data[start_idx:start_idx + length]
item_bytes = serialized_data[start_idx : start_idx + length]
obj.append(item_bytes)
start_idx += length
obj = self.mm_decoder.decode(obj)
elif type_name == bytes.__name__:
obj = pickle.loads(serialized_data)
else:
raise ValueError(
f"Unsupported object type '{type_name}' in metadata")
raise ValueError(f"Unsupported object type '{type_name}' in metadata")
return obj
@@ -388,18 +404,18 @@ class SingleWriterShmObjectStorage:
A single-writer, multiple-reader object storage system built on top of a
shared memory ring buffer. Provides key-value storage with automatic memory
management and cross-process serialization support.
This storage system follows a FIFO (First-In-First-Out) eviction policy
where the oldest objects are automatically freed when memory runs low.
Memory is reclaimed based on reader reference counting - objects are only
freed when all readers have finished accessing them.
Architecture:
- Single writer process can put(key, value) objects
- Multiple reader processes can get(address, monotonic_id) objects
- Built on SingleWriterShmRingBuffer for efficient shared memory management
- Thread-safe operations with reader synchronization via locks
Key Features:
- FIFO Eviction: Oldest objects are evicted first when memory is full
- Reference Counting: Objects are only freed when no readers are
@@ -414,7 +430,7 @@ class SingleWriterShmObjectStorage:
Memory Layout per Object:
`[4-byte reference_count][metadata_size][serialized_object_data]`
Thread Safety:
- Writer operations (put, clear) are single-threaded by design
- Reader operations (get) are thread-safe with lock-based reference
@@ -482,18 +498,17 @@ class SingleWriterShmObjectStorage:
md_bytes: int,
data_view: memoryview,
) -> None:
data_view[self.flag_bytes:self.flag_bytes + md_bytes] = metadata
data_view[self.flag_bytes : self.flag_bytes + md_bytes] = metadata
if isinstance(data, bytes):
data_view[-data_bytes:] = data
elif isinstance(data, list):
start_idx = self.flag_bytes + md_bytes
for item_bytes in data:
item_size = len(item_bytes)
data_view[start_idx:start_idx + item_size] = item_bytes
data_view[start_idx : start_idx + item_size] = item_bytes
start_idx += item_size
else:
raise ValueError(
f"Unsupported data type for serialization: {type(data)}")
raise ValueError(f"Unsupported data type for serialization: {type(data)}")
def increment_writer_flag(self, id: int) -> None:
"""Set the in-use flag for the writer."""
@@ -509,8 +524,9 @@ class SingleWriterShmObjectStorage:
"""Free unused buffers in the ring buffer."""
# try to free up 2*max_object_size bytes of space in the ring buffer,
# since the buffer might be fragmented
freed_ids = self.ring_buffer.free_buf(self.default_is_free_check,
2 * self.max_object_size)
freed_ids = self.ring_buffer.free_buf(
self.default_is_free_check, 2 * self.max_object_size
)
# update the metadata after freeing up space
for freed_id in freed_ids:
key_to_free = self.id_index[freed_id]
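# Hedged sketch of the FIFO eviction policy used here, with a plain OrderedDict
# standing in for the shared-memory ring buffer. The entry layout and the
# free_oldest helper are hypothetical; real eviction also reads the reader
# reference counts stored in shared memory rather than a local field.
from collections import OrderedDict

buffers: OrderedDict[int, tuple[str, int, int]] = OrderedDict()  # id -> (key, size, readers)


def free_oldest(bytes_needed: int) -> list[int]:
    freed, reclaimed = [], 0
    for monotonic_id, (key, size, readers) in list(buffers.items()):
        if reclaimed >= bytes_needed:
            break
        if readers > 0:  # oldest entry still referenced: FIFO eviction must stop
            break
        buffers.pop(monotonic_id)
        freed.append(monotonic_id)
        reclaimed += size
    return freed


buffers[1] = ("req-0", 1024, 0)
buffers[2] = ("req-1", 2048, 1)
print(free_oldest(4096))  # [1]: id 2 is still being read, so eviction stops there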
@@ -537,7 +553,7 @@ class SingleWriterShmObjectStorage:
Store a key-value pair in the object storage.
Attempts to free max_object_size bytes using FIFO order
when the ring buffer runs out of space during a put() operation.
Args:
key: String key to identify the object
value: Any serializable Python object
@@ -550,15 +566,17 @@ class SingleWriterShmObjectStorage:
if key in self.key_index:
raise ValueError(f"Key '{key}' already exists in the storage.")
object_data, data_bytes, object_metadata, md_bytes = \
self.ser_de.serialize(value)
object_data, data_bytes, object_metadata, md_bytes = self.ser_de.serialize(
value
)
buffer_size = self.flag_bytes + data_bytes + md_bytes
# Sanity checks
if buffer_size > self.max_object_size:
raise ValueError(
f"Serialized object size ({buffer_size} bytes) exceeds "
f"max object size ({self.max_object_size} bytes)")
f"max object size ({self.max_object_size} bytes)"
)
# Allocate new buffer
try:
@@ -570,9 +588,10 @@ class SingleWriterShmObjectStorage:
# Write data to buffer
with self.ring_buffer.access_buf(address) as (data_view, metadata):
data_view[:self.flag_bytes] = self.ring_buffer.int2byte(0)
self.copy_to_buffer(object_data, data_bytes, object_metadata,
md_bytes, data_view)
data_view[: self.flag_bytes] = self.ring_buffer.int2byte(0)
self.copy_to_buffer(
object_data, data_bytes, object_metadata, md_bytes, data_view
)
self.increment_writer_flag(monotonic_id)
# Update key index
@@ -587,14 +606,15 @@ class SingleWriterShmObjectStorage:
if buf_metadata[0] != monotonic_id:
raise ValueError(
f"Data for address:id '{address}:{monotonic_id}'"
" has been modified or is invalid.")
" has been modified or is invalid."
)
obj = self.ser_de.deserialize(data_view[self.flag_bytes:])
obj = self.ser_de.deserialize(data_view[self.flag_bytes :])
# decrease the in-use flag for reader reads
if self._reader_lock is not None:
with self._reader_lock:
self.increment_reader_flag(data_view[:self.flag_bytes])
self.increment_reader_flag(data_view[: self.flag_bytes])
else:
# if self._reader_lock is None, it means we are the writer
# in this case, we do not need to decrease the reader count
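# Hedged usage sketch of the writer/reader flow above. Constructor arguments
# are elided, and two details are assumptions not shown in this hunk: that
# put() returns the (address, monotonic_id) pair readers later pass to get(),
# and that a picklable handle object (hypothetical attribute name below) is
# available for create_from_handle().
#
# writer process
storage = SingleWriterShmObjectStorage(...)            # sizing/serde args omitted
address, monotonic_id = storage.put("req-0", {"ids": [1, 2, 3]})  # assumed return value
handle = storage.handle                                # hypothetical attribute name
#
# reader process, after receiving handle/address/monotonic_id over IPC
reader = SingleWriterShmObjectStorage.create_from_handle(handle)
obj = reader.get(address, monotonic_id)                # raises if the slot was reused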
@@ -614,7 +634,8 @@ class SingleWriterShmObjectStorage:
@staticmethod
def create_from_handle(
handle: ShmObjectStorageHandle) -> "SingleWriterShmObjectStorage":
handle: ShmObjectStorageHandle,
) -> "SingleWriterShmObjectStorage":
logger.debug("Creating storage from handle: %s", handle)
ring_buffer = SingleWriterShmRingBuffer(*handle.ring_buffer_handle)
return SingleWriterShmObjectStorage(

View File

@@ -7,7 +7,8 @@ import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.distributed.device_communicators.all_reduce_utils import (
SYMM_MEM_ALL_REDUCE_MAX_SIZES)
SYMM_MEM_ALL_REDUCE_MAX_SIZES,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -28,20 +29,20 @@ class SymmMemCommunicator:
}
def __init__(
self,
group: ProcessGroup,
device: Union[int, str, torch.device],
# add options for testing
force_multimem: Optional[bool] = None,
max_size_override: Optional[int] = None):
self,
group: ProcessGroup,
device: Union[int, str, torch.device],
# add options for testing
force_multimem: Optional[bool] = None,
max_size_override: Optional[int] = None,
):
self.disabled = True
if not symm_mem_available:
return
if not current_platform.is_cuda():
logger.warning("SymmMemCommunicator: symmetric "
"memory is not available.")
logger.warning("SymmMemCommunicator: symmetric memory is not available.")
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
@@ -52,8 +53,9 @@ class SymmMemCommunicator:
self.device = device
self.group = group
self.world_size = dist.get_world_size(self.group)
self.device_capability = current_platform.get_device_capability(
).as_version_str()
self.device_capability = (
current_platform.get_device_capability().as_version_str()
)
if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
logger.warning(
"SymmMemCommunicator: Device capability %s not supported, "
@@ -61,8 +63,7 @@ class SymmMemCommunicator:
self.device_capability,
)
return
if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[
self.device_capability]:
if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]:
logger.warning(
"SymmMemCommunicator: World size %d not supported, "
"communicator is not available.",
@@ -77,8 +78,9 @@ class SymmMemCommunicator:
self.max_size,
)
else:
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[
self.device_capability][self.world_size]
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
self.world_size
]
self.buffer = torch_symm_mem.empty(
self.max_size // self.dtype.itemsize,
@@ -87,8 +89,10 @@ class SymmMemCommunicator:
)
handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
if handle.multicast_ptr == 0:
logger.warning("SymmMemCommunicator: symmetric memory "
"multicast operations are not supported.")
logger.warning(
"SymmMemCommunicator: symmetric memory "
"multicast operations are not supported."
)
return
self.force_multimem = force_multimem
self.disabled = False
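# Hedged sketch of the eligibility test performed above: the communicator only
# enables itself when a max-size entry exists for the (device capability,
# world size) pair. The table below uses placeholder byte values, not vLLM's
# real SYMM_MEM_ALL_REDUCE_MAX_SIZES data.
SYMM_MEM_MAX_SIZES_EXAMPLE = {
    "9.0": {2: 64 * 1024, 4: 32 * 1024, 8: 16 * 1024},  # illustrative numbers only
}


def symm_mem_enabled(capability: str, world_size: int) -> bool:
    sizes = SYMM_MEM_MAX_SIZES_EXAMPLE.get(capability)
    return sizes is not None and world_size in sizes


print(symm_mem_enabled("9.0", 4))  # True
print(symm_mem_enabled("8.0", 4))  # False -> communicator stays disabled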
@@ -104,15 +108,13 @@ class SymmMemCommunicator:
return inp_size < self.max_size
def all_reduce(
self,
inp: torch.Tensor,
*,
out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None
) -> Optional[torch.Tensor]:
if not self.should_use_symm_mem(inp):
return None
if out is None:
out = torch.empty_like(inp)
self.buffer[:inp.numel()].copy_(inp.view(-1))
self.buffer[: inp.numel()].copy_(inp.view(-1))
# Determine which algorithm to use
use_multimem = False
@@ -121,16 +123,17 @@ class SymmMemCommunicator:
use_multimem = self.force_multimem
else:
# Normal logic: use multimem for supported world sizes
use_multimem = self.world_size in self._WORLD_SIZES_MULTIMEM[
self.device_capability]
use_multimem = (
self.world_size in self._WORLD_SIZES_MULTIMEM[self.device_capability]
)
if use_multimem:
torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
torch.ops.symm_mem.multimem_all_reduce_(
self.buffer[: inp.numel()], "sum", self.group.group_name
)
else:
torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
out.copy_(self.buffer[:inp.numel()].view(out.shape))
torch.ops.symm_mem.two_shot_all_reduce_(
self.buffer[: inp.numel()], "sum", self.group.group_name
)
out.copy_(self.buffer[: inp.numel()].view(out.shape))
return out
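# Hedged sketch of the algorithm selection above: multimem all-reduce when the
# world size is in the per-capability allow-list (or forced for testing),
# two-shot all-reduce otherwise. The allow-list below is illustrative, not
# vLLM's real _WORLD_SIZES_MULTIMEM table.
from typing import Optional

WORLD_SIZES_MULTIMEM_EXAMPLE = {"9.0": [4, 6, 8]}  # placeholder values


def pick_algorithm(
    capability: str, world_size: int, force_multimem: Optional[bool] = None
) -> str:
    if force_multimem is not None:
        use_multimem = force_multimem
    else:
        use_multimem = world_size in WORLD_SIZES_MULTIMEM_EXAMPLE.get(capability, [])
    return "multimem_all_reduce_" if use_multimem else "two_shot_all_reduce_"


print(pick_algorithm("9.0", 8))                       # multimem_all_reduce_
print(pick_algorithm("9.0", 2))                       # two_shot_all_reduce_
print(pick_algorithm("9.0", 2, force_multimem=True))  # multimem_all_reduce_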

View File

@@ -14,8 +14,9 @@ from vllm.platforms.tpu import USE_TPU_COMMONS
from .base_device_communicator import DeviceCommunicatorBase
USE_RAY = parallel_config = get_current_vllm_config(
).parallel_config.distributed_executor_backend == "ray"
USE_RAY = parallel_config = (
get_current_vllm_config().parallel_config.distributed_executor_backend == "ray"
)
logger = init_logger(__name__)
@@ -27,18 +28,21 @@ if not USE_TPU_COMMONS:
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
from torch_xla.distributed.xla_multiprocessing import (
create_optimized_replica_groups)
create_optimized_replica_groups,
)
if USE_RAY:
from vllm.executor import ray_utils
class TpuCommunicator(DeviceCommunicatorBase):
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
def __init__(
self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = "",
):
super().__init__(cpu_group, device, device_group, unique_name)
# NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
@@ -98,5 +102,7 @@ class TpuCommunicator(DeviceCommunicatorBase):
if USE_TPU_COMMONS:
from tpu_commons.distributed.device_communicators import (
TpuCommunicator as TpuCommonsCommunicator)
TpuCommunicator as TpuCommonsCommunicator,
)
TpuCommunicator = TpuCommonsCommunicator # type: ignore

View File

@@ -16,12 +16,13 @@ logger = init_logger(__name__)
class XpuCommunicator(DeviceCommunicatorBase):
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
def __init__(
self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = "",
):
super().__init__(cpu_group, device, device_group, unique_name)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
@@ -29,10 +30,12 @@ class XpuCommunicator(DeviceCommunicatorBase):
logger.warning(
"`%s` all2all manager is not supported on XPU."
"Falling back to `naive` all2all manager for XPU.",
all2all_backend)
all2all_backend,
)
all2all_backend = "naive"
if all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
@@ -40,12 +43,12 @@ class XpuCommunicator(DeviceCommunicatorBase):
dist.all_reduce(input_, group=self.device_group)
return input_
def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
def gather(
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
)
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
@@ -53,20 +56,19 @@ class XpuCommunicator(DeviceCommunicatorBase):
# cluster so we use all_gather instead for now.
input_size = input_.size()
# Allocate output tensor.
output_tensor = torch.empty((self.world_size, ) + input_size,
dtype=input_.dtype,
device=input_.device)
output_tensor = torch.empty(
(self.world_size,) + input_size, dtype=input_.dtype, device=input_.device
)
# All-gather.
dist.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
if self.rank_in_group == dst:
# Reshape
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(self.world_size *
input_size[dim], ) +
input_size[dim + 1:])
output_tensor = output_tensor.reshape(
input_size[:dim]
+ (self.world_size * input_size[dim],)
+ input_size[dim + 1 :]
)
else:
output_tensor = None
return output_tensor
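# Hedged, single-process sketch of the reshape trick above: recovering a
# gather along `dim` from an all_gather result of shape (world_size,
# *input_size). torch.stack stands in for dist.all_gather_into_tensor here,
# so no process group is needed to run it.
import torch

world_size, dim = 2, -1
chunks = [torch.randn(3, 4) for _ in range(world_size)]  # one tensor per rank

gathered = torch.stack(chunks)   # (world_size, 3, 4), as all_gather produces
dim = dim % chunks[0].dim()      # convert negative dim, as the code above does
out = gathered.movedim(0, dim).reshape(
    chunks[0].shape[:dim]
    + (world_size * chunks[0].shape[dim],)
    + chunks[0].shape[dim + 1:]
)
assert torch.equal(out, torch.cat(chunks, dim=dim))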
@@ -78,17 +80,19 @@ class XpuCommunicator(DeviceCommunicatorBase):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits, is_sequence_parallel)
hidden_states, router_logits, is_sequence_parallel
)
return hidden_states, router_logits
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states,
is_sequence_parallel)
hidden_states = self.all2all_manager.combine(
hidden_states, is_sequence_parallel
)
return hidden_states