[docs] Add docs for new RL flows (#36188)
Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
415
examples/rl/rlhf_async_new_apis.py
Normal file
415
examples/rl/rlhf_async_new_apis.py
Normal file
@@ -0,0 +1,415 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Demonstrates async reinforcement learning using vLLM and Ray,
|
||||
with native weight syncing APIs at engine instance.
|
||||
|
||||
The script separates training and inference workloads onto distinct GPUs
|
||||
so that Ray can manage process placement and inter-process communication.
|
||||
A Hugging Face Transformer model occupies one GPU for training, whereas a
|
||||
2x tensor-parallel vLLM inference engine occupies two GPUs.
|
||||
|
||||
The example performs the following steps:
|
||||
* Load the training model on one gpu (scheduled via ray)
|
||||
* Initialize the inference model with dummy weights across
|
||||
two gpus using vLLM's tensor parallelism and Ray placement groups.
|
||||
* Generate gibberish from a list of prompts using the randomly initialized
|
||||
inference engine.
|
||||
* Pause generation once generation completes for one sequence
|
||||
* Update the weights of the training model and broadcast the updated weights
|
||||
to the inference engine by using a Ray collective RPC group.
|
||||
* Resume generation and print out the results
|
||||
|
||||
This example assumes a single-node cluster with three GPUs, but Ray
|
||||
supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
|
||||
workloads. Residual GPU activity interferes with vLLM memory profiling and
|
||||
causes unexpected behavior.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from dataclasses import asdict
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import vllm
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import WeightTransferConfig
|
||||
from vllm.distributed.weight_transfer.base import (
|
||||
WeightTransferInitRequest,
|
||||
WeightTransferUpdateRequest,
|
||||
)
|
||||
from vllm.distributed.weight_transfer.nccl_engine import (
|
||||
NCCLTrainerSendWeightsArgs,
|
||||
NCCLWeightTransferEngine,
|
||||
NCCLWeightTransferInitInfo,
|
||||
NCCLWeightTransferUpdateInfo,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import get_ip, get_open_port
|
||||
from vllm.v1.executor import Executor
|
||||
|
||||
MODEL_NAME_V1 = "Qwen/Qwen3-1.7B-Base"
|
||||
MODEL_NAME_V2 = "Qwen/Qwen3-1.7B"
|
||||
PAUSE_TOKEN_THRESHOLD = 10
|
||||
ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "FLASH_ATTN"
|
||||
|
||||
|
||||
class MyLLM(vllm.AsyncLLMEngine):
    """Configure the vLLM worker for Ray placement group execution.

    Adds demo-level pause coordination on top of AsyncLLMEngine: the first
    request to cross PAUSE_TOKEN_THRESHOLD tokens raises a flag, and
    pause_after_n_tokens() waits on that flag before pausing the engine.
    """

    def __init__(self, **kwargs):
        engine_args = vllm.AsyncEngineArgs(**kwargs)
        vllm_config = engine_args.create_engine_config()
        executor_class = Executor.get_class(vllm_config)
        super().__init__(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
        )
        # Set once the engine has actually been paused (see
        # pause_after_n_tokens); read by in-flight do_generate loops.
        self._generation_paused = False
        # Raised by the first request that reaches PAUSE_TOKEN_THRESHOLD.
        self._request_pause_flag = False

    async def do_generate(
        self, prompt_token_ids: list[int], sampling_params: vllm.SamplingParams
    ) -> tuple[vllm.RequestOutput, int]:
        """Generate a single request, setting the request pause flag once the
        token count reaches the threshold.

        Returns (output, pause_token_index). pause_token_index is the number
        of tokens generated before the weight change, or -1 if no pause.

        Raises:
            RuntimeError: if the underlying generator yields no outputs.
        """
        pause_token_index = -1
        prev_token_count = 0
        # BUG FIX: initialize so an empty stream raises a clear error below
        # instead of an UnboundLocalError at the return statement.
        output = None
        async for request_output in self.generate(
            {"prompt_token_ids": prompt_token_ids},
            sampling_params,
            request_id=str(uuid.uuid4()),
        ):
            output = request_output
            cur_token_count = len(output.outputs[0].token_ids)
            if (
                cur_token_count >= PAUSE_TOKEN_THRESHOLD
                and not self._request_pause_flag
            ):
                self._request_pause_flag = True
            if self._generation_paused and pause_token_index == -1:
                # The pause happened between the previous yield and this one,
                # so everything up to prev_token_count used the old weights.
                pause_token_index = prev_token_count
            prev_token_count = cur_token_count
        if output is None:
            raise RuntimeError("generate() yielded no outputs for this request")
        return output, pause_token_index

    async def pause_after_n_tokens(self):
        """Wait for any request to set the pause flag, then pause."""
        while not self._request_pause_flag:
            # sleep(0) yields control so generation tasks can make progress.
            await asyncio.sleep(0)
        await super().pause_generation(mode="keep")
        # Demo-level grace period so in-flight iterations drain before
        # do_generate loops observe the pause.
        await asyncio.sleep(5)
        self._generation_paused = True
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1)
class TrainModel:
    """Ray actor that wraps the training model on a dedicated GPU."""

    def __init__(self, model_name: str):
        from vllm.model_executor.layers.batch_invariant import (
            init_batch_invariance,
        )
        from vllm.platforms import current_platform
        from vllm.v1.attention.backends.registry import AttentionBackendEnum

        # need to init all env vars for batch invariance which affect nccl ops
        if current_platform.is_rocm():
            backend = AttentionBackendEnum.TRITON_ATTN
        else:
            backend = AttentionBackendEnum.FLASH_ATTN
        init_batch_invariance(backend)

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, dtype=torch.bfloat16
        ).to("cuda:0")
        self.port = get_open_port()
        self.master_address = get_ip()

    def get_master_address_and_port(self):
        """Return the (address, port) rendezvous pair for NCCL setup."""
        return self.master_address, self.port

    def get_weight_metadata(self):
        """Return weight names, dtypes, and shapes for weight transfer."""
        params = list(self.model.named_parameters())
        names = [name for name, _ in params]
        dtype_names = [str(param.dtype).split(".")[-1] for _, param in params]
        shapes = [list(param.shape) for _, param in params]
        return names, dtype_names, shapes

    def init_weight_transfer_group(self, world_size):
        """Initialize the NCCL process group for weight transfer."""
        init_info = dict(
            master_address=self.master_address,
            master_port=self.port,
            world_size=world_size,
        )
        self.model_update_group = NCCLWeightTransferEngine.trainer_init(init_info)

    def broadcast_weights(self, packed: bool = True):
        """Broadcast weights to the inference engine."""
        NCCLWeightTransferEngine.trainer_send_weights(
            iterator=self.model.named_parameters(),
            trainer_args=NCCLTrainerSendWeightsArgs(
                group=self.model_update_group,
                packed=packed,
            ),
        )

    @torch.inference_mode()
    def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]:
        """Greedy-decode max_new_tokens from the given context."""
        context = torch.tensor([token_ids], device="cuda:0")
        generated = self.model.generate(
            context,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
        return generated[0, len(token_ids) :].tolist()
|
||||
|
||||
|
||||
# Build platform-specific env vars for Ray worker processes.
ray_env_vars = {
    # Prevent Ray from setting CUDA_VISIBLE_DEVICES
    "RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1",
}
if current_platform.is_rocm():
    # For ROCm, BATCH_INVARIANT vllm is not supported
    ray_env_vars["VLLM_ROCM_USE_SKINNY_GEMM"] = "0"
else:
    # Enable batch invariance for deterministic outputs on NVIDIA
    ray_env_vars["VLLM_BATCH_INVARIANT"] = "1"

ray.init(runtime_env={"env_vars": ray_env_vars})

# Launch the training model actor. Ray's resource scheduler allocates it one
# dedicated GPU (num_gpus=1 in the decorator), so the inference engine
# launched below is scheduled onto different GPUs.
train_model = TrainModel.remote(MODEL_NAME_V2)

# ROCm: To minimize non-determinism, use a fixed seed, no prefix caching, and
# sequential request processing (max_num_seqs=1).
rocm_determinism_kwargs = (
    {"seed": 0, "enable_prefix_caching": False, "max_num_seqs": 1}
    if current_platform.is_rocm()
    else {}
)

# Build platform-specific LLM kwargs
llm_kwargs = {
    "model": MODEL_NAME_V1,
    "enforce_eager": True,
    "max_model_len": 8192,
    "distributed_executor_backend": "ray",
    "attention_backend": ATTN_BACKEND,
    "gpu_memory_utilization": 0.75,
    "weight_transfer_config": WeightTransferConfig(backend="nccl"),
    **rocm_determinism_kwargs,
}

# Launch the vLLM inference engine.
# With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates
# its own placement groups internally for each DP rank, so we must NOT
# create an outer placement group (it would reserve GPUs and hide them
# from the internal DP resource check).
llm = ray.remote(num_cpus=0, num_gpus=0)(MyLLM).remote(**llm_kwargs)

PROMPTS = [
    "The president of the United States is",
    "The capital of France is",
    "The largest ocean on Earth is",
    "The speed of light in a vacuum is",
    "The chemical formula for water is",
    "The tallest mountain in the world is",
    "The first person to walk on the moon was",
    "The Great Wall of China was built to",
    "Photosynthesis is the process by which",
    "The theory of general relativity was proposed by",
    "The boiling point of water at sea level is",
    "The largest planet in our solar system is",
    "DNA stands for deoxyribonucleic acid and it",
]

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_V1)
batch_prompt_token_ids = [
    tokenizer.encode(prompt, add_special_tokens=False) for prompt in PROMPTS
]
|
||||
|
||||
|
||||
# Set up the communication channel between the training process and the
# inference engine. The trainer owns the NCCL rendezvous (rank 0); the vLLM
# worker joins at rank_offset=1.
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())

