[RL] Validation for pause_mode='keep' (#34992)

Signed-off-by: ahao-anyscale <ahao@anyscale.com>
This commit is contained in:
Aaron Hao
2026-02-23 13:30:56 -08:00
committed by GitHub
parent b8d8b7e934
commit 596ed1f02e
2 changed files with 180 additions and 104 deletions

View File

@@ -104,7 +104,6 @@ steps:
# NEW rlhf examples
- cd new_weight_syncing
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
- label: Distributed Tests (8 GPUs)(H100)
timeout_in_minutes: 10
@@ -146,6 +145,7 @@ steps:
num_devices: 2
commands:
- pytest -v -s tests/distributed/test_context_parallel.py
- cd examples/offline_inference/new_weight_syncing && VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
- VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model=Qwen/Qwen1.5-MoE-A2.7B -tp=1 -dp=2 --max-model-len=2048 --all2all-backend=deepep_high_throughput
- pytest -v -s tests/v1/distributed/test_dbo.py

View File

@@ -26,14 +26,12 @@ workloads. Residual GPU activity interferes with vLLM memory profiling and
causes unexpected behavior.
"""
import os
import asyncio
import uuid
from dataclasses import asdict
import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM, AutoTokenizer
import vllm
@@ -51,14 +49,15 @@ from vllm.distributed.weight_transfer.nccl_engine import (
from vllm.utils.network_utils import get_ip, get_open_port
from vllm.v1.executor import Executor
MODEL_NAME = "facebook/opt-125m"
MODEL_NAME_V1 = "Qwen/Qwen3-1.7B-Base"
MODEL_NAME_V2 = "Qwen/Qwen3-1.7B"
PAUSE_TOKEN_THRESHOLD = 10
class MyLLM(vllm.AsyncLLMEngine):
"""Configure the vLLM worker for Ray placement group execution."""
def __init__(self, **kwargs):
os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0,1"
engine_args = vllm.AsyncEngineArgs(**kwargs)
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)
@@ -68,26 +67,44 @@ class MyLLM(vllm.AsyncLLMEngine):
log_requests=engine_args.enable_log_requests,
log_stats=not engine_args.disable_log_stats,
)
self._generation_paused = False
self._request_pause_flag = False
async def generate_with_retry(
async def do_generate(
self, prompt_token_ids: list[int], sampling_params: vllm.SamplingParams
) -> vllm.RequestOutput:
finish_reason = "abort"
while finish_reason == "abort":
async for request_output in self.generate(
{"prompt_token_ids": prompt_token_ids},
sampling_params,
request_id=str(uuid.uuid4()),
) -> tuple[vllm.RequestOutput, int]:
"""Generate a single request, setting the request pause flag once the
token count reaches the threshold.
Returns (output, pause_token_index). pause_token_index is the number
of tokens generated before the weight change, or -1 if no pause.
"""
pause_token_index = -1
prev_token_count = 0
async for request_output in self.generate(
{"prompt_token_ids": prompt_token_ids},
sampling_params,
request_id=str(uuid.uuid4()),
):
output = request_output
cur_token_count = len(output.outputs[0].token_ids)
if (
cur_token_count >= PAUSE_TOKEN_THRESHOLD
and not self._request_pause_flag
):
output = request_output
finish_reason = output.outputs[0].finish_reason
if finish_reason == "abort":
print(
f"ABORT, prompt_token_ids: {prompt_token_ids}, "
f"generated token_ids: {list(output.outputs[0].token_ids)}"
)
prompt_token_ids = prompt_token_ids + list(output.outputs[0].token_ids)
return output
self._request_pause_flag = True
if self._generation_paused and pause_token_index == -1:
pause_token_index = prev_token_count
prev_token_count = cur_token_count
return output, pause_token_index
async def pause_after_n_tokens(self):
"""Wait for any request to set the pause flag, then pause."""
while not self._request_pause_flag:
await asyncio.sleep(0)
await super().pause_generation(mode="keep")
await asyncio.sleep(0.2)
self._generation_paused = True
@ray.remote(num_gpus=1)
@@ -95,6 +112,14 @@ class TrainModel:
"""Ray actor that wraps the training model on a dedicated GPU."""
def __init__(self, model_name: str):
from vllm.model_executor.layers.batch_invariant import (
init_batch_invariance,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
# need to init all env vars for batch invariance which affect nccl ops
init_batch_invariance(AttentionBackendEnum.FLASH_ATTN)
self.model = AutoModelForCausalLM.from_pretrained(
model_name, dtype=torch.bfloat16
).to("cuda:0")
@@ -133,70 +158,80 @@ class TrainModel:
packed=packed,
)
@torch.inference_mode()
def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]:
"""Greedy-decode max_new_tokens from the given context."""
input_ids = torch.tensor([token_ids], device="cuda:0")
output = self.model.generate(
input_ids,
max_new_tokens=max_new_tokens,
do_sample=False,
)
new_token_ids = output[0, len(token_ids) :].tolist()
return new_token_ids
# Initialize Ray and set the visible devices. The vLLM engine will
# be placed on GPUs 1 and 2.
ray.init()
ray.init(
runtime_env={
"env_vars": {
# enable batch invariance for deterministic outputs
"VLLM_BATCH_INVARIANT": "1",
# prevent ray from setting CUDA_VISIBLE_DEVICES
"RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1",
}
}
)
# Launch the training model actor. Ray's resource scheduler will allocate
# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs.
train_model = TrainModel.remote(MODEL_NAME)
# Create a placement group that reserves GPU 12 for the vLLM inference engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/placement-groups.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
placement_group=pg_inference,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=0,
)
train_model = TrainModel.remote(MODEL_NAME_V2)
# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.
# Note: Weight transfer APIs (init_weight_transfer_engine, update_weights)
# are now native to vLLM workers.
# With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates
# its own placement groups internally for each DP rank, so we must NOT
# create an outer placement group (it would reserve GPUs and hide them
# from the internal DP resource check).
llm = ray.remote(
num_cpus=0,
num_gpus=0,
scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
model=MODEL_NAME,
model=MODEL_NAME_V1,
enforce_eager=True,
tensor_parallel_size=2,
max_model_len=8192,
distributed_executor_backend="ray",
load_format="dummy",
attention_backend="FLASH_ATTN",
gpu_memory_utilization=0.75,
weight_transfer_config=WeightTransferConfig(backend="nccl"),
)
# Generate text from the prompts.
prompts = [
"My name is",
PROMPTS = [
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"The largest ocean on Earth is",
"The speed of light in a vacuum is",
"The chemical formula for water is",
"The tallest mountain in the world is",
"The first person to walk on the moon was",
"The Great Wall of China was built to",
"Photosynthesis is the process by which",
"The theory of general relativity was proposed by",
"The boiling point of water at sea level is",
"The largest planet in our solar system is",
"DNA stands for deoxyribonucleic acid and it",
]
# Tokenize prompts to token IDs
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
prompt_token_ids_list = [
tokenizer.encode(prompt, add_special_tokens=False) for prompt in prompts
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_V1)
batch_prompt_token_ids = [
tokenizer.encode(prompt, add_special_tokens=False) for prompt in PROMPTS
]
sampling_params = [
SamplingParams(temperature=0, max_tokens=2),
SamplingParams(temperature=0, max_tokens=32),
SamplingParams(temperature=0, max_tokens=32),
SamplingParams(temperature=0, max_tokens=32),
]
# Set up the communication channel between the training process and the
# inference engine.
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())
world_size = 3 # 1 trainer + 2 inference workers (tensor_parallel_size=2)
world_size = 2 # 1 trainer + 1 inference worker
inference_handle = llm.init_weight_transfer_engine.remote(
WeightTransferInitRequest(
init_info=asdict(
@@ -215,22 +250,28 @@ train_handle = train_model.init_weight_transfer_group.remote(world_size)
ray.get([train_handle, inference_handle])
generation_futures = [
llm.generate_with_retry.remote(prompt_token_ids, params)
for prompt_token_ids, params in zip(prompt_token_ids_list, sampling_params)
]
N_NEW_TOKENS = 100
finished, pending = ray.wait(generation_futures, num_returns=1)
# Pause generation in preparation for weight sync
ray.get(llm.pause_generation.remote(wait_for_inflight_requests=False))
# Synchronize the updated weights to the inference engine using batched API.
# Collect all weight metadata from the training actor
# Collect weight metadata once
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())
# Issue update_weights call with NCCL-specific update info
# packed=True enables efficient batched tensor broadcasting
# ── Phase 1: concurrent requests with weight sync ───────────────────
print(f"\n{'=' * 50}")
print(f"Prompts ({len(PROMPTS)}):")
for p in PROMPTS:
print(f" - {p!r}")
print(f"{'=' * 50}")
sampling_params = SamplingParams(
temperature=0, max_tokens=PAUSE_TOKEN_THRESHOLD + N_NEW_TOKENS
)
gen_futures = [
llm.do_generate.remote(ptids, sampling_params) for ptids in batch_prompt_token_ids
]
ray.get(llm.pause_after_n_tokens.remote())
inference_handle = llm.update_weights.remote(
WeightTransferUpdateRequest(
update_info=asdict(
@@ -243,41 +284,76 @@ inference_handle = llm.update_weights.remote(
)
)
)
# Broadcast all weights from trainer using the weight transfer API
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])
# Resume generation since weight sync is complete
ray.get(llm.resume_generation.remote())
results = ray.get(gen_futures)
# Get outputs separately - finished completed before pause, pending were paused/resumed
finished_outputs = ray.get(finished)
pending_outputs = ray.get(pending)
for i, (output, pause_idx) in enumerate(results):
all_token_ids = list(output.outputs[0].token_ids)
before_text = tokenizer.decode(all_token_ids[:pause_idx])
after_text = tokenizer.decode(all_token_ids[pause_idx:])
print(f"\n Request {i} ({PROMPTS[i]!r}):")
print(f" Old weights ({pause_idx} tokens): {before_text!r}")
n_after = len(all_token_ids) - pause_idx
print(f" New weights ({n_after} tokens): {after_text!r}")
# Requests that finished before the pause: all generation used original weights
print("-" * 50)
print("Requests that completed BEFORE weight change:")
print("-" * 50)
for output in finished_outputs:
prompt_text = tokenizer.decode(output.prompt_token_ids)
print(f"Prompt: {prompt_text!r}")
print(f"Generated (with original weights): {output.outputs[0].text!r}")
print("-" * 50)
# ── Phase 2: validate with a fresh V2 vLLM instance ────────────────
print(f"\n{'=' * 50}")
print("VALIDATION: comparing weight-synced vLLM with fresh V2 instance")
print(f"{'=' * 50}")
# Requests that were paused mid-generation: some text before, some after weight change
print("Requests that were PAUSED and RESUMED after weight change:")
print("-" * 50)
for output in pending_outputs:
# Decode the full prompt token IDs (original + generated before pause)
full_prompt_text = tokenizer.decode(output.prompt_token_ids)
# Find the original prompt by checking which one this output started with
original_prompt = next(p for p in prompts if full_prompt_text.startswith(p))
# output.prompt_token_ids contains original prompt + tokens generated before pause
# output.outputs[0].text is what was generated after resuming with new weights
text_before_pause = full_prompt_text[len(original_prompt) :]
text_after_pause = output.outputs[0].text
print(f"Original prompt: {original_prompt!r}")
print(f"Generated before weight change: {text_before_pause!r}")
print(f"Generated after weight change: {text_after_pause!r}")
print("-" * 50)
ray.get(llm.shutdown.remote())
ray.kill(llm)
ray.kill(train_model)
llm_v2 = ray.remote(
num_cpus=0,
num_gpus=0,
)(MyLLM).remote(
model=MODEL_NAME_V2,
enforce_eager=True,
max_model_len=8192,
gpu_memory_utilization=0.75,
distributed_executor_backend="ray",
attention_backend="FLASH_ATTN",
)
val_futures = [
llm_v2.do_generate.remote(
list(output.prompt_token_ids) + list(output.outputs[0].token_ids)[:pause_idx],
SamplingParams(
temperature=0, max_tokens=len(output.outputs[0].token_ids) - pause_idx
),
)
for output, pause_idx in results
]
val_results = ray.get(val_futures)
all_pass = True
for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_results)):
expected = list(output.outputs[0].token_ids)[pause_idx:]
actual = list(val_output.outputs[0].token_ids)
match = actual == expected
if match:
print(f" [PASS] {PROMPTS[i]!r}")
else:
all_pass = False
print(f" [FAIL] {PROMPTS[i]!r}")
print(f" weight-synced vLLM: {tokenizer.decode(expected)!r}")
print(f" V2 vLLM: {tokenizer.decode(actual)!r}")
for j, (e, a) in enumerate(zip(expected, actual)):
if e != a:
print(
f" first divergence at output token {j}: "
f"expected {e} ({tokenizer.decode([e])!r}) vs "
f"actual {a} ({tokenizer.decode([a])!r})"
)
break
ray.get(llm_v2.shutdown.remote())
ray.kill(llm_v2)
assert all_pass, "Some prompts failed validation, see above for details"
print("=" * 50)