[Spec Decode][KV Connector] Fix KV transfer in PD + speculative decoding (#35158)
Signed-off-by: Claude <noreply@anthropic.com> Signed-off-by: Zhanqiu Hu <zh338@cornell.edu> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
237
tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh
Executable file
237
tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh
Executable file
@@ -0,0 +1,237 @@
|
||||
#!/bin/bash
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# NixlConnector PD + speculative decoding acceptance length test.
# Tests EAGLE3 acceptance length for both RDMA (cuda) and CPU host (cpu)
# KV buffer device paths.
#
# For each kv_buffer_device setting, starts prefill + decode vllm servers
# with NixlConnector, then runs test_spec_decode_acceptance.py to validate
# acceptance length matches the standalone SD baseline.
#
# Usage:
#   CUDA_VISIBLE_DEVICES=0,1 bash tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh
#
# Environment variables:
#   KV_BUFFER_DEVICES - space-separated list of devices to test
#                       (default: "cuda cpu")
#   SD_METHOD - spec decode method (default: eagle3)
#   SD_MODEL - drafter model path
#   MODEL_NAME - target model (default: meta-llama/Llama-3.1-8B-Instruct)
#   NUM_SPEC_TOKENS - number of speculative tokens (default: 3)
#   GPU_MEMORY_UTILIZATION - (default: 0.7)
set -x
# Fail fast: without -e a failing pytest run inside run_test_for_device would
# fall through and the script would still print the final
# "All spec decode acceptance tests passed" banner.
set -euo pipefail

# ── Model & spec decode config ──────────────────────────────────────────

MODEL_NAME="${MODEL_NAME:-meta-llama/Llama-3.1-8B-Instruct}"
SD_METHOD="${SD_METHOD:-eagle3}"
SD_MODEL="${SD_MODEL:-RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3}"
NUM_SPEC_TOKENS="${NUM_SPEC_TOKENS:-3}"
MAX_MODEL_LEN="${MAX_MODEL_LEN:-16384}"

# Prefill is pinned to a single speculative token; decode uses the full
# NUM_SPEC_TOKENS budget.
PREFILL_SPEC_CONFIG="{\"method\":\"${SD_METHOD}\",\"model\":\"${SD_MODEL}\",\"num_speculative_tokens\":1,\"max_model_len\":${MAX_MODEL_LEN}}"
DECODE_SPEC_CONFIG="{\"method\":\"${SD_METHOD}\",\"model\":\"${SD_MODEL}\",\"num_speculative_tokens\":${NUM_SPEC_TOKENS},\"max_model_len\":${MAX_MODEL_LEN}}"

# ── Test matrix ──────────────────────────────────────────────────────────

KV_BUFFER_DEVICES="${KV_BUFFER_DEVICES:-cuda cpu}"

# ── Cluster layout ───────────────────────────────────────────────────────

NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1}
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1}
PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.7}
BLOCK_SIZE=${BLOCK_SIZE:-16}

# Empty when not run from inside a git checkout (kept non-fatal to match the
# original behavior; path-dependent steps will fail later instead).
GIT_ROOT=$(git rev-parse --show-toplevel || true)