world_size = 2  # 1 trainer + 1 inference worker
init_info = NCCLWeightTransferInitInfo(
    master_address=master_address,
    master_port=master_port,
    rank_offset=1,
    world_size=world_size,
)
inference_handle = llm.init_weight_transfer_engine.remote(
    WeightTransferInitRequest(init_info=asdict(init_info))
)

# Initialize weight transfer group on both the training actor and inference engine
train_handle = train_model.init_weight_transfer_group.remote(world_size)
ray.get([train_handle, inference_handle])


N_NEW_TOKENS = 100

# Collect weight metadata once
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())

# ── Phase 1: concurrent requests with weight sync ───────────────────
print(f"\n{'=' * 50}")
print(f"Prompts ({len(PROMPTS)}):")
for p in PROMPTS:
    print(f" - {p!r}")
print(f"{'=' * 50}")

sampling_params = SamplingParams(
    temperature=0, max_tokens=PAUSE_TOKEN_THRESHOLD + N_NEW_TOKENS
)

# Launch every request concurrently; each future resolves to
# (RequestOutput, pause_token_index).
gen_futures = [
    llm.do_generate.remote(ptids, sampling_params) for ptids in batch_prompt_token_ids
]

# Block until some request has produced PAUSE_TOKEN_THRESHOLD tokens, then
# pause the engine so weights can be swapped mid-generation.
ray.get(llm.pause_after_n_tokens.remote())

