[V1] [Feature] Collective RPC (#15444)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
@@ -8,7 +8,7 @@ import time
|
||||
from concurrent.futures import Future
|
||||
from inspect import isclass, signature
|
||||
from logging import DEBUG
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
|
||||
import msgspec
|
||||
import psutil
|
||||
@@ -43,6 +43,8 @@ logger = init_logger(__name__)
|
||||
|
||||
POLLING_TIMEOUT_S = 2.5
|
||||
|
||||
_R = TypeVar('_R') # Return type for collective_rpc
|
||||
|
||||
|
||||
class EngineCore:
|
||||
"""Inner loop of vLLM's Engine."""
|
||||
@@ -280,6 +282,14 @@ class EngineCore:
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
return self.model_executor.pin_lora(lora_id)
|
||||
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable[..., _R]],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
|
||||
return self.model_executor.collective_rpc(method, timeout, args,
|
||||
kwargs)
|
||||
|
||||
|
||||
class EngineCoreProc(EngineCore):
|
||||
"""ZMQ-wrapper for running EngineCore in background process."""
|
||||
|
||||
@@ -12,7 +12,7 @@ from collections.abc import Awaitable, Sequence
|
||||
from concurrent.futures import Future
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Thread
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
@@ -33,6 +33,8 @@ logger = init_logger(__name__)
|
||||
|
||||
AnyFuture = Union[asyncio.Future[Any], Future[Any]]
|
||||
|
||||
_R = TypeVar('_R') # Return type for collective_rpc
|
||||
|
||||
|
||||
class EngineCoreClient(ABC):
|
||||
"""
|
||||
@@ -117,6 +119,13 @@ class EngineCoreClient(ABC):
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable[..., _R]],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_output_async(self) -> EngineCoreOutputs:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -153,6 +162,14 @@ class EngineCoreClient(ABC):
|
||||
async def pin_lora_async(self, lora_id: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
async def collective_rpc_async(
|
||||
self,
|
||||
method: Union[str, Callable[..., _R]],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class InprocClient(EngineCoreClient):
|
||||
"""
|
||||
@@ -210,6 +227,13 @@ class InprocClient(EngineCoreClient):
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
return self.engine_core.pin_lora(lora_id)
|
||||
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable[..., _R]],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
|
||||
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
|
||||
|
||||
|
||||
class CoreEngine:
|
||||
"""One per data parallel rank."""
|
||||
@@ -505,6 +529,14 @@ class SyncMPClient(MPClient):
|
||||
def execute_dummy_batch(self) -> None:
|
||||
self.call_utility("execute_dummy_batch")
|
||||
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable[..., _R]],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
|
||||
return self.call_utility("collective_rpc", method, timeout, args,
|
||||
kwargs)
|
||||
|
||||
|
||||
class AsyncMPClient(MPClient):
|
||||
"""Asyncio-compatible client for multi-proc EngineCore."""
|
||||
@@ -636,6 +668,15 @@ class AsyncMPClient(MPClient):
|
||||
async def pin_lora_async(self, lora_id: int) -> bool:
|
||||
return await self.call_utility_async("pin_lora", lora_id)
|
||||
|
||||
async def collective_rpc_async(
|
||||
self,
|
||||
method: Union[str, Callable[..., _R]],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
|
||||
return await self.call_utility_async("collective_rpc", method, timeout,
|
||||
args, kwargs)
|
||||
|
||||
|
||||
class DPAsyncMPClient(AsyncMPClient):
|
||||
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from collections.abc import Mapping
|
||||
from copy import copy
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
@@ -32,6 +32,7 @@ from vllm.v1.executor.abstract import Executor
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
|
||||
_R = TypeVar("_R", default=Any)
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
@@ -282,6 +283,13 @@ class LLMEngine:
|
||||
"""Prevent an adapter from being evicted."""
|
||||
return self.engine_core.pin_lora(lora_id)
|
||||
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable[..., _R]],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
|
||||
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
|
||||
|
||||
def __del__(self):
|
||||
if dp_group := getattr(self, "dp_group", None):
|
||||
stateless_destroy_torch_distributed_process_group(dp_group)
|
||||
|
||||
Reference in New Issue
Block a user