2026-02-05 09:13:23 -08:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
"""
|
|
|
|
|
Demonstrates async reinforcement learning using vLLM and Ray,
|
2026-03-19 12:46:07 -07:00
|
|
|
with native weight syncing APIs and batch-invariant generation.
|
2026-02-05 09:13:23 -08:00
|
|
|
|
|
|
|
|
The script separates training and inference workloads onto distinct GPUs
|
|
|
|
|
so that Ray can manage process placement and inter-process communication.
|
2026-03-19 12:46:07 -07:00
|
|
|
A Hugging Face Transformer model occupies one GPU for training, and a
|
|
|
|
|
vLLM AsyncLLMEngine occupies another GPU for inference.
|
|
|
|
|
|
|
|
|
|
Batch invariance is enabled so that generation output is deterministic
|
|
|
|
|
regardless of how many requests are batched together. This is required
|
|
|
|
|
for the validation phase to succeed. Batch invariance currently requires
|
|
|
|
|
NVIDIA GPUs with compute capability 9.0 or higher:
|
|
|
|
|
- H-series: H100, H200
|
|
|
|
|
- B-series: B100, B200
|
2026-02-05 09:13:23 -08:00
|
|
|
|
|
|
|
|
The example performs the following steps:
|
2026-03-19 12:46:07 -07:00
|
|
|
* Load the training model (Qwen3-1.7B) on one GPU via a Ray actor.
|
|
|
|
|
* Initialize the inference engine with a base model (Qwen3-1.7B-Base)
|
|
|
|
|
on a separate GPU using vLLM's AsyncLLMEngine with Ray as the
|
|
|
|
|
distributed executor backend.
|
|
|
|
|
* Set up an NCCL-based weight transfer channel between the trainer
|
|
|
|
|
and the inference engine.
|
|
|
|
|
* Submit generation requests for a batch of prompts.
|
|
|
|
|
* Pause generation once any request reaches a token threshold.
|
|
|
|
|
* Broadcast the training model's weights to the inference engine
|
|
|
|
|
via the NCCL weight transfer engine, replacing the base weights.
|
|
|
|
|
* Resume generation and collect results, noting which tokens were
|
|
|
|
|
generated before vs. after the weight swap.
|
|
|
|
|
* Validate correctness by launching a fresh vLLM instance loaded
|
|
|
|
|
directly with the training model and comparing its output to the
|
|
|
|
|
post-swap tokens from the weight-synced engine.
|
2026-02-05 09:13:23 -08:00
|
|
|
|
2026-03-19 12:46:07 -07:00
|
|
|
This example assumes a single-node cluster with two GPUs, but Ray
|
2026-02-05 09:13:23 -08:00
|
|
|
supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
|
|
|
|
|
workloads. Residual GPU activity interferes with vLLM memory profiling and
|
|
|
|
|
causes unexpected behavior.
|
|
|
|
|
"""
|
|
|
|
|
|
2026-02-23 13:30:56 -08:00
|
|
|
import asyncio
|
2026-02-05 09:13:23 -08:00
|
|
|
import uuid
|
|
|
|
|
from dataclasses import asdict
|
|
|
|
|
|
|
|
|
|
import ray
|
|
|
|
|
import torch
|
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
|
|
|
|
|
import vllm
|
|
|
|
|
from vllm import SamplingParams
|
|
|
|
|
from vllm.config import WeightTransferConfig
|
|
|
|
|
from vllm.distributed.weight_transfer.base import (
|
|
|
|
|
WeightTransferInitRequest,
|
|
|
|
|
WeightTransferUpdateRequest,
|
|
|
|
|
)
|
|
|
|
|
from vllm.distributed.weight_transfer.nccl_engine import (
|
2026-02-27 12:45:21 -08:00
|
|
|
NCCLTrainerSendWeightsArgs,
|
2026-02-05 09:13:23 -08:00
|
|
|
NCCLWeightTransferEngine,
|
|
|
|
|
NCCLWeightTransferInitInfo,
|
|
|
|
|
NCCLWeightTransferUpdateInfo,
|
|
|
|
|
)
|
2026-03-03 19:44:14 -06:00
|
|
|
from vllm.platforms import current_platform
|
2026-02-05 09:13:23 -08:00
|
|
|
from vllm.utils.network_utils import get_ip, get_open_port
|
|
|
|
|
from vllm.v1.executor import Executor
|
|
|
|
|
|
2026-02-23 13:30:56 -08:00
|
|
|
# Base model loaded initially by the inference engine (pre-swap weights).
MODEL_NAME_V1 = "Qwen/Qwen3-1.7B-Base"
# "Trained" model hosted by the training actor; its weights are broadcast
# to the inference engine mid-generation.
MODEL_NAME_V2 = "Qwen/Qwen3-1.7B"
# Once any single request has generated this many tokens, generation is
# paused and the weight swap is performed.
PAUSE_TOKEN_THRESHOLD = 10
# Attention backend per platform; FLASH_ATTN is used on NVIDIA,
# TRITON_ATTN on ROCm.
ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "FLASH_ATTN"
|
2026-02-05 09:13:23 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class MyLLM(vllm.AsyncLLMEngine):
    """Configure the vLLM worker for Ray placement group execution."""

    def __init__(self, **kwargs):
        # Build the engine config manually so the executor class can be
        # resolved before invoking the AsyncLLMEngine constructor.
        engine_args = vllm.AsyncEngineArgs(**kwargs)
        vllm_config = engine_args.create_engine_config()
        executor_class = Executor.get_class(vllm_config)
        super().__init__(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
        )
        # Set True after pause_generation() completes; do_generate() reads it
        # to record where the weight swap landed in each request's output.
        self._generation_paused = False
        # Set True by the first request whose token count crosses
        # PAUSE_TOKEN_THRESHOLD; pause_after_n_tokens() polls it.
        self._request_pause_flag = False

    async def do_generate(
        self, prompt_token_ids: list[int], sampling_params: vllm.SamplingParams
    ) -> tuple[vllm.RequestOutput, int]:
        """Generate a single request, setting the request pause flag once the
        token count reaches the threshold.

        Returns (output, pause_token_index). pause_token_index is the number
        of tokens generated before the weight change, or -1 if no pause.
        """
        pause_token_index = -1
        prev_token_count = 0
        async for request_output in self.generate(
            {"prompt_token_ids": prompt_token_ids},
            sampling_params,
            request_id=str(uuid.uuid4()),
        ):
            output = request_output
            cur_token_count = len(output.outputs[0].token_ids)
            if (
                cur_token_count >= PAUSE_TOKEN_THRESHOLD
                and not self._request_pause_flag
            ):
                # First request past the threshold signals the pause waiter.
                self._request_pause_flag = True
            if self._generation_paused and pause_token_index == -1:
                # The pause completed between the previous iteration and this
                # one, so the previously counted tokens were produced by the
                # old weights; record that boundary exactly once.
                pause_token_index = prev_token_count
            prev_token_count = cur_token_count
        return output, pause_token_index

    async def pause_after_n_tokens(self):
        """Wait for any request to set the pause flag, then pause."""
        # Busy-wait (yielding each iteration) until some in-flight request
        # sets the flag in do_generate().
        while not self._request_pause_flag:
            await asyncio.sleep(0)
        # mode="keep" retains in-flight requests so they resume afterwards.
        await super().pause_generation(mode="keep")
        # NOTE(review): fixed 5s settle time before flagging the pause to
        # do_generate() — presumably lets in-flight iterations drain; confirm
        # this delay is actually required.
        await asyncio.sleep(5)
        self._generation_paused = True
