[docs] Add docs for new RL flows (#36188)
Signed-off-by: ahao-anyscale <ahao@anyscale.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
162
docs/training/weight_transfer/base.md
Normal file
@@ -0,0 +1,162 @@
# Base Class and Custom Engines

The weight transfer system is built on an abstract base class that defines the contract between vLLM's worker infrastructure and the transport backend. You can implement custom backends by subclassing `WeightTransferEngine` and registering them with the `WeightTransferEngineFactory`.

## WeightTransferEngine

The `WeightTransferEngine` is a generic abstract class parameterized by two dataclass types:

- **`TInitInfo`** (extends `WeightTransferInitInfo`): Backend-specific initialization parameters.
- **`TUpdateInfo`** (extends `WeightTransferUpdateInfo`): Backend-specific weight update metadata.

### Abstract Methods

Subclasses must implement these four methods:

| Method | Side | Description |
| ------ | ---- | ----------- |
| `init_transfer_engine(init_info)` | Inference | Initialize the communication channel on each inference worker |
| `receive_weights(update_info, load_weights)` | Inference | Receive weights and call `load_weights` incrementally |
| `shutdown()` | Inference | Clean up resources |
| `trainer_send_weights(iterator, trainer_args)` | Trainer | Static method to send weights from the trainer process |

### Request Classes

The API-level request classes provide backend-agnostic serialization using plain dictionaries. The engine's `parse_init_info` and `parse_update_info` methods convert these dictionaries into typed dataclasses.

```python
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)

# Init request (dict is converted to backend-specific TInitInfo)
init_request = WeightTransferInitRequest(
    init_info={
        "master_address": "10.0.0.1",
        "master_port": 29500,
        # ... remaining backend-specific fields
    }
)

# Update request (dict is converted to backend-specific TUpdateInfo)
# The [...] values are placeholders for the actual lists.
update_request = WeightTransferUpdateRequest(
    update_info={"names": [...], "dtype_names": [...], "shapes": [...]}
)
```
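
The dictionary-to-dataclass conversion described above can be illustrated with a minimal, self-contained sketch. It does not reproduce vLLM's `parse_init_info`; `ExampleInitInfo` and `parse_info` are hypothetical names, and rejecting unknown keys is an assumption about reasonable behavior rather than something the source states.

```python
from dataclasses import dataclass, fields
from typing import Any


@dataclass
class ExampleInitInfo:
    """Hypothetical stand-in for a backend-specific TInitInfo dataclass."""

    master_address: str
    master_port: int


def parse_info(info_cls: type, raw: dict[str, Any]) -> Any:
    """Build info_cls from a plain dict, rejecting unknown or missing keys."""
    known = {f.name for f in fields(info_cls)}
    unknown = set(raw) - known
    if unknown:
        raise ValueError(
            f"Unknown fields for {info_cls.__name__}: {sorted(unknown)}"
        )
    try:
        return info_cls(**raw)
    except TypeError as exc:
        # Raised by the dataclass __init__ when a required field is missing.
        raise ValueError(f"Invalid fields for {info_cls.__name__}: {exc}") from exc


init_info = parse_info(
    ExampleInitInfo, {"master_address": "10.0.0.1", "master_port": 29500}
)
print(init_info)
```

Keeping the wire format as a plain dictionary means the request classes never need to know which backend is in use; only the engine that owns the dataclass type validates the fields.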

### WeightTransferUpdateInfo

The base `WeightTransferUpdateInfo` includes an `is_checkpoint_format` flag:

```python
from abc import ABC
from dataclasses import dataclass

@dataclass
class WeightTransferUpdateInfo(ABC):
    is_checkpoint_format: bool = True
```

When `is_checkpoint_format=True` (the default), vLLM applies layerwise weight processing (repacking, renaming, etc.) to the received weights before loading them. Set it to `False` if the trainer has already converted the weights to the kernel format expected by the model.

## Implementing a Custom Engine

To create a custom weight transfer backend:

### 1. Define Info Dataclasses

```python
from dataclasses import dataclass

from vllm.distributed.weight_transfer.base import (
    WeightTransferEngine,
    WeightTransferInitInfo,
    WeightTransferUpdateInfo,
)

@dataclass
class MyInitInfo(WeightTransferInitInfo):
    endpoint: str
    token: str

@dataclass
class MyUpdateInfo(WeightTransferUpdateInfo):
    names: list[str]
    dtype_names: list[str]
    shapes: list[list[int]]
    # Add custom fields as needed
```

### 2. Implement the Engine

```python
from collections.abc import Callable, Iterator
from typing import Any

import torch

class MyWeightTransferEngine(WeightTransferEngine[MyInitInfo, MyUpdateInfo]):
    init_info_cls = MyInitInfo
    update_info_cls = MyUpdateInfo

    def init_transfer_engine(self, init_info: MyInitInfo) -> None:
        # Set up connection to trainer using init_info.endpoint, etc.
        ...

    def receive_weights(
        self,
        update_info: MyUpdateInfo,
        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
    ) -> None:
        # Receive each weight and call load_weights incrementally
        for name, dtype_name, shape in zip(
            update_info.names, update_info.dtype_names, update_info.shapes
        ):
            # Resolve the dtype from its name, e.g. "float16" -> torch.float16
            dtype = getattr(torch, dtype_name)
            # _fetch_weight is a placeholder for your transport's receive call
            weight = self._fetch_weight(name, shape, dtype)
            load_weights([(name, weight)])

    def shutdown(self) -> None:
        # Clean up resources
        ...

    @staticmethod
    def trainer_send_weights(
        iterator: Iterator[tuple[str, torch.Tensor]],
        trainer_args: dict[str, Any],
    ) -> None:
        # Send weights from the trainer process
        for name, tensor in iterator:
            # Send tensor via custom transport
            ...
```

!!! important
    The `load_weights` callable passed to `receive_weights` should be called **incrementally** (one or a few weights at a time) rather than accumulating all weights first. This avoids GPU out-of-memory errors with large models.
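
The incremental pattern can be sketched without any GPU or transport code. In the example below, `bytes` objects stand in for tensors, and `iter_batches`, `receive_incrementally`, and `fake_stream` are hypothetical helpers introduced only for illustration; the batch size of 2 is likewise arbitrary.

```python
from collections.abc import Callable, Iterable, Iterator
from itertools import islice

# bytes stands in for a tensor so the sketch runs without torch.
Weight = tuple[str, bytes]


def iter_batches(weights: Iterable[Weight], batch_size: int) -> Iterator[list[Weight]]:
    """Yield lists of at most batch_size weights without materializing the rest."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    it = iter(weights)
    while batch := list(islice(it, batch_size)):
        yield batch


def receive_incrementally(
    weights: Iterable[Weight],
    load_weights: Callable[[list[Weight]], None],
    batch_size: int = 2,
) -> None:
    """Hand each small batch to load_weights as soon as it is available."""
    for batch in iter_batches(weights, batch_size):
        load_weights(batch)


def fake_stream(count: int) -> Iterator[Weight]:
    """Hypothetical transport: produces one weight at a time."""
    for i in range(count):
        yield (f"layer.{i}.weight", bytes(4))


loaded_batch_sizes: list[int] = []
receive_incrementally(
    fake_stream(5), lambda batch: loaded_batch_sizes.append(len(batch))
)
print(loaded_batch_sizes)  # [2, 2, 1]
```

Because both the stream and the batching are generators, at most `batch_size` weights are held at once, which is the property the note above asks for.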

### 3. Register with the Factory

```python
from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory

# Use one of the two options below.

# Option 1: Lazy loading by module path and class name
# (recommended for built-in engines)
WeightTransferEngineFactory.register_engine(
    "my_backend",
    "my_package.my_module",
    "MyWeightTransferEngine",
)

# Option 2: Direct class registration
WeightTransferEngineFactory.register_engine(
    "my_backend",
    MyWeightTransferEngine,
)
```

Once registered, users can select your backend via `WeightTransferConfig(backend="my_backend")`.

## WeightTransferEngineFactory

The factory uses a registry pattern with lazy loading. Built-in engines (`nccl` and `ipc`) are registered at import time, but their modules are loaded only when the backend is actually requested. This avoids importing heavy dependencies (like NCCL communicators) when they aren't needed.

```python
from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory

# Create an engine from config.
# weight_transfer_config and parallel_config are assumed to be
# already-constructed config objects.
engine = WeightTransferEngineFactory.create_engine(
    config=weight_transfer_config,
    parallel_config=parallel_config,
)
```