[docs] Add docs for new RL flows (#36188)

Signed-off-by: ahao-anyscale <ahao@anyscale.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Aaron Hao
2026-03-18 02:04:26 -07:00
committed by GitHub
parent fad09e8a1f
commit 47a1f11bff
18 changed files with 514 additions and 760 deletions

View File

@@ -0,0 +1,415 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates async reinforcement learning using vLLM and Ray,
with native weight syncing APIs at engine instance.
The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformer model occupies one GPU for training, whereas a
2x tensor-parallel vLLM inference engine occupies two GPUs.
The example performs the following steps:
* Load the training model on one gpu (scheduled via ray)
* Initialize the inference model with dummy weights across
two gpus using vLLM's tensor parallelism and Ray placement groups.
* Generate gibberish from a list of prompts using the randomly initialized
inference engine.
* Pause generation once generation completes for one sequence
* Update the weights of the training model and broadcast the updated weights
to the inference engine by using a Ray collective RPC group.
* Resume generation and print out the results
This example assumes a single-node cluster with three GPUs, but Ray
supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
workloads. Residual GPU activity interferes with vLLM memory profiling and
causes unexpected behavior.
"""
import asyncio
import uuid
from dataclasses import asdict
import ray
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import vllm
from vllm import SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
WeightTransferUpdateRequest,
)
from vllm.distributed.weight_transfer.nccl_engine import (
NCCLTrainerSendWeightsArgs,
NCCLWeightTransferEngine,
NCCLWeightTransferInitInfo,
NCCLWeightTransferUpdateInfo,
)
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_ip, get_open_port
from vllm.v1.executor import Executor
# V1 is loaded (with dummy weights) into the inference engine; V2 is the
# training model whose real weights get broadcast over.
MODEL_NAME_V1 = "Qwen/Qwen3-1.7B-Base"
MODEL_NAME_V2 = "Qwen/Qwen3-1.7B"
# Number of generated tokens after which a request raises the pause flag.
PAUSE_TOKEN_THRESHOLD = 10
# FLASH_ATTN is unavailable on ROCm; fall back to the Triton backend there.
ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "FLASH_ATTN"
class MyLLM(vllm.AsyncLLMEngine):
    """Async vLLM engine with cooperative pause/resume hooks for weight sync.

    Accepts plain keyword arguments (translated through ``AsyncEngineArgs``)
    so it can be constructed directly as a Ray actor.
    """

    def __init__(self, **kwargs):
        # Build the full engine config and executor class from plain kwargs.
        engine_args = vllm.AsyncEngineArgs(**kwargs)
        vllm_config = engine_args.create_engine_config()
        executor_class = Executor.get_class(vllm_config)
        super().__init__(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
        )
        # Set by pause_after_n_tokens() once the engine is actually paused.
        self._generation_paused = False
        # Set by the first request whose token count crosses the threshold.
        self._request_pause_flag = False

    async def do_generate(
        self, prompt_token_ids: list[int], sampling_params: vllm.SamplingParams
    ) -> tuple[vllm.RequestOutput, int]:
        """Generate a single request, setting the request pause flag once the
        token count reaches the threshold.

        Returns (output, pause_token_index). pause_token_index is the number
        of tokens generated before the weight change, or -1 if no pause.
        """
        pause_token_index = -1
        prev_token_count = 0
        async for request_output in self.generate(
            {"prompt_token_ids": prompt_token_ids},
            sampling_params,
            request_id=str(uuid.uuid4()),
        ):
            output = request_output
            cur_token_count = len(output.outputs[0].token_ids)
            if (
                cur_token_count >= PAUSE_TOKEN_THRESHOLD
                and not self._request_pause_flag
            ):
                # First request to reach the threshold signals the pauser task.
                self._request_pause_flag = True
            if self._generation_paused and pause_token_index == -1:
                # Record how many tokens were produced under the old weights.
                # prev_token_count is used (not cur) because the current
                # iteration may already contain post-pause tokens.
                pause_token_index = prev_token_count
            prev_token_count = cur_token_count
        return output, pause_token_index

    async def pause_after_n_tokens(self):
        """Wait for any request to set the pause flag, then pause.

        Busy-waits with sleep(0) so in-flight generation keeps making
        progress on the event loop until the flag is raised.
        """
        while not self._request_pause_flag:
            await asyncio.sleep(0)
        await super().pause_generation(mode="keep")
        # NOTE(review): the fixed 5s sleep presumably lets in-flight steps
        # drain before _generation_paused is observed by do_generate —
        # confirm whether pause_generation() already guarantees quiescence.
        await asyncio.sleep(5)
        self._generation_paused = True
@ray.remote(num_gpus=1)
class TrainModel:
    """Ray actor that wraps the training model on a dedicated GPU."""

    def __init__(self, model_name: str):
        from vllm.model_executor.layers.batch_invariant import (
            init_batch_invariance,
        )
        from vllm.platforms import current_platform
        from vllm.v1.attention.backends.registry import AttentionBackendEnum

        # need to init all env vars for batch invariance which affect nccl ops
        # (must run before the model is loaded / any CUDA work happens)
        attn_backend = (
            AttentionBackendEnum.TRITON_ATTN
            if current_platform.is_rocm()
            else AttentionBackendEnum.FLASH_ATTN
        )
        init_batch_invariance(attn_backend)
        # Inside this actor only one GPU is visible, so cuda:0 is "our" GPU.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, dtype=torch.bfloat16
        ).to("cuda:0")
        # Rendezvous endpoint the inference workers will connect to for NCCL.
        self.port = get_open_port()
        self.master_address = get_ip()

    def get_master_address_and_port(self):
        """Return the (address, port) NCCL rendezvous endpoint."""
        return self.master_address, self.port

    def get_weight_metadata(self):
        """Return weight names, dtypes, and shapes for weight transfer."""
        names = []
        dtype_names = []
        shapes = []
        for name, p in self.model.named_parameters():
            names.append(name)
            # "torch.bfloat16" -> "bfloat16"
            dtype_names.append(str(p.dtype).split(".")[-1])
            shapes.append(list(p.shape))
        return names, dtype_names, shapes

    def init_weight_transfer_group(self, world_size):
        """Initialize the NCCL process group for weight transfer.

        The trainer occupies rank 0; the inference workers fill the
        remaining ranks (see rank_offset on the engine side).
        """
        self.model_update_group = NCCLWeightTransferEngine.trainer_init(
            dict(
                master_address=self.master_address,
                master_port=self.port,
                world_size=world_size,
            ),
        )

    def broadcast_weights(self, packed: bool = True):
        """Broadcast weights to the inference engine."""
        trainer_args = NCCLTrainerSendWeightsArgs(
            group=self.model_update_group,
            packed=packed,
        )
        NCCLWeightTransferEngine.trainer_send_weights(
            iterator=self.model.named_parameters(),
            trainer_args=trainer_args,
        )

    @torch.inference_mode()
    def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]:
        """Greedy-decode max_new_tokens from the given context."""
        input_ids = torch.tensor([token_ids], device="cuda:0")
        output = self.model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
        # Strip the prompt prefix: keep only the newly generated token ids.
        new_token_ids = output[0, len(token_ids) :].tolist()
        return new_token_ids
# Build platform-specific env vars for Ray
ray_env_vars = {
    # Prevent Ray from setting CUDA_VISIBLE_DEVICES
    "RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1",
}
if current_platform.is_rocm():
    # For ROCm, BATCH_INVARIANT vllm is not supported
    ray_env_vars["VLLM_ROCM_USE_SKINNY_GEMM"] = "0"
else:
    # Enable batch invariance for deterministic outputs on NVIDIA
    ray_env_vars["VLLM_BATCH_INVARIANT"] = "1"
ray.init(runtime_env={"env_vars": ray_env_vars})
# Launch the training model actor. Ray's resource scheduler will allocate
# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs.
train_model = TrainModel.remote(MODEL_NAME_V2)
rocm_determinism_kwargs = {}
if current_platform.is_rocm():
    # ROCm: To minimize non-determinism, we set fixed seed, no prefix caching, and
    # sequential request processing (max_num_seqs=1).
    rocm_determinism_kwargs = {
        "seed": 0,
        "enable_prefix_caching": False,
        "max_num_seqs": 1,
    }
# Build platform-specific LLM kwargs
llm_kwargs = dict(
    model=MODEL_NAME_V1,
    enforce_eager=True,
    max_model_len=8192,
    distributed_executor_backend="ray",
    attention_backend=ATTN_BACKEND,
    gpu_memory_utilization=0.75,
    weight_transfer_config=WeightTransferConfig(backend="nccl"),
)
llm_kwargs.update(rocm_determinism_kwargs)
# Launch the vLLM inference engine.
# With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates
# its own placement groups internally for each DP rank, so we must NOT
# create an outer placement group (it would reserve GPUs and hide them
# from the internal DP resource check).
llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
)(MyLLM).remote(**llm_kwargs)
PROMPTS = [
    "The president of the United States is",
    "The capital of France is",
    "The largest ocean on Earth is",
    "The speed of light in a vacuum is",
    "The chemical formula for water is",
    "The tallest mountain in the world is",
    "The first person to walk on the moon was",
    "The Great Wall of China was built to",
    "Photosynthesis is the process by which",
    "The theory of general relativity was proposed by",
    "The boiling point of water at sea level is",
    "The largest planet in our solar system is",
    "DNA stands for deoxyribonucleic acid and it",
]
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_V1)
batch_prompt_token_ids = [
    tokenizer.encode(prompt, add_special_tokens=False) for prompt in PROMPTS
]
# Set up the communication channel between the training process and the
# inference engine.
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())
world_size = 2  # 1 trainer + 1 inference worker
inference_handle = llm.init_weight_transfer_engine.remote(
    WeightTransferInitRequest(
        init_info=asdict(
            NCCLWeightTransferInitInfo(
                master_address=master_address,
                master_port=master_port,
                rank_offset=1,
                world_size=world_size,
            )
        )
    )
)
# Initialize weight transfer group on both the training actor and inference engine
train_handle = train_model.init_weight_transfer_group.remote(world_size)
ray.get([train_handle, inference_handle])
N_NEW_TOKENS = 100
# Collect weight metadata once
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())
# ── Phase 1: concurrent requests with weight sync ───────────────────
print(f"\n{'=' * 50}")
print(f"Prompts ({len(PROMPTS)}):")
for p in PROMPTS:
    print(f" - {p!r}")
print(f"{'=' * 50}")
sampling_params = SamplingParams(
    temperature=0, max_tokens=PAUSE_TOKEN_THRESHOLD + N_NEW_TOKENS
)
# Submit all requests concurrently; each records its own pause index.
gen_futures = [
    llm.do_generate.remote(ptids, sampling_params) for ptids in batch_prompt_token_ids
]
# Block until some request crosses the token threshold and the engine pauses.
ray.get(llm.pause_after_n_tokens.remote())
# Engine side blocks receiving the broadcast; trainer side sends it — run both
# RPCs concurrently and join on them together.
inference_handle = llm.update_weights.remote(
    WeightTransferUpdateRequest(
        update_info=asdict(
            NCCLWeightTransferUpdateInfo(
                names=names,
                dtype_names=dtype_names,
                shapes=shapes,
                packed=True,
            )
        )
    )
)
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])
ray.get(llm.resume_generation.remote())
results = ray.get(gen_futures)
for i, (output, pause_idx) in enumerate(results):
    all_token_ids = list(output.outputs[0].token_ids)
    before_text = tokenizer.decode(all_token_ids[:pause_idx])
    after_text = tokenizer.decode(all_token_ids[pause_idx:])
    print(f"\n Request {i} ({PROMPTS[i]!r}):")
    print(f" Old weights ({pause_idx} tokens): {before_text!r}")
    n_after = len(all_token_ids) - pause_idx
    print(f" New weights ({n_after} tokens): {after_text!r}")
# ── Phase 2: validate with a fresh V2 vLLM instance ────────────────
# This validation relies on batch-invariant (deterministic) generation to
# compare outputs from the weight-synced engine against a fresh V2 instance.
# On NVIDIA, batch invariance is fully supported, so we require 100% exact
# token match. On ROCm, batch invariance is not yet fully implemented
# (see https://github.com/vllm-project/vllm/issues/27433 and
# https://github.com/vllm-project/vllm/issues/33123), so residual
# non-determinism (e.g. GEMM accumulation order, missing kernel overrides)
# can cause single-token divergences that don't indicate a weight-sync
# failure. We relax the pass rate to 90% on ROCm to accommodate this; a
# real regression (broken weight transfer) would cause ~0% pass rate, not 90%+.
MIN_PASS_RATE = 1.0 if not current_platform.is_rocm() else 0.9
print(f"\n{'=' * 50}")
print("VALIDATION: comparing weight-synced vLLM with fresh V2 instance")
if current_platform.is_rocm():
    print(f" (ROCm mode: requiring >= {MIN_PASS_RATE:.0%} exact match rate)")
print(f"{'=' * 50}")
# Tear down the weight-synced engine and trainer to free GPUs for the
# fresh validation instance.
ray.get(llm.shutdown.remote())
ray.kill(llm)
ray.kill(train_model)
llm_v2_kwargs = dict(
    model=MODEL_NAME_V2,
    enforce_eager=True,
    max_model_len=8192,
    gpu_memory_utilization=0.75,
    distributed_executor_backend="ray",
    attention_backend=ATTN_BACKEND,
)
llm_v2_kwargs.update(rocm_determinism_kwargs)
llm_v2 = ray.remote(
    num_cpus=0,
    num_gpus=0,
)(MyLLM).remote(**llm_v2_kwargs)
# Replay each request: prompt + pre-pause tokens as context, generating
# exactly as many tokens as were produced after the pause.
val_futures = [
    llm_v2.do_generate.remote(
        list(output.prompt_token_ids) + list(output.outputs[0].token_ids)[:pause_idx],
        SamplingParams(
            temperature=0, max_tokens=len(output.outputs[0].token_ids) - pause_idx
        ),
    )
    for output, pause_idx in results
]
val_results = ray.get(val_futures)
num_pass = 0
num_total = len(results)
for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_results)):
    expected = list(output.outputs[0].token_ids)[pause_idx:]
    actual = list(val_output.outputs[0].token_ids)
    match = actual == expected
    if match:
        num_pass += 1
        print(f" [PASS] {PROMPTS[i]!r}")
    else:
        print(f" [FAIL] {PROMPTS[i]!r}")
        print(f" weight-synced vLLM: {tokenizer.decode(expected)!r}")
        print(f" V2 vLLM: {tokenizer.decode(actual)!r}")
        for j, (e, a) in enumerate(zip(expected, actual)):
            if e != a:
                print(
                    f" first divergence at output token {j}: "
                    f"expected {e} ({tokenizer.decode([e])!r}) vs "
                    f"actual {a} ({tokenizer.decode([a])!r})"
                )
                break
ray.get(llm_v2.shutdown.remote())
ray.kill(llm_v2)
pass_rate = num_pass / num_total
print(f"\n Result: {num_pass}/{num_total} prompts passed ({pass_rate:.0%})")
print(f" Required: >= {MIN_PASS_RATE:.0%}")
assert pass_rate >= MIN_PASS_RATE, (
    f"Validation pass rate {pass_rate:.0%} ({num_pass}/{num_total}) "
    f"is below the required {MIN_PASS_RATE:.0%} threshold. "
    f"See failures above for details."
)
print("=" * 50)

View File

@@ -0,0 +1,181 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM
via HTTP API, with IPC-based weight syncing APIs.
Unlike rlhf_nccl.py which uses NCCL and can use separate GPUs, this script
uses CUDA IPC which requires the training model and vLLM server to be on the
same GPU. Memory must be carefully managed to fit both models.
Unlike rlhf.py which creates a vLLM instance programmatically, this script
assumes you have already started a vLLM server using `vllm serve`. It uses:
- OpenAI-compatible API for inference requests
- HTTP endpoints for weight transfer control plane
- CUDA IPC for actual weight data transfer
Prerequisites:
Start a vLLM server with weight transfer enabled and reduced GPU memory
utilization to leave room for the training model:
$ VLLM_SERVER_DEV_MODE=1 VLLM_ALLOW_INSECURE_SERIALIZATION=1 \
vllm serve facebook/opt-125m --enforce-eager \
--weight-transfer-config '{"backend": "ipc"}' \
--load-format dummy \
--gpu-memory-utilization 0.5
Then run this script:
$ python rlhf_http_ipc.py
The example performs the following steps:
* Load the training model on GPU 0 (same GPU as the vLLM server).
* Generate text using the vLLM server via OpenAI-compatible API. The output
is expected to be nonsense because the server is initialized with dummy weights.
* Initialize weight transfer via HTTP endpoint (no-op for IPC).
* Broadcast the real weights from the training model to the vLLM server
using CUDA IPC handles.
* Generate text again to show normal output after the weight update.
"""
import os
import requests
import torch
from openai import OpenAI
from transformers import AutoModelForCausalLM
from vllm.distributed.weight_transfer.ipc_engine import (
IPCTrainerSendWeightsArgs,
IPCWeightTransferEngine,
)
# Address of the already-running `vllm serve` instance (see module docstring).
BASE_URL = "http://localhost:8000"
MODEL_NAME = "facebook/opt-125m"
# Enable insecure serialization for IPC handle serialization
# (must match the env var the server was started with).
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"
def generate_completions(client: OpenAI, model: str, prompts: list[str]) -> list[str]:
    """Request a completion for each prompt via the OpenAI-compatible API.

    Uses greedy decoding (temperature=0) capped at 32 tokens per prompt and
    returns the generated texts in prompt order.
    """

    def _complete(prompt: str) -> str:
        # One blocking API round-trip per prompt.
        reply = client.completions.create(
            model=model,
            prompt=prompt,
            max_tokens=32,
            temperature=0,
        )
        return reply.choices[0].text

    return [_complete(p) for p in prompts]
def init_weight_transfer_engine(base_url: str) -> None:
    """Ask the server to initialize its weight-transfer engine.

    The IPC backend needs no init parameters, so an empty init_info is sent;
    the call is still required before any weight update.
    """
    resp = requests.post(
        f"{base_url}/init_weight_transfer_engine",
        json={"init_info": {}},
        timeout=60,
    )
    resp.raise_for_status()
def pause_generation(base_url: str) -> None:
    """Tell the server to pause generation before a weight sync."""
    resp = requests.post(f"{base_url}/pause", timeout=60)
    resp.raise_for_status()
def resume_generation(base_url: str) -> None:
    """Tell the server to resume generation after a weight sync."""
    resp = requests.post(f"{base_url}/resume", timeout=60)
    resp.raise_for_status()
def get_world_size(base_url: str) -> int:
    """Query the vLLM server for its inference world size."""
    resp = requests.get(f"{base_url}/get_world_size", timeout=10)
    resp.raise_for_status()
    return resp.json()["world_size"]
def main():
    """Run the CUDA-IPC weight-sync demo against an already-running server."""
    # IPC requires the training model to be on the same GPU as the vLLM server
    # The server should be started on GPU 0 with reduced memory utilization
    device = "cuda:0"
    # NOTE(review): torch.accelerator.set_device_index appears to take a
    # device *index*; passing the "cuda:0" string may be rejected — confirm
    # against the torch.accelerator documentation.
    torch.accelerator.set_device_index(device)
    # Load the training model on the same GPU as the server
    # Use bfloat16 to reduce memory footprint
    print(f"Loading training model: {MODEL_NAME} on {device}")
    print(
        "Note: Ensure the vLLM server was started with --gpu-memory-utilization 0.5 "
        "or lower to leave room for the training model."
    )
    train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16)
    train_model.to(device)
    train_model.eval()  # Set to eval mode to save memory
    # Create OpenAI client pointing to the vLLM server
    client = OpenAI(
        base_url=f"{BASE_URL}/v1",
        api_key="EMPTY",  # vLLM doesn't require an API key by default
    )
    # Test prompts
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Generate text before weight update. The output is expected to be nonsense
    # because the server is initialized with dummy weights.
    print("-" * 50)
    print("Generating text BEFORE weight update (expect nonsense):")
    print("-" * 50)
    outputs = generate_completions(client, MODEL_NAME, prompts)
    for prompt, generated_text in zip(prompts, outputs):
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)
    print("Initializing weight transfer (IPC backend)...")
    # Initialize weight transfer on vLLM server (no-op for IPC, but still required)
    init_weight_transfer_engine(BASE_URL)
    # Pause generation before weight sync
    pause_generation(BASE_URL)
    # Broadcast weights via IPC handles using HTTP mode
    print("Broadcasting weights via CUDA IPC (HTTP)...")
    trainer_args = IPCTrainerSendWeightsArgs(mode="http", url=BASE_URL)
    IPCWeightTransferEngine.trainer_send_weights(
        iterator=train_model.named_parameters(),
        trainer_args=trainer_args,
    )
    # Resume generation after weight sync
    resume_generation(BASE_URL)
    # Generate text after weight update. The output is expected to be normal
    # because the real weights are now loaded.
    print("-" * 50)
    print("Generating text AFTER weight update:")
    print("-" * 50)
    outputs_updated = generate_completions(client, MODEL_NAME, prompts)
    for prompt, generated_text in zip(prompts, outputs_updated):
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)
    # Note: The training model and IPC handles remain in memory.
    # In a real RLHF training loop, you would update the training model
    # and create new IPC handles for each weight update.


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,245 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM
via HTTP API, with native weight syncing APIs.
Unlike rlhf.py which creates a vLLM instance programmatically, this script
assumes you have already started a vLLM server using `vllm serve`. It uses:
- OpenAI-compatible API for inference requests
- HTTP endpoints for weight transfer control plane
- NCCL for actual weight data transfer
Prerequisites:
Start a vLLM server with weight transfer enabled:
$ VLLM_SERVER_DEV_MODE=1 vllm serve facebook/opt-125m \
--enforce-eager \
--weight-transfer-config '{"backend": "nccl"}' \
--load-format dummy
Then run this script:
$ python rlhf_http.py
The example performs the following steps:
* Load the training model on GPU 0.
* Generate text using the vLLM server via OpenAI-compatible API. The output
is expected to be nonsense because the server is initialized with dummy weights.
* Initialize weight transfer via HTTP endpoint.
* Broadcast the real weights from the training model to the vLLM server
using NCCL.
* Generate text again to show normal output after the weight update.
"""
import requests
import torch
from openai import OpenAI
from transformers import AutoModelForCausalLM
from vllm.distributed.weight_transfer.nccl_engine import (
NCCLTrainerSendWeightsArgs,
NCCLWeightTransferEngine,
)
from vllm.utils.network_utils import get_ip, get_open_port
# Address of the already-running `vllm serve` instance (see module docstring).
BASE_URL = "http://localhost:8000"
MODEL_NAME = "facebook/opt-125m"
def generate_completions(client: OpenAI, model: str, prompts: list[str]) -> list[str]:
    """Fetch a short greedy completion for every prompt, preserving order."""
    completions: list[str] = []
    for text in prompts:
        # Greedy decoding (temperature=0), capped at 32 new tokens.
        result = client.completions.create(
            model=model,
            prompt=text,
            max_tokens=32,
            temperature=0,
        )
        completions.append(result.choices[0].text)
    return completions
def init_weight_transfer_engine(
    base_url: str,
    master_address: str,
    master_port: int,
    rank_offset: int,
    world_size: int,
) -> None:
    """Send NCCL rendezvous info to the server's weight-transfer endpoint.

    The trainer is rank 0; the vLLM worker ranks start at rank_offset.
    """
    init_info = {
        "master_address": master_address,
        "master_port": master_port,
        "rank_offset": rank_offset,
        "world_size": world_size,
    }
    resp = requests.post(
        f"{base_url}/init_weight_transfer_engine",
        json={"init_info": init_info},
        timeout=60,
    )
    resp.raise_for_status()
def update_weights(
    base_url: str,
    names: list[str],
    dtype_names: list[str],
    shapes: list[list[int]],
    packed: bool = False,
) -> None:
    """Post weight metadata so the server starts receiving the broadcast.

    The call blocks while the server waits for the NCCL broadcasts, so
    callers typically run it on a separate thread while the trainer sends.
    """
    update_info = {
        "names": names,
        "dtype_names": dtype_names,
        "shapes": shapes,
        "packed": packed,
    }
    resp = requests.post(
        f"{base_url}/update_weights",
        json={"update_info": update_info},
        timeout=300,
    )
    resp.raise_for_status()
def pause_generation(base_url: str) -> None:
    """Pause server-side generation ahead of a weight sync."""
    reply = requests.post(f"{base_url}/pause", timeout=60)
    reply.raise_for_status()
def resume_generation(base_url: str) -> None:
    """Resume server-side generation once the weight sync is done."""
    reply = requests.post(f"{base_url}/resume", timeout=60)
    reply.raise_for_status()
def get_world_size(base_url: str) -> int:
    """Return the server's inference world size."""
    reply = requests.get(f"{base_url}/get_world_size", timeout=10)
    reply.raise_for_status()
    return reply.json()["world_size"]