|
2026-02-05 09:13:23 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@ray.remote(num_gpus=1)
class TrainModel:
    """Ray actor that wraps the training model on a dedicated GPU."""

    def __init__(self, model_name: str):
        from vllm.model_executor.layers.batch_invariant import (
            init_batch_invariance,
        )
        from vllm.platforms import current_platform
        from vllm.v1.attention.backends.registry import AttentionBackendEnum

        # need to init all env vars for batch invariance which affect nccl ops
        if current_platform.is_rocm():
            backend = AttentionBackendEnum.TRITON_ATTN
        else:
            backend = AttentionBackendEnum.FLASH_ATTN
        init_batch_invariance(backend)

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, dtype=torch.bfloat16
        ).to("cuda:0")
        self.port = get_open_port()
        self.master_address = get_ip()

    def get_master_address_and_port(self):
        """Return the (address, port) rendezvous endpoint for NCCL init."""
        return self.master_address, self.port

    def get_weight_metadata(self):
        """Return weight names, dtypes, and shapes for weight transfer."""
        params = list(self.model.named_parameters())
        weight_names = [name for name, _ in params]
        weight_dtypes = [str(p.dtype).split(".")[-1] for _, p in params]
        weight_shapes = [list(p.shape) for _, p in params]
        return weight_names, weight_dtypes, weight_shapes

    def init_weight_transfer_group(self, world_size):
        """Initialize the NCCL process group for weight transfer."""
        init_kwargs = dict(
            master_address=self.master_address,
            master_port=self.port,
            world_size=world_size,
        )
        self.model_update_group = NCCLWeightTransferEngine.trainer_init(init_kwargs)

    def broadcast_weights(self, packed: bool = True):
        """Broadcast weights to the inference engine."""
        send_args = NCCLTrainerSendWeightsArgs(
            group=self.model_update_group,
            packed=packed,
        )
        NCCLWeightTransferEngine.trainer_send_weights(
            iterator=self.model.named_parameters(),
            trainer_args=send_args,
        )

    @torch.inference_mode()
    def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]:
        """Greedy-decode max_new_tokens from the given context."""
        context = torch.tensor([token_ids], device="cuda:0")
        generated = self.model.generate(
            context,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
        # Strip the prompt prefix; return only the newly generated tokens.
        return generated[0, len(token_ids) :].tolist()
|
|
|
|
|
|
|
|
|
|
|
2026-03-03 19:44:14 -06:00
|
|
|
# Environment variables applied to every Ray worker process.
# RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR stops Ray from rewriting
# CUDA_VISIBLE_DEVICES, so vLLM sees the physical GPU indices.
_common_env = {"RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1"}

if current_platform.is_rocm():
    # For ROCm, BATCH_INVARIANT vllm is not supported
    _platform_env = {"VLLM_ROCM_USE_SKINNY_GEMM": "0"}
else:
    # Enable batch invariance for deterministic outputs on NVIDIA
    _platform_env = {"VLLM_BATCH_INVARIANT": "1"}

ray_env_vars = {**_common_env, **_platform_env}

ray.init(runtime_env={"env_vars": ray_env_vars})
|
2026-02-05 09:13:23 -08:00
|
|
|
|
|
|
|
|
# Launch the training model actor. Ray's resource scheduler will allocate
# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs.
train_model = TrainModel.remote(MODEL_NAME_V2)

rocm_determinism_kwargs = {}
if current_platform.is_rocm():
    # ROCm: To minimize non-determinism, we set fixed seed, no prefix caching, and
    # sequential request processing (max_num_seqs=1).
    rocm_determinism_kwargs = {
        "seed": 0,
        "enable_prefix_caching": False,
        "max_num_seqs": 1,
    }

# Build platform-specific LLM kwargs
llm_kwargs = dict(
    model=MODEL_NAME_V1,
    enforce_eager=True,
    max_model_len=8192,
    distributed_executor_backend="ray",
    attention_backend=ATTN_BACKEND,
    gpu_memory_utilization=0.75,
    # Enables the NCCL-based weight transfer path used for the mid-run swap.
    weight_transfer_config=WeightTransferConfig(backend="nccl"),
)
# Applied on top of the base kwargs; empty dict (no-op) on NVIDIA.
llm_kwargs.update(rocm_determinism_kwargs)

# Launch the vLLM inference engine.
# With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates
# its own placement groups internally for each DP rank, so we must NOT
# create an outer placement group (it would reserve GPUs and hide them
# from the internal DP resource check).
# num_cpus=0/num_gpus=0: the actor itself reserves nothing; vLLM's internal
# Ray workers claim the GPU resources.
llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
)(MyLLM).remote(**llm_kwargs)
|
2026-02-05 09:13:23 -08:00
|
|
|
|
2026-02-23 13:30:56 -08:00
|
|
|
# Factual completion prompts; one generation request is submitted per prompt.
PROMPTS = [
    "The president of the United States is",
    "The capital of France is",
    "The largest ocean on Earth is",
    "The speed of light in a vacuum is",
    "The chemical formula for water is",
    "The tallest mountain in the world is",
    "The first person to walk on the moon was",
    "The Great Wall of China was built to",
    "Photosynthesis is the process by which",
    "The theory of general relativity was proposed by",
    "The boiling point of water at sea level is",
    "The largest planet in our solar system is",
    "DNA stands for deoxyribonucleic acid and it",
]

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_V1)
# Pre-tokenize so requests are submitted as raw token IDs (no special
# tokens), matching what do_generate() expects.
batch_prompt_token_ids = [
    tokenizer.encode(prompt, add_special_tokens=False) for prompt in PROMPTS
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Set up the communication channel between the training process and the
# inference engine.
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())

