diff --git a/.buildkite/test_areas/distributed.yaml b/.buildkite/test_areas/distributed.yaml index f15e5018b..df748a5fc 100644 --- a/.buildkite/test_areas/distributed.yaml +++ b/.buildkite/test_areas/distributed.yaml @@ -104,7 +104,6 @@ steps: # NEW rlhf examples - cd new_weight_syncing - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py - label: Distributed Tests (8 GPUs)(H100) timeout_in_minutes: 10 @@ -146,6 +145,7 @@ steps: num_devices: 2 commands: - pytest -v -s tests/distributed/test_context_parallel.py + - cd examples/offline_inference/new_weight_syncing && VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py - VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model=Qwen/Qwen1.5-MoE-A2.7B -tp=1 -dp=2 --max-model-len=2048 --all2all-backend=deepep_high_throughput - pytest -v -s tests/v1/distributed/test_dbo.py diff --git a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py index 835c16a7f..8714eb92b 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py +++ b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py @@ -26,14 +26,12 @@ workloads. Residual GPU activity interferes with vLLM memory profiling and causes unexpected behavior. """ -import os +import asyncio import uuid from dataclasses import asdict import ray import torch -from ray.util.placement_group import placement_group -from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from transformers import AutoModelForCausalLM, AutoTokenizer import vllm @@ -51,14 +49,15 @@ from vllm.distributed.weight_transfer.nccl_engine import ( from vllm.utils.network_utils import get_ip, get_open_port from vllm.v1.executor import Executor -MODEL_NAME = "facebook/opt-125m" +MODEL_NAME_V1 = "Qwen/Qwen3-1.7B-Base" +MODEL_NAME_V2 = "Qwen/Qwen3-1.7B" +PAUSE_TOKEN_THRESHOLD = 10 class MyLLM(vllm.AsyncLLMEngine): """Configure the vLLM worker for Ray placement group execution.""" def __init__(self, **kwargs): - os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0,1" engine_args = vllm.AsyncEngineArgs(**kwargs) vllm_config = engine_args.create_engine_config() executor_class = Executor.get_class(vllm_config) @@ -68,26 +67,44 @@ class MyLLM(vllm.AsyncLLMEngine): log_requests=engine_args.enable_log_requests, log_stats=not engine_args.disable_log_stats, ) + self._generation_paused = False + self._request_pause_flag = False - async def generate_with_retry( + async def do_generate( self, prompt_token_ids: list[int], sampling_params: vllm.SamplingParams - ) -> vllm.RequestOutput: - finish_reason = "abort" - while finish_reason == "abort": - async for request_output in self.generate( - {"prompt_token_ids": prompt_token_ids}, - sampling_params, - request_id=str(uuid.uuid4()), + ) -> tuple[vllm.RequestOutput, int]: + """Generate a single request, setting the request pause flag once the + token count reaches the threshold. + + Returns (output, pause_token_index). pause_token_index is the number + of tokens generated before the weight change, or -1 if no pause. + """ + pause_token_index = -1 + prev_token_count = 0 + async for request_output in self.generate( + {"prompt_token_ids": prompt_token_ids}, + sampling_params, + request_id=str(uuid.uuid4()), + ): + output = request_output + cur_token_count = len(output.outputs[0].token_ids) + if ( + cur_token_count >= PAUSE_TOKEN_THRESHOLD + and not self._request_pause_flag ): - output = request_output - finish_reason = output.outputs[0].finish_reason - if finish_reason == "abort": - print( - f"ABORT, prompt_token_ids: {prompt_token_ids}, " - f"generated token_ids: {list(output.outputs[0].token_ids)}" - ) - prompt_token_ids = prompt_token_ids + list(output.outputs[0].token_ids) - return output + self._request_pause_flag = True + if self._generation_paused and pause_token_index == -1: + pause_token_index = prev_token_count + prev_token_count = cur_token_count + return output, pause_token_index + + async def pause_after_n_tokens(self): + """Wait for any request to set the pause flag, then pause.""" + while not self._request_pause_flag: + await asyncio.sleep(0) + await super().pause_generation(mode="keep") + await asyncio.sleep(0.2) + self._generation_paused = True @ray.remote(num_gpus=1) @@ -95,6 +112,14 @@ class TrainModel: """Ray actor that wraps the training model on a dedicated GPU.""" def __init__(self, model_name: str): + from vllm.model_executor.layers.batch_invariant import ( + init_batch_invariance, + ) + from vllm.v1.attention.backends.registry import AttentionBackendEnum + + # need to init all env vars for batch invariance which affect nccl ops + init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) + self.model = AutoModelForCausalLM.from_pretrained( model_name, dtype=torch.bfloat16 ).to("cuda:0") @@ -133,70 +158,80 @@ class TrainModel: packed=packed, ) + @torch.inference_mode() + def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]: + """Greedy-decode max_new_tokens from the given context.""" + input_ids = torch.tensor([token_ids], device="cuda:0") + output = self.model.generate( + input_ids, + max_new_tokens=max_new_tokens, + do_sample=False, + ) + new_token_ids = output[0, len(token_ids) :].tolist() + return new_token_ids -# Initialize Ray and set the visible devices. The vLLM engine will -# be placed on GPUs 1 and 2. -ray.init() + +ray.init( + runtime_env={ + "env_vars": { + # enable batch invariance for deterministic outputs + "VLLM_BATCH_INVARIANT": "1", + # prevent ray from setting CUDA_VISIBLE_DEVICES + "RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1", + } + } +) # Launch the training model actor. Ray's resource scheduler will allocate # 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs. -train_model = TrainModel.remote(MODEL_NAME) - -# Create a placement group that reserves GPU 1–2 for the vLLM inference engine. -# Learn more about Ray placement groups: -# https://docs.ray.io/en/latest/placement-groups.html - -pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2) -ray.get(pg_inference.ready()) -scheduling_inference = PlacementGroupSchedulingStrategy( - placement_group=pg_inference, - placement_group_capture_child_tasks=True, - placement_group_bundle_index=0, -) +train_model = TrainModel.remote(MODEL_NAME_V2) # Launch the vLLM inference engine. The `enforce_eager` flag reduces # start-up latency. -# Note: Weight transfer APIs (init_weight_transfer_engine, update_weights) -# are now native to vLLM workers. +# With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates +# its own placement groups internally for each DP rank, so we must NOT +# create an outer placement group (it would reserve GPUs and hide them +# from the internal DP resource check). llm = ray.remote( num_cpus=0, num_gpus=0, - scheduling_strategy=scheduling_inference, )(MyLLM).remote( - model=MODEL_NAME, + model=MODEL_NAME_V1, enforce_eager=True, - tensor_parallel_size=2, + max_model_len=8192, distributed_executor_backend="ray", - load_format="dummy", + attention_backend="FLASH_ATTN", + gpu_memory_utilization=0.75, weight_transfer_config=WeightTransferConfig(backend="nccl"), ) -# Generate text from the prompts. -prompts = [ - "My name is", +PROMPTS = [ "The president of the United States is", "The capital of France is", - "The future of AI is", + "The largest ocean on Earth is", + "The speed of light in a vacuum is", + "The chemical formula for water is", + "The tallest mountain in the world is", + "The first person to walk on the moon was", + "The Great Wall of China was built to", + "Photosynthesis is the process by which", + "The theory of general relativity was proposed by", + "The boiling point of water at sea level is", + "The largest planet in our solar system is", + "DNA stands for deoxyribonucleic acid and it", ] -# Tokenize prompts to token IDs -tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) -prompt_token_ids_list = [ - tokenizer.encode(prompt, add_special_tokens=False) for prompt in prompts +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_V1) +batch_prompt_token_ids = [ + tokenizer.encode(prompt, add_special_tokens=False) for prompt in PROMPTS ] -sampling_params = [ - SamplingParams(temperature=0, max_tokens=2), - SamplingParams(temperature=0, max_tokens=32), - SamplingParams(temperature=0, max_tokens=32), - SamplingParams(temperature=0, max_tokens=32), -] # Set up the communication channel between the training process and the # inference engine. master_address, master_port = ray.get(train_model.get_master_address_and_port.remote()) -world_size = 3 # 1 trainer + 2 inference workers (tensor_parallel_size=2) +world_size = 2 # 1 trainer + 1 inference worker inference_handle = llm.init_weight_transfer_engine.remote( WeightTransferInitRequest( init_info=asdict( @@ -215,22 +250,28 @@ train_handle = train_model.init_weight_transfer_group.remote(world_size) ray.get([train_handle, inference_handle]) -generation_futures = [ - llm.generate_with_retry.remote(prompt_token_ids, params) - for prompt_token_ids, params in zip(prompt_token_ids_list, sampling_params) -] +N_NEW_TOKENS = 100 -finished, pending = ray.wait(generation_futures, num_returns=1) - -# Pause generation in preparation for weight sync -ray.get(llm.pause_generation.remote(wait_for_inflight_requests=False)) - -# Synchronize the updated weights to the inference engine using batched API. -# Collect all weight metadata from the training actor +# Collect weight metadata once names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote()) -# Issue update_weights call with NCCL-specific update info -# packed=True enables efficient batched tensor broadcasting +# ── Phase 1: concurrent requests with weight sync ─────────────────── +print(f"\n{'=' * 50}") +print(f"Prompts ({len(PROMPTS)}):") +for p in PROMPTS: + print(f" - {p!r}") +print(f"{'=' * 50}") + +sampling_params = SamplingParams( + temperature=0, max_tokens=PAUSE_TOKEN_THRESHOLD + N_NEW_TOKENS +) + +gen_futures = [ + llm.do_generate.remote(ptids, sampling_params) for ptids in batch_prompt_token_ids +] + +ray.get(llm.pause_after_n_tokens.remote()) + inference_handle = llm.update_weights.remote( WeightTransferUpdateRequest( update_info=asdict( @@ -243,41 +284,76 @@ inference_handle = llm.update_weights.remote( ) ) ) - -# Broadcast all weights from trainer using the weight transfer API train_handle = train_model.broadcast_weights.remote(packed=True) ray.get([train_handle, inference_handle]) -# Resume generation since weight sync is complete ray.get(llm.resume_generation.remote()) +results = ray.get(gen_futures) -# Get outputs separately - finished completed before pause, pending were paused/resumed -finished_outputs = ray.get(finished) -pending_outputs = ray.get(pending) +for i, (output, pause_idx) in enumerate(results): + all_token_ids = list(output.outputs[0].token_ids) + before_text = tokenizer.decode(all_token_ids[:pause_idx]) + after_text = tokenizer.decode(all_token_ids[pause_idx:]) + print(f"\n Request {i} ({PROMPTS[i]!r}):") + print(f" Old weights ({pause_idx} tokens): {before_text!r}") + n_after = len(all_token_ids) - pause_idx + print(f" New weights ({n_after} tokens): {after_text!r}") -# Requests that finished before the pause: all generation used original weights -print("-" * 50) -print("Requests that completed BEFORE weight change:") -print("-" * 50) -for output in finished_outputs: - prompt_text = tokenizer.decode(output.prompt_token_ids) - print(f"Prompt: {prompt_text!r}") - print(f"Generated (with original weights): {output.outputs[0].text!r}") - print("-" * 50) +# ── Phase 2: validate with a fresh V2 vLLM instance ──────────────── +print(f"\n{'=' * 50}") +print("VALIDATION: comparing weight-synced vLLM with fresh V2 instance") +print(f"{'=' * 50}") -# Requests that were paused mid-generation: some text before, some after weight change -print("Requests that were PAUSED and RESUMED after weight change:") -print("-" * 50) -for output in pending_outputs: - # Decode the full prompt token IDs (original + generated before pause) - full_prompt_text = tokenizer.decode(output.prompt_token_ids) - # Find the original prompt by checking which one this output started with - original_prompt = next(p for p in prompts if full_prompt_text.startswith(p)) - # output.prompt_token_ids contains original prompt + tokens generated before pause - # output.outputs[0].text is what was generated after resuming with new weights - text_before_pause = full_prompt_text[len(original_prompt) :] - text_after_pause = output.outputs[0].text - print(f"Original prompt: {original_prompt!r}") - print(f"Generated before weight change: {text_before_pause!r}") - print(f"Generated after weight change: {text_after_pause!r}") - print("-" * 50) +ray.get(llm.shutdown.remote()) +ray.kill(llm) +ray.kill(train_model) + +llm_v2 = ray.remote( + num_cpus=0, + num_gpus=0, +)(MyLLM).remote( + model=MODEL_NAME_V2, + enforce_eager=True, + max_model_len=8192, + gpu_memory_utilization=0.75, + distributed_executor_backend="ray", + attention_backend="FLASH_ATTN", +) + +val_futures = [ + llm_v2.do_generate.remote( + list(output.prompt_token_ids) + list(output.outputs[0].token_ids)[:pause_idx], + SamplingParams( + temperature=0, max_tokens=len(output.outputs[0].token_ids) - pause_idx + ), + ) + for output, pause_idx in results +] +val_results = ray.get(val_futures) + +all_pass = True +for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_results)): + expected = list(output.outputs[0].token_ids)[pause_idx:] + actual = list(val_output.outputs[0].token_ids) + match = actual == expected + + if match: + print(f" [PASS] {PROMPTS[i]!r}") + else: + all_pass = False + print(f" [FAIL] {PROMPTS[i]!r}") + print(f" weight-synced vLLM: {tokenizer.decode(expected)!r}") + print(f" V2 vLLM: {tokenizer.decode(actual)!r}") + for j, (e, a) in enumerate(zip(expected, actual)): + if e != a: + print( + f" first divergence at output token {j}: " + f"expected {e} ({tokenizer.decode([e])!r}) vs " + f"actual {a} ({tokenizer.decode([a])!r})" + ) + break + +ray.get(llm_v2.shutdown.remote()) +ray.kill(llm_v2) +assert all_pass, "Some prompts failed validation, see above for details" +print("=" * 50)