[core] allow callable in collective_rpc (#12151)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -6,9 +6,11 @@ import time
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from functools import partial
|
||||
from multiprocessing.process import BaseProcess
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import cloudpickle
|
||||
import psutil
|
||||
import zmq
|
||||
|
||||
@@ -120,7 +122,7 @@ class MultiprocExecutor(Executor):
|
||||
return kv_cache_specs[0]
|
||||
|
||||
def collective_rpc(self,
|
||||
method: str,
|
||||
method: Union[str, Callable],
|
||||
timeout: Optional[float] = None,
|
||||
args: Tuple = (),
|
||||
kwargs: Optional[Dict] = None) -> List[Any]:
|
||||
@@ -141,7 +143,12 @@ class MultiprocExecutor(Executor):
|
||||
kwargs = kwargs or {}
|
||||
|
||||
try:
|
||||
self.rpc_broadcast_mq.enqueue((method, args, kwargs))
|
||||
if isinstance(method, str):
|
||||
send_method = method
|
||||
else:
|
||||
send_method = cloudpickle.dumps(
|
||||
method, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs))
|
||||
|
||||
responses = [None] * self.world_size
|
||||
for w in self.workers:
|
||||
@@ -408,7 +415,11 @@ class WorkerProc:
|
||||
method, args, kwargs = self.rpc_broadcast_mq.dequeue()
|
||||
|
||||
try:
|
||||
output = getattr(self.worker, method)(*args, **kwargs)
|
||||
if isinstance(method, str):
|
||||
func = getattr(self.worker, method)
|
||||
elif isinstance(method, bytes):
|
||||
func = partial(cloudpickle.loads(method), self.worker)
|
||||
output = func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
self.worker_response_mq.enqueue(
|
||||
(WorkerProc.ResponseStatus.FAILURE, e))
|
||||
|
||||
Reference in New Issue
Block a user