world_size = 2  # 1 trainer + 1 inference worker
# rank_offset=1: the trainer occupies rank 0 of the NCCL group, so the
# inference worker starts at rank 1.
inference_handle = llm.init_weight_transfer_engine.remote(
    WeightTransferInitRequest(
        init_info=asdict(
            NCCLWeightTransferInitInfo(
                master_address=master_address,
                master_port=master_port,
                rank_offset=1,
                world_size=world_size,
            )
        )
    )
)

# Initialize weight transfer group on both the training actor and inference engine
train_handle = train_model.init_weight_transfer_group.remote(world_size)
# Both sides must join the group concurrently; wait on both together.
ray.get([train_handle, inference_handle])
|
|
|
|
|
|
|
|
|
|
|
2026-02-23 13:30:56 -08:00
|
|
|
# Tokens to generate after the pause threshold, per request.
N_NEW_TOKENS = 100

# Collect weight metadata once
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())

# ── Phase 1: concurrent requests with weight sync ───────────────────
print(f"\n{'=' * 50}")
print(f"Prompts ({len(PROMPTS)}):")
for p in PROMPTS:
    print(f" - {p!r}")
print(f"{'=' * 50}")

# temperature=0: greedy decoding, required for deterministic comparison.
sampling_params = SamplingParams(
    temperature=0, max_tokens=PAUSE_TOKEN_THRESHOLD + N_NEW_TOKENS
)