def main():
    """Run the NCCL weight-sync demo against an already-running server."""
    # Get the inference world size from the vLLM server
    inference_world_size = get_world_size(BASE_URL)
    world_size = inference_world_size + 1  # +1 for the trainer
    # Place the trainer on the first GPU after the inference GPUs.
    device = f"cuda:{inference_world_size}"
    # NOTE(review): torch.accelerator.set_device_index appears to take a
    # device *index*; passing the "cuda:N" string may be rejected — confirm
    # against the torch.accelerator documentation.
    torch.accelerator.set_device_index(device)
    # Load the training model
    print(f"Loading training model: {MODEL_NAME}")
    train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16)
    train_model.to(device)
    # Create OpenAI client pointing to the vLLM server
    client = OpenAI(
        base_url=f"{BASE_URL}/v1",
        api_key="EMPTY",  # vLLM doesn't require an API key by default
    )
    # Test prompts
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Generate text before weight update. The output is expected to be nonsense
    # because the server is initialized with dummy weights.
    print("-" * 50)
    print("Generating text BEFORE weight update (expect nonsense):")
    print("-" * 50)
    outputs = generate_completions(client, MODEL_NAME, prompts)
    for prompt, generated_text in zip(prompts, outputs):
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)
    # Set up the communication channel between the training process and the
    # vLLM server. The trainer is rank 0, vLLM worker(s) start at rank_offset.
    master_address = get_ip()
    master_port = get_open_port()
    rank_offset = 1
    print(f"Initializing weight transfer: master={master_address}:{master_port}")
    # Initialize weight transfer on vLLM server (this is async, server will
    # wait for NCCL connection)
    import threading

    init_thread = threading.Thread(
        target=init_weight_transfer_engine,
        args=(BASE_URL, master_address, master_port, rank_offset, world_size),
    )
    init_thread.start()
    # Initialize NCCL process group on trainer side
    model_update_group = NCCLWeightTransferEngine.trainer_init(
        dict(
            master_address=master_address,
            master_port=master_port,
            world_size=world_size,
        ),
    )
    # Wait for init_weight_transfer_engine to complete
    init_thread.join()
    # Pause generation before weight sync
    pause_generation(BASE_URL)
    # Collect weight metadata for the update request
    names = []
    dtype_names = []
    shapes = []
    for name, p in train_model.named_parameters():
        names.append(name)
        # "torch.bfloat16" -> "bfloat16"
        dtype_names.append(str(p.dtype).split(".")[-1])
        shapes.append(list(p.shape))
    # Start the update_weights call in a separate thread since it will block
    # waiting for NCCL broadcasts
    # packed=True enables efficient batched tensor broadcasting
    update_thread = threading.Thread(
        target=update_weights,
        args=(BASE_URL, names, dtype_names, shapes, True),  # packed=True
    )
    update_thread.start()
    # Broadcast all weights from trainer to vLLM workers
    print("Broadcasting weights via NCCL...")
    trainer_args = NCCLTrainerSendWeightsArgs(
        group=model_update_group,
        packed=True,
    )
    NCCLWeightTransferEngine.trainer_send_weights(
        iterator=train_model.named_parameters(),
        trainer_args=trainer_args,
    )
    # Wait for update_weights to complete
    update_thread.join()
    # Resume generation after weight sync
    resume_generation(BASE_URL)
    # Generate text after weight update. The output is expected to be normal
    # because the real weights are now loaded.
    print("-" * 50)
    print("Generating text AFTER weight update:")
    print("-" * 50)
    outputs_updated = generate_completions(client, MODEL_NAME, prompts)
    for prompt, generated_text in zip(prompts, outputs_updated):
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)