update_info = NCCLWeightTransferUpdateInfo(
    names=names,
    dtype_names=dtype_names,
    shapes=shapes,
    packed=True,
)
inference_handle = llm.update_weights.remote(
    WeightTransferUpdateRequest(update_info=asdict(update_info))
)
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])
|
||||
|
||||
ray.get(llm.resume_generation.remote())
results = ray.get(gen_futures)

# Report what each request produced before and after the weight swap.
for idx, (req_output, pause_idx) in enumerate(results):
    token_ids = list(req_output.outputs[0].token_ids)
    old_text = tokenizer.decode(token_ids[:pause_idx])
    new_text = tokenizer.decode(token_ids[pause_idx:])
    print(f"\n Request {idx} ({PROMPTS[idx]!r}):")
    print(f" Old weights ({pause_idx} tokens): {old_text!r}")
    print(f" New weights ({len(token_ids) - pause_idx} tokens): {new_text!r}")

# ── Phase 2: validate with a fresh V2 vLLM instance ────────────────
# This validation relies on batch-invariant (deterministic) generation to
# compare outputs from the weight-synced engine against a fresh V2 instance.
# On NVIDIA, batch invariance is fully supported, so we require 100% exact
# token match. On ROCm, batch invariance is not yet fully implemented
# (see https://github.com/vllm-project/vllm/issues/27433 and
# https://github.com/vllm-project/vllm/issues/33123), so residual
# non-determinism (e.g. GEMM accumulation order, missing kernel overrides)
# can cause single-token divergences that don't indicate a weight-sync
# failure. We relax the pass rate to 90% on ROCm to accommodate this; a
# real regression (broken weight transfer) would cause ~0% pass rate, not 90%+.
MIN_PASS_RATE = 0.9 if current_platform.is_rocm() else 1.0

print(f"\n{'=' * 50}")
print("VALIDATION: comparing weight-synced vLLM with fresh V2 instance")
if current_platform.is_rocm():
    print(f" (ROCm mode: requiring >= {MIN_PASS_RATE:.0%} exact match rate)")
print(f"{'=' * 50}")

# Tear down phase-1 actors so their GPUs are free for the validation engine.
ray.get(llm.shutdown.remote())
ray.kill(llm)
ray.kill(train_model)

llm_v2_kwargs = {
    "model": MODEL_NAME_V2,
    "enforce_eager": True,
    "max_model_len": 8192,
    "gpu_memory_utilization": 0.75,
    "distributed_executor_backend": "ray",
    "attention_backend": ATTN_BACKEND,
    **rocm_determinism_kwargs,
}

llm_v2 = ray.remote(num_cpus=0, num_gpus=0)(MyLLM).remote(**llm_v2_kwargs)

# Replay each request from its pause point on the fresh V2 engine.
val_futures = [
    llm_v2.do_generate.remote(
        list(req_output.prompt_token_ids)
        + list(req_output.outputs[0].token_ids)[:pause_idx],
        SamplingParams(
            temperature=0,
            max_tokens=len(req_output.outputs[0].token_ids) - pause_idx,
        ),
    )
    for req_output, pause_idx in results
]
val_results = ray.get(val_futures)

num_pass = 0
num_total = len(results)
for idx, ((req_output, pause_idx), (val_output, _)) in enumerate(
    zip(results, val_results)
):
    expected = list(req_output.outputs[0].token_ids)[pause_idx:]
    actual = list(val_output.outputs[0].token_ids)
    if expected == actual:
        num_pass += 1
        print(f" [PASS] {PROMPTS[idx]!r}")
        continue
    print(f" [FAIL] {PROMPTS[idx]!r}")
    print(f" weight-synced vLLM: {tokenizer.decode(expected)!r}")
    print(f" V2 vLLM: {tokenizer.decode(actual)!r}")
    # Find the first position (within the common prefix) where they differ.
    divergence = next(
        ((j, e, a) for j, (e, a) in enumerate(zip(expected, actual)) if e != a),
        None,
    )
    if divergence is not None:
        j, e, a = divergence
        print(
            f" first divergence at output token {j}: "
            f"expected {e} ({tokenizer.decode([e])!r}) vs "
            f"actual {a} ({tokenizer.decode([a])!r})"
        )

ray.get(llm_v2.shutdown.remote())
ray.kill(llm_v2)

pass_rate = num_pass / num_total
print(f"\n Result: {num_pass}/{num_total} prompts passed ({pass_rate:.0%})")
print(f" Required: >= {MIN_PASS_RATE:.0%}")

