diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 2b80937e8..9130026e1 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -1339,6 +1339,7 @@ steps: - tests/v1/entrypoints/openai/test_multi_api_servers.py - tests/v1/shutdown - tests/v1/worker/test_worker_memory_snapshot.py + - examples/offline_inference/new_weight_syncing/ commands: # Work around HIP bug tracked here: https://github.com/ROCm/hip/issues/3876 # TODO: Remove when the bug is fixed in a future ROCm release @@ -1970,8 +1971,10 @@ steps: - label: Distributed Tests (4 GPUs) # 35min timeout_in_minutes: 50 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi355_4 + optional: true + # grade: Blocking working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -2025,7 +2028,8 @@ steps: - popd # NEW rlhf examples - pushd ../examples/offline_inference/new_weight_syncing - - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py - popd @@ -2989,8 +2993,10 @@ steps: - label: Distributed Tests (2 GPUs) # 68min timeout_in_minutes: 90 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi355_2 + optional: true + # grade: Blocking working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: diff --git a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py index e9bc06180..5b72bf159 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py +++ b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py @@ -47,12 +47,14 @@ from vllm.distributed.weight_transfer.nccl_engine import ( NCCLWeightTransferInitInfo, NCCLWeightTransferUpdateInfo, ) 
+from vllm.platforms import current_platform from vllm.utils.network_utils import get_ip, get_open_port from vllm.v1.executor import Executor MODEL_NAME_V1 = "Qwen/Qwen3-1.7B-Base" MODEL_NAME_V2 = "Qwen/Qwen3-1.7B" PAUSE_TOKEN_THRESHOLD = 10 +ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "FLASH_ATTN" class MyLLM(vllm.AsyncLLMEngine): @@ -116,10 +118,16 @@ class TrainModel: from vllm.model_executor.layers.batch_invariant import ( init_batch_invariance, ) + from vllm.platforms import current_platform from vllm.v1.attention.backends.registry import AttentionBackendEnum # need to init all env vars for batch invariance which affect nccl ops - init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) + attn_backend = ( + AttentionBackendEnum.TRITON_ATTN + if current_platform.is_rocm() + else AttentionBackendEnum.FLASH_ATTN + ) + init_batch_invariance(attn_backend) self.model = AutoModelForCausalLM.from_pretrained( model_name, dtype=torch.bfloat16 @@ -175,23 +183,48 @@ class TrainModel: return new_token_ids -ray.init( - runtime_env={ - "env_vars": { - # enable batch invariance for deterministic outputs - "VLLM_BATCH_INVARIANT": "1", - # prevent ray from setting CUDA_VISIBLE_DEVICES - "RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1", - } - } -) +# Build platform-specific env vars for Ray +ray_env_vars = { + # Prevent Ray from setting CUDA_VISIBLE_DEVICES + "RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1", +} + +if current_platform.is_rocm(): + # VLLM_BATCH_INVARIANT is not supported on ROCm + ray_env_vars["VLLM_ROCM_USE_SKINNY_GEMM"] = "0" +else: + # Enable batch invariance for deterministic outputs on NVIDIA + ray_env_vars["VLLM_BATCH_INVARIANT"] = "1" + +ray.init(runtime_env={"env_vars": ray_env_vars}) # Launch the training model actor. Ray's resource scheduler will allocate # 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs. train_model = TrainModel.remote(MODEL_NAME_V2) -# Launch the vLLM inference engine. 
The `enforce_eager` flag reduces -# start-up latency. +rocm_determinism_kwargs = {} +if current_platform.is_rocm(): + # ROCm: to minimize non-determinism, use a fixed seed, disable prefix + # caching, and process requests sequentially (max_num_seqs=1). + rocm_determinism_kwargs = { + "seed": 0, + "enable_prefix_caching": False, + "max_num_seqs": 1, + } + +# Build platform-specific LLM kwargs +llm_kwargs = dict( + model=MODEL_NAME_V1, + enforce_eager=True, + max_model_len=8192, + distributed_executor_backend="ray", + attention_backend=ATTN_BACKEND, + gpu_memory_utilization=0.75, + weight_transfer_config=WeightTransferConfig(backend="nccl"), +) +llm_kwargs.update(rocm_determinism_kwargs) + +# Launch the vLLM inference engine. # With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates # its own placement groups internally for each DP rank, so we must NOT # create an outer placement group (it would reserve GPUs and hide them @@ -199,15 +232,7 @@ train_model = TrainModel.remote(MODEL_NAME_V2) llm = ray.remote( num_cpus=0, num_gpus=0, -)(MyLLM).remote( - model=MODEL_NAME_V1, - enforce_eager=True, - max_model_len=8192, - distributed_executor_backend="ray", - attention_backend="FLASH_ATTN", - gpu_memory_utilization=0.75, - weight_transfer_config=WeightTransferConfig(backend="nccl"), -) +)(MyLLM).remote(**llm_kwargs) PROMPTS = [ "The president of the United States is", @@ -304,25 +329,42 @@ for i, (output, pause_idx) in enumerate(results): print(f" New weights ({n_after} tokens): {after_text!r}") # ── Phase 2: validate with a fresh V2 vLLM instance ──────────────── +# This validation relies on batch-invariant (deterministic) generation to +# compare outputs from the weight-synced engine against a fresh V2 instance. +# On NVIDIA, batch invariance is fully supported, so we require 100% exact +# token match. 
On ROCm, batch invariance is not yet fully implemented +# (see https://github.com/vllm-project/vllm/issues/27433 and +# https://github.com/vllm-project/vllm/issues/33123), so residual +# non-determinism (e.g. GEMM accumulation order, missing kernel overrides) +# can cause single-token divergences that don't indicate a weight-sync +# failure. We relax the pass rate to 90% on ROCm to accommodate this; a +# real regression (broken weight transfer) would cause ~0% pass rate, not 90%+. +MIN_PASS_RATE = 1.0 if not current_platform.is_rocm() else 0.9 + print(f"\n{'=' * 50}") print("VALIDATION: comparing weight-synced vLLM with fresh V2 instance") +if current_platform.is_rocm(): + print(f" (ROCm mode: requiring >= {MIN_PASS_RATE:.0%} exact match rate)") print(f"{'=' * 50}") ray.get(llm.shutdown.remote()) ray.kill(llm) ray.kill(train_model) -llm_v2 = ray.remote( - num_cpus=0, - num_gpus=0, -)(MyLLM).remote( +llm_v2_kwargs = dict( model=MODEL_NAME_V2, enforce_eager=True, max_model_len=8192, gpu_memory_utilization=0.75, distributed_executor_backend="ray", - attention_backend="FLASH_ATTN", + attention_backend=ATTN_BACKEND, ) +llm_v2_kwargs.update(rocm_determinism_kwargs) + +llm_v2 = ray.remote( + num_cpus=0, + num_gpus=0, +)(MyLLM).remote(**llm_v2_kwargs) val_futures = [ llm_v2.do_generate.remote( @@ -335,16 +377,17 @@ val_futures = [ ] val_results = ray.get(val_futures) -all_pass = True +num_pass = 0 +num_total = len(results) for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_results)): expected = list(output.outputs[0].token_ids)[pause_idx:] actual = list(val_output.outputs[0].token_ids) match = actual == expected if match: + num_pass += 1 print(f" [PASS] {PROMPTS[i]!r}") else: - all_pass = False print(f" [FAIL] {PROMPTS[i]!r}") print(f" weight-synced vLLM: {tokenizer.decode(expected)!r}") print(f" V2 vLLM: {tokenizer.decode(actual)!r}") @@ -359,5 +402,14 @@ for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_resu 
ray.get(llm_v2.shutdown.remote()) ray.kill(llm_v2) -assert all_pass, "Some prompts failed validation, see above for details" + +pass_rate = num_pass / num_total +print(f"\n Result: {num_pass}/{num_total} prompts passed ({pass_rate:.0%})") +print(f" Required: >= {MIN_PASS_RATE:.0%}") + +assert pass_rate >= MIN_PASS_RATE, ( + f"Validation pass rate {pass_rate:.0%} ({num_pass}/{num_total}) " + f"is below the required {MIN_PASS_RATE:.0%} threshold. " + f"See failures above for details." +) print("=" * 50)