[Core] Interface for accessing model from VllmRunner (#10353)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -3,6 +3,9 @@ from abc import ABC, abstractmethod
|
||||
from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
|
||||
Union)
|
||||
|
||||
import torch.nn as nn
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@@ -11,9 +14,12 @@ from vllm.platforms import current_platform
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import ExecuteModelRequest, PoolerOutput
|
||||
from vllm.utils import make_async
|
||||
from vllm.worker.worker_base import WorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_R = TypeVar("_R", default=Any)
|
||||
|
||||
|
||||
class ExecutorBase(ABC):
|
||||
"""Base class for all executors.
|
||||
@@ -44,22 +50,37 @@ class ExecutorBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def _init_executor(self) -> None:
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable],
|
||||
method: Union[str, Callable[..., _R]],
|
||||
timeout: Optional[float] = None,
|
||||
args: Tuple = (),
|
||||
kwargs: Optional[Dict] = None) -> List[Any]:
|
||||
kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
|
||||
"""
|
||||
The main interface of the executor to run a method on all workers,
|
||||
with homogeneous arguments.
|
||||
If the args are heterogeneous, then we can pack them into a list,
|
||||
and unpack them in the method of every worker, because every worker
|
||||
knows their own rank.
|
||||
Execute an RPC call on all workers.
|
||||
|
||||
Args:
|
||||
method: Name of the worker method to execute, or a callable that
|
||||
is serialized and sent to all workers to execute.
|
||||
|
||||
If the method is a callable, it should accept an additional
|
||||
`self` argument, in addition to the arguments passed in `args`
|
||||
and `kwargs`. The `self` argument will be the worker object.
|
||||
timeout: Maximum time in seconds to wait for execution. Raises a
|
||||
:exc:`TimeoutError` on timeout. `None` means wait indefinitely.
|
||||
args: Positional arguments to pass to the worker method.
|
||||
kwargs: Keyword arguments to pass to the worker method.
|
||||
|
||||
Returns:
|
||||
A list containing the results from each worker.
|
||||
|
||||
Note:
|
||||
It is recommended to use this API to only pass control messages,
|
||||
and set up data-plane communication to pass data.
|
||||
"""
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Determine the number of available blocks for the GPU KV cache and
|
||||
@@ -97,6 +118,17 @@ class ExecutorBase(ABC):
|
||||
self.collective_rpc("initialize_cache",
|
||||
args=(num_gpu_blocks, num_cpu_blocks))
|
||||
|
||||
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
|
||||
"""
|
||||
Run a function directly on the model inside each worker,
|
||||
returning the result for each of them.
|
||||
"""
|
||||
|
||||
def rpc_func(worker: WorkerBase) -> _R:
|
||||
return func(worker.get_model())
|
||||
|
||||
return self.collective_rpc(rpc_func)
|
||||
|
||||
def execute_model(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
|
||||
|
||||
@@ -148,7 +148,7 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
|
||||
async_run_tensor_parallel_workers_only: bool = False,
|
||||
max_concurrent_workers: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
) -> List[Any]:
|
||||
"""Runs the given method on all workers.
|
||||
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user