assert pass_rate >= MIN_PASS_RATE, (
    f"Validation pass rate {pass_rate:.0%} ({num_pass}/{num_total}) "
    f"is below the required {MIN_PASS_RATE:.0%} threshold. "
    f"See failures above for details."
)
print("=" * 50)
|
||||
181
examples/rl/rlhf_http_ipc.py
Normal file
181
examples/rl/rlhf_http_ipc.py
Normal file
@@ -0,0 +1,181 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM
|
||||
via HTTP API, with IPC-based weight syncing APIs.
|
||||
|
||||
Unlike rlhf_nccl.py which uses NCCL and can use separate GPUs, this script
|
||||
uses CUDA IPC which requires the training model and vLLM server to be on the
|
||||
same GPU. Memory must be carefully managed to fit both models.
|
||||
|
||||
Unlike rlhf.py which creates a vLLM instance programmatically, this script
|
||||
assumes you have already started a vLLM server using `vllm serve`. It uses:
|
||||
- OpenAI-compatible API for inference requests
|
||||
- HTTP endpoints for weight transfer control plane
|
||||
- CUDA IPC for actual weight data transfer
|
||||
|
||||
Prerequisites:
|
||||
Start a vLLM server with weight transfer enabled and reduced GPU memory
|
||||
utilization to leave room for the training model:
|
||||
|
||||
$ VLLM_SERVER_DEV_MODE=1 VLLM_ALLOW_INSECURE_SERIALIZATION=1 \
|
||||
vllm serve facebook/opt-125m --enforce-eager \
|
||||
--weight-transfer-config '{"backend": "ipc"}' \
|
||||
--load-format dummy \
|
||||
--gpu-memory-utilization 0.5
|
||||
|
||||
Then run this script:
|
||||
|
||||
$ python rlhf_http_ipc.py
|
||||
|
||||
The example performs the following steps:
|
||||
|
||||
* Load the training model on GPU 0 (same GPU as the vLLM server).
|
||||
* Generate text using the vLLM server via OpenAI-compatible API. The output
|
||||
is expected to be nonsense because the server is initialized with dummy weights.
|
||||
* Initialize weight transfer via HTTP endpoint (no-op for IPC).
|
||||
* Broadcast the real weights from the training model to the vLLM server
|
||||
using CUDA IPC handles.
|
||||
* Generate text again to show normal output after the weight update.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from openai import OpenAI
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm.distributed.weight_transfer.ipc_engine import (
|
||||
IPCTrainerSendWeightsArgs,
|
||||
IPCWeightTransferEngine,
|
||||
)
|
||||
|
||||
BASE_URL = "http://localhost:8000"
|
||||
MODEL_NAME = "facebook/opt-125m"
|
||||
|
||||
# Enable insecure serialization for IPC handle serialization
|
||||
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"
|
||||
|
||||
|
||||
def generate_completions(client: OpenAI, model: str, prompts: list[str]) -> list[str]:
    """Generate one completion per prompt using the OpenAI-compatible API.

    Args:
        client: OpenAI client pointed at the vLLM server.
        model: Served model name to route each request to.
        prompts: Plain-text prompts; one request is issued per prompt.

    Returns:
        The completion text for each prompt, in the same order.
    """
    # temperature=0 keeps outputs deterministic so before/after weight-sync
    # comparisons are meaningful.
    return [
        client.completions.create(
            model=model,
            prompt=prompt,
            max_tokens=32,
            temperature=0,
        )
        .choices[0]
        .text
        for prompt in prompts
    ]
|
||||
|
||||
|
||||
def init_weight_transfer_engine(base_url: str) -> None:
    """Initialize weight transfer via HTTP endpoint (no-op for IPC)."""
    resp = requests.post(
        f"{base_url}/init_weight_transfer_engine",
        json={"init_info": {}},
        timeout=60,
    )
    resp.raise_for_status()
|
||||
|
||||
|
||||
def pause_generation(base_url: str) -> None:
    """Pause generation via HTTP endpoint."""
    resp = requests.post(f"{base_url}/pause", timeout=60)
    resp.raise_for_status()
|
||||
|
||||
|
||||
def resume_generation(base_url: str) -> None:
    """Resume generation via HTTP endpoint."""
    resp = requests.post(f"{base_url}/resume", timeout=60)
    resp.raise_for_status()
|
||||
|
||||
|
||||
def get_world_size(base_url: str) -> int:
    """Get world size from the vLLM server."""
    resp = requests.get(f"{base_url}/get_world_size", timeout=10)
    resp.raise_for_status()
    return resp.json()["world_size"]
|
||||
|
||||
|
||||
def main():
    """Run the CUDA-IPC weight-sync demo against a running vLLM server."""
    # IPC requires the training model to be on the same GPU as the vLLM server
    # The server should be started on GPU 0 with reduced memory utilization
    device_index = 0
    device = f"cuda:{device_index}"
    # BUG FIX: set_device_index expects an integer device index, not a
    # "cuda:N" device string.
    torch.accelerator.set_device_index(device_index)

    # Load the training model on the same GPU as the server
    # Use bfloat16 to reduce memory footprint
    print(f"Loading training model: {MODEL_NAME} on {device}")
    print(
        "Note: Ensure the vLLM server was started with --gpu-memory-utilization 0.5 "
        "or lower to leave room for the training model."
    )
    train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16)
    train_model.to(device)
    train_model.eval()  # Set to eval mode to save memory

    # Create OpenAI client pointing to the vLLM server
    client = OpenAI(
        base_url=f"{BASE_URL}/v1",
        api_key="EMPTY",  # vLLM doesn't require an API key by default
    )

    # Test prompts
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Generate text before weight update. The output is expected to be nonsense
    # because the server is initialized with dummy weights.
    print("-" * 50)
    print("Generating text BEFORE weight update (expect nonsense):")
    print("-" * 50)
    outputs = generate_completions(client, MODEL_NAME, prompts)
    for prompt, generated_text in zip(prompts, outputs):
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)

    print("Initializing weight transfer (IPC backend)...")

    # Initialize weight transfer on vLLM server (no-op for IPC, but still required)
    init_weight_transfer_engine(BASE_URL)

    # Pause generation before weight sync
    pause_generation(BASE_URL)

    # Broadcast weights via IPC handles using HTTP mode
    print("Broadcasting weights via CUDA IPC (HTTP)...")
    trainer_args = IPCTrainerSendWeightsArgs(mode="http", url=BASE_URL)
    IPCWeightTransferEngine.trainer_send_weights(
        iterator=train_model.named_parameters(),
        trainer_args=trainer_args,
    )

    # Resume generation after weight sync
    resume_generation(BASE_URL)

    # Generate text after weight update. The output is expected to be normal
    # because the real weights are now loaded.
    print("-" * 50)
    print("Generating text AFTER weight update:")
    print("-" * 50)
    outputs_updated = generate_completions(client, MODEL_NAME, prompts)
    for prompt, generated_text in zip(prompts, outputs_updated):
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)

    # Note: The training model and IPC handles remain in memory.
    # In a real RLHF training loop, you would update the training model
    # and create new IPC handles for each weight update.


