[Spec Decode][KV Connector] Fix KV transfer in PD + speculative decoding (#35158)

Signed-off-by: Claude <noreply@anthropic.com>
Signed-off-by: Zhanqiu Hu <zh338@cornell.edu>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Committed via GitHub by zhanqiuhu on 2026-03-06 02:50:44 -05:00
parent 807d680337
commit 90f3c01fa4
5 changed files with 484 additions and 21 deletions

tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh

@@ -0,0 +1,237 @@
#!/bin/bash
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# NixlConnector PD + speculative decoding acceptance length test.
# Tests EAGLE3 acceptance length for both RDMA (cuda) and CPU host (cpu)
# KV buffer device paths.
#
# For each kv_buffer_device setting, starts prefill + decode vllm servers
# with NixlConnector, then runs test_spec_decode_acceptance.py to validate
# acceptance length matches the standalone SD baseline.
#
# Usage:
# CUDA_VISIBLE_DEVICES=0,1 bash tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh
#
# Environment variables:
# KV_BUFFER_DEVICES - space-separated list of devices to test
# (default: "cuda cpu")
# SD_METHOD - spec decode method (default: eagle3)
# SD_MODEL - drafter model path
# MODEL_NAME - target model (default: meta-llama/Llama-3.1-8B-Instruct)
# NUM_SPEC_TOKENS - number of speculative tokens (default: 3)
# GPU_MEMORY_UTILIZATION - (default: 0.7)
set -x
# ── Model & spec decode config ──────────────────────────────────────────
MODEL_NAME="${MODEL_NAME:-meta-llama/Llama-3.1-8B-Instruct}"
SD_METHOD="${SD_METHOD:-eagle3}"
SD_MODEL="${SD_MODEL:-RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3}"
NUM_SPEC_TOKENS="${NUM_SPEC_TOKENS:-3}"
MAX_MODEL_LEN="${MAX_MODEL_LEN:-16384}"
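# The prefill server still runs the drafter (with one speculative token) so
# drafter KV exists and can be transferred; the decode server does the
# actual multi-token speculation.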
PREFILL_SPEC_CONFIG="{\"method\":\"${SD_METHOD}\",\"model\":\"${SD_MODEL}\",\"num_speculative_tokens\":1,\"max_model_len\":${MAX_MODEL_LEN}}"
DECODE_SPEC_CONFIG="{\"method\":\"${SD_METHOD}\",\"model\":\"${SD_MODEL}\",\"num_speculative_tokens\":${NUM_SPEC_TOKENS},\"max_model_len\":${MAX_MODEL_LEN}}"
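# With the defaults above, DECODE_SPEC_CONFIG expands to:
#   {"method":"eagle3","model":"RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3","num_speculative_tokens":3,"max_model_len":16384}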
# ── Test matrix ──────────────────────────────────────────────────────────
KV_BUFFER_DEVICES="${KV_BUFFER_DEVICES:-cuda cpu}"
# ── Cluster layout ───────────────────────────────────────────────────────
NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1}
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1}
PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.7}
BLOCK_SIZE=${BLOCK_SIZE:-16}
GIT_ROOT=$(git rev-parse --show-toplevel)
SMI_BIN=$(which nvidia-smi || which rocm-smi || echo "")
cleanup_instances() {
echo ""
echo "Cleaning up..."
kill $(jobs -pr) 2>/dev/null || true
sleep 1
kill -9 $(jobs -pr) 2>/dev/null || true
pkill -9 -f "vllm serve.*${MODEL_NAME}" 2>/dev/null || true
pkill -9 -f "toy_proxy_server.*8192" 2>/dev/null || true
sleep 1
echo "Cleanup done."
}
trap cleanup_instances EXIT
trap 'echo "Interrupted."; exit 130' INT TERM
wait_for_server() {
local port=$1
local deadline=600
local elapsed=0
echo "Waiting for server on port ${port}..."
while [ $elapsed -lt $deadline ]; do
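    # A bare GET to /v1/completions returns 405 once the server is up, but
    # curl still exits 0 on any HTTP response, which is all we need here.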
if curl -s "localhost:${port}/v1/completions" > /dev/null 2>&1; then
echo "Server on port ${port} ready"
return 0
fi
sleep 2
elapsed=$((elapsed + 2))
done
echo "FAIL: Server on port ${port} did not start within ${deadline}s"
exit 1
}
# ── Resolve GPU list ─────────────────────────────────────────────────────
if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then
IFS=',' read -ra ALL_GPUS <<< "$CUDA_VISIBLE_DEVICES"
else
ALL_GPUS=()
if [[ "$SMI_BIN" == *"nvidia"* ]]; then
num=$($SMI_BIN --query-gpu=name --format=csv,noheader | wc -l)
elif [[ "$SMI_BIN" == *"rocm"* ]]; then
num=$($SMI_BIN -l | grep -c GPU)
else
num=1
fi
for (( g=0; g<num; g++ )); do ALL_GPUS+=($g); done
fi
TOTAL_GPUS_NEEDED=$(( (NUM_PREFILL_INSTANCES * PREFILLER_TP_SIZE) + (NUM_DECODE_INSTANCES * DECODER_TP_SIZE) ))
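# e.g. the default 1 prefill + 1 decode instance at TP=1 needs 2 GPUs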
if [[ ${#ALL_GPUS[@]} -lt $TOTAL_GPUS_NEEDED ]]; then
echo "FAIL: Need $TOTAL_GPUS_NEEDED GPUs but only have ${#ALL_GPUS[@]} (CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-not set})"
exit 1
fi
# ── Run one test iteration ───────────────────────────────────────────────
run_test_for_device() {
local kv_device=$1
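  # "cuda" is NixlConnector's default kv_buffer_device (direct GPU RDMA),
  # so only non-default paths like "cpu" are set explicitly.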
if [[ "$kv_device" == "cuda" ]]; then
local kv_config='{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
else
local kv_config="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"${kv_device}\"}"
fi
echo ""
echo "================================================================"
echo "NixlConnector PD + Spec Decode Acceptance Test (kv_buffer_device=${kv_device})"
echo "================================================================"
echo "Model: ${MODEL_NAME}"
echo "SD method: ${SD_METHOD}"
echo "SD model: ${SD_MODEL}"
echo "Spec tokens: ${NUM_SPEC_TOKENS}"
echo "KV buffer device: ${kv_device}"
echo "GPUs available: ${ALL_GPUS[*]}"
echo "================================================================"
local PREFILL_HOSTS=()
local PREFILL_PORTS=()
local DECODE_HOSTS=()
local DECODE_PORTS=()
local GPU_IDX=0
# Start prefill instances
for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
local GPU_ID="${ALL_GPUS[$GPU_IDX]}"
GPU_IDX=$((GPU_IDX + 1))
for (( j=1; j < PREFILLER_TP_SIZE; j++ )); do
GPU_ID="${GPU_ID},${ALL_GPUS[$GPU_IDX]}"
GPU_IDX=$((GPU_IDX + 1))
done
local PORT=$((8100 + i))
    local SIDE_CHANNEL_PORT=$((5559 + i * PREFILLER_TP_SIZE))  # spaced by TP size per instance
echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"
CUDA_VISIBLE_DEVICES=$GPU_ID \
VLLM_KV_CACHE_LAYOUT='HND' \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
vllm serve $MODEL_NAME \
--port $PORT \
--enforce-eager \
--max-model-len $MAX_MODEL_LEN \
--block-size ${BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config "$kv_config" \
--speculative-config "$PREFILL_SPEC_CONFIG" \
--attention-backend FLASH_ATTN &
PREFILL_HOSTS+=("localhost")
PREFILL_PORTS+=("$PORT")
done
# Start decode instances
for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do
local GPU_ID="${ALL_GPUS[$GPU_IDX]}"
GPU_IDX=$((GPU_IDX + 1))
for (( j=1; j < DECODER_TP_SIZE; j++ )); do
GPU_ID="${GPU_ID},${ALL_GPUS[$GPU_IDX]}"
GPU_IDX=$((GPU_IDX + 1))
done
local PORT=$((8200 + i))
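    # Each TP rank binds its own offset from the base side-channel port, so
    # space instances by TP size to avoid collisions.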
    local SIDE_CHANNEL_PORT=$((5659 + i * DECODER_TP_SIZE))
echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"
CUDA_VISIBLE_DEVICES=$GPU_ID \
VLLM_KV_CACHE_LAYOUT='HND' \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
vllm serve $MODEL_NAME \
--port $PORT \
--enforce-eager \
--max-model-len $MAX_MODEL_LEN \
--block-size ${BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--tensor-parallel-size $DECODER_TP_SIZE \
--kv-transfer-config "$kv_config" \
--speculative-config "$DECODE_SPEC_CONFIG" \
--attention-backend FLASH_ATTN &
DECODE_HOSTS+=("localhost")
DECODE_PORTS+=("$PORT")
done
# Wait for servers
for PORT in "${PREFILL_PORTS[@]}"; do
wait_for_server "$PORT"
done
for PORT in "${DECODE_PORTS[@]}"; do
wait_for_server "$PORT"
done
# Start proxy
local PROXY_PORT=8192
echo "Starting proxy server on port $PROXY_PORT..."
python3 "${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py" \
--port $PROXY_PORT \
--prefiller-hosts ${PREFILL_HOSTS[*]} \
--prefiller-ports ${PREFILL_PORTS[*]} \
--decoder-hosts ${DECODE_HOSTS[*]} \
--decoder-ports ${DECODE_PORTS[*]} &
sleep 5
# Run test
echo "Running spec decode acceptance test (kv_buffer_device=${kv_device})..."
DECODE_PORT=${DECODE_PORTS[0]} \
TEST_MODEL=$MODEL_NAME \
python3 -m pytest -s -x "${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_spec_decode_acceptance.py"
# Tear down before next iteration
cleanup_instances
sleep 3
}
# ── Main: loop over kv_buffer_device values ──────────────────────────────
for device in $KV_BUFFER_DEVICES; do
run_test_for_device "$device"
done
echo "=== All spec decode acceptance tests passed ==="

tests/v1/kv_connector/nixl_integration/test_spec_decode_acceptance.py

@@ -0,0 +1,208 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""NixlConnector PD + EAGLE3 speculative decoding acceptance length test.
- Loads MT-Bench prompts (80 prompts, 256 output tokens)
- Sends through the PD proxy (completions API)
- Scrapes Prometheus metrics from the decode server
- Asserts acceptance length matches standalone EAGLE3 baselines
Baselines come from tests/v1/spec_decode/test_acceptance_length.py
(standalone EAGLE3 with the same model/drafter on MT-Bench, temperature=0).
PD disaggregation via NixlConnector should match them within tolerance.
Environment variables (set by spec_decode_acceptance_test.sh):
TEST_MODEL - target model name
DECODE_PORT - port of the decode vLLM server (for /metrics)
"""
import os
from dataclasses import dataclass, field
from types import SimpleNamespace
from urllib.request import urlopen
import openai
import regex as re
from transformers import AutoTokenizer
from vllm.benchmarks.datasets import get_samples
PROXY_BASE_URL = "http://localhost:8192/v1"
DECODE_PORT = os.environ.get("DECODE_PORT", "8200")
MODEL_NAME = os.environ.get("TEST_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
@dataclass
class Eagle3ModelConfig:
verifier: str
drafter: str
expected_acceptance_length: float
expected_acceptance_lengths_per_pos: list[float] = field(default_factory=list)
id: str = ""
rtol: float | None = None
# Standalone EAGLE3 baselines (MT-Bench, 80 prompts, 256 tokens, temp=0).
# Source: tests/v1/spec_decode/test_acceptance_length.py
EAGLE3_MODEL_CONFIGS = [
Eagle3ModelConfig(
verifier="meta-llama/Llama-3.1-8B-Instruct",
drafter="RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
expected_acceptance_length=2.60,
expected_acceptance_lengths_per_pos=[0.7296, 0.5208, 0.3545],
id="llama3-8b-eagle3",
),
]
DEFAULT_NUM_PROMPTS = 80
DEFAULT_OUTPUT_LEN = 256
DEFAULT_RTOL = 0.05
def _get_model_config() -> Eagle3ModelConfig:
"""Get the model config matching MODEL_NAME."""
for config in EAGLE3_MODEL_CONFIGS:
if config.verifier == MODEL_NAME:
return config
raise ValueError(
f"No Eagle3ModelConfig found for model {MODEL_NAME}. "
f"Available: {[c.verifier for c in EAGLE3_MODEL_CONFIGS]}"
)
def _get_mt_bench_prompts() -> list[str]:
"""Load MT-Bench prompts via vllm.benchmarks.datasets.get_samples."""
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
args = SimpleNamespace(
dataset_name="hf",
dataset_path="philschmid/mt-bench",
num_prompts=DEFAULT_NUM_PROMPTS,
seed=42,
no_oversample=False,
endpoint_type="openai-chat",
backend="openai-chat",
input_len=None,
output_len=DEFAULT_OUTPUT_LEN,
sharegpt_output_len=DEFAULT_OUTPUT_LEN,
hf_name=None,
hf_split="train",
hf_subset=None,
hf_output_len=DEFAULT_OUTPUT_LEN,
no_stream=True,
disable_shuffle=False,
skip_chat_template=False,
trust_remote_code=False,
enable_multimodal_chat=False,
request_id_prefix="",
)
samples = get_samples(args, tokenizer)
return [sample.prompt for sample in samples]
def _fetch_metric(metric_name: str) -> float:
"""Fetch a single counter metric from the decode server's /metrics."""
url = f"http://localhost:{DECODE_PORT}/metrics"
body = urlopen(url).read().decode()
for line in body.split("\n"):
if line.startswith(metric_name + "{") or line.startswith(metric_name + " "):
return float(line.rsplit(" ", 1)[-1])
raise ValueError(f"Metric {metric_name} not found in decode /metrics")
def _fetch_per_position_acceptance() -> dict[int, float]:
"""Fetch per-position acceptance counts from decode /metrics."""
url = f"http://localhost:{DECODE_PORT}/metrics"
body = urlopen(url).read().decode()
counts: dict[int, float] = {}
for line in body.split("\n"):
if (
"spec_decode_num_accepted_tokens_per_pos_total" in line
and not line.startswith("#")
):
m = re.search(r'position="(\d+)"', line)
if m:
counts[int(m.group(1))] = float(line.rsplit(" ", 1)[-1])
return counts
def test_spec_decode_acceptance_length():
"""Validate PD+SD acceptance length against standalone baseline.
Sends MT-Bench prompts through the PD proxy (completions API),
then checks that the decode server's speculative decoding metrics
match the known standalone baselines.
"""
config = _get_model_config()
rtol = config.rtol if config.rtol is not None else DEFAULT_RTOL
prompts = _get_mt_bench_prompts()
assert len(prompts) == DEFAULT_NUM_PROMPTS, (
f"Expected {DEFAULT_NUM_PROMPTS} prompts, got {len(prompts)}"
)
client = openai.OpenAI(api_key="EMPTY", base_url=PROXY_BASE_URL)
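    # Greedy decoding (temperature=0) keeps the acceptance statistics
    # comparable to the standalone baseline, which was measured the same way.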
for i, prompt in enumerate(prompts):
resp = client.completions.create(
model=MODEL_NAME,
prompt=prompt,
max_tokens=DEFAULT_OUTPUT_LEN,
temperature=0.0,
top_p=1.0,
)
if i < 3:
text = resp.choices[0].text.strip()[:100]
print(f" [{i}] {prompt[:60]}... -> {text}...")
# ── Extract metrics from decode server ────────────────────────────
n_drafts = _fetch_metric("vllm:spec_decode_num_drafts_total")
n_accepted = _fetch_metric("vllm:spec_decode_num_accepted_tokens_total")
assert n_drafts > 0, "No spec-decode drafts were generated"
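    # Mean accepted draft tokens per draft step, plus 1 for the token the
    # target model emits on every step regardless of acceptance.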
acceptance_length = 1 + (n_accepted / n_drafts)
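    # Per-position rate = fraction of draft steps whose i-th speculative
    # token was accepted; missing positions count as 0.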
per_pos_counts = _fetch_per_position_acceptance()
per_pos_rates = [
per_pos_counts.get(i, 0) / n_drafts
for i in range(len(config.expected_acceptance_lengths_per_pos))
]
# ── Report ────────────────────────────────────────────────────────
expected = config.expected_acceptance_length
expected_per_pos = config.expected_acceptance_lengths_per_pos
print(
f"\n{config.id}: acceptance_length={acceptance_length:.3f} "
f"(expected={expected:.3f})"
)
print(f" Drafts: {n_drafts:.0f}, Accepted: {n_accepted:.0f}")
for i, (actual, exp) in enumerate(zip(per_pos_rates, expected_per_pos)):
print(f" Position {i}: {actual:.4f} (expected: {exp:.4f})")
# ── Assert overall acceptance length ──────────────────────────────
rel_error = abs(acceptance_length - expected) / expected
assert rel_error <= rtol, (
f"Acceptance length regression for {config.id}! "
f"Expected: {expected:.3f}, "
f"Got: {acceptance_length:.3f}, "
f"Relative error: {rel_error:.2%} (tolerance: {rtol:.0%}). "
f"This may indicate drafter KV was not correctly transferred."
)
# ── Assert per-position acceptance ────────────────────────────────
for i, (actual, exp) in enumerate(zip(per_pos_rates, expected_per_pos)):
if exp > 0:
pos_err = abs(actual - exp) / exp
assert pos_err <= rtol, (
f"Per-position acceptance regression at position {i} "
f"for {config.id}! "
f"Expected: {exp:.4f}, Got: {actual:.4f}, "
f"Relative error: {pos_err:.2%} "
f"(tolerance: {rtol:.0%})"
)
print(
f"\n=== PASS: {config.id} acceptance length {acceptance_length:.3f} "
f"within {rtol:.0%} of {expected:.3f} ==="
)