# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
RLHF with FSDP2 training (4 GPUs) and vLLM expert-parallel inference (4 GPUs).

8-GPU layout:
    Training  — 4 GPUs, PyTorch FSDP2 (fully_shard)
    Inference — 4 GPUs, vLLM AsyncLLMEngine with expert parallelism +
                data parallelism (TP=1, DP=4, enable_expert_parallel
                → EP_SIZE = TP×DP = 4)

FSDP workers are Ray actors that form a single FSDP2 process group.
Rank 0 gathers full parameters via DTensor.full_tensor() and broadcasts
them to the vLLM inference engine through the NCCL weight-transfer API.

The inference engine uses AsyncLLMEngine which automatically spawns
DP worker processes (no manual placement group needed). Weight sync
uses pause_generation / resume_generation.

Steps:
    1. Launch 4 FSDP training workers.
    2. Launch AsyncLLMEngine with EP+DP (dummy weights).
    3. Generate from prompts → gibberish (random weights).
    4. Pause generation, transfer weights from FSDP, resume.
    5. Generate from prompts → sensible output (synced weights).

Assumes a single-node cluster with 8 GPUs.
"""

import asyncio
import os
import uuid
from dataclasses import asdict

import ray
import torch
import torch.distributed as dist
from huggingface_hub import snapshot_download
from torch.distributed.fsdp import fully_shard
from transformers import AutoModelForCausalLM

import vllm
from vllm import SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
from vllm.distributed.weight_transfer.nccl_engine import (
    NCCLTrainerSendWeightsArgs,
    NCCLWeightTransferEngine,
    NCCLWeightTransferInitInfo,
    NCCLWeightTransferUpdateInfo,
)
from vllm.utils.network_utils import get_ip, get_open_port
from vllm.v1.executor import Executor

MODEL_NAME = "Qwen/Qwen3-30B-A3B"
|
|||
|
|
|
|||
|
|
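# GPU budget on the 8-GPU node: FSDP_WORLD_SIZE GPUs run FSDP2 training and
# INFERENCE_TP_SIZE * INFERENCE_DP_SIZE GPUs run vLLM inference.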
FSDP_WORLD_SIZE = 4
INFERENCE_TP_SIZE = 1
INFERENCE_DP_SIZE = 4


@ray.remote(num_gpus=1)
class FSDPTrainWorker:
    """
    One FSDP2 training worker per GPU. Four of these form the FSDP group.
    Rank 0 additionally handles weight transfer to the vLLM engine.
    """

    def __init__(
        self,
        model_name: str,
        rank: int,
        fsdp_world_size: int,
        fsdp_master_addr: str,
        fsdp_master_port: int,
    ):
        self.rank = rank

        os.environ["MASTER_ADDR"] = fsdp_master_addr
        os.environ["MASTER_PORT"] = str(fsdp_master_port)

        dist.init_process_group(backend="nccl", rank=rank, world_size=fsdp_world_size)
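        # Ray exposes exactly one GPU to this actor (num_gpus=1), so local
        # device index 0 is always this worker's assigned GPU.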
        torch.accelerator.set_device_index(0)

        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16
        )

        self.weight_names = [n for n, _ in model.named_parameters()]
        self.weight_dtype_names = [
            str(p.dtype).split(".")[-1] for _, p in model.named_parameters()
        ]
        self.weight_shapes = [list(p.shape) for _, p in model.named_parameters()]

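        # FSDP2 sharding: wrap each decoder layer, then the root module, so the
        # parameters become DTensors sharded across the 4 training ranks.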
        for layer in model.model.layers:
            fully_shard(layer)
        fully_shard(model)

        self.model = model

        self.transfer_port = None
        self.transfer_master_address = None
        self.model_update_group = None

    def get_rank(self):
        return self.rank

    # ---- weight-transfer setup (rank 0 only) ----

    def setup_transfer_endpoint(self):
        """Create the NCCL rendezvous endpoint for weight transfer."""
        assert self.rank == 0
        self.transfer_port = get_open_port()
        self.transfer_master_address = get_ip()
        return self.transfer_master_address, self.transfer_port

    def init_weight_transfer_group(self, transfer_world_size: int):
        """Join the weight-transfer NCCL group as rank 0 (the source)."""
        assert self.rank == 0
        self.model_update_group = NCCLWeightTransferEngine.trainer_init(
            dict(
                master_address=self.transfer_master_address,
                master_port=self.transfer_port,
                world_size=transfer_world_size,
            ),
        )

    def get_weight_metadata(self):
        """Return weight names, dtypes, and shapes captured before FSDP wrapping."""
        return self.weight_names, self.weight_dtype_names, self.weight_shapes

    # ---- collective ops (ALL FSDP ranks must call concurrently) ----

    def gather_and_broadcast_weights(self, packed: bool = True):
        """
        All-gather full parameters and broadcast them to vLLM.
        Only rank 0 performs the actual NCCL broadcast; others just
        participate in the FSDP all-gather.

        full_tensor() is a collective — all FSDP ranks must call it
        for each parameter in the same order. Rank 0 additionally
        feeds each gathered tensor to the weight-transfer engine.
        """
        if self.rank == 0:

            def _full_param_iter():
                for name, param in self.model.named_parameters():
                    yield name, param.full_tensor()

            trainer_args = NCCLTrainerSendWeightsArgs(
                group=self.model_update_group,
                packed=packed,
            )
            NCCLWeightTransferEngine.trainer_send_weights(
                iterator=_full_param_iter(),
                trainer_args=trainer_args,
            )
        else:
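            # Non-zero ranks only drive the collective all-gather; the gathered
            # full tensors are discarded immediately.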
            for _, param in self.model.named_parameters():
                param.full_tensor()


def create_async_engine(**kwargs):
    """Create an AsyncLLMEngine directly (no subclass needed)."""
    engine_args = vllm.AsyncEngineArgs(**kwargs)
    vllm_config = engine_args.create_engine_config()
    executor_class = Executor.get_class(vllm_config)
    return vllm.AsyncLLMEngine(
        vllm_config=vllm_config,
        executor_class=executor_class,
        log_requests=engine_args.enable_log_requests,
        log_stats=not engine_args.disable_log_stats,
    )


async def generate_batch(engine, prompts, sampling_params):
    """Generate completions for a batch of prompts."""

    async def gen_one(prompt):
        output = None
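        # engine.generate streams incremental RequestOutput objects; the last
        # one yielded holds the completed text.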
        async for request_output in engine.generate(
            {"prompt": prompt},
            sampling_params,
            request_id=str(uuid.uuid4()),
        ):
            output = request_output
        return output

    return await asyncio.gather(*[gen_one(p) for p in prompts])


async def main():
    ray.init()

    # Download model weights to local/shared disk once.
    local_model_path = snapshot_download(MODEL_NAME)
    print(f"[init] Model downloaded to {local_model_path}")

    # FSDP rendezvous address (single-node)
    fsdp_master_addr = get_ip()
    fsdp_master_port = get_open_port()

    # Launch 4 FSDP training workers.
    # Ray allocates 1 GPU per worker; AsyncLLMEngine's internal DP
    # placement groups will land on the remaining 4 GPUs.
    fsdp_workers = [
        FSDPTrainWorker.remote(
            local_model_path,
            rank,
            FSDP_WORLD_SIZE,
            fsdp_master_addr,
            fsdp_master_port,
        )
        for rank in range(FSDP_WORLD_SIZE)
    ]
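    # ray.get on a trivial method blocks until every actor has finished
    # __init__ (model load + FSDP sharding), i.e. the training group is up.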
    ray.get([w.get_rank.remote() for w in fsdp_workers])
    print(f"[init] {FSDP_WORLD_SIZE} FSDP training workers ready.")

    # Launch vLLM with expert parallelism + data parallelism.
    # AsyncLLMEngine with data_parallel_backend="ray" creates its own
    # placement groups internally — no manual placement group needed.
    print("[engine] Creating AsyncLLMEngine...")
    engine = create_async_engine(
        model=local_model_path,
        enforce_eager=True,
        tensor_parallel_size=INFERENCE_TP_SIZE,
        data_parallel_size=INFERENCE_DP_SIZE,
        enable_expert_parallel=True,
        distributed_executor_backend="ray",
        data_parallel_backend="ray",
        weight_transfer_config=WeightTransferConfig(backend="nccl"),
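        # Start with random "dummy" weights; the real checkpoint arrives later
        # over the NCCL weight-transfer group, so nothing is loaded from disk.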
        load_format="dummy",
        gpu_memory_utilization=0.7,
    )
    print("[engine] AsyncLLMEngine created.")

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
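    # Greedy decoding (temperature=0) keeps the before/after outputs directly comparable.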
    sampling_params = SamplingParams(temperature=0)

    # Generate with dummy weights — expect gibberish.
    print("[generate] Starting generation with dummy weights...")
    outputs = await generate_batch(engine, prompts, sampling_params)
    print("[generate] Generation complete.")

    print("-" * 60)
    print("BEFORE weight sync (dummy weights):")
    print("-" * 60)
    for output in outputs:
        print(f"Prompt: {output.prompt!r}")
        print(f"Generated: {output.outputs[0].text!r}")
        print("-" * 60)

    # --- Weight-transfer setup ---
    print("[transfer] Setting up weight-transfer endpoint...")
    transfer_addr, transfer_port = ray.get(
        fsdp_workers[0].setup_transfer_endpoint.remote()
    )
    print(f"[transfer] Endpoint ready at {transfer_addr}:{transfer_port}")

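    # Weight-transfer NCCL group layout: the FSDP rank-0 trainer takes rank 0,
    # and the TP*DP vLLM workers take ranks 1..N (hence rank_offset=1 below).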
    transfer_world_size = INFERENCE_TP_SIZE * INFERENCE_DP_SIZE + 1
    print(
        f"[transfer] World size: {transfer_world_size} "
        f"(1 trainer + {INFERENCE_TP_SIZE * INFERENCE_DP_SIZE} vLLM workers)"
    )

    print("[transfer] Initializing NCCL groups...")
    train_handle = fsdp_workers[0].init_weight_transfer_group.remote(
        transfer_world_size
    )
    await engine.init_weight_transfer_engine(
        WeightTransferInitRequest(
            init_info=asdict(
                NCCLWeightTransferInitInfo(
                    master_address=transfer_addr,
                    master_port=transfer_port,
                    rank_offset=1,
                    world_size=transfer_world_size,
                )
            )
        )
    )
    ray.get(train_handle)
    print("[transfer] NCCL groups initialized.")

    # --- Pause, transfer weights, resume ---
    print("[sync] Pausing generation...")
    await engine.pause_generation(mode="abort")
    print("[sync] Generation paused.")

    names, dtype_names, shapes = ray.get(fsdp_workers[0].get_weight_metadata.remote())
    print(f"[sync] Got metadata for {len(names)} parameters.")

    print("[sync] Broadcasting weights from FSDP → vLLM...")
    broadcast_handles = [
        w.gather_and_broadcast_weights.remote(packed=True) for w in fsdp_workers
    ]
    await engine.update_weights(
        WeightTransferUpdateRequest(
            update_info=asdict(
                NCCLWeightTransferUpdateInfo(
                    names=names,
                    dtype_names=dtype_names,
                    shapes=shapes,
                    packed=True,
                )
            )
        )
    )
    ray.get(broadcast_handles)
    print("[sync] Weight broadcast complete.")

    print("[sync] Resuming generation...")
    await engine.resume_generation()
    print("[sync] Generation resumed.")

    # Generate with synced weights — expect sensible output.
    print("[generate] Starting generation with synced weights...")
    outputs_updated = await generate_batch(engine, prompts, sampling_params)
    print("[generate] Generation complete.")

    print("-" * 60)
    print("AFTER weight sync (real weights):")
    print("-" * 60)
    for output in outputs_updated:
        print(f"Prompt: {output.prompt!r}")
        print(f"Generated: {output.outputs[0].text!r}")
        print("-" * 60)


if __name__ == "__main__":
    asyncio.run(main())