[V1] [Feature] Collective RPC (#15444)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
wwl2755
2025-03-29 05:39:14 -05:00
committed by GitHub
parent 4965ec42d2
commit 94744ba41a
7 changed files with 86 additions and 10 deletions

View File

@@ -8,7 +8,7 @@ import time
from concurrent.futures import Future
from inspect import isclass, signature
from logging import DEBUG
from typing import Any, Optional
from typing import Any, Callable, Optional, TypeVar, Union
import msgspec
import psutil
@@ -43,6 +43,8 @@ logger = init_logger(__name__)
POLLING_TIMEOUT_S = 2.5
_R = TypeVar('_R') # Return type for collective_rpc
class EngineCore:
"""Inner loop of vLLM's Engine."""
@@ -280,6 +282,14 @@ class EngineCore:
def pin_lora(self, lora_id: int) -> bool:
return self.model_executor.pin_lora(lora_id)
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.model_executor.collective_rpc(method, timeout, args,
kwargs)
class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""

View File

@@ -12,7 +12,7 @@ from collections.abc import Awaitable, Sequence
from concurrent.futures import Future
from dataclasses import dataclass, field
from threading import Thread
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, TypeVar, Union
import zmq
import zmq.asyncio
@@ -33,6 +33,8 @@ logger = init_logger(__name__)
AnyFuture = Union[asyncio.Future[Any], Future[Any]]
_R = TypeVar('_R') # Return type for collective_rpc
class EngineCoreClient(ABC):
"""
@@ -117,6 +119,13 @@ class EngineCoreClient(ABC):
def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
raise NotImplementedError
async def get_output_async(self) -> EngineCoreOutputs:
raise NotImplementedError
@@ -153,6 +162,14 @@ class EngineCoreClient(ABC):
async def pin_lora_async(self, lora_id: int) -> bool:
raise NotImplementedError
async def collective_rpc_async(
self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
raise NotImplementedError
class InprocClient(EngineCoreClient):
"""
@@ -210,6 +227,13 @@ class InprocClient(EngineCoreClient):
def pin_lora(self, lora_id: int) -> bool:
return self.engine_core.pin_lora(lora_id)
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
class CoreEngine:
"""One per data parallel rank."""
@@ -505,6 +529,14 @@ class SyncMPClient(MPClient):
def execute_dummy_batch(self) -> None:
self.call_utility("execute_dummy_batch")
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.call_utility("collective_rpc", method, timeout, args,
kwargs)
class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore."""
@@ -636,6 +668,15 @@ class AsyncMPClient(MPClient):
async def pin_lora_async(self, lora_id: int) -> bool:
return await self.call_utility_async("pin_lora", lora_id)
async def collective_rpc_async(
self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return await self.call_utility_async("collective_rpc", method, timeout,
args, kwargs)
class DPAsyncMPClient(AsyncMPClient):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)

View File

@@ -2,7 +2,7 @@
from collections.abc import Mapping
from copy import copy
from typing import Optional, Union
from typing import Any, Callable, Optional, Union
from typing_extensions import TypeVar
@@ -32,6 +32,7 @@ from vllm.v1.executor.abstract import Executor
logger = init_logger(__name__)
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_R = TypeVar("_R", default=Any)
class LLMEngine:
@@ -282,6 +283,13 @@ class LLMEngine:
"""Prevent an adapter from being evicted."""
return self.engine_core.pin_lora(lora_id)
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
def __del__(self):
if dp_group := getattr(self, "dp_group", None):
stateless_destroy_torch_distributed_process_group(dp_group)