if __name__ == "__main__":
    main()

149
examples/rl/rlhf_ipc.py Normal file
View File

@@ -0,0 +1,149 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray,
with IPC-based weight syncing APIs
The script colocates the training and inference workloads onto the same GPU using Ray.
The example performs the following steps:
* Request a placement group of 1 GPU.
* Place the inference model on the above GPU using the placement group.
* Place and load the training model on the same GPU using the placement group.
* Generate text from a list of prompts using the inference engine.
* Update the weights of the training model and broadcast the updated weights
to the inference engine by using CUDA IPC handles. Note that
for demonstration purposes we simply zero out the weights.
This example assumes a single-node cluster with a single GPU,
but can be extended to multiple GPUs.
"""
import os
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.ipc_engine import (
IPCTrainerSendWeightsArgs,
IPCWeightTransferEngine,
)
class MyLLM(LLM):
    """vLLM engine configured to share one GPU inside a Ray placement group."""

    def __init__(self, *args, **kwargs):
        # Ray pins CUDA_VISIBLE_DEVICES for actors; drop it so vLLM can
        # manage its own device placement within the worker.
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        os.environ.update(
            {
                # Each worker uses 0.4 GPU so two instances fit on one GPU.
                "VLLM_RAY_PER_WORKER_GPUS": "0.4",
                "VLLM_RAY_BUNDLE_INDICES": "0",
                # needed for ipc handle serialization
                "VLLM_ALLOW_INSECURE_SERIALIZATION": "1",
            }
        )
        super().__init__(*args, **kwargs)
# Model used for both the training workload and the vLLM engine below.
MODEL_NAME = "facebook/opt-125m"
@ray.remote
class TrainModel:
    """Ray actor holding the HF training model, colocated with the vLLM engine."""

    def __init__(self, llm_handle: ray.actor.ActorHandle):
        # IPC transfer requires trainer and engine on the same GPU, which is
        # cuda:0 inside this actor's placement-group bundle.
        self.train_model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
        )
        self.train_model.to("cuda:0")
        self.llm_handle = llm_handle

    def init_weight_transfer(self):
        """Initialize the engine-side weight-transfer state.

        The IPC backend needs no connection info, so an empty init_info is
        sent; the call is still required before the first weight update.
        """
        ray.get(
            self.llm_handle.init_weight_transfer_engine.remote(dict(init_info=dict()))
        )

    def broadcast_weights(self, llm_handle: ray.actor.ActorHandle | None = None):
        """Broadcast weights to the inference engine using IPC.

        Args:
            llm_handle: Engine actor to send the weights to. Defaults to the
                handle supplied at construction time, so callers no longer
                need to pass it again on every sync (passing one explicitly
                keeps the original call sites working and retargets the
                stored handle).
        """
        if llm_handle is not None:
            self.llm_handle = llm_handle
        trainer_args = IPCTrainerSendWeightsArgs(mode="ray", llm_handle=self.llm_handle)
        IPCWeightTransferEngine.trainer_send_weights(
            iterator=self.train_model.named_parameters(),
            trainer_args=trainer_args,
        )
ray.init()
# One-GPU placement group shared by the vLLM engine and the training actor.
pg_colocate = placement_group([{"GPU": 1, "CPU": 0}])
ray.get(pg_colocate.ready())
# The actor itself reserves no GPU; its internal Ray workers claim
# fractional GPU via VLLM_RAY_PER_WORKER_GPUS (see MyLLM).
llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
    scheduling_strategy=PlacementGroupSchedulingStrategy(
        placement_group=pg_colocate,
        placement_group_capture_child_tasks=True,
    ),
)(MyLLM).remote(
    model=MODEL_NAME,
    enforce_eager=True,
    tensor_parallel_size=1,
    distributed_executor_backend="ray",
    gpu_memory_utilization=0.7,
    weight_transfer_config=WeightTransferConfig(backend="ipc"),
    load_format="dummy",
)
# Training actor takes a small GPU fraction on the same bundle.
train_model = TrainModel.options(
    num_gpus=0.1,
    num_cpus=0,
    scheduling_strategy=PlacementGroupSchedulingStrategy(
        placement_group=pg_colocate, placement_group_capture_child_tasks=True
    ),
).remote(llm)
# Generate text from the prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
outputs = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)
# NOTE(review): sleep(level=0) — confirm this level is intended; the engine
# is woken below with only the "scheduling" tag.
ray.get(llm.sleep.remote(level=0))
ray.get(train_model.init_weight_transfer.remote())
# Synchronize the updated weights to the inference engine using batched API.
ray.get(train_model.broadcast_weights.remote(llm))
ray.get(llm.wake_up.remote(tags=["scheduling"]))
# Generate text with the updated model.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)

216
examples/rl/rlhf_nccl.py Normal file
View File

@@ -0,0 +1,216 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning using vLLM and Ray,
with native weight syncing APIs at engine instance.
The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformer model occupies one GPU for training, whereas a
2x tensor-parallel vLLM inference engine occupies two GPUs.
The example performs the following steps:
* Load the training model on one gpu (scheduled via ray)
* Initialize the inference model with dummy weights across
two gpus using vLLM's tensor parallelism and Ray placement groups.
* Generate gibberish from a list of prompts using the randomly initialized
inference engine.
* Update the weights of the training model and broadcast the updated weights
to the inference engine by using a Ray collective RPC group.
* Generating from the list of prompts after weight sync should result
in sensible outputs.
This example assumes a single-node cluster with three GPUs, but Ray
supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
workloads. Residual GPU activity interferes with vLLM memory profiling and
causes unexpected behavior.
"""
import os
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.nccl_engine import (
NCCLTrainerSendWeightsArgs,
NCCLWeightTransferEngine,
)
from vllm.utils.network_utils import get_ip, get_open_port
# Hugging Face model loaded by both the trainer and the inference engine.
MODEL_NAME = "facebook/opt-125m"
# Alternative (quantized) model for experimentation:
# MODEL_NAME = "inference-optimization/Qwen3-0.6B-W4A16-G128"
class MyLLM(LLM):
    """vLLM engine subclass that pins its Ray workers to placement-group bundles."""

    def __init__(self, *args, **kwargs):
        # Direct vLLM's Ray executor to schedule its two tensor-parallel
        # workers on bundles 0 and 1 of the placement group this actor
        # was launched into.
        os.environ.update(VLLM_RAY_BUNDLE_INDICES="0,1")
        super().__init__(*args, **kwargs)
