[Kernel][Misc] Use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#5047)
This commit is contained in:
@@ -6,6 +6,7 @@ import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
|
||||
gpu_p2p_access_check)
|
||||
from vllm.distributed.parallel_state import (
|
||||
@@ -15,7 +16,11 @@ from vllm.logger import init_logger
|
||||
try:
|
||||
import pynvml
|
||||
|
||||
from vllm._C import custom_ar
|
||||
# Simulate ImportError if custom_ar ops are not supported.
|
||||
if not ops.is_custom_op_supported("_C_custom_ar::meta_size"):
|
||||
raise ImportError("custom_ar", __file__)
|
||||
|
||||
custom_ar = True
|
||||
|
||||
@contextmanager
|
||||
def _nvml():
|
||||
@@ -27,7 +32,7 @@ try:
|
||||
|
||||
except ImportError:
|
||||
# For AMD GPUs
|
||||
custom_ar = None
|
||||
custom_ar = False
|
||||
pynvml = None
|
||||
|
||||
@contextmanager
|
||||
@@ -97,7 +102,7 @@ class CustomAllreduce:
|
||||
self._IS_CAPTURING = False
|
||||
self.disabled = True
|
||||
|
||||
if custom_ar is None:
|
||||
if not custom_ar:
|
||||
# disable because of missing custom allreduce library
|
||||
# e.g. in a non-cuda environment
|
||||
return
|
||||
@@ -175,7 +180,7 @@ class CustomAllreduce:
|
||||
# meta data is composed of two parts: meta data for synchronization
|
||||
# (256 bytes) and a temporary buffer for storing intermediate
|
||||
# allreduce results.
|
||||
self.meta = torch.zeros(custom_ar.meta_size() + max_size,
|
||||
self.meta = torch.zeros(ops.meta_size() + max_size,
|
||||
dtype=torch.uint8,
|
||||
device=self.device)
|
||||
# This is a pre-registered IPC buffer. In eager mode, input tensors
|
||||
@@ -196,9 +201,8 @@ class CustomAllreduce:
|
||||
self.world_size = world_size
|
||||
handles, offsets = self._get_ipc_meta(self.meta)
|
||||
self.full_nvlink = full_nvlink
|
||||
self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data,
|
||||
handles, offsets, rank,
|
||||
self.full_nvlink)
|
||||
self._ptr = ops.init_custom_ar(self.meta, self.rank_data, handles,
|
||||
offsets, rank, self.full_nvlink)
|
||||
self.register_buffer(self.buffer)
|
||||
|
||||
@contextmanager
|
||||
@@ -252,31 +256,31 @@ class CustomAllreduce:
|
||||
|
||||
def register_buffer(self, inp: torch.Tensor):
|
||||
handles, offsets = self._get_ipc_meta(inp)
|
||||
custom_ar.register_buffer(self._ptr, inp, handles, offsets)
|
||||
ops.register_buffer(self._ptr, inp, handles, offsets)
|
||||
|
||||
def register_graph_buffers(self):
|
||||
handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr)
|
||||
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
|
||||
handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
|
||||
logger.info("Registering %d cuda graph addresses", len(offset))
|
||||
custom_ar.register_graph_buffers(self._ptr, handles, offsets)
|
||||
ops.register_graph_buffers(self._ptr, handles, offsets)
|
||||
|
||||
def should_custom_ar(self, inp: torch.Tensor):
|
||||
return custom_ar.should_custom_ar(inp, self.max_size, self.world_size,
|
||||
self.full_nvlink)
|
||||
return ops.should_custom_ar(inp, self.max_size, self.world_size,
|
||||
self.full_nvlink)
|
||||
|
||||
# all reduce, assuming inp tensor is IPC registered with register_buffer,
|
||||
# or, in the context of cuda graphs, register_graph_buffers
|
||||
def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
custom_ar.all_reduce_reg(self._ptr, inp, out)
|
||||
ops.all_reduce_reg(self._ptr, inp, out)
|
||||
return out
|
||||
|
||||
# all reduce, assuming inp tensor is NOT IPC registered
|
||||
def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out)
|
||||
ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
|
||||
return out
|
||||
|
||||
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
@@ -304,7 +308,7 @@ class CustomAllreduce:
|
||||
|
||||
def close(self):
|
||||
if not self.disabled and self._ptr:
|
||||
custom_ar.dispose(self._ptr)
|
||||
ops.dispose(self._ptr)
|
||||
self._ptr = 0
|
||||
|
||||
def __del__(self):
|
||||
|
||||
Reference in New Issue
Block a user