[Core] Interface for accessing model from VllmRunner (#10353)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -5,8 +5,9 @@ from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
                     Tuple, Type, Union, cast, overload)
 
 import cloudpickle
+import torch.nn as nn
 from tqdm import tqdm
-from typing_extensions import deprecated
+from typing_extensions import TypeVar, deprecated
 
 from vllm import envs
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
@@ -42,6 +43,8 @@ from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
 
 logger = init_logger(__name__)
 
+_R = TypeVar("_R", default=Any)
+
 
 class LLM:
     """An LLM for generating texts from given prompts and sampling parameters.
@@ -464,25 +467,42 @@ class LLM:
         return self.engine_class.validate_outputs(outputs, RequestOutput)
 
     def collective_rpc(self,
-                       method: Union[str, Callable],
+                       method: Union[str, Callable[..., _R]],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
-                       kwargs: Optional[Dict] = None) -> List[Any]:
+                       kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
         """
-        Run a method on all workers, with homogeneous arguments.
-        The main extension point for the LLM entrypoint.
-        Users can provide custom worker class through `worker_cls`
-        argument, and implement new methods in the worker class.
-        Then, users can call the new methods through this API.
-        It is recommended to use this API to only pass control messages,
-        and set up data-plane communication to pass data.
-        The method can also be a callable, which will be serialized
-        and sent to all workers to execute.
-        If the method is a callable, it should accept an additional
-        `self` argument, in addition to the arguments passed in `args`
-        and `kwargs`. The `self` argument will be the worker object.
+        Execute an RPC call on all workers.
+
+        Args:
+            method: Name of the worker method to execute, or a callable that
+                is serialized and sent to all workers to execute.
+
+                If the method is a callable, it should accept an additional
+                `self` argument, in addition to the arguments passed in `args`
+                and `kwargs`. The `self` argument will be the worker object.
+            timeout: Maximum time in seconds to wait for execution. Raises a
+                :exc:`TimeoutError` on timeout. `None` means wait indefinitely.
+            args: Positional arguments to pass to the worker method.
+            kwargs: Keyword arguments to pass to the worker method.
+
+        Returns:
+            A list containing the results from each worker.
+
+        Note:
+            It is recommended to use this API to only pass control messages,
+            and set up data-plane communication to pass data.
         """
-        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
+        executor = self.llm_engine.model_executor
+        return executor.collective_rpc(method, timeout, args, kwargs)
+
+    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
+        """
+        Run a function directly on the model inside each worker,
+        returning the result for each of them.
+        """
+        executor = self.llm_engine.model_executor
+        return executor.apply_model(func)
 
     def beam_search(
         self,
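
Usage sketch (not part of the commit): with the typed signature above, a callable passed to collective_rpc is serialized, sent to every worker, and invoked with the worker object as its first argument. The new `TypeVar("_R", default=Any)` (PEP 696 defaults, via typing_extensions) lets the return type degrade to `List[Any]` when `method` is a plain string. The `rank` attribute below is an assumption about the worker class, shown for illustration only.

    from vllm import LLM

    def get_rank(worker) -> int:
        # Runs on every worker; `worker` is the worker object itself.
        # Assumes the worker exposes a `rank` attribute (illustrative).
        return worker.rank

    llm = LLM(model="facebook/opt-125m")
    ranks = llm.collective_rpc(get_rank)  # List[int], one entry per worker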
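
Usage sketch for the new apply_model API (not part of the commit): the given function runs directly on the nn.Module inside each worker, so with tensor parallelism each result reflects that worker's shard of the model. The parameter-counting helper is hypothetical.

    import torch.nn as nn

    from vllm import LLM

    def count_params(model: nn.Module) -> int:
        # Runs inside each worker, directly on its model replica/shard.
        return sum(p.numel() for p in model.parameters())

    llm = LLM(model="facebook/opt-125m")
    param_counts = llm.apply_model(count_params)  # list[int], one per worker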