[RL] Validation for pause_mode='keep' (#34992)
Signed-off-by: ahao-anyscale <ahao@anyscale.com>
This commit is contained in:
@@ -104,7 +104,6 @@ steps:
|
||||
# NEW rlhf examples
|
||||
- cd new_weight_syncing
|
||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
|
||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
|
||||
|
||||
- label: Distributed Tests (8 GPUs)(H100)
|
||||
timeout_in_minutes: 10
|
||||
@@ -146,6 +145,7 @@ steps:
|
||||
num_devices: 2
|
||||
commands:
|
||||
- pytest -v -s tests/distributed/test_context_parallel.py
|
||||
- cd examples/offline_inference/new_weight_syncing && VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
|
||||
- VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model=Qwen/Qwen1.5-MoE-A2.7B -tp=1 -dp=2 --max-model-len=2048 --all2all-backend=deepep_high_throughput
|
||||
- pytest -v -s tests/v1/distributed/test_dbo.py
|
||||
|
||||
|
||||
@@ -26,14 +26,12 @@ workloads. Residual GPU activity interferes with vLLM memory profiling and
|
||||
causes unexpected behavior.
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import uuid
|
||||
from dataclasses import asdict
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from ray.util.placement_group import placement_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import vllm
|
||||
@@ -51,14 +49,15 @@ from vllm.distributed.weight_transfer.nccl_engine import (
|
||||
from vllm.utils.network_utils import get_ip, get_open_port
|
||||
from vllm.v1.executor import Executor
|
||||
|
||||
MODEL_NAME = "facebook/opt-125m"
|
||||
MODEL_NAME_V1 = "Qwen/Qwen3-1.7B-Base"
|
||||
MODEL_NAME_V2 = "Qwen/Qwen3-1.7B"
|
||||
PAUSE_TOKEN_THRESHOLD = 10
|
||||
|
||||
|
||||
class MyLLM(vllm.AsyncLLMEngine):
|
||||
"""Configure the vLLM worker for Ray placement group execution."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0,1"
|
||||
engine_args = vllm.AsyncEngineArgs(**kwargs)
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
@@ -68,26 +67,44 @@ class MyLLM(vllm.AsyncLLMEngine):
|
||||
log_requests=engine_args.enable_log_requests,
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
)
|
||||
self._generation_paused = False
|
||||
self._request_pause_flag = False
|
||||
|
||||
async def generate_with_retry(
|
||||
async def do_generate(
|
||||
self, prompt_token_ids: list[int], sampling_params: vllm.SamplingParams
|
||||
) -> vllm.RequestOutput:
|
||||
finish_reason = "abort"
|
||||
while finish_reason == "abort":
|
||||
async for request_output in self.generate(
|
||||
{"prompt_token_ids": prompt_token_ids},
|
||||
sampling_params,
|
||||
request_id=str(uuid.uuid4()),
|
||||
) -> tuple[vllm.RequestOutput, int]:
|
||||
"""Generate a single request, setting the request pause flag once the
|
||||
token count reaches the threshold.
|
||||
|
||||
Returns (output, pause_token_index). pause_token_index is the number
|
||||
of tokens generated before the weight change, or -1 if no pause.
|
||||
"""
|
||||
pause_token_index = -1
|
||||
prev_token_count = 0
|
||||
async for request_output in self.generate(
|
||||
{"prompt_token_ids": prompt_token_ids},
|
||||
sampling_params,
|
||||
request_id=str(uuid.uuid4()),
|
||||
):
|
||||
output = request_output
|
||||
cur_token_count = len(output.outputs[0].token_ids)
|
||||
if (
|
||||
cur_token_count >= PAUSE_TOKEN_THRESHOLD
|
||||
and not self._request_pause_flag
|
||||
):
|
||||
output = request_output
|
||||
finish_reason = output.outputs[0].finish_reason
|
||||
if finish_reason == "abort":
|
||||
print(
|
||||
f"ABORT, prompt_token_ids: {prompt_token_ids}, "
|
||||
f"generated token_ids: {list(output.outputs[0].token_ids)}"
|
||||
)
|
||||
prompt_token_ids = prompt_token_ids + list(output.outputs[0].token_ids)
|
||||
return output
|
||||
self._request_pause_flag = True
|
||||
if self._generation_paused and pause_token_index == -1:
|
||||
pause_token_index = prev_token_count
|
||||
prev_token_count = cur_token_count
|
||||
return output, pause_token_index
|
||||
|
||||
async def pause_after_n_tokens(self):
|
||||
"""Wait for any request to set the pause flag, then pause."""
|
||||
while not self._request_pause_flag:
|
||||
await asyncio.sleep(0)
|
||||
await super().pause_generation(mode="keep")
|
||||
await asyncio.sleep(0.2)
|
||||
self._generation_paused = True
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
@@ -95,6 +112,14 @@ class TrainModel:
|
||||
"""Ray actor that wraps the training model on a dedicated GPU."""
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
init_batch_invariance,
|
||||
)
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
# need to init all env vars for batch invariance which affect nccl ops
|
||||
init_batch_invariance(AttentionBackendEnum.FLASH_ATTN)
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, dtype=torch.bfloat16
|
||||
).to("cuda:0")
|
||||
@@ -133,70 +158,80 @@ class TrainModel:
|
||||
packed=packed,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]:
|
||||
"""Greedy-decode max_new_tokens from the given context."""
|
||||
input_ids = torch.tensor([token_ids], device="cuda:0")
|
||||
output = self.model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=False,
|
||||
)
|
||||
new_token_ids = output[0, len(token_ids) :].tolist()
|
||||
return new_token_ids
|
||||
|
||||
# Initialize Ray and set the visible devices. The vLLM engine will
|
||||
# be placed on GPUs 1 and 2.
|
||||
ray.init()
|
||||
|
||||
ray.init(
|
||||
runtime_env={
|
||||
"env_vars": {
|
||||
# enable batch invariance for deterministic outputs
|
||||
"VLLM_BATCH_INVARIANT": "1",
|
||||
# prevent ray from setting CUDA_VISIBLE_DEVICES
|
||||
"RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Launch the training model actor. Ray's resource scheduler will allocate
|
||||
# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs.
|
||||
train_model = TrainModel.remote(MODEL_NAME)
|
||||
|
||||
# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
|
||||
# Learn more about Ray placement groups:
|
||||
# https://docs.ray.io/en/latest/placement-groups.html
|
||||
|
||||
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
|
||||
ray.get(pg_inference.ready())
|
||||
scheduling_inference = PlacementGroupSchedulingStrategy(
|
||||
placement_group=pg_inference,
|
||||
placement_group_capture_child_tasks=True,
|
||||
placement_group_bundle_index=0,
|
||||
)
|
||||
train_model = TrainModel.remote(MODEL_NAME_V2)
|
||||
|
||||
# Launch the vLLM inference engine. The `enforce_eager` flag reduces
|
||||
# start-up latency.
|
||||
# Note: Weight transfer APIs (init_weight_transfer_engine, update_weights)
|
||||
# are now native to vLLM workers.
|
||||
# With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates
|
||||
# its own placement groups internally for each DP rank, so we must NOT
|
||||
# create an outer placement group (it would reserve GPUs and hide them
|
||||
# from the internal DP resource check).
|
||||
llm = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=0,
|
||||
scheduling_strategy=scheduling_inference,
|
||||
)(MyLLM).remote(
|
||||
model=MODEL_NAME,
|
||||
model=MODEL_NAME_V1,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=2,
|
||||
max_model_len=8192,
|
||||
distributed_executor_backend="ray",
|
||||
load_format="dummy",
|
||||
attention_backend="FLASH_ATTN",
|
||||
gpu_memory_utilization=0.75,
|
||||
weight_transfer_config=WeightTransferConfig(backend="nccl"),
|
||||
)
|
||||
|
||||
# Generate text from the prompts.
|
||||
prompts = [
|
||||
"My name is",
|
||||
PROMPTS = [
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
"The largest ocean on Earth is",
|
||||
"The speed of light in a vacuum is",
|
||||
"The chemical formula for water is",
|
||||
"The tallest mountain in the world is",
|
||||
"The first person to walk on the moon was",
|
||||
"The Great Wall of China was built to",
|
||||
"Photosynthesis is the process by which",
|
||||
"The theory of general relativity was proposed by",
|
||||
"The boiling point of water at sea level is",
|
||||
"The largest planet in our solar system is",
|
||||
"DNA stands for deoxyribonucleic acid and it",
|
||||
]
|
||||
|
||||
# Tokenize prompts to token IDs
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
prompt_token_ids_list = [
|
||||
tokenizer.encode(prompt, add_special_tokens=False) for prompt in prompts
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_V1)
|
||||
batch_prompt_token_ids = [
|
||||
tokenizer.encode(prompt, add_special_tokens=False) for prompt in PROMPTS
|
||||
]
|
||||
|
||||
sampling_params = [
|
||||
SamplingParams(temperature=0, max_tokens=2),
|
||||
SamplingParams(temperature=0, max_tokens=32),
|
||||
SamplingParams(temperature=0, max_tokens=32),
|
||||
SamplingParams(temperature=0, max_tokens=32),
|
||||
]
|
||||
|
||||
# Set up the communication channel between the training process and the
|
||||
# inference engine.
|
||||
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())
|
||||
|
||||
world_size = 3 # 1 trainer + 2 inference workers (tensor_parallel_size=2)
|
||||
world_size = 2 # 1 trainer + 1 inference worker
|
||||
inference_handle = llm.init_weight_transfer_engine.remote(
|
||||
WeightTransferInitRequest(
|
||||
init_info=asdict(
|
||||
@@ -215,22 +250,28 @@ train_handle = train_model.init_weight_transfer_group.remote(world_size)
|
||||
ray.get([train_handle, inference_handle])
|
||||
|
||||
|
||||
generation_futures = [
|
||||
llm.generate_with_retry.remote(prompt_token_ids, params)
|
||||
for prompt_token_ids, params in zip(prompt_token_ids_list, sampling_params)
|
||||
]
|
||||
N_NEW_TOKENS = 100
|
||||
|
||||
finished, pending = ray.wait(generation_futures, num_returns=1)
|
||||
|
||||
# Pause generation in preparation for weight sync
|
||||
ray.get(llm.pause_generation.remote(wait_for_inflight_requests=False))
|
||||
|
||||
# Synchronize the updated weights to the inference engine using batched API.
|
||||
# Collect all weight metadata from the training actor
|
||||
# Collect weight metadata once
|
||||
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())
|
||||
|
||||
# Issue update_weights call with NCCL-specific update info
|
||||
# packed=True enables efficient batched tensor broadcasting
|
||||
# ── Phase 1: concurrent requests with weight sync ───────────────────
|
||||
print(f"\n{'=' * 50}")
|
||||
print(f"Prompts ({len(PROMPTS)}):")
|
||||
for p in PROMPTS:
|
||||
print(f" - {p!r}")
|
||||
print(f"{'=' * 50}")
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0, max_tokens=PAUSE_TOKEN_THRESHOLD + N_NEW_TOKENS
|
||||
)
|
||||
|
||||
gen_futures = [
|
||||
llm.do_generate.remote(ptids, sampling_params) for ptids in batch_prompt_token_ids
|
||||
]
|
||||
|
||||
ray.get(llm.pause_after_n_tokens.remote())
|
||||
|
||||
inference_handle = llm.update_weights.remote(
|
||||
WeightTransferUpdateRequest(
|
||||
update_info=asdict(
|
||||
@@ -243,41 +284,76 @@ inference_handle = llm.update_weights.remote(
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Broadcast all weights from trainer using the weight transfer API
|
||||
train_handle = train_model.broadcast_weights.remote(packed=True)
|
||||
ray.get([train_handle, inference_handle])
|
||||
|
||||
# Resume generation since weight sync is complete
|
||||
ray.get(llm.resume_generation.remote())
|
||||
results = ray.get(gen_futures)
|
||||
|
||||
# Get outputs separately - finished completed before pause, pending were paused/resumed
|
||||
finished_outputs = ray.get(finished)
|
||||
pending_outputs = ray.get(pending)
|
||||
for i, (output, pause_idx) in enumerate(results):
|
||||
all_token_ids = list(output.outputs[0].token_ids)
|
||||
before_text = tokenizer.decode(all_token_ids[:pause_idx])
|
||||
after_text = tokenizer.decode(all_token_ids[pause_idx:])
|
||||
print(f"\n Request {i} ({PROMPTS[i]!r}):")
|
||||
print(f" Old weights ({pause_idx} tokens): {before_text!r}")
|
||||
n_after = len(all_token_ids) - pause_idx
|
||||
print(f" New weights ({n_after} tokens): {after_text!r}")
|
||||
|
||||
# Requests that finished before the pause: all generation used original weights
|
||||
print("-" * 50)
|
||||
print("Requests that completed BEFORE weight change:")
|
||||
print("-" * 50)
|
||||
for output in finished_outputs:
|
||||
prompt_text = tokenizer.decode(output.prompt_token_ids)
|
||||
print(f"Prompt: {prompt_text!r}")
|
||||
print(f"Generated (with original weights): {output.outputs[0].text!r}")
|
||||
print("-" * 50)
|
||||
# ── Phase 2: validate with a fresh V2 vLLM instance ────────────────
|
||||
print(f"\n{'=' * 50}")
|
||||
print("VALIDATION: comparing weight-synced vLLM with fresh V2 instance")
|
||||
print(f"{'=' * 50}")
|
||||
|
||||
# Requests that were paused mid-generation: some text before, some after weight change
|
||||
print("Requests that were PAUSED and RESUMED after weight change:")
|
||||
print("-" * 50)
|
||||
for output in pending_outputs:
|
||||
# Decode the full prompt token IDs (original + generated before pause)
|
||||
full_prompt_text = tokenizer.decode(output.prompt_token_ids)
|
||||
# Find the original prompt by checking which one this output started with
|
||||
original_prompt = next(p for p in prompts if full_prompt_text.startswith(p))
|
||||
# output.prompt_token_ids contains original prompt + tokens generated before pause
|
||||
# output.outputs[0].text is what was generated after resuming with new weights
|
||||
text_before_pause = full_prompt_text[len(original_prompt) :]
|
||||
text_after_pause = output.outputs[0].text
|
||||
print(f"Original prompt: {original_prompt!r}")
|
||||
print(f"Generated before weight change: {text_before_pause!r}")
|
||||
print(f"Generated after weight change: {text_after_pause!r}")
|
||||
print("-" * 50)
|
||||
ray.get(llm.shutdown.remote())
|
||||
ray.kill(llm)
|
||||
ray.kill(train_model)
|
||||
|
||||
llm_v2 = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=0,
|
||||
)(MyLLM).remote(
|
||||
model=MODEL_NAME_V2,
|
||||
enforce_eager=True,
|
||||
max_model_len=8192,
|
||||
gpu_memory_utilization=0.75,
|
||||
distributed_executor_backend="ray",
|
||||
attention_backend="FLASH_ATTN",
|
||||
)
|
||||
|
||||
val_futures = [
|
||||
llm_v2.do_generate.remote(
|
||||
list(output.prompt_token_ids) + list(output.outputs[0].token_ids)[:pause_idx],
|
||||
SamplingParams(
|
||||
temperature=0, max_tokens=len(output.outputs[0].token_ids) - pause_idx
|
||||
),
|
||||
)
|
||||
for output, pause_idx in results
|
||||
]
|
||||
val_results = ray.get(val_futures)
|
||||
|
||||
all_pass = True
|
||||
for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_results)):
|
||||
expected = list(output.outputs[0].token_ids)[pause_idx:]
|
||||
actual = list(val_output.outputs[0].token_ids)
|
||||
match = actual == expected
|
||||
|
||||
if match:
|
||||
print(f" [PASS] {PROMPTS[i]!r}")
|
||||
else:
|
||||
all_pass = False
|
||||
print(f" [FAIL] {PROMPTS[i]!r}")
|
||||
print(f" weight-synced vLLM: {tokenizer.decode(expected)!r}")
|
||||
print(f" V2 vLLM: {tokenizer.decode(actual)!r}")
|
||||
for j, (e, a) in enumerate(zip(expected, actual)):
|
||||
if e != a:
|
||||
print(
|
||||
f" first divergence at output token {j}: "
|
||||
f"expected {e} ({tokenizer.decode([e])!r}) vs "
|
||||
f"actual {a} ({tokenizer.decode([a])!r})"
|
||||
)
|
||||
break
|
||||
|
||||
ray.get(llm_v2.shutdown.remote())
|
||||
ray.kill(llm_v2)
|
||||
assert all_pass, "Some prompts failed validation, see above for details"
|
||||
print("=" * 50)
|
||||
|
||||
Reference in New Issue
Block a user