# Submit all requests concurrently; they run until the pause fires.
gen_futures = [
    llm.do_generate.remote(ptids, sampling_params) for ptids in batch_prompt_token_ids
]

# Blocks until any request crosses PAUSE_TOKEN_THRESHOLD and the engine
# is actually paused.
ray.get(llm.pause_after_n_tokens.remote())

# While paused, receive the new weights on the inference side...
inference_handle = llm.update_weights.remote(
    WeightTransferUpdateRequest(
        update_info=asdict(
            NCCLWeightTransferUpdateInfo(
                names=names,
                dtype_names=dtype_names,
                shapes=shapes,
                packed=True,
            )
        )
    )
)
# ...and broadcast them from the trainer; both must run concurrently for
# the NCCL collective to complete.
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])

# Resume the paused requests; remaining tokens come from the new weights.
ray.get(llm.resume_generation.remote())
results = ray.get(gen_futures)

# Report each request's output split at the recorded swap boundary.
for i, (output, pause_idx) in enumerate(results):
    all_token_ids = list(output.outputs[0].token_ids)
    before_text = tokenizer.decode(all_token_ids[:pause_idx])
    after_text = tokenizer.decode(all_token_ids[pause_idx:])
    print(f"\n Request {i} ({PROMPTS[i]!r}):")
    print(f" Old weights ({pause_idx} tokens): {before_text!r}")
    n_after = len(all_token_ids) - pause_idx
    print(f" New weights ({n_after} tokens): {after_text!r}")
|
|
|
|
|
|
|
|
|
|
# ── Phase 2: validate with a fresh V2 vLLM instance ────────────────
# This validation relies on batch-invariant (deterministic) generation to
# compare outputs from the weight-synced engine against a fresh V2 instance.
# On NVIDIA, batch invariance is fully supported, so we require 100% exact
# token match. On ROCm, batch invariance is not yet fully implemented
# (see https://github.com/vllm-project/vllm/issues/27433 and
# https://github.com/vllm-project/vllm/issues/33123), so residual
# non-determinism (e.g. GEMM accumulation order, missing kernel overrides)
# can cause single-token divergences that don't indicate a weight-sync
# failure. We relax the pass rate to 90% on ROCm to accommodate this; a
# real regression (broken weight transfer) would cause ~0% pass rate, not 90%+.
MIN_PASS_RATE = 1.0 if not current_platform.is_rocm() else 0.9

