[core] LLM.collective_rpc interface and RLHF example (#12084)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao
2025-01-16 20:19:52 +08:00
committed by GitHub
parent bf53e0c70b
commit 92e793d91a
6 changed files with 270 additions and 35 deletions

vllm/entrypoints/llm.py

@@ -4,6 +4,7 @@ from contextlib import contextmanager
 from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
                     Union, cast, overload)
+import cloudpickle
 from tqdm import tqdm
 from typing_extensions import deprecated
@@ -186,6 +187,13 @@ class LLM:
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
+        if "worker_cls" in kwargs:
+            worker_cls = kwargs["worker_cls"]
+            # if worker_cls is not a qualified string name,
+            # serialize it with cloudpickle to avoid pickling issues
+            if isinstance(worker_cls, type):
+                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)
         if compilation_config is not None:
             if isinstance(compilation_config, (int, dict)):
                 compilation_config_instance = CompilationConfig.from_cli(
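
Because the constructor now serializes class objects with cloudpickle, a custom worker can be passed as a Python type rather than as a dotted import string. A minimal sketch of what this enables, assuming the v0 worker base class lives at vllm.worker.worker.Worker; the MyWorker class and its echo method are hypothetical:

from vllm import LLM
from vllm.worker.worker import Worker  # assumed base-class location


class MyWorker(Worker):
    # Hypothetical subclass: any extra method defined here becomes
    # callable from the driver via LLM.collective_rpc (added below).
    def echo(self, msg: str) -> str:
        # Runs inside each worker process; return values are gathered
        # into a list on the driver side.
        return f"worker received: {msg}"


# Passing the class object (not a string) exercises the
# cloudpickle.dumps path added above.
llm = LLM(model="facebook/opt-125m", worker_cls=MyWorker)
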
@@ -455,6 +463,23 @@ class LLM:
         outputs = self._run_engine(use_tqdm=use_tqdm)
         return self.engine_class.validate_outputs(outputs, RequestOutput)

+    def collective_rpc(self,
+                       method: str,
+                       timeout: Optional[float] = None,
+                       args: Tuple = (),
+                       kwargs: Optional[Dict] = None) -> List[Any]:
+        """
+        Run a method on all workers, with homogeneous arguments.
+        This is the main extension point for the LLM entrypoint.
+        Users can provide a custom worker class through the `worker_cls`
+        argument, implement new methods in that class, and then invoke
+        them through this API.
+        It is recommended to use this API only to pass control messages,
+        and to set up separate data-plane communication for passing data.
+        """
+        return self.llm_engine.model_executor.collective_rpc(
+            method, timeout, args, kwargs)
+
     def beam_search(
         self,
         prompts: List[Union[TokensPrompt, TextPrompt]],
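
Continuing the hypothetical MyWorker sketch above, a hedged usage example of the new API: collective_rpc runs the named method on every worker with the same arguments and returns the per-worker results as a list. Following the docstring's recommendation, bulk data (e.g. RLHF weight updates) should travel over a separate data plane such as a dedicated process group, with this API carrying only the control message.

# Invoke the hypothetical echo method on all workers with
# homogeneous arguments.
results = llm.collective_rpc("echo", args=("sync weights now",))

# One entry per worker, e.g. with tensor_parallel_size=2:
#   ["worker received: sync weights now", "worker received: sync weights now"]
print(results)
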