if __name__ == "__main__":
    main()
|
||||
245
examples/rl/rlhf_http_nccl.py
Normal file
245
examples/rl/rlhf_http_nccl.py
Normal file
@@ -0,0 +1,245 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM
|
||||
via HTTP API, with native weight syncing APIs.
|
||||
|
||||
Unlike rlhf.py which creates a vLLM instance programmatically, this script
|
||||
assumes you have already started a vLLM server using `vllm serve`. It uses:
|
||||
- OpenAI-compatible API for inference requests
|
||||
- HTTP endpoints for weight transfer control plane
|
||||
- NCCL for actual weight data transfer
|
||||
|
||||
Prerequisites:
|
||||
Start a vLLM server with weight transfer enabled:
|
||||
|
||||
$ VLLM_SERVER_DEV_MODE=1 vllm serve facebook/opt-125m \
|
||||
--enforce-eager \
|
||||
--weight-transfer-config '{"backend": "nccl"}' \
|
||||
--load-format dummy
|
||||
|
||||
Then run this script:
|
||||
|
||||
$ python rlhf_http_nccl.py
|
||||
|
||||
The example performs the following steps:
|
||||
|
||||
* Load the training model on GPU 0.
|
||||
* Generate text using the vLLM server via OpenAI-compatible API. The output
|
||||
is expected to be nonsense because the server is initialized with dummy weights.
|
||||
* Initialize weight transfer via HTTP endpoint.
|
||||
* Broadcast the real weights from the training model to the vLLM server
|
||||
using NCCL.
|
||||
* Generate text again to show normal output after the weight update.
|
||||
"""
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from openai import OpenAI
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm.distributed.weight_transfer.nccl_engine import (
|
||||
NCCLTrainerSendWeightsArgs,
|
||||
NCCLWeightTransferEngine,
|
||||
)
|
||||
from vllm.utils.network_utils import get_ip, get_open_port
|
||||
|
||||
BASE_URL = "http://localhost:8000"
|
||||
MODEL_NAME = "facebook/opt-125m"
|
||||
|
||||
|
||||
def generate_completions(client: OpenAI, model: str, prompts: list[str]) -> list[str]:
    """Generate one completion per prompt using the OpenAI-compatible API.

    Args:
        client: OpenAI client pointed at the vLLM server.
        model: Served model name to route each request to.
        prompts: Plain-text prompts; one request is issued per prompt.

    Returns:
        The completion text for each prompt, in the same order.
    """
    # temperature=0 keeps outputs deterministic so before/after weight-sync
    # comparisons are meaningful.
    return [
        client.completions.create(
            model=model,
            prompt=prompt,
            max_tokens=32,
            temperature=0,
        )
        .choices[0]
        .text
        for prompt in prompts
    ]
|
||||
|
||||
|
||||
def init_weight_transfer_engine(
    base_url: str,
    master_address: str,
    master_port: int,
    rank_offset: int,
    world_size: int,
) -> None:
    """Initialize weight transfer via HTTP endpoint."""
    init_info = {
        "master_address": master_address,
        "master_port": master_port,
        "rank_offset": rank_offset,
        "world_size": world_size,
    }
    resp = requests.post(
        f"{base_url}/init_weight_transfer_engine",
        json={"init_info": init_info},
        timeout=60,
    )
    resp.raise_for_status()
|
||||
|
||||
|
||||
def update_weights(
    base_url: str,
    names: list[str],
    dtype_names: list[str],
    shapes: list[list[int]],
    packed: bool = False,
) -> None:
    """Update weights via HTTP endpoint."""
    update_info = {
        "names": names,
        "dtype_names": dtype_names,
        "shapes": shapes,
        "packed": packed,
    }
    resp = requests.post(
        f"{base_url}/update_weights",
        json={"update_info": update_info},
        timeout=300,
    )
    resp.raise_for_status()
|
||||
|
||||
|
||||
def pause_generation(base_url: str) -> None:
    """Pause generation via HTTP endpoint."""
    resp = requests.post(f"{base_url}/pause", timeout=60)
    resp.raise_for_status()
|
||||
|
||||
|
||||
def resume_generation(base_url: str) -> None:
    """Resume generation via HTTP endpoint."""
    resp = requests.post(f"{base_url}/resume", timeout=60)
    resp.raise_for_status()
|
||||
|
||||
|
||||
def get_world_size(base_url: str) -> int:
    """Get world size from the vLLM server."""
    resp = requests.get(f"{base_url}/get_world_size", timeout=10)
    resp.raise_for_status()
    return resp.json()["world_size"]
|
||||
|
||||
|
||||
def main():
    """Run the NCCL weight-transfer demo against a running vLLM server."""
    # Get the inference world size from the vLLM server
    inference_world_size = get_world_size(BASE_URL)
    world_size = inference_world_size + 1  # +1 for the trainer
    # The trainer takes the first GPU after the inference worker(s).
    device = f"cuda:{inference_world_size}"
    # BUG FIX: set_device_index expects an integer device index, not a
    # "cuda:N" device string.
    torch.accelerator.set_device_index(inference_world_size)

    # Load the training model
    print(f"Loading training model: {MODEL_NAME}")
    train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16)
    train_model.to(device)

    # Create OpenAI client pointing to the vLLM server
    client = OpenAI(
        base_url=f"{BASE_URL}/v1",
        api_key="EMPTY",  # vLLM doesn't require an API key by default
    )

    # Test prompts
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Generate text before weight update. The output is expected to be nonsense
    # because the server is initialized with dummy weights.
    print("-" * 50)
    print("Generating text BEFORE weight update (expect nonsense):")
    print("-" * 50)
    outputs = generate_completions(client, MODEL_NAME, prompts)
    for prompt, generated_text in zip(prompts, outputs):
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)

    # Set up the communication channel between the training process and the
    # vLLM server. The trainer is rank 0, vLLM worker(s) start at rank_offset.
    master_address = get_ip()
    master_port = get_open_port()
    rank_offset = 1

    print(f"Initializing weight transfer: master={master_address}:{master_port}")

    # Initialize weight transfer on vLLM server (this is async, server will
    # wait for NCCL connection)
    import threading

    init_thread = threading.Thread(
        target=init_weight_transfer_engine,
        args=(BASE_URL, master_address, master_port, rank_offset, world_size),
    )
    init_thread.start()

    # Initialize NCCL process group on trainer side
    model_update_group = NCCLWeightTransferEngine.trainer_init(
        dict(
            master_address=master_address,
            master_port=master_port,
            world_size=world_size,
        ),
    )

    # Wait for init_weight_transfer_engine to complete
    init_thread.join()

    # Pause generation before weight sync
    pause_generation(BASE_URL)

    # Collect weight metadata for the update request
    names = []
    dtype_names = []
    shapes = []
    for name, p in train_model.named_parameters():
        names.append(name)
        dtype_names.append(str(p.dtype).split(".")[-1])
        shapes.append(list(p.shape))

    # Start the update_weights call in a separate thread since it will block
    # waiting for NCCL broadcasts
    # packed=True enables efficient batched tensor broadcasting
    update_thread = threading.Thread(
        target=update_weights,
        args=(BASE_URL, names, dtype_names, shapes, True),  # packed=True
    )
    update_thread.start()

    # Broadcast all weights from trainer to vLLM workers
    print("Broadcasting weights via NCCL...")
    trainer_args = NCCLTrainerSendWeightsArgs(
        group=model_update_group,
        packed=True,
    )
    NCCLWeightTransferEngine.trainer_send_weights(
        iterator=train_model.named_parameters(),
        trainer_args=trainer_args,
    )

    # Wait for update_weights to complete
    update_thread.join()

    # Resume generation after weight sync
    resume_generation(BASE_URL)

    # Generate text after weight update. The output is expected to be normal
    # because the real weights are now loaded.
    print("-" * 50)
    print("Generating text AFTER weight update:")
    print("-" * 50)
    outputs_updated = generate_completions(client, MODEL_NAME, prompts)
    for prompt, generated_text in zip(prompts, outputs_updated):
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)


if __name__ == "__main__":
    main()
|
||||
149
examples/rl/rlhf_ipc.py
Normal file
149
examples/rl/rlhf_ipc.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray,
|
||||
with IPC-based weight syncing APIs
|
||||
|
||||
The script colocates the training and inference workloads onto the same GPU using Ray.
|
||||
|
||||
The example performs the following steps:
|
||||
|
||||
* Request a placement group of 1 GPU.
|
||||
* Place the inference model on the above GPU using the placement group.
|
||||
* Place and load the training model on the same GPU using the placement group.
|
||||
* Generate text from a list of prompts using the inference engine.
|
||||
* Update the weights of the training model and broadcast the updated weights
|
||||
to the inference engine by using CUDA IPC handles. Note that
|
||||
for demonstration purposes we simply zero out the weights.
|
||||
|
||||
This example assumes a single-node cluster with a single GPU,
|
||||
but can be extended to multiple GPUs.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import ray
|
||||
from ray.util.placement_group import placement_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import WeightTransferConfig
|
||||
from vllm.distributed.weight_transfer.ipc_engine import (
|
||||
IPCTrainerSendWeightsArgs,
|
||||
IPCWeightTransferEngine,
|
||||
)
|
||||
|
||||
|
||||
class MyLLM(LLM):
    """vLLM engine subclass tuned for colocated Ray placement-group execution."""

    def __init__(self, *args, **kwargs):
        # Ray pre-sets CUDA_VISIBLE_DEVICES for the actor; drop it so vLLM
        # performs its own device placement inside the worker.
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        # Two processes (trainer + engine) share one GPU: cap each vLLM Ray
        # worker at a 0.4-GPU slice on bundle 0, and allow the insecure
        # serialization path needed to ship CUDA IPC handles between actors.
        os.environ.update(
            {
                "VLLM_RAY_PER_WORKER_GPUS": "0.4",
                "VLLM_RAY_BUNDLE_INDICES": "0",
                "VLLM_ALLOW_INSECURE_SERIALIZATION": "1",
            }
        )
        super().__init__(*args, **kwargs)
||||
|
||||
|
||||
# Hugging Face model id used by both the trainer and the inference engine.
MODEL_NAME = "facebook/opt-125m"
||||
@ray.remote
class TrainModel:
    """Ray actor hosting the Hugging Face training model, colocated with vLLM."""

    def __init__(self, llm_handle: ray.actor.ActorHandle):
        # Trainer-side copy of the weights, moved onto the (single) GPU that
        # this actor shares with the vLLM engine.
        self.train_model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
        )
        self.train_model.to("cuda:0")
        # Handle to the vLLM engine actor that will receive weight updates.
        self.llm_handle = llm_handle

    def init_weight_transfer(self) -> None:
        """Initialize the weight-transfer engine on the inference side."""
        # IPC backend doesn't need initialization info
        ray.get(
            self.llm_handle.init_weight_transfer_engine.remote(dict(init_info=dict()))
        )

    def broadcast_weights(self, llm_handle: ray.actor.ActorHandle) -> None:
        """Broadcast weights to the inference engine using IPC."""
        # NOTE(review): this overwrites the handle stored in __init__ with the
        # one passed by the caller; both are expected to be the same actor.
        self.llm_handle = llm_handle
        trainer_args = IPCTrainerSendWeightsArgs(mode="ray", llm_handle=llm_handle)
        IPCWeightTransferEngine.trainer_send_weights(
            iterator=self.train_model.named_parameters(),
            trainer_args=trainer_args,
        )
||||
|
||||
|
||||
# Start Ray on the local node; all actors below are scheduled through it.
ray.init()

# One placement group with a single-GPU bundle — the vLLM engine and the
# trainer actor are both packed onto this same GPU.
pg_colocate = placement_group([{"GPU": 1, "CPU": 0}])
ray.get(pg_colocate.ready())

# Launch the vLLM engine inside the placement group. num_gpus=0 at the actor
# level because MyLLM assigns fractional GPUs to its own Ray workers via
# VLLM_RAY_PER_WORKER_GPUS.
llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
    scheduling_strategy=PlacementGroupSchedulingStrategy(
        placement_group=pg_colocate,
        placement_group_capture_child_tasks=True,
    ),
)(MyLLM).remote(
    model=MODEL_NAME,
    enforce_eager=True,
    tensor_parallel_size=1,
    distributed_executor_backend="ray",
    gpu_memory_utilization=0.7,
    # IPC backend: weights move between colocated processes via CUDA IPC.
    weight_transfer_config=WeightTransferConfig(backend="ipc"),
    # Dummy weights: the engine starts uninitialized until the first sync.
    load_format="dummy",
)

# The trainer actor takes a small slice (0.1 GPU) of the same bundle so it
# is colocated with the engine on the one reserved GPU.
train_model = TrainModel.options(
    num_gpus=0.1,
    num_cpus=0,
    scheduling_strategy=PlacementGroupSchedulingStrategy(
        placement_group=pg_colocate, placement_group_capture_child_tasks=True
    ),
).remote(llm)

# Generate text from the prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# temperature=0 -> greedy decoding, so each run is deterministic.
sampling_params = SamplingParams(temperature=0)

outputs = ray.get(llm.generate.remote(prompts, sampling_params))

# First pass uses the dummy-initialized weights, so output is gibberish.
print("-" * 50)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)

# Put the engine to sleep before the weight sync.
# NOTE(review): level=0 semantics are defined by vLLM's sleep API — confirm.
ray.get(llm.sleep.remote(level=0))

# Set up the IPC weight-transfer engine on the inference side.
ray.get(train_model.init_weight_transfer.remote())
# Synchronize the updated weights to the inference engine using batched API.
ray.get(train_model.broadcast_weights.remote(llm))

# Wake the engine's scheduler back up after the weight sync.
ray.get(llm.wake_up.remote(tags=["scheduling"]))

# Generate text with the updated model.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)
||||
216
examples/rl/rlhf_nccl.py
Normal file
216
examples/rl/rlhf_nccl.py
Normal file
@@ -0,0 +1,216 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Demonstrates reinforcement learning using vLLM and Ray,
|
||||
with native weight syncing APIs at engine instance.
|
||||
|
||||
The script separates training and inference workloads onto distinct GPUs
|
||||
so that Ray can manage process placement and inter-process communication.
|
||||
A Hugging Face Transformer model occupies one GPU for training, whereas a
|
||||
2x tensor-parallel vLLM inference engine occupies two GPUs.
|
||||
|
||||
The example performs the following steps:
|
||||
* Load the training model on one gpu (scheduled via ray)
|
||||
* Initialize the inference model with dummy weights across
|
||||
two gpus using vLLM's tensor parallelism and Ray placement groups.
|
||||
* Generate gibberish from a list of prompts using the randomly initialized
|
||||
inference engine.
|
||||
* Update the weights of the training model and broadcast the updated weights
|
||||
to the inference engine by using a Ray collective RPC group.
|
||||
* Generating from the list of prompts after weight sync should result
|
||||
in sensible outputs.
|
||||
|
||||
This example assumes a single-node cluster with three GPUs, but Ray
|
||||
supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
|
||||
workloads. Residual GPU activity interferes with vLLM memory profiling and
|
||||
causes unexpected behavior.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import ray
|
||||
from ray.util.placement_group import placement_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import WeightTransferConfig
|
||||
from vllm.distributed.weight_transfer.nccl_engine import (
|
||||
NCCLTrainerSendWeightsArgs,
|
||||
NCCLWeightTransferEngine,
|
||||
)
|
||||
from vllm.utils.network_utils import get_ip, get_open_port
|
||||
|
||||
# Hugging Face model id used by both the trainer and the inference engine.
MODEL_NAME = "facebook/opt-125m"
# Alternate model id left by the example author (quantized variant):
# MODEL_NAME = "inference-optimization/Qwen3-0.6B-W4A16-G128"
||||
|
||||
class MyLLM(LLM):
    """Configure the vLLM worker for Ray placement group execution."""

    def __init__(self, *args, **kwargs):
        # Pin the engine's two tensor-parallel workers to bundles 0 and 1
        # of the surrounding placement group.
        os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0,1"
        super().__init__(*args, **kwargs)
||||
|
||||
|
||||
@ray.remote(num_gpus=1)
class TrainModel:
    """Ray actor that wraps the training model on a dedicated GPU."""

    def __init__(self, model_name: str):
        # Trainer-side copy of the weights on this actor's sole visible GPU.
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda:0")

        # Rendezvous endpoint handed to the inference workers so they can
        # join the trainer's NCCL group.
        self.port = get_open_port()
        self.master_address = get_ip()

    def get_master_address_and_port(self):
        """Return the (address, port) rendezvous pair for the NCCL group."""
        return self.master_address, self.port

    def get_weight_metadata(self):
        """Return weight names, dtype names, and shapes for weight transfer."""
        records = [
            (name, str(param.dtype).split(".")[-1], list(param.shape))
            for name, param in self.model.named_parameters()
        ]
        names = [rec[0] for rec in records]
        dtype_names = [rec[1] for rec in records]
        shapes = [rec[2] for rec in records]
        return names, dtype_names, shapes

    def init_weight_transfer_group(self, world_size):
        """Initialize the NCCL process group for weight transfer."""
        init_info = dict(
            master_address=self.master_address,
            master_port=self.port,
            world_size=world_size,
        )
        self.model_update_group = NCCLWeightTransferEngine.trainer_init(init_info)

    def broadcast_weights(self, packed: bool = True):
        """Broadcast weights to the inference engine."""
        NCCLWeightTransferEngine.trainer_send_weights(
            iterator=self.model.named_parameters(),
            trainer_args=NCCLTrainerSendWeightsArgs(
                group=self.model_update_group,
                packed=packed,
            ),
        )
||||
|
||||
|
||||
# Initialize Ray. The trainer actor claims one GPU (num_gpus=1 on its
# decorator); the TP=2 vLLM engine claims two more via the placement group.
ray.init()

# Launch the training model actor first. Ray's resource scheduler allocates
# its dedicated GPU, ensuring the inference placement group below lands on
# different GPUs.
train_model = TrainModel.remote(MODEL_NAME)

# Placement group reserving two single-GPU bundles for the vLLM engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/placement-groups.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
    placement_group=pg_inference,
    placement_group_capture_child_tasks=True,
    placement_group_bundle_index=0,
)

# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.
# Note: Weight transfer APIs (init_weight_transfer_engine, update_weights)
# are native to vLLM workers.
llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
    scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
    model=MODEL_NAME,
    enforce_eager=True,
    tensor_parallel_size=2,
    data_parallel_size=1,
    distributed_executor_backend="ray",
    weight_transfer_config=WeightTransferConfig(backend="nccl"),
    # Dummy weights: the engine holds random parameters until the sync below.
    load_format="dummy",
    # NOTE(review): fp8 quantization combined with facebook/opt-125m — confirm
    # this pairing is intended for the example.
    quantization="fp8",
)

# Generate text from the prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# temperature=0 -> greedy decoding, so each run is deterministic.
sampling_params = SamplingParams(temperature=0)

outputs = ray.get(llm.generate.remote(prompts, sampling_params))

# Generate text with the initial model. The output is expected to be nonsense
# because the weights are randomly initialized.
print("-" * 50)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)

# Put the engine to sleep before the weight sync.
# NOTE(review): level=0 semantics are defined by vLLM's sleep API — confirm.
ray.get(llm.sleep.remote(level=0))

# Set up the communication channel between the training process and the
# inference engine.
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())

world_size = ray.get(llm.get_world_size.remote()) + 1  # +1 for the trainer
# rank_offset=1: the trainer occupies rank 0, inference workers start at 1.
inference_handle = llm.init_weight_transfer_engine.remote(
    dict(
        init_info=dict(
            master_address=master_address,
            master_port=master_port,
            rank_offset=1,
            world_size=world_size,
        )
    )
)

# Initialize the weight-transfer group on both the training actor and the
# inference engine; both sides must join before either call completes.
train_handle = train_model.init_weight_transfer_group.remote(world_size)
ray.get([train_handle, inference_handle])

# Synchronize the updated weights to the inference engine using batched API.
# Collect all weight metadata from the training actor.
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())

# Issue the update_weights call with NCCL-specific update info.
# packed=True enables efficient batched tensor broadcasting.
inference_handle = llm.update_weights.remote(
    dict(
        update_info=dict(
            names=names,
            dtype_names=dtype_names,
            shapes=shapes,
            packed=True,
        )
    )
)

# Broadcast all weights from the trainer using the weight transfer API.
# Both sides are awaited together since the NCCL broadcast blocks each side.
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])

ray.get(llm.wake_up.remote(tags=["scheduling"]))

# Generate text with the updated model. The output is expected to be normal
# because the weights are updated.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)
|
||||
Reference in New Issue
Block a user