[core] allow callable in collective_rpc (#12151)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2025-01-17 20:47:01 +08:00
committed by GitHub
parent d4e6194570
commit 87a0c076af
13 changed files with 147 additions and 50 deletions

View File

@@ -1,8 +1,8 @@
import itertools
import warnings
from contextlib import contextmanager
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
Union, cast, overload)
from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
Tuple, Type, Union, cast, overload)
import cloudpickle
from tqdm import tqdm
@@ -464,7 +464,7 @@ class LLM:
return self.engine_class.validate_outputs(outputs, RequestOutput)
def collective_rpc(self,
method: str,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
@@ -476,9 +476,13 @@ class LLM:
Then, users can call the new methods through this API.
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
The method can also be a callable, which will be serialized
and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
"""
return self.llm_engine.model_executor.collective_rpc(
method, timeout, args, kwargs)
return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
def beam_search(
self,