@ray.remote(num_gpus=1)
class TrainModel:
    """Ray actor that wraps the training model on a dedicated GPU.

    Exposes the metadata and NCCL hooks the inference engine needs in
    order to receive updated weights from this trainer.
    """

    def __init__(self, model_name: str):
        # Ray sets CUDA_VISIBLE_DEVICES for this actor (num_gpus=1 in the
        # decorator), so "cuda:0" is the single GPU assigned to it.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
        ).to("cuda:0")
        # Rendezvous endpoint for the NCCL weight-transfer process group.
        self.port = get_open_port()
        self.master_address = get_ip()

    def get_master_address_and_port(self):
        """Return the (address, port) rendezvous endpoint for NCCL init."""
        return self.master_address, self.port

    def get_weight_metadata(self):
        """Return parallel lists of weight names, dtype names, and shapes.

        The inference engine uses this metadata to pre-allocate receive
        buffers before the trainer broadcasts the tensors.
        """
        named = list(self.model.named_parameters())
        names = [name for name, _ in named]
        # str(torch.float32) == "torch.float32"; keep only the last segment.
        dtype_names = [str(p.dtype).split(".")[-1] for _, p in named]
        shapes = [list(p.shape) for _, p in named]
        return names, dtype_names, shapes

    def init_weight_transfer_group(self, world_size):
        """Initialize the NCCL process group for weight transfer.

        The engine-side workers join with rank_offset=1, which leaves
        rank 0 for this trainer; `world_size` must count both sides.
        """
        self.model_update_group = NCCLWeightTransferEngine.trainer_init(
            dict(
                master_address=self.master_address,
                master_port=self.port,
                world_size=world_size,
            ),
        )

    def broadcast_weights(self, packed: bool = True):
        """Broadcast all model weights to the inference engine.

        Args:
            packed: When True, tensors are batched into larger broadcasts
                for efficiency; must match the engine-side `packed` flag.
        """
        trainer_args = NCCLTrainerSendWeightsArgs(
            group=self.model_update_group,
            packed=packed,
        )
        NCCLWeightTransferEngine.trainer_send_weights(
            iterator=self.model.named_parameters(),
            trainer_args=trainer_args,
        )
