[Core] Interface for accessing model from VllmRunner (#10353)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-01-20 15:00:59 +08:00
committed by GitHub
parent 83609791d2
commit 59a0192fb9
35 changed files with 460 additions and 293 deletions

View File

@@ -3,6 +3,9 @@ from abc import ABC, abstractmethod
from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
Union)
import torch.nn as nn
from typing_extensions import TypeVar
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -11,9 +14,12 @@ from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase
logger = init_logger(__name__)
_R = TypeVar("_R", default=Any)
class ExecutorBase(ABC):
"""Base class for all executors.
@@ -44,22 +50,37 @@ class ExecutorBase(ABC):
@abstractmethod
def _init_executor(self) -> None:
pass
raise NotImplementedError
@abstractmethod
def collective_rpc(self,
method: Union[str, Callable],
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
"""
The main interface of the executor to run a method on all workers,
with homogeneous arguments.
If the args are heterogeneous, then we can pack them into a list,
and unpack them in the method of every worker, because every worker
knows their own rank.
Execute an RPC call on all workers.
Args:
method: Name of the worker method to execute, or a callable that
is serialized and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
:exc:`TimeoutError` on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.
Returns:
A list containing the results from each worker.
Note:
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
pass
raise NotImplementedError
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
@@ -97,6 +118,17 @@ class ExecutorBase(ABC):
self.collective_rpc("initialize_cache",
args=(num_gpu_blocks, num_cpu_blocks))
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
    """Invoke *func* on the model object held by every worker.

    Args:
        func: Callable that receives each worker's ``nn.Module`` and
            produces a per-worker result.

    Returns:
        The results from all workers, one entry per worker.
    """

    def _call_on_model(worker: WorkerBase) -> _R:
        model = worker.get_model()
        return func(model)

    return self.collective_rpc(_call_on_model)
def execute_model(
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:

View File

@@ -148,7 +148,7 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
) -> List[Any]:
"""Runs the given method on all workers.
Args: