# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
RLHF with FSDP2 training (4 GPUs) and vLLM expert-parallel inference (4 GPUs).

8-GPU layout:
  Training  — 4 GPUs, PyTorch FSDP2 (fully_shard)
  Inference — 4 GPUs, vLLM AsyncLLMEngine with expert parallelism +
              data parallelism (TP=1, DP=4, enable_expert_parallel
              → EP_SIZE = TP×DP = 4)

FSDP workers are Ray actors that form a single FSDP2 process group.
Rank 0 gathers full parameters via DTensor.full_tensor() and broadcasts
them to the vLLM inference engine through the NCCL weight-transfer API.

The inference engine uses AsyncLLMEngine, which automatically spawns
DP worker processes (no manual placement group needed). Weight sync
uses pause_generation / resume_generation.

Steps:
  1. Launch 4 FSDP training workers.
  2. Launch AsyncLLMEngine with EP+DP (dummy weights).
  3. Generate from prompts → gibberish (random weights).
  4. Pause generation, transfer weights from FSDP, resume.
  5. Generate from prompts → sensible output (synced weights).

Assumes a single-node cluster with 8 GPUs.
"""
|
||
|
||
import asyncio
|
||
import os
|
||
import uuid
|
||
from dataclasses import asdict
|
||
|
||
import ray
|
||
import torch
|
||
import torch.distributed as dist
|
||
from huggingface_hub import snapshot_download
|
||
from torch.distributed.fsdp import fully_shard
|
||
from transformers import AutoModelForCausalLM
|
||
|
||
import vllm
|
||
from vllm import SamplingParams
|
||
from vllm.config import WeightTransferConfig
|
||
from vllm.distributed.weight_transfer.base import (
|
||
WeightTransferInitRequest,
|
||
WeightTransferUpdateRequest,
|
||
)
|
||
from vllm.distributed.weight_transfer.nccl_engine import (
|
||
NCCLTrainerSendWeightsArgs,
|
||
NCCLWeightTransferEngine,
|
||
NCCLWeightTransferInitInfo,
|
||
NCCLWeightTransferUpdateInfo,
|
||
)
|
||
from vllm.utils.network_utils import get_ip, get_open_port
|
||
from vllm.v1.executor import Executor
|
||
|
||
# Hugging Face model id (a mixture-of-experts model, per the EP setup below).
MODEL_NAME = "Qwen/Qwen3-30B-A3B"

# Number of FSDP2 training workers — one Ray actor / GPU each.
FSDP_WORLD_SIZE = 4
# vLLM inference parallelism: TP=1, DP=4. With enable_expert_parallel,
# the expert-parallel size is TP × DP = 4 (see module docstring).
INFERENCE_TP_SIZE = 1
INFERENCE_DP_SIZE = 4
|
||
|
||
|
||
@ray.remote(num_gpus=1)
class FSDPTrainWorker:
    """
    One FSDP2 training worker per GPU. Four of these form the FSDP group.
    Rank 0 additionally handles weight transfer to the vLLM engine.
    """

    def __init__(
        self,
        model_name: str,
        rank: int,
        fsdp_world_size: int,
        fsdp_master_addr: str,
        fsdp_master_port: int,
    ):
        # Rank within the FSDP process group (0..fsdp_world_size-1).
        self.rank = rank

        # torch.distributed rendezvous is configured via environment
        # variables; every worker must agree on the same addr/port.
        os.environ["MASTER_ADDR"] = fsdp_master_addr
        os.environ["MASTER_PORT"] = str(fsdp_master_port)

        dist.init_process_group(backend="nccl", rank=rank, world_size=fsdp_world_size)
        # Ray allocated exactly one GPU to this actor (num_gpus=1), so the
        # only visible accelerator is local index 0.
        torch.accelerator.set_device_index(0)

        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16
        )

        # Capture parameter metadata BEFORE fully_shard() wraps parameters
        # (afterwards they become DTensors; the receiver needs the full,
        # unsharded names/dtypes/shapes to allocate its buffers).
        self.weight_names = [n for n, _ in model.named_parameters()]
        # dtype recorded as a short string, e.g. "bfloat16" from "torch.bfloat16".
        self.weight_dtype_names = [
            str(p.dtype).split(".")[-1] for _, p in model.named_parameters()
        ]
        self.weight_shapes = [list(p.shape) for _, p in model.named_parameters()]

        # Shard per-decoder-layer first, then the root module — the standard
        # FSDP2 pattern so each layer is its own all-gather unit.
        for layer in model.model.layers:
            fully_shard(layer)
        fully_shard(model)

        self.model = model

        # Weight-transfer state; populated lazily on rank 0 only
        # (see setup_transfer_endpoint / init_weight_transfer_group).
        self.transfer_port = None
        self.transfer_master_address = None
        self.model_update_group = None

    def get_rank(self) -> int:
        # Also used by the driver as a cheap "constructor finished" barrier.
        return self.rank

    # ---- weight-transfer setup (rank 0 only) ----

    def setup_transfer_endpoint(self):
        """Create the NCCL rendezvous endpoint for weight transfer.

        Returns (ip, port) that the vLLM side uses to join the group.
        """
        assert self.rank == 0
        self.transfer_port = get_open_port()
        self.transfer_master_address = get_ip()
        return self.transfer_master_address, self.transfer_port

    def init_weight_transfer_group(self, transfer_world_size: int):
        """Join the weight-transfer NCCL group as rank 0 (the source).

        transfer_world_size must count this trainer plus every vLLM worker.
        Blocks until all ranks have joined the rendezvous.
        """
        assert self.rank == 0
        self.model_update_group = NCCLWeightTransferEngine.trainer_init(
            dict(
                master_address=self.transfer_master_address,
                master_port=self.transfer_port,
                world_size=transfer_world_size,
            ),
        )

    def get_weight_metadata(self):
        """Return weight names, dtypes, and shapes captured before FSDP wrapping."""
        return self.weight_names, self.weight_dtype_names, self.weight_shapes

    # ---- collective ops (ALL FSDP ranks must call concurrently) ----

    def gather_and_broadcast_weights(self, packed: bool = True):
        """
        All-gather full parameters and broadcast them to vLLM.
        Only rank 0 performs the actual NCCL broadcast; others just
        participate in the FSDP all-gather.

        full_tensor() is a collective — all FSDP ranks must call it
        for each parameter in the same order. Rank 0 additionally
        feeds each gathered tensor to the weight-transfer engine.
        """
        if self.rank == 0:

            # Lazy generator: each parameter is gathered only when the
            # transfer engine pulls it, bounding peak memory to one
            # full parameter at a time.
            def _full_param_iter():
                for name, param in self.model.named_parameters():
                    yield name, param.full_tensor()

            trainer_args = NCCLTrainerSendWeightsArgs(
                group=self.model_update_group,
                packed=packed,
            )
            NCCLWeightTransferEngine.trainer_send_weights(
                iterator=_full_param_iter(),
                trainer_args=trainer_args,
            )
        else:
            # Non-zero ranks only participate in the all-gather; the
            # gathered result is discarded. Order must match rank 0's.
            for _, param in self.model.named_parameters():
                param.full_tensor()
||
|
||
def create_async_engine(**kwargs):
    """Build an AsyncLLMEngine directly from keyword engine arguments.

    Performs the three construction steps vLLM's CLI would do for us:
    parse AsyncEngineArgs, derive the full engine config, and resolve
    the executor class — no AsyncLLMEngine subclass required.
    """
    args = vllm.AsyncEngineArgs(**kwargs)
    config = args.create_engine_config()
    return vllm.AsyncLLMEngine(
        vllm_config=config,
        executor_class=Executor.get_class(config),
        log_requests=args.enable_log_requests,
        log_stats=not args.disable_log_stats,
    )
||
|
||
async def generate_batch(engine, prompts, sampling_params):
    """Generate completions for a batch of prompts concurrently.

    Each prompt gets its own request id; the per-request stream is
    drained and only its final RequestOutput is kept. Results come
    back in the same order as *prompts*.
    """

    async def _final_output(prompt):
        stream = engine.generate(
            {"prompt": prompt},
            sampling_params,
            request_id=str(uuid.uuid4()),
        )
        last = None
        async for step in stream:
            last = step
        return last

    tasks = [_final_output(prompt) for prompt in prompts]
    return await asyncio.gather(*tasks)
|
||
|
||
|
||
async def main():
    """End-to-end demo: launch FSDP2 trainers and a vLLM EP/DP engine,
    then sync real weights into the engine over NCCL and compare
    generations before/after the sync."""
    ray.init()

    # Download model weights to local/shared disk once; workers and the
    # engine all load from this local path.
    local_model_path = snapshot_download(MODEL_NAME)
    print(f"[init] Model downloaded to {local_model_path}")

    # FSDP rendezvous address (single-node)
    fsdp_master_addr = get_ip()
    fsdp_master_port = get_open_port()

    # Launch 4 FSDP training workers.
    # Ray allocates 1 GPU per worker; AsyncLLMEngine's internal DP
    # placement groups will land on the remaining 4 GPUs.
    fsdp_workers = [
        FSDPTrainWorker.remote(
            local_model_path,
            rank,
            FSDP_WORLD_SIZE,
            fsdp_master_addr,
            fsdp_master_port,
        )
        for rank in range(FSDP_WORLD_SIZE)
    ]
    # Barrier: get_rank() returning means every actor finished __init__
    # (process group joined, model loaded and sharded).
    ray.get([w.get_rank.remote() for w in fsdp_workers])
    print(f"[init] {FSDP_WORLD_SIZE} FSDP training workers ready.")

    # Launch vLLM with expert parallelism + data parallelism.
    # AsyncLLMEngine with data_parallel_backend="ray" creates its own
    # placement groups internally — no manual placement group needed.
    print("[engine] Creating AsyncLLMEngine...")
    engine = create_async_engine(
        model=local_model_path,
        enforce_eager=True,
        tensor_parallel_size=INFERENCE_TP_SIZE,
        data_parallel_size=INFERENCE_DP_SIZE,
        enable_expert_parallel=True,
        distributed_executor_backend="ray",
        data_parallel_backend="ray",
        # Enables the NCCL weight-transfer path used below.
        weight_transfer_config=WeightTransferConfig(backend="nccl"),
        # Start with random ("dummy") weights — real ones arrive via sync.
        load_format="dummy",
        gpu_memory_utilization=0.7,
    )
    print("[engine] AsyncLLMEngine created.")

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Greedy decoding so before/after outputs are deterministic.
    sampling_params = SamplingParams(temperature=0)

    # Generate with dummy weights — expect gibberish.
    print("[generate] Starting generation with dummy weights...")
    outputs = await generate_batch(engine, prompts, sampling_params)
    print("[generate] Generation complete.")

    print("-" * 60)
    print("BEFORE weight sync (dummy weights):")
    print("-" * 60)
    for output in outputs:
        print(f"Prompt: {output.prompt!r}")
        print(f"Generated: {output.outputs[0].text!r}")
        print("-" * 60)

    # --- Weight-transfer setup ---
    # Rank 0 of the FSDP group owns the NCCL rendezvous endpoint.
    print("[transfer] Setting up weight-transfer endpoint...")
    transfer_addr, transfer_port = ray.get(
        fsdp_workers[0].setup_transfer_endpoint.remote()
    )
    print(f"[transfer] Endpoint ready at {transfer_addr}:{transfer_port}")

    # Transfer group = 1 trainer rank + every vLLM worker (TP × DP).
    transfer_world_size = INFERENCE_TP_SIZE * INFERENCE_DP_SIZE + 1
    print(
        f"[transfer] World size: {transfer_world_size} "
        f"(1 trainer + {INFERENCE_TP_SIZE * INFERENCE_DP_SIZE} vLLM workers)"
    )

    # Both sides must join the NCCL group concurrently: fire the trainer's
    # init as a non-blocking Ray task, await the engine side, then collect.
    print("[transfer] Initializing NCCL groups...")
    train_handle = fsdp_workers[0].init_weight_transfer_group.remote(
        transfer_world_size
    )
    await engine.init_weight_transfer_engine(
        WeightTransferInitRequest(
            init_info=asdict(
                NCCLWeightTransferInitInfo(
                    master_address=transfer_addr,
                    master_port=transfer_port,
                    # Trainer is rank 0; vLLM workers occupy ranks 1..N.
                    rank_offset=1,
                    world_size=transfer_world_size,
                )
            )
        )
    )
    ray.get(train_handle)
    print("[transfer] NCCL groups initialized.")

    # --- Pause, transfer weights, resume ---
    print("[sync] Pausing generation...")
    await engine.pause_generation(mode="abort")
    print("[sync] Generation paused.")

    # Full (pre-shard) parameter metadata, needed by the receiving side
    # to allocate buffers for the incoming broadcast.
    names, dtype_names, shapes = ray.get(fsdp_workers[0].get_weight_metadata.remote())
    print(f"[sync] Got metadata for {len(names)} parameters.")

    print("[sync] Broadcasting weights from FSDP → vLLM...")
    # Launch the send on ALL FSDP workers first (full_tensor() is a
    # collective), then await the engine's receive; ray.get() afterwards
    # surfaces any trainer-side errors once both sides are done.
    broadcast_handles = [
        w.gather_and_broadcast_weights.remote(packed=True) for w in fsdp_workers
    ]
    await engine.update_weights(
        WeightTransferUpdateRequest(
            update_info=asdict(
                NCCLWeightTransferUpdateInfo(
                    names=names,
                    dtype_names=dtype_names,
                    shapes=shapes,
                    # Must match packed= on the sending side.
                    packed=True,
                )
            )
        )
    )
    ray.get(broadcast_handles)
    print("[sync] Weight broadcast complete.")

    print("[sync] Resuming generation...")
    await engine.resume_generation()
    print("[sync] Generation resumed.")

    # Generate with synced weights — expect sensible output.
    print("[generate] Starting generation with synced weights...")
    outputs_updated = await generate_batch(engine, prompts, sampling_params)
    print("[generate] Generation complete.")

    print("-" * 60)
    print("AFTER weight sync (real weights):")
    print("-" * 60)
    for output in outputs_updated:
        print(f"Prompt: {output.prompt!r}")
        print(f"Generated: {output.outputs[0].text!r}")
        print("-" * 60)
||
|
||
# Entry point: run the async workflow on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())
|