# Initialize Ray and set the visible devices. The vLLM engine will
# be placed on GPUs 1 and 2.
ray.init()
# Create a placement group that reserves GPU 12 for the vLLM inference engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/placement-groups.html
# Launch the training model actor. Ray's resource scheduler will allocate
# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs.
train_model = TrainModel.remote(MODEL_NAME)
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
placement_group=pg_inference,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=0,
)
# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.
# Note: Weight transfer APIs (init_weight_transfer_engine, update_weights)
# are now native to vLLM workers.
llm = ray.remote(
num_cpus=0,
num_gpus=0,
scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
model=MODEL_NAME,
enforce_eager=True,
tensor_parallel_size=2,
data_parallel_size=1,
distributed_executor_backend="ray",
weight_transfer_config=WeightTransferConfig(backend="nccl"),
load_format="dummy",
quantization="fp8",
)
# Prompts used both before and after the weight sync.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Greedy decoding so the pre/post-sync comparison is deterministic.
sampling_params = SamplingParams(temperature=0)

# First generation pass: the engine still holds random ("dummy") weights,
# so the completions are expected to be nonsense.
outputs = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for out in outputs:
    prompt = out.prompt
    generated_text = out.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)
# Put the engine to sleep while the weight-transfer group is set up.
# NOTE(review): level=0 — confirm this is a supported sleep level for
# this vLLM version.
ray.get(llm.sleep.remote(level=0))
# Set up the communication channel between the training process and the
# inference engine: fetch the NCCL rendezvous endpoint from the trainer
# and size the group to cover all engine workers plus the trainer.
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())
world_size = ray.get(llm.get_world_size.remote()) + 1 # +1 for the trainer
# Begin NCCL initialization on the engine side. rank_offset=1 leaves rank 0
# for the trainer.
inference_handle = llm.init_weight_transfer_engine.remote(
    dict(
        init_info=dict(
            master_address=master_address,
            master_port=master_port,
            rank_offset=1,
            world_size=world_size,
        )
    )
)
# Initialize the weight-transfer group on the training actor as well, and
# await both handles together so the two sides can rendezvous concurrently.
train_handle = train_model.init_weight_transfer_group.remote(world_size)
ray.get([train_handle, inference_handle])
# Synchronize the updated weights to the inference engine using batched API.
# Collect all weight metadata (names, dtype names, shapes) from the training
# actor so the engine can pre-allocate receive buffers.
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())
# Issue update_weights call with NCCL-specific update info.
# packed=True enables efficient batched tensor broadcasting and must match
# the trainer-side `packed` flag below.
inference_handle = llm.update_weights.remote(
    dict(
        update_info=dict(
            names=names,
            dtype_names=dtype_names,
            shapes=shapes,
            packed=True,
        )
    )
)
# Broadcast all weights from the trainer using the weight transfer API.
# The receive (update_weights) and send (broadcast_weights) sides are
# awaited together so both can participate in the collective concurrently.
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])
# Wake the engine's scheduler back up now that the new weights are loaded.
ray.get(llm.wake_up.remote(tags=["scheduling"]))
# Generate text with the updated model. The output is expected to be normal
# because the weights are updated.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)