[ROCm][CI] Support async weight transfer example with platform-aware determinism (#35710)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-03-03 19:44:14 -06:00
committed by GitHub
parent f22ff2958c
commit f7da9cdffc
2 changed files with 91 additions and 33 deletions

View File

@@ -1339,6 +1339,7 @@ steps:
- tests/v1/entrypoints/openai/test_multi_api_servers.py
- tests/v1/shutdown
- tests/v1/worker/test_worker_memory_snapshot.py
- examples/offline_inference/new_weight_syncing/
commands:
# Work around HIP bug tracked here: https://github.com/ROCm/hip/issues/3876
# TODO: Remove when the bug is fixed in a future ROCm release
@@ -1970,8 +1971,10 @@ steps:
- label: Distributed Tests (4 GPUs) # 35min
timeout_in_minutes: 50
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi355_4
optional: true
# grade: Blocking
working_dir: "/vllm-workspace/tests"
num_gpus: 4
source_file_dependencies:
@@ -2025,7 +2028,8 @@ steps:
- popd
# NEW rlhf examples
- pushd ../examples/offline_inference/new_weight_syncing
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
- popd
@@ -2989,8 +2993,10 @@ steps:
- label: Distributed Tests (2 GPUs) # 68min
timeout_in_minutes: 90
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi355_2
optional: true
# grade: Blocking
working_dir: "/vllm-workspace/tests"
num_gpus: 2
source_file_dependencies:

View File

@@ -47,12 +47,14 @@ from vllm.distributed.weight_transfer.nccl_engine import (
NCCLWeightTransferInitInfo,
NCCLWeightTransferUpdateInfo,
)
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_ip, get_open_port
from vllm.v1.executor import Executor
MODEL_NAME_V1 = "Qwen/Qwen3-1.7B-Base"
MODEL_NAME_V2 = "Qwen/Qwen3-1.7B"
PAUSE_TOKEN_THRESHOLD = 10
ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "FLASH_ATTN"
class MyLLM(vllm.AsyncLLMEngine):
@@ -116,10 +118,16 @@ class TrainModel:
from vllm.model_executor.layers.batch_invariant import (
init_batch_invariance,
)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
# need to init all env vars for batch invariance which affect nccl ops
init_batch_invariance(AttentionBackendEnum.FLASH_ATTN)
attn_backend = (
AttentionBackendEnum.TRITON_ATTN
if current_platform.is_rocm()
else AttentionBackendEnum.FLASH_ATTN
)
init_batch_invariance(attn_backend)
self.model = AutoModelForCausalLM.from_pretrained(
model_name, dtype=torch.bfloat16
@@ -175,23 +183,48 @@ class TrainModel:
return new_token_ids
ray.init(
runtime_env={
"env_vars": {
# enable batch invariance for deterministic outputs
"VLLM_BATCH_INVARIANT": "1",
# prevent ray from setting CUDA_VISIBLE_DEVICES
"RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1",
}
}
)
# Build platform-specific env vars for Ray
ray_env_vars = {
# Prevent Ray from setting CUDA_VISIBLE_DEVICES
"RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1",
}
if current_platform.is_rocm():
# On ROCm, vLLM batch invariance (VLLM_BATCH_INVARIANT) is not supported
ray_env_vars["VLLM_ROCM_USE_SKINNY_GEMM"] = "0"
else:
# Enable batch invariance for deterministic outputs on NVIDIA
ray_env_vars["VLLM_BATCH_INVARIANT"] = "1"
ray.init(runtime_env={"env_vars": ray_env_vars})
# Launch the training model actor. Ray's resource scheduler will allocate
# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs.
train_model = TrainModel.remote(MODEL_NAME_V2)
# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.
rocm_determinism_kwargs = {}
if current_platform.is_rocm():
# ROCm: to minimize non-determinism, we use a fixed seed, disable prefix
# caching, and process requests sequentially (max_num_seqs=1).
rocm_determinism_kwargs = {
"seed": 0,
"enable_prefix_caching": False,
"max_num_seqs": 1,
}
# Build platform-specific LLM kwargs
llm_kwargs = dict(
model=MODEL_NAME_V1,
enforce_eager=True,
max_model_len=8192,
distributed_executor_backend="ray",
attention_backend=ATTN_BACKEND,
gpu_memory_utilization=0.75,
weight_transfer_config=WeightTransferConfig(backend="nccl"),
)
llm_kwargs.update(rocm_determinism_kwargs)
# Launch the vLLM inference engine.
# With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates
# its own placement groups internally for each DP rank, so we must NOT
# create an outer placement group (it would reserve GPUs and hide them
@@ -199,15 +232,7 @@ train_model = TrainModel.remote(MODEL_NAME_V2)
llm = ray.remote(
num_cpus=0,
num_gpus=0,
)(MyLLM).remote(
model=MODEL_NAME_V1,
enforce_eager=True,
max_model_len=8192,
distributed_executor_backend="ray",
attention_backend="FLASH_ATTN",
gpu_memory_utilization=0.75,
weight_transfer_config=WeightTransferConfig(backend="nccl"),
)
)(MyLLM).remote(**llm_kwargs)
PROMPTS = [
"The president of the United States is",
@@ -304,25 +329,42 @@ for i, (output, pause_idx) in enumerate(results):
print(f" New weights ({n_after} tokens): {after_text!r}")
# ── Phase 2: validate with a fresh V2 vLLM instance ────────────────
# This validation relies on batch-invariant (deterministic) generation to
# compare outputs from the weight-synced engine against a fresh V2 instance.
# On NVIDIA, batch invariance is fully supported, so we require 100% exact
# token match. On ROCm, batch invariance is not yet fully implemented
# (see https://github.com/vllm-project/vllm/issues/27433 and
# https://github.com/vllm-project/vllm/issues/33123), so residual
# non-determinism (e.g. GEMM accumulation order, missing kernel overrides)
# can cause single-token divergences that don't indicate a weight-sync
# failure. We relax the pass rate to 90% on ROCm to accommodate this; a
# real regression (broken weight transfer) would cause ~0% pass rate, not 90%+.
MIN_PASS_RATE = 1.0 if not current_platform.is_rocm() else 0.9
print(f"\n{'=' * 50}")
print("VALIDATION: comparing weight-synced vLLM with fresh V2 instance")
if current_platform.is_rocm():
print(f" (ROCm mode: requiring >= {MIN_PASS_RATE:.0%} exact match rate)")
print(f"{'=' * 50}")
ray.get(llm.shutdown.remote())
ray.kill(llm)
ray.kill(train_model)
llm_v2 = ray.remote(
num_cpus=0,
num_gpus=0,
)(MyLLM).remote(
llm_v2_kwargs = dict(
model=MODEL_NAME_V2,
enforce_eager=True,
max_model_len=8192,
gpu_memory_utilization=0.75,
distributed_executor_backend="ray",
attention_backend="FLASH_ATTN",
attention_backend=ATTN_BACKEND,
)
llm_v2_kwargs.update(rocm_determinism_kwargs)
llm_v2 = ray.remote(
num_cpus=0,
num_gpus=0,
)(MyLLM).remote(**llm_v2_kwargs)
val_futures = [
llm_v2.do_generate.remote(
@@ -335,16 +377,17 @@ val_futures = [
]
val_results = ray.get(val_futures)
all_pass = True
num_pass = 0
num_total = len(results)
for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_results)):
expected = list(output.outputs[0].token_ids)[pause_idx:]
actual = list(val_output.outputs[0].token_ids)
match = actual == expected
if match:
num_pass += 1
print(f" [PASS] {PROMPTS[i]!r}")
else:
all_pass = False
print(f" [FAIL] {PROMPTS[i]!r}")
print(f" weight-synced vLLM: {tokenizer.decode(expected)!r}")
print(f" V2 vLLM: {tokenizer.decode(actual)!r}")
@@ -359,5 +402,14 @@ for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_resu
ray.get(llm_v2.shutdown.remote())
ray.kill(llm_v2)
assert all_pass, "Some prompts failed validation, see above for details"
pass_rate = num_pass / num_total
print(f"\n Result: {num_pass}/{num_total} prompts passed ({pass_rate:.0%})")
print(f" Required: >= {MIN_PASS_RATE:.0%}")
assert pass_rate >= MIN_PASS_RATE, (
f"Validation pass rate {pass_rate:.0%} ({num_pass}/{num_total}) "
f"is below the required {MIN_PASS_RATE:.0%} threshold. "
f"See failures above for details."
)
print("=" * 50)