[ROCm][CI] Support async weight transfer example with platform-aware determinism (#35710)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -1339,6 +1339,7 @@ steps:
|
||||
- tests/v1/entrypoints/openai/test_multi_api_servers.py
|
||||
- tests/v1/shutdown
|
||||
- tests/v1/worker/test_worker_memory_snapshot.py
|
||||
- examples/offline_inference/new_weight_syncing/
|
||||
commands:
|
||||
# Work around HIP bug tracked here: https://github.com/ROCm/hip/issues/3876
|
||||
# TODO: Remove when the bug is fixed in a future ROCm release
|
||||
@@ -1970,8 +1971,10 @@ steps:
|
||||
|
||||
- label: Distributed Tests (4 GPUs) # 35min
|
||||
timeout_in_minutes: 50
|
||||
mirror_hardwares: [amdexperimental]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi355_4
|
||||
optional: true
|
||||
# grade: Blocking
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 4
|
||||
source_file_dependencies:
|
||||
@@ -2025,7 +2028,8 @@ steps:
|
||||
- popd
|
||||
# NEW rlhf examples
|
||||
- pushd ../examples/offline_inference/new_weight_syncing
|
||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
|
||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py
|
||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py
|
||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
|
||||
- popd
|
||||
|
||||
@@ -2989,8 +2993,10 @@ steps:
|
||||
|
||||
- label: Distributed Tests (2 GPUs) # 68min
|
||||
timeout_in_minutes: 90
|
||||
mirror_hardwares: [amdexperimental]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi355_2
|
||||
optional: true
|
||||
# grade: Blocking
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
source_file_dependencies:
|
||||
|
||||
@@ -47,12 +47,14 @@ from vllm.distributed.weight_transfer.nccl_engine import (
|
||||
NCCLWeightTransferInitInfo,
|
||||
NCCLWeightTransferUpdateInfo,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import get_ip, get_open_port
|
||||
from vllm.v1.executor import Executor
|
||||
|
||||
MODEL_NAME_V1 = "Qwen/Qwen3-1.7B-Base"
|
||||
MODEL_NAME_V2 = "Qwen/Qwen3-1.7B"
|
||||
PAUSE_TOKEN_THRESHOLD = 10
|
||||
ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "FLASH_ATTN"
|
||||
|
||||
|
||||
class MyLLM(vllm.AsyncLLMEngine):
|
||||
@@ -116,10 +118,16 @@ class TrainModel:
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
init_batch_invariance,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
# need to init all env vars for batch invariance which affect nccl ops
|
||||
init_batch_invariance(AttentionBackendEnum.FLASH_ATTN)
|
||||
attn_backend = (
|
||||
AttentionBackendEnum.TRITON_ATTN
|
||||
if current_platform.is_rocm()
|
||||
else AttentionBackendEnum.FLASH_ATTN
|
||||
)
|
||||
init_batch_invariance(attn_backend)
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, dtype=torch.bfloat16
|
||||
@@ -175,23 +183,48 @@ class TrainModel:
|
||||
return new_token_ids
|
||||
|
||||
|
||||
ray.init(
|
||||
runtime_env={
|
||||
"env_vars": {
|
||||
# enable batch invariance for deterministic outputs
|
||||
"VLLM_BATCH_INVARIANT": "1",
|
||||
# prevent ray from setting CUDA_VISIBLE_DEVICES
|
||||
"RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1",
|
||||
}
|
||||
}
|
||||
)
|
||||
# Build platform-specific env vars for Ray
|
||||
ray_env_vars = {
|
||||
# Prevent Ray from setting CUDA_VISIBLE_DEVICES
|
||||
"RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1",
|
||||
}
|
||||
|
||||
if current_platform.is_rocm():
|
||||
# On ROCm, VLLM_BATCH_INVARIANT is not supported, so disable skinny GEMM instead
|
||||
ray_env_vars["VLLM_ROCM_USE_SKINNY_GEMM"] = "0"
|
||||
else:
|
||||
# Enable batch invariance for deterministic outputs on NVIDIA
|
||||
ray_env_vars["VLLM_BATCH_INVARIANT"] = "1"
|
||||
|
||||
ray.init(runtime_env={"env_vars": ray_env_vars})
|
||||
|
||||
# Launch the training model actor. Ray's resource scheduler will allocate
|
||||
# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs.
|
||||
train_model = TrainModel.remote(MODEL_NAME_V2)
|
||||
|
||||
# Launch the vLLM inference engine. The `enforce_eager` flag reduces
|
||||
# start-up latency.
|
||||
rocm_determinism_kwargs = {}
|
||||
if current_platform.is_rocm():
|
||||
# ROCm: to minimize non-determinism, we use a fixed seed, disable prefix caching,
|
||||
# and process requests sequentially (max_num_seqs=1).
|
||||
rocm_determinism_kwargs = {
|
||||
"seed": 0,
|
||||
"enable_prefix_caching": False,
|
||||
"max_num_seqs": 1,
|
||||
}
|
||||
|
||||
# Build platform-specific LLM kwargs
|
||||
llm_kwargs = dict(
|
||||
model=MODEL_NAME_V1,
|
||||
enforce_eager=True,
|
||||
max_model_len=8192,
|
||||
distributed_executor_backend="ray",
|
||||
attention_backend=ATTN_BACKEND,
|
||||
gpu_memory_utilization=0.75,
|
||||
weight_transfer_config=WeightTransferConfig(backend="nccl"),
|
||||
)
|
||||
llm_kwargs.update(rocm_determinism_kwargs)
|
||||
|
||||
# Launch the vLLM inference engine.
|
||||
# With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates
|
||||
# its own placement groups internally for each DP rank, so we must NOT
|
||||
# create an outer placement group (it would reserve GPUs and hide them
|
||||
@@ -199,15 +232,7 @@ train_model = TrainModel.remote(MODEL_NAME_V2)
|
||||
llm = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=0,
|
||||
)(MyLLM).remote(
|
||||
model=MODEL_NAME_V1,
|
||||
enforce_eager=True,
|
||||
max_model_len=8192,
|
||||
distributed_executor_backend="ray",
|
||||
attention_backend="FLASH_ATTN",
|
||||
gpu_memory_utilization=0.75,
|
||||
weight_transfer_config=WeightTransferConfig(backend="nccl"),
|
||||
)
|
||||
)(MyLLM).remote(**llm_kwargs)
|
||||
|
||||
PROMPTS = [
|
||||
"The president of the United States is",
|
||||
@@ -304,25 +329,42 @@ for i, (output, pause_idx) in enumerate(results):
|
||||
print(f" New weights ({n_after} tokens): {after_text!r}")
|
||||
|
||||
# ── Phase 2: validate with a fresh V2 vLLM instance ────────────────
|
||||
# This validation relies on batch-invariant (deterministic) generation to
|
||||
# compare outputs from the weight-synced engine against a fresh V2 instance.
|
||||
# On NVIDIA, batch invariance is fully supported, so we require 100% exact
|
||||
# token match. On ROCm, batch invariance is not yet fully implemented
|
||||
# (see https://github.com/vllm-project/vllm/issues/27433 and
|
||||
# https://github.com/vllm-project/vllm/issues/33123), so residual
|
||||
# non-determinism (e.g. GEMM accumulation order, missing kernel overrides)
|
||||
# can cause single-token divergences that don't indicate a weight-sync
|
||||
# failure. We relax the pass rate to 90% on ROCm to accommodate this; a
|
||||
# real regression (broken weight transfer) would cause ~0% pass rate, not 90%+.
|
||||
MIN_PASS_RATE = 1.0 if not current_platform.is_rocm() else 0.9
|
||||
|
||||
print(f"\n{'=' * 50}")
|
||||
print("VALIDATION: comparing weight-synced vLLM with fresh V2 instance")
|
||||
if current_platform.is_rocm():
|
||||
print(f" (ROCm mode: requiring >= {MIN_PASS_RATE:.0%} exact match rate)")
|
||||
print(f"{'=' * 50}")
|
||||
|
||||
ray.get(llm.shutdown.remote())
|
||||
ray.kill(llm)
|
||||
ray.kill(train_model)
|
||||
|
||||
llm_v2 = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=0,
|
||||
)(MyLLM).remote(
|
||||
llm_v2_kwargs = dict(
|
||||
model=MODEL_NAME_V2,
|
||||
enforce_eager=True,
|
||||
max_model_len=8192,
|
||||
gpu_memory_utilization=0.75,
|
||||
distributed_executor_backend="ray",
|
||||
attention_backend="FLASH_ATTN",
|
||||
attention_backend=ATTN_BACKEND,
|
||||
)
|
||||
llm_v2_kwargs.update(rocm_determinism_kwargs)
|
||||
|
||||
llm_v2 = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=0,
|
||||
)(MyLLM).remote(**llm_v2_kwargs)
|
||||
|
||||
val_futures = [
|
||||
llm_v2.do_generate.remote(
|
||||
@@ -335,16 +377,17 @@ val_futures = [
|
||||
]
|
||||
val_results = ray.get(val_futures)
|
||||
|
||||
all_pass = True
|
||||
num_pass = 0
|
||||
num_total = len(results)
|
||||
for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_results)):
|
||||
expected = list(output.outputs[0].token_ids)[pause_idx:]
|
||||
actual = list(val_output.outputs[0].token_ids)
|
||||
match = actual == expected
|
||||
|
||||
if match:
|
||||
num_pass += 1
|
||||
print(f" [PASS] {PROMPTS[i]!r}")
|
||||
else:
|
||||
all_pass = False
|
||||
print(f" [FAIL] {PROMPTS[i]!r}")
|
||||
print(f" weight-synced vLLM: {tokenizer.decode(expected)!r}")
|
||||
print(f" V2 vLLM: {tokenizer.decode(actual)!r}")
|
||||
@@ -359,5 +402,14 @@ for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_resu
|
||||
|
||||
ray.get(llm_v2.shutdown.remote())
|
||||
ray.kill(llm_v2)
|
||||
assert all_pass, "Some prompts failed validation, see above for details"
|
||||
|
||||
pass_rate = num_pass / num_total
|
||||
print(f"\n Result: {num_pass}/{num_total} prompts passed ({pass_rate:.0%})")
|
||||
print(f" Required: >= {MIN_PASS_RATE:.0%}")
|
||||
|
||||
assert pass_rate >= MIN_PASS_RATE, (
|
||||
f"Validation pass rate {pass_rate:.0%} ({num_pass}/{num_total}) "
|
||||
f"is below the required {MIN_PASS_RATE:.0%} threshold. "
|
||||
f"See failures above for details."
|
||||
)
|
||||
print("=" * 50)
|
||||
|
||||
Reference in New Issue
Block a user