[core] allow callable in collective_rpc (#12151)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2025-01-17 20:47:01 +08:00
committed by GitHub
parent d4e6194570
commit 87a0c076af
13 changed files with 147 additions and 50 deletions

View File

@@ -6,9 +6,11 @@ import time
import weakref
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from multiprocessing.process import BaseProcess
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import cloudpickle
import psutil
import zmq
@@ -120,7 +122,7 @@ class MultiprocExecutor(Executor):
return kv_cache_specs[0]
def collective_rpc(self,
method: str,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
@@ -141,7 +143,12 @@ class MultiprocExecutor(Executor):
kwargs = kwargs or {}
try:
self.rpc_broadcast_mq.enqueue((method, args, kwargs))
if isinstance(method, str):
send_method = method
else:
send_method = cloudpickle.dumps(
method, protocol=pickle.HIGHEST_PROTOCOL)
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs))
responses = [None] * self.world_size
for w in self.workers:
@@ -408,7 +415,11 @@ class WorkerProc:
method, args, kwargs = self.rpc_broadcast_mq.dequeue()
try:
output = getattr(self.worker, method)(*args, **kwargs)
if isinstance(method, str):
func = getattr(self.worker, method)
elif isinstance(method, bytes):
func = partial(cloudpickle.loads(method), self.worker)
output = func(*args, **kwargs)
except Exception as e:
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.FAILURE, e))