[Feat][RL][1/2] Native Weight Syncing API: NCCL (#31943)
Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Aaron Hao <ahao@anyscale.com> Co-authored-by: SumanthRH <sumanthrh99@gmail.com>
This commit is contained in:
@@ -6,6 +6,10 @@ from collections.abc import AsyncGenerator, Iterable, Mapping
|
||||
from typing import Any
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.distributed.weight_transfer.base import (
|
||||
WeightTransferInitRequest,
|
||||
WeightTransferUpdateRequest,
|
||||
)
|
||||
from vllm.inputs.data import PromptType, StreamingInput
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
@@ -191,3 +195,13 @@ class EngineClient(ABC):
|
||||
async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||
"""Get supported tasks"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def init_weight_transfer_engine(
|
||||
self, init_request: WeightTransferInitRequest
|
||||
) -> None:
|
||||
"""Initialize weight transfer for RL training."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
|
||||
"""Batched weight update for RL training."""
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user