# Prefer nvidia-smi, fall back to rocm-smi; empty when neither is installed.
SMI_BIN=$(command -v nvidia-smi || command -v rocm-smi || echo "")
|
||||
|
||||
# Kill all background servers/proxies started by this script.
# Runs on every exit path via the EXIT trap below.
cleanup_instances() {
    echo ""
    echo "Cleaning up..."
    local bg_pids
    bg_pids=$(jobs -pr)
    # Graceful TERM first, then force-kill anything still alive.
    kill $bg_pids 2>/dev/null || true
    sleep 1
    kill -9 $bg_pids 2>/dev/null || true
    # Belt-and-braces: catch processes that detached from our job table.
    pkill -9 -f "vllm serve.*${MODEL_NAME}" 2>/dev/null || true
    pkill -9 -f "toy_proxy_server.*8192" 2>/dev/null || true
    sleep 1
    echo "Cleanup done."
}
trap cleanup_instances EXIT
trap 'echo " Interrupted."; exit 130' INT TERM
|
||||
|
||||
# Poll the vLLM server on $1 every 2s until it answers HTTP,
# or exit the whole script after a 600s timeout.
wait_for_server() {
    local port=$1
    local timeout_s=600
    local waited_s=0
    echo "Waiting for server on port ${port}..."
    while (( waited_s < timeout_s )); do
        # Any HTTP response (even an error status) means the server is up;
        # curl only fails while the socket is still refusing connections.
        if curl -s "localhost:${port}/v1/completions" > /dev/null 2>&1; then
            echo "Server on port ${port} ready"
            return 0
        fi
        sleep 2
        waited_s=$(( waited_s + 2 ))
    done
    echo "FAIL: Server on port ${port} did not start within ${timeout_s}s"
    exit 1
}
|
||||
|
||||
# ── Resolve GPU list ─────────────────────────────────────────────────────

if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then
    # Honour an explicit GPU list from the caller.
    IFS=',' read -ra ALL_GPUS <<< "$CUDA_VISIBLE_DEVICES"
else
    # Otherwise enumerate GPUs 0..N-1 by probing the available SMI tool.
    ALL_GPUS=()
    case "$SMI_BIN" in
        *nvidia*) gpu_count=$($SMI_BIN --query-gpu=name --format=csv,noheader | wc -l) ;;
        *rocm*)   gpu_count=$($SMI_BIN -l | grep -c GPU) ;;
        *)        gpu_count=1 ;;
    esac
    for (( g = 0; g < gpu_count; g++ )); do
        ALL_GPUS+=("$g")
    done
fi