print(f"\n{'=' * 50}")
print("VALIDATION: comparing weight-synced vLLM with fresh V2 instance")
if current_platform.is_rocm():
    print(f" (ROCm mode: requiring >= {MIN_PASS_RATE:.0%} exact match rate)")
print(f"{'=' * 50}")

# Tear down the weight-synced engine to free its GPU for the fresh
# validation instance.
ray.get(llm.shutdown.remote())
ray.kill(llm)
|
|
|
|
|
|
2026-03-03 19:44:14 -06:00
|
|
|
# Same engine configuration as phase 1, but loading the training model
# directly (and no weight_transfer_config — no swap happens here).
llm_v2_kwargs = dict(
    model=MODEL_NAME_V2,
    enforce_eager=True,
    max_model_len=8192,
    gpu_memory_utilization=0.75,
    distributed_executor_backend="ray",
    attention_backend=ATTN_BACKEND,
)
llm_v2_kwargs.update(rocm_determinism_kwargs)

llm_v2 = ray.remote(
    num_cpus=0,
    num_gpus=0,
)(MyLLM).remote(**llm_v2_kwargs)

# For each phase-1 result, replay prompt + pre-swap tokens as the context
# and generate exactly as many tokens as were produced after the swap.
val_futures = [
    llm_v2.do_generate.remote(
        list(output.prompt_token_ids) + list(output.outputs[0].token_ids)[:pause_idx],
        SamplingParams(
            temperature=0, max_tokens=len(output.outputs[0].token_ids) - pause_idx
        ),
    )
    for output, pause_idx in results
]
val_results = ray.get(val_futures)
|
|
|
|
|
|
2026-03-03 19:44:14 -06:00
|
|
|
# Compare the post-swap tokens from the weight-synced engine against the
# fresh V2 engine's continuation from the identical context.
num_pass = 0
num_total = len(results)
for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_results)):
    # Tokens produced by the weight-synced engine after the swap.
    expected = list(output.outputs[0].token_ids)[pause_idx:]
    # Tokens produced by the fresh V2 engine from the same context.
    actual = list(val_output.outputs[0].token_ids)
    match = actual == expected

    if match:
        num_pass += 1
        print(f" [PASS] {PROMPTS[i]!r}")
    else:
        print(f" [FAIL] {PROMPTS[i]!r}")
        print(f" weight-synced vLLM: {tokenizer.decode(expected)!r}")
        print(f" V2 vLLM: {tokenizer.decode(actual)!r}")
        for j, (e, a) in enumerate(zip(expected, actual)):
            if e != a:
                print(
                    f" first divergence at output token {j}: "
                    f"expected {e} ({tokenizer.decode([e])!r}) vs "
                    f"actual {a} ({tokenizer.decode([a])!r})"
                )
                break
        else:
            # Bug fix: zip() truncates to the shorter sequence, so when one
            # output was a strict prefix of the other the loop above found
            # no differing token and the failure printed no diagnostic at
            # all. Report the length-only divergence explicitly.
            print(
                f" outputs diverge in length only: "
                f"expected {len(expected)} tokens vs actual {len(actual)}"
            )

# Shut down the validation engine before reporting the summary.
ray.get(llm_v2.shutdown.remote())
ray.kill(llm_v2)

pass_rate = num_pass / num_total
print(f"\n Result: {num_pass}/{num_total} prompts passed ({pass_rate:.0%})")
print(f" Required: >= {MIN_PASS_RATE:.0%}")

assert pass_rate >= MIN_PASS_RATE, (
    f"Validation pass rate {pass_rate:.0%} ({num_pass}/{num_total}) "
    f"is below the required {MIN_PASS_RATE:.0%} threshold. "
    f"See failures above for details."
)
print("=" * 50)
|