TOTAL_GPUS_NEEDED=$(( NUM_PREFILL_INSTANCES * PREFILLER_TP_SIZE + NUM_DECODE_INSTANCES * DECODER_TP_SIZE ))
if (( ${#ALL_GPUS[@]} < TOTAL_GPUS_NEEDED )); then
    echo "FAIL: Need $TOTAL_GPUS_NEEDED GPUs but only have ${#ALL_GPUS[@]} (CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-not set})"
    exit 1
fi
|
||||
|
||||
# ── Run one test iteration ───────────────────────────────────────────────
|
||||
|
||||
# Run one full PD acceptance-test iteration for a single kv_buffer_device
# value ($1, e.g. "cuda" or "cpu"): launch prefill and decode vllm servers
# with NixlConnector, front them with the toy proxy on port 8192, run the
# pytest acceptance check against the proxy, then tear everything down.
run_test_for_device() {
    local kv_device=$1

    # "cuda" is NixlConnector's default buffer device, so it is omitted from
    # the config; any other value is passed through as kv_buffer_device.
    if [[ "$kv_device" == "cuda" ]]; then
        local kv_config='{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
    else
        local kv_config="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"${kv_device}\"}"
    fi

    echo ""
    echo "================================================================"
    echo "NixlConnector PD + Spec Decode Acceptance Test (kv_buffer_device=${kv_device})"
    echo "================================================================"
    echo "Model: ${MODEL_NAME}"
    echo "SD method: ${SD_METHOD}"
    echo "SD model: ${SD_MODEL}"
    echo "Spec tokens: ${NUM_SPEC_TOKENS}"
    echo "KV buffer device: ${kv_device}"
    echo "GPUs available: ${ALL_GPUS[*]}"
    echo "================================================================"

    local PREFILL_HOSTS=()
    local PREFILL_PORTS=()
    local DECODE_HOSTS=()
    local DECODE_PORTS=()
    # Running cursor into ALL_GPUS; each instance consumes TP-size GPUs.
    local GPU_IDX=0

    # Start prefill instances
    for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
        # Build a comma-separated GPU list of PREFILLER_TP_SIZE entries.
        local GPU_ID="${ALL_GPUS[$GPU_IDX]}"
        GPU_IDX=$((GPU_IDX + 1))
        for (( j=1; j < PREFILLER_TP_SIZE; j++ )); do
            GPU_ID="${GPU_ID},${ALL_GPUS[$GPU_IDX]}"
            GPU_IDX=$((GPU_IDX + 1))
        done

        local PORT=$((8100 + i))
        # NOTE(review): unlike the decode side below, this offset is not
        # spaced by PREFILLER_TP_SIZE; with NUM_PREFILL_INSTANCES>1 and
        # TP>1 side-channel ports could collide — confirm.
        local SIDE_CHANNEL_PORT=$((5559 + i))

        echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"
        # Prefill uses PREFILL_SPEC_CONFIG (num_speculative_tokens pinned
        # to 1 by the config section at the top of this script).
        CUDA_VISIBLE_DEVICES=$GPU_ID \
        VLLM_KV_CACHE_LAYOUT='HND' \
        UCX_NET_DEVICES=all \
        VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
        vllm serve $MODEL_NAME \
            --port $PORT \
            --enforce-eager \
            --max-model-len $MAX_MODEL_LEN \
            --block-size ${BLOCK_SIZE} \
            --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
            --tensor-parallel-size $PREFILLER_TP_SIZE \
            --kv-transfer-config "$kv_config" \
            --speculative-config "$PREFILL_SPEC_CONFIG" \
            --attention-backend FLASH_ATTN &

        PREFILL_HOSTS+=("localhost")
        PREFILL_PORTS+=("$PORT")
    done

    # Start decode instances
    for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do
        # Build a comma-separated GPU list of DECODER_TP_SIZE entries.
        local GPU_ID="${ALL_GPUS[$GPU_IDX]}"
        GPU_IDX=$((GPU_IDX + 1))
        for (( j=1; j < DECODER_TP_SIZE; j++ )); do
            GPU_ID="${GPU_ID},${ALL_GPUS[$GPU_IDX]}"
            GPU_IDX=$((GPU_IDX + 1))
        done

        local PORT=$((8200 + i))
        # Decode side-channel ports are spaced by TP size per instance.
        local SIDE_CHANNEL_PORT=$((5659 + i * $DECODER_TP_SIZE))

        echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"
        # Decode uses DECODE_SPEC_CONFIG (full NUM_SPEC_TOKENS budget).
        CUDA_VISIBLE_DEVICES=$GPU_ID \
        VLLM_KV_CACHE_LAYOUT='HND' \
        UCX_NET_DEVICES=all \
        VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
        vllm serve $MODEL_NAME \
            --port $PORT \
            --enforce-eager \
            --max-model-len $MAX_MODEL_LEN \
            --block-size ${BLOCK_SIZE} \
            --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
            --tensor-parallel-size $DECODER_TP_SIZE \
            --kv-transfer-config "$kv_config" \
            --speculative-config "$DECODE_SPEC_CONFIG" \
            --attention-backend FLASH_ATTN &

        DECODE_HOSTS+=("localhost")
        DECODE_PORTS+=("$PORT")
    done

    # Wait for servers
    for PORT in "${PREFILL_PORTS[@]}"; do
        wait_for_server "$PORT"
    done
    for PORT in "${DECODE_PORTS[@]}"; do
        wait_for_server "$PORT"
    done

    # Start proxy
    local PROXY_PORT=8192
    echo "Starting proxy server on port $PROXY_PORT..."
    python3 "${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py" \
        --port $PROXY_PORT \
        --prefiller-hosts ${PREFILL_HOSTS[*]} \
        --prefiller-ports ${PREFILL_PORTS[*]} \
        --decoder-hosts ${DECODE_HOSTS[*]} \
        --decoder-ports ${DECODE_PORTS[*]} &

    # Give the proxy a moment to bind; wait_for_server is not used here
    # since the toy proxy has no health endpoint visible in this script.
    sleep 5

    # Run test
    # test_spec_decode_acceptance.py reads DECODE_PORT (for /metrics
    # scraping) and TEST_MODEL from the environment.
    echo "Running spec decode acceptance test (kv_buffer_device=${kv_device})..."
    DECODE_PORT=${DECODE_PORTS[0]} \
    TEST_MODEL=$MODEL_NAME \
    python3 -m pytest -s -x "${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_spec_decode_acceptance.py"

    # Tear down before next iteration
    cleanup_instances
    sleep 3
}
|
||||
|
||||
# ── Main: loop over kv_buffer_device values ──────────────────────────────

# KV_BUFFER_DEVICES is a space-separated list; split it explicitly.
read -ra _devices <<< "$KV_BUFFER_DEVICES"
for _dev in "${_devices[@]}"; do
    run_test_for_device "$_dev"
done

echo "=== All spec decode acceptance tests passed ==="
|
||||
@@ -0,0 +1,208 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""NixlConnector PD + EAGLE3 speculative decoding acceptance length test.
|
||||
|
||||
- Loads MT-Bench prompts (80 prompts, 256 output tokens)
|
||||
- Sends through the PD proxy (completions API)
|
||||
- Scrapes Prometheus metrics from the decode server
|
||||
- Asserts acceptance length matches standalone EAGLE3 baselines
|
||||
|
||||
Baselines from tests/v1/spec_decode/test_acceptance_length.py
|
||||
(standalone EAGLE3 with same model/drafter on MT-Bench, temp=0).
|
||||
PD disaggregation via NixlConnector should match within tolerance.
|
||||
|
||||
Environment variables (set by spec_decode_acceptance_test.sh):
|
||||
TEST_MODEL - target model name
|
||||
DECODE_PORT - port of the decode vLLM server (for /metrics)
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from types import SimpleNamespace
|
||||
from urllib.request import urlopen
|
||||
|
||||
import openai
|
||||
import regex as re
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.benchmarks.datasets import get_samples
|
||||
|
||||
# Base URL of the PD proxy started by spec_decode_acceptance_test.sh.
PROXY_BASE_URL = "http://localhost:8192/v1"
# Port of the decode vLLM server; its /metrics endpoint is scraped below.
DECODE_PORT = os.environ.get("DECODE_PORT", "8200")
# Target (verifier) model under test.
MODEL_NAME = os.environ.get("TEST_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
|
||||
|
||||
|
||||
@dataclass
class Eagle3ModelConfig:
    """Baseline acceptance-length expectations for one verifier/drafter pair."""

    # Target model name; matched against MODEL_NAME by _get_model_config().
    verifier: str
    # EAGLE3 drafter model path.
    drafter: str
    # Expected overall acceptance length (computed as 1 + accepted/drafts).
    expected_acceptance_length: float
    # Expected acceptance rate at each draft position (index 0 = first token).
    expected_acceptance_lengths_per_pos: list[float] = field(default_factory=list)
    # Short identifier used in log and assertion messages.
    id: str = ""
    # Per-config relative tolerance; DEFAULT_RTOL is used when None.
    rtol: float | None = None
|
||||
|
||||
|
||||
# Standalone EAGLE3 baselines (MT-Bench, 80 prompts, 256 tokens, temp=0).
# Source: tests/v1/spec_decode/test_acceptance_length.py
EAGLE3_MODEL_CONFIGS = [
    Eagle3ModelConfig(
        verifier="meta-llama/Llama-3.1-8B-Instruct",
        drafter="RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
        expected_acceptance_length=2.60,
        expected_acceptance_lengths_per_pos=[0.7296, 0.5208, 0.3545],
        id="llama3-8b-eagle3",
    ),
]

# Benchmark shape — must stay in sync with the standalone baseline run that
# produced the expected values above.
DEFAULT_NUM_PROMPTS = 80
DEFAULT_OUTPUT_LEN = 256
# Relative tolerance for acceptance-length comparisons (overridable per
# config via Eagle3ModelConfig.rtol).
DEFAULT_RTOL = 0.05
|
||||
|
||||
|
||||
def _get_model_config() -> Eagle3ModelConfig:
    """Return the baseline config whose verifier matches MODEL_NAME."""
    matched = next(
        (cfg for cfg in EAGLE3_MODEL_CONFIGS if cfg.verifier == MODEL_NAME),
        None,
    )
    if matched is None:
        raise ValueError(
            f"No Eagle3ModelConfig found for model {MODEL_NAME}. "
            f"Available: {[c.verifier for c in EAGLE3_MODEL_CONFIGS]}"
        )
    return matched
|
||||
|
||||
|
||||
def _get_mt_bench_prompts() -> list[str]:
    """Load MT-Bench prompts via vllm.benchmarks.datasets.get_samples."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # get_samples expects an argparse-style namespace; build the options as
    # a dict first so they are easy to scan, then wrap them.
    sample_opts = {
        "dataset_name": "hf",
        "dataset_path": "philschmid/mt-bench",
        "num_prompts": DEFAULT_NUM_PROMPTS,
        "seed": 42,
        "no_oversample": False,
        "endpoint_type": "openai-chat",
        "backend": "openai-chat",
        "input_len": None,
        "output_len": DEFAULT_OUTPUT_LEN,
        "sharegpt_output_len": DEFAULT_OUTPUT_LEN,
        "hf_name": None,
        "hf_split": "train",
        "hf_subset": None,
        "hf_output_len": DEFAULT_OUTPUT_LEN,
        "no_stream": True,
        "disable_shuffle": False,
        "skip_chat_template": False,
        "trust_remote_code": False,
        "enable_multimodal_chat": False,
        "request_id_prefix": "",
    }
    samples = get_samples(SimpleNamespace(**sample_opts), tokenizer)
    return [s.prompt for s in samples]
|
||||
|
||||
|
||||
def _fetch_metric(metric_name: str) -> float:
    """Fetch a single counter metric from the decode server's /metrics.

    Args:
        metric_name: Prometheus metric name, e.g.
            "vllm:spec_decode_num_drafts_total".

    Returns:
        The value of the first matching sample line (with or without labels).

    Raises:
        ValueError: if no sample line for ``metric_name`` is present.
    """
    url = f"http://localhost:{DECODE_PORT}/metrics"
    # Bounded timeout so a hung decode server fails the test instead of
    # blocking the run forever.
    body = urlopen(url, timeout=30).read().decode()
    for line in body.split("\n"):
        # Match "name{labels} value" or "name value"; this skips HELP/TYPE
        # lines and other metrics that merely share the prefix.
        if line.startswith(metric_name + "{") or line.startswith(metric_name + " "):
            return float(line.rsplit(" ", 1)[-1])
    raise ValueError(f"Metric {metric_name} not found in decode /metrics")
|
||||
|
||||
|
||||
def _fetch_per_position_acceptance() -> dict[int, float]:
    """Fetch per-position accepted-token counts from decode /metrics.

    Returns:
        Mapping of draft position (0-based, parsed from the ``position``
        label) to the cumulative count of the
        ``spec_decode_num_accepted_tokens_per_pos_total`` counter.
        Empty if the metric is not exposed.
    """
    url = f"http://localhost:{DECODE_PORT}/metrics"
    # Same bounded timeout as _fetch_metric: never hang on a dead server.
    body = urlopen(url, timeout=30).read().decode()
    counts: dict[int, float] = {}
    for line in body.split("\n"):
        if (
            "spec_decode_num_accepted_tokens_per_pos_total" in line
            and not line.startswith("#")  # skip HELP/TYPE comment lines
        ):
            m = re.search(r'position="(\d+)"', line)
            if m:
                counts[int(m.group(1))] = float(line.rsplit(" ", 1)[-1])
    return counts
|
||||
|
||||
|
||||
def test_spec_decode_acceptance_length():
    """Validate PD+SD acceptance length against standalone baseline.

    Sends MT-Bench prompts through the PD proxy (completions API),
    then checks that the decode server's speculative decoding metrics
    match the known standalone baselines.

    NOTE(review): the metrics scraped below are cumulative counters, so
    this assumes a freshly started decode server — the driver script
    restarts the servers for each kv_buffer_device iteration.
    """
    config = _get_model_config()
    # Per-config tolerance wins over the module-wide default.
    rtol = config.rtol if config.rtol is not None else DEFAULT_RTOL

    prompts = _get_mt_bench_prompts()
    assert len(prompts) == DEFAULT_NUM_PROMPTS, (
        f"Expected {DEFAULT_NUM_PROMPTS} prompts, got {len(prompts)}"
    )

    # Drive traffic through the proxy so every request takes the
    # prefill -> KV transfer -> decode path.
    client = openai.OpenAI(api_key="EMPTY", base_url=PROXY_BASE_URL)
    for i, prompt in enumerate(prompts):
        resp = client.completions.create(
            model=MODEL_NAME,
            prompt=prompt,
            max_tokens=DEFAULT_OUTPUT_LEN,
            temperature=0.0,  # greedy, to match the baseline run
            top_p=1.0,
        )
        # Echo the first few completions for eyeball debugging.
        if i < 3:
            text = resp.choices[0].text.strip()[:100]
            print(f" [{i}] {prompt[:60]}... -> {text}...")

    # ── Extract metrics from decode server ────────────────────────────
    n_drafts = _fetch_metric("vllm:spec_decode_num_drafts_total")
    n_accepted = _fetch_metric("vllm:spec_decode_num_accepted_tokens_total")

    assert n_drafts > 0, "No spec-decode drafts were generated"

    # Acceptance length = 1 + mean accepted draft tokens per draft step.
    acceptance_length = 1 + (n_accepted / n_drafts)

    # Per-position acceptance rate, normalized by the total draft count;
    # positions missing from /metrics count as zero.
    per_pos_counts = _fetch_per_position_acceptance()
    per_pos_rates = [
        per_pos_counts.get(i, 0) / n_drafts
        for i in range(len(config.expected_acceptance_lengths_per_pos))
    ]

    # ── Report ────────────────────────────────────────────────────────
    expected = config.expected_acceptance_length
    expected_per_pos = config.expected_acceptance_lengths_per_pos

    print(
        f"\n{config.id}: acceptance_length={acceptance_length:.3f} "
        f"(expected={expected:.3f})"
    )
    print(f" Drafts: {n_drafts:.0f}, Accepted: {n_accepted:.0f}")
    for i, (actual, exp) in enumerate(zip(per_pos_rates, expected_per_pos)):
        print(f" Position {i}: {actual:.4f} (expected: {exp:.4f})")

    # ── Assert overall acceptance length ──────────────────────────────
    rel_error = abs(acceptance_length - expected) / expected

    assert rel_error <= rtol, (
        f"Acceptance length regression for {config.id}! "
        f"Expected: {expected:.3f}, "
        f"Got: {acceptance_length:.3f}, "
        f"Relative error: {rel_error:.2%} (tolerance: {rtol:.0%}). "
        f"This may indicate drafter KV was not correctly transferred."
    )

    # ── Assert per-position acceptance ────────────────────────────────
    for i, (actual, exp) in enumerate(zip(per_pos_rates, expected_per_pos)):
        # Guard against division by zero for positions with no baseline.
        if exp > 0:
            pos_err = abs(actual - exp) / exp
            assert pos_err <= rtol, (
                f"Per-position acceptance regression at position {i} "
                f"for {config.id}! "
                f"Expected: {exp:.4f}, Got: {actual:.4f}, "
                f"Relative error: {pos_err:.2%} "
                f"(tolerance: {rtol:.0%})"
            )

    print(
        f"\n=== PASS: {config.id} acceptance length {acceptance_length:.3f} "
        f"within {rtol:.0%} of {expected:.3f} ==="
    )
|
||||
Reference in New Issue
Block a user