[Spec Decode][KV Connector] Fix KV transfer in PD + speculative decoding (#35158)
Signed-off-by: Claude <noreply@anthropic.com> Signed-off-by: Zhanqiu Hu <zh338@cornell.edu> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
@@ -213,6 +213,19 @@ steps:
|
||||
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
|
||||
- CROSS_LAYERS_BLOCKS=True bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh
|
||||
|
||||
- label: NixlConnector PD + Spec Decode acceptance (2 GPUs)
|
||||
timeout_in_minutes: 30
|
||||
device: a100
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_devices: 2
|
||||
source_file_dependencies:
|
||||
- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
|
||||
- vllm/v1/worker/kv_connector_model_runner_mixin.py
|
||||
- tests/v1/kv_connector/nixl_integration/
|
||||
commands:
|
||||
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
|
||||
- bash v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh
|
||||
|
||||
- label: Pipeline + Context Parallelism (4 GPUs)
|
||||
timeout_in_minutes: 60
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
|
||||
237
tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh
Executable file
237
tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh
Executable file
@@ -0,0 +1,237 @@
|
||||
#!/bin/bash
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
#
|
||||
# NixlConnector PD + speculative decoding acceptance length test.
|
||||
# Tests EAGLE3 acceptance length for both RDMA (cuda) and CPU host (cpu)
|
||||
# KV buffer device paths.
|
||||
#
|
||||
# For each kv_buffer_device setting, starts prefill + decode vllm servers
|
||||
# with NixlConnector, then runs test_spec_decode_acceptance.py to validate
|
||||
# acceptance length matches the standalone SD baseline.
|
||||
#
|
||||
# Usage:
|
||||
# CUDA_VISIBLE_DEVICES=0,1 bash tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh
|
||||
#
|
||||
# Environment variables:
|
||||
# KV_BUFFER_DEVICES - space-separated list of devices to test
|
||||
# (default: "cuda cpu")
|
||||
# SD_METHOD - spec decode method (default: eagle3)
|
||||
# SD_MODEL - drafter model path
|
||||
# MODEL_NAME - target model (default: meta-llama/Llama-3.1-8B-Instruct)
|
||||
# NUM_SPEC_TOKENS - number of speculative tokens (default: 3)
|
||||
# GPU_MEMORY_UTILIZATION - (default: 0.7)
|
||||
set -x
|
||||
|
||||
# ── Model & spec decode config ──────────────────────────────────────────
|
||||
|
||||
MODEL_NAME="${MODEL_NAME:-meta-llama/Llama-3.1-8B-Instruct}"
|
||||
SD_METHOD="${SD_METHOD:-eagle3}"
|
||||
SD_MODEL="${SD_MODEL:-RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3}"
|
||||
NUM_SPEC_TOKENS="${NUM_SPEC_TOKENS:-3}"
|
||||
MAX_MODEL_LEN="${MAX_MODEL_LEN:-16384}"
|
||||
|
||||
PREFILL_SPEC_CONFIG="{\"method\":\"${SD_METHOD}\",\"model\":\"${SD_MODEL}\",\"num_speculative_tokens\":1,\"max_model_len\":${MAX_MODEL_LEN}}"
|
||||
DECODE_SPEC_CONFIG="{\"method\":\"${SD_METHOD}\",\"model\":\"${SD_MODEL}\",\"num_speculative_tokens\":${NUM_SPEC_TOKENS},\"max_model_len\":${MAX_MODEL_LEN}}"
|
||||
|
||||
# ── Test matrix ──────────────────────────────────────────────────────────
|
||||
|
||||
KV_BUFFER_DEVICES="${KV_BUFFER_DEVICES:-cuda cpu}"
|
||||
|
||||
# ── Cluster layout ───────────────────────────────────────────────────────
|
||||
|
||||
NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1}
|
||||
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1}
|
||||
PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
|
||||
DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
|
||||
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.7}
|
||||
BLOCK_SIZE=${BLOCK_SIZE:-16}
|
||||
|
||||
GIT_ROOT=$(git rev-parse --show-toplevel)
|
||||
|
||||
SMI_BIN=$(which nvidia-smi || which rocm-smi || echo "")
|
||||
|
||||
cleanup_instances() {
|
||||
echo ""
|
||||
echo "Cleaning up..."
|
||||
kill $(jobs -pr) 2>/dev/null || true
|
||||
sleep 1
|
||||
kill -9 $(jobs -pr) 2>/dev/null || true
|
||||
pkill -9 -f "vllm serve.*${MODEL_NAME}" 2>/dev/null || true
|
||||
pkill -9 -f "toy_proxy_server.*8192" 2>/dev/null || true
|
||||
sleep 1
|
||||
echo "Cleanup done."
|
||||
}
|
||||
trap cleanup_instances EXIT
|
||||
trap 'echo " Interrupted."; exit 130' INT TERM
|
||||
|
||||
wait_for_server() {
|
||||
local port=$1
|
||||
local deadline=600
|
||||
local elapsed=0
|
||||
echo "Waiting for server on port ${port}..."
|
||||
while [ $elapsed -lt $deadline ]; do
|
||||
if curl -s "localhost:${port}/v1/completions" > /dev/null 2>&1; then
|
||||
echo "Server on port ${port} ready"
|
||||
return 0
|
||||
fi
|
||||
sleep 2
|
||||
elapsed=$((elapsed + 2))
|
||||
done
|
||||
echo "FAIL: Server on port ${port} did not start within ${deadline}s"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# ── Resolve GPU list ─────────────────────────────────────────────────────
|
||||
|
||||
if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then
|
||||
IFS=',' read -ra ALL_GPUS <<< "$CUDA_VISIBLE_DEVICES"
|
||||
else
|
||||
ALL_GPUS=()
|
||||
if [[ "$SMI_BIN" == *"nvidia"* ]]; then
|
||||
num=$($SMI_BIN --query-gpu=name --format=csv,noheader | wc -l)
|
||||
elif [[ "$SMI_BIN" == *"rocm"* ]]; then
|
||||
num=$($SMI_BIN -l | grep -c GPU)
|
||||
else
|
||||
num=1
|
||||
fi
|
||||
for (( g=0; g<num; g++ )); do ALL_GPUS+=($g); done
|
||||
fi
|
||||
|
||||
TOTAL_GPUS_NEEDED=$(( (NUM_PREFILL_INSTANCES * PREFILLER_TP_SIZE) + (NUM_DECODE_INSTANCES * DECODER_TP_SIZE) ))
|
||||
if [[ ${#ALL_GPUS[@]} -lt $TOTAL_GPUS_NEEDED ]]; then
|
||||
echo "FAIL: Need $TOTAL_GPUS_NEEDED GPUs but only have ${#ALL_GPUS[@]} (CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-not set})"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# ── Run one test iteration ───────────────────────────────────────────────
|
||||
|
||||
run_test_for_device() {
|
||||
local kv_device=$1
|
||||
|
||||
if [[ "$kv_device" == "cuda" ]]; then
|
||||
local kv_config='{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
|
||||
else
|
||||
local kv_config="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"${kv_device}\"}"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "================================================================"
|
||||
echo "NixlConnector PD + Spec Decode Acceptance Test (kv_buffer_device=${kv_device})"
|
||||
echo "================================================================"
|
||||
echo "Model: ${MODEL_NAME}"
|
||||
echo "SD method: ${SD_METHOD}"
|
||||
echo "SD model: ${SD_MODEL}"
|
||||
echo "Spec tokens: ${NUM_SPEC_TOKENS}"
|
||||
echo "KV buffer device: ${kv_device}"
|
||||
echo "GPUs available: ${ALL_GPUS[*]}"
|
||||
echo "================================================================"
|
||||
|
||||
local PREFILL_HOSTS=()
|
||||
local PREFILL_PORTS=()
|
||||
local DECODE_HOSTS=()
|
||||
local DECODE_PORTS=()
|
||||
local GPU_IDX=0
|
||||
|
||||
# Start prefill instances
|
||||
for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
|
||||
local GPU_ID="${ALL_GPUS[$GPU_IDX]}"
|
||||
GPU_IDX=$((GPU_IDX + 1))
|
||||
for (( j=1; j < PREFILLER_TP_SIZE; j++ )); do
|
||||
GPU_ID="${GPU_ID},${ALL_GPUS[$GPU_IDX]}"
|
||||
GPU_IDX=$((GPU_IDX + 1))
|
||||
done
|
||||
|
||||
local PORT=$((8100 + i))
|
||||
local SIDE_CHANNEL_PORT=$((5559 + i))
|
||||
|
||||
echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"
|
||||
CUDA_VISIBLE_DEVICES=$GPU_ID \
|
||||
VLLM_KV_CACHE_LAYOUT='HND' \
|
||||
UCX_NET_DEVICES=all \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
|
||||
vllm serve $MODEL_NAME \
|
||||
--port $PORT \
|
||||
--enforce-eager \
|
||||
--max-model-len $MAX_MODEL_LEN \
|
||||
--block-size ${BLOCK_SIZE} \
|
||||
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
|
||||
--tensor-parallel-size $PREFILLER_TP_SIZE \
|
||||
--kv-transfer-config "$kv_config" \
|
||||
--speculative-config "$PREFILL_SPEC_CONFIG" \
|
||||
--attention-backend FLASH_ATTN &
|
||||
|
||||
PREFILL_HOSTS+=("localhost")
|
||||
PREFILL_PORTS+=("$PORT")
|
||||
done
|
||||
|
||||
# Start decode instances
|
||||
for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do
|
||||
local GPU_ID="${ALL_GPUS[$GPU_IDX]}"
|
||||
GPU_IDX=$((GPU_IDX + 1))
|
||||
for (( j=1; j < DECODER_TP_SIZE; j++ )); do
|
||||
GPU_ID="${GPU_ID},${ALL_GPUS[$GPU_IDX]}"
|
||||
GPU_IDX=$((GPU_IDX + 1))
|
||||
done
|
||||
|
||||
local PORT=$((8200 + i))
|
||||
local SIDE_CHANNEL_PORT=$((5659 + i * $DECODER_TP_SIZE))
|
||||
|
||||
echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"
|
||||
CUDA_VISIBLE_DEVICES=$GPU_ID \
|
||||
VLLM_KV_CACHE_LAYOUT='HND' \
|
||||
UCX_NET_DEVICES=all \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
|
||||
vllm serve $MODEL_NAME \
|
||||
--port $PORT \
|
||||
--enforce-eager \
|
||||
--max-model-len $MAX_MODEL_LEN \
|
||||
--block-size ${BLOCK_SIZE} \
|
||||
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
|
||||
--tensor-parallel-size $DECODER_TP_SIZE \
|
||||
--kv-transfer-config "$kv_config" \
|
||||
--speculative-config "$DECODE_SPEC_CONFIG" \
|
||||
--attention-backend FLASH_ATTN &
|
||||
|
||||
DECODE_HOSTS+=("localhost")
|
||||
DECODE_PORTS+=("$PORT")
|
||||
done
|
||||
|
||||
# Wait for servers
|
||||
for PORT in "${PREFILL_PORTS[@]}"; do
|
||||
wait_for_server "$PORT"
|
||||
done
|
||||
for PORT in "${DECODE_PORTS[@]}"; do
|
||||
wait_for_server "$PORT"
|
||||
done
|
||||
|
||||
# Start proxy
|
||||
local PROXY_PORT=8192
|
||||
echo "Starting proxy server on port $PROXY_PORT..."
|
||||
python3 "${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py" \
|
||||
--port $PROXY_PORT \
|
||||
--prefiller-hosts ${PREFILL_HOSTS[*]} \
|
||||
--prefiller-ports ${PREFILL_PORTS[*]} \
|
||||
--decoder-hosts ${DECODE_HOSTS[*]} \
|
||||
--decoder-ports ${DECODE_PORTS[*]} &
|
||||
|
||||
sleep 5
|
||||
|
||||
# Run test
|
||||
echo "Running spec decode acceptance test (kv_buffer_device=${kv_device})..."
|
||||
DECODE_PORT=${DECODE_PORTS[0]} \
|
||||
TEST_MODEL=$MODEL_NAME \
|
||||
python3 -m pytest -s -x "${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_spec_decode_acceptance.py"
|
||||
|
||||
# Tear down before next iteration
|
||||
cleanup_instances
|
||||
sleep 3
|
||||
}
|
||||
|
||||
# ── Main: loop over kv_buffer_device values ──────────────────────────────
|
||||
|
||||
for device in $KV_BUFFER_DEVICES; do
|
||||
run_test_for_device "$device"
|
||||
done
|
||||
|
||||
echo "=== All spec decode acceptance tests passed ==="
|
||||
@@ -0,0 +1,208 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""NixlConnector PD + EAGLE3 speculative decoding acceptance length test.
|
||||
|
||||
- Loads MT-Bench prompts (80 prompts, 256 output tokens)
|
||||
- Sends through the PD proxy (completions API)
|
||||
- Scrapes Prometheus metrics from the decode server
|
||||
- Asserts acceptance length matches standalone EAGLE3 baselines
|
||||
|
||||
Baselines from tests/v1/spec_decode/test_acceptance_length.py
|
||||
(standalone EAGLE3 with same model/drafter on MT-Bench, temp=0).
|
||||
PD disaggregation via NixlConnector should match within tolerance.
|
||||
|
||||
Environment variables (set by spec_decode_acceptance_test.sh):
|
||||
TEST_MODEL - target model name
|
||||
DECODE_PORT - port of the decode vLLM server (for /metrics)
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from types import SimpleNamespace
|
||||
from urllib.request import urlopen
|
||||
|
||||
import openai
|
||||
import regex as re
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.benchmarks.datasets import get_samples
|
||||
|
||||
PROXY_BASE_URL = "http://localhost:8192/v1"
|
||||
DECODE_PORT = os.environ.get("DECODE_PORT", "8200")
|
||||
MODEL_NAME = os.environ.get("TEST_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Eagle3ModelConfig:
|
||||
verifier: str
|
||||
drafter: str
|
||||
expected_acceptance_length: float
|
||||
expected_acceptance_lengths_per_pos: list[float] = field(default_factory=list)
|
||||
id: str = ""
|
||||
rtol: float | None = None
|
||||
|
||||
|
||||
# Standalone EAGLE3 baselines (MT-Bench, 80 prompts, 256 tokens, temp=0).
|
||||
# Source: tests/v1/spec_decode/test_acceptance_length.py
|
||||
EAGLE3_MODEL_CONFIGS = [
|
||||
Eagle3ModelConfig(
|
||||
verifier="meta-llama/Llama-3.1-8B-Instruct",
|
||||
drafter="RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
|
||||
expected_acceptance_length=2.60,
|
||||
expected_acceptance_lengths_per_pos=[0.7296, 0.5208, 0.3545],
|
||||
id="llama3-8b-eagle3",
|
||||
),
|
||||
]
|
||||
|
||||
DEFAULT_NUM_PROMPTS = 80
|
||||
DEFAULT_OUTPUT_LEN = 256
|
||||
DEFAULT_RTOL = 0.05
|
||||
|
||||
|
||||
def _get_model_config() -> Eagle3ModelConfig:
|
||||
"""Get the model config matching MODEL_NAME."""
|
||||
for config in EAGLE3_MODEL_CONFIGS:
|
||||
if config.verifier == MODEL_NAME:
|
||||
return config
|
||||
raise ValueError(
|
||||
f"No Eagle3ModelConfig found for model {MODEL_NAME}. "
|
||||
f"Available: {[c.verifier for c in EAGLE3_MODEL_CONFIGS]}"
|
||||
)
|
||||
|
||||
|
||||
def _get_mt_bench_prompts() -> list[str]:
|
||||
"""Load MT-Bench prompts via vllm.benchmarks.datasets.get_samples."""
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
args = SimpleNamespace(
|
||||
dataset_name="hf",
|
||||
dataset_path="philschmid/mt-bench",
|
||||
num_prompts=DEFAULT_NUM_PROMPTS,
|
||||
seed=42,
|
||||
no_oversample=False,
|
||||
endpoint_type="openai-chat",
|
||||
backend="openai-chat",
|
||||
input_len=None,
|
||||
output_len=DEFAULT_OUTPUT_LEN,
|
||||
sharegpt_output_len=DEFAULT_OUTPUT_LEN,
|
||||
hf_name=None,
|
||||
hf_split="train",
|
||||
hf_subset=None,
|
||||
hf_output_len=DEFAULT_OUTPUT_LEN,
|
||||
no_stream=True,
|
||||
disable_shuffle=False,
|
||||
skip_chat_template=False,
|
||||
trust_remote_code=False,
|
||||
enable_multimodal_chat=False,
|
||||
request_id_prefix="",
|
||||
)
|
||||
samples = get_samples(args, tokenizer)
|
||||
return [sample.prompt for sample in samples]
|
||||
|
||||
|
||||
def _fetch_metric(metric_name: str) -> float:
|
||||
"""Fetch a single counter metric from the decode server's /metrics."""
|
||||
url = f"http://localhost:{DECODE_PORT}/metrics"
|
||||
body = urlopen(url).read().decode()
|
||||
for line in body.split("\n"):
|
||||
if line.startswith(metric_name + "{") or line.startswith(metric_name + " "):
|
||||
return float(line.rsplit(" ", 1)[-1])
|
||||
raise ValueError(f"Metric {metric_name} not found in decode /metrics")
|
||||
|
||||
|
||||
def _fetch_per_position_acceptance() -> dict[int, float]:
|
||||
"""Fetch per-position acceptance counts from decode /metrics."""
|
||||
url = f"http://localhost:{DECODE_PORT}/metrics"
|
||||
body = urlopen(url).read().decode()
|
||||
counts: dict[int, float] = {}
|
||||
for line in body.split("\n"):
|
||||
if (
|
||||
"spec_decode_num_accepted_tokens_per_pos_total" in line
|
||||
and not line.startswith("#")
|
||||
):
|
||||
m = re.search(r'position="(\d+)"', line)
|
||||
if m:
|
||||
counts[int(m.group(1))] = float(line.rsplit(" ", 1)[-1])
|
||||
return counts
|
||||
|
||||
|
||||
def test_spec_decode_acceptance_length():
|
||||
"""Validate PD+SD acceptance length against standalone baseline.
|
||||
|
||||
Sends MT-Bench prompts through the PD proxy (completions API),
|
||||
then checks that the decode server's speculative decoding metrics
|
||||
match the known standalone baselines.
|
||||
"""
|
||||
config = _get_model_config()
|
||||
rtol = config.rtol if config.rtol is not None else DEFAULT_RTOL
|
||||
|
||||
prompts = _get_mt_bench_prompts()
|
||||
assert len(prompts) == DEFAULT_NUM_PROMPTS, (
|
||||
f"Expected {DEFAULT_NUM_PROMPTS} prompts, got {len(prompts)}"
|
||||
)
|
||||
|
||||
client = openai.OpenAI(api_key="EMPTY", base_url=PROXY_BASE_URL)
|
||||
for i, prompt in enumerate(prompts):
|
||||
resp = client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=prompt,
|
||||
max_tokens=DEFAULT_OUTPUT_LEN,
|
||||
temperature=0.0,
|
||||
top_p=1.0,
|
||||
)
|
||||
if i < 3:
|
||||
text = resp.choices[0].text.strip()[:100]
|
||||
print(f" [{i}] {prompt[:60]}... -> {text}...")
|
||||
|
||||
# ── Extract metrics from decode server ────────────────────────────
|
||||
n_drafts = _fetch_metric("vllm:spec_decode_num_drafts_total")
|
||||
n_accepted = _fetch_metric("vllm:spec_decode_num_accepted_tokens_total")
|
||||
|
||||
assert n_drafts > 0, "No spec-decode drafts were generated"
|
||||
|
||||
acceptance_length = 1 + (n_accepted / n_drafts)
|
||||
|
||||
per_pos_counts = _fetch_per_position_acceptance()
|
||||
per_pos_rates = [
|
||||
per_pos_counts.get(i, 0) / n_drafts
|
||||
for i in range(len(config.expected_acceptance_lengths_per_pos))
|
||||
]
|
||||
|
||||
# ── Report ────────────────────────────────────────────────────────
|
||||
expected = config.expected_acceptance_length
|
||||
expected_per_pos = config.expected_acceptance_lengths_per_pos
|
||||
|
||||
print(
|
||||
f"\n{config.id}: acceptance_length={acceptance_length:.3f} "
|
||||
f"(expected={expected:.3f})"
|
||||
)
|
||||
print(f" Drafts: {n_drafts:.0f}, Accepted: {n_accepted:.0f}")
|
||||
for i, (actual, exp) in enumerate(zip(per_pos_rates, expected_per_pos)):
|
||||
print(f" Position {i}: {actual:.4f} (expected: {exp:.4f})")
|
||||
|
||||
# ── Assert overall acceptance length ──────────────────────────────
|
||||
rel_error = abs(acceptance_length - expected) / expected
|
||||
|
||||
assert rel_error <= rtol, (
|
||||
f"Acceptance length regression for {config.id}! "
|
||||
f"Expected: {expected:.3f}, "
|
||||
f"Got: {acceptance_length:.3f}, "
|
||||
f"Relative error: {rel_error:.2%} (tolerance: {rtol:.0%}). "
|
||||
f"This may indicate drafter KV was not correctly transferred."
|
||||
)
|
||||
|
||||
# ── Assert per-position acceptance ────────────────────────────────
|
||||
for i, (actual, exp) in enumerate(zip(per_pos_rates, expected_per_pos)):
|
||||
if exp > 0:
|
||||
pos_err = abs(actual - exp) / exp
|
||||
assert pos_err <= rtol, (
|
||||
f"Per-position acceptance regression at position {i} "
|
||||
f"for {config.id}! "
|
||||
f"Expected: {exp:.4f}, Got: {actual:.4f}, "
|
||||
f"Relative error: {pos_err:.2%} "
|
||||
f"(tolerance: {rtol:.0%})"
|
||||
)
|
||||
|
||||
print(
|
||||
f"\n=== PASS: {config.id} acceptance length {acceptance_length:.3f} "
|
||||
f"within {rtol:.0%} of {expected:.3f} ==="
|
||||
)
|
||||
@@ -3593,9 +3593,9 @@ class GPUModelRunner(
|
||||
|
||||
# Run the model.
|
||||
# Use persistent buffers for CUDA graphs.
|
||||
# When spec decode is enabled, delay clearing connector metadata
|
||||
# until after draft model runs in sample_tokens.
|
||||
clear_kv_metadata = self.speculative_config is None
|
||||
# When spec decode is enabled, defer connector finalization
|
||||
# (wait_for_save + clear metadata) until after draft model runs.
|
||||
defer_kv_connector_finalize = self.speculative_config is not None
|
||||
with (
|
||||
set_forward_context(
|
||||
attn_metadata,
|
||||
@@ -3610,7 +3610,8 @@ class GPUModelRunner(
|
||||
),
|
||||
record_function_or_nullcontext("gpu_model_runner: forward"),
|
||||
self.maybe_get_kv_connector_output(
|
||||
scheduler_output, clear_metadata=clear_kv_metadata
|
||||
scheduler_output,
|
||||
defer_finalize=defer_kv_connector_finalize,
|
||||
) as kv_connector_output,
|
||||
):
|
||||
model_output = self._model_forward(
|
||||
@@ -3843,11 +3844,11 @@ class GPUModelRunner(
|
||||
# tokens on the CPU, so they are run after bookkeeping.
|
||||
propose_draft_token_ids(valid_sampled_token_ids)
|
||||
|
||||
# Clear KV connector metadata after draft model runs (if spec decode).
|
||||
# This was deferred from target model forward to allow draft model
|
||||
# to also save its KV cache.
|
||||
if self.speculative_config is not None:
|
||||
self.clear_kv_connector_metadata()
|
||||
# Finalize KV connector (wait_for_save + clear metadata) after
|
||||
# draft model runs. Deferred from target model forward to allow
|
||||
# draft model to also save its KV cache.
|
||||
if spec_config is not None:
|
||||
self.finalize_kv_connector()
|
||||
|
||||
with record_function_or_nullcontext("gpu_model_runner: eplb"):
|
||||
self.eplb_step()
|
||||
|
||||
@@ -67,16 +67,27 @@ class KVConnectorModelRunnerMixin:
|
||||
@staticmethod
|
||||
def maybe_get_kv_connector_output(
|
||||
scheduler_output: "SchedulerOutput",
|
||||
clear_metadata: bool = True,
|
||||
defer_finalize: bool = False,
|
||||
) -> AbstractContextManager[KVConnectorOutput | None]:
|
||||
return (
|
||||
KVConnectorModelRunnerMixin._get_kv_connector_output(
|
||||
scheduler_output, clear_metadata=clear_metadata
|
||||
scheduler_output, defer_finalize=defer_finalize
|
||||
)
|
||||
if has_kv_transfer_group()
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def finalize_kv_connector() -> None:
|
||||
"""Finalize the KV connector: wait_for_save and clear metadata.
|
||||
|
||||
Call after draft model forward when defer_finalize=True was used.
|
||||
"""
|
||||
if has_kv_transfer_group():
|
||||
kv_connector = get_kv_transfer_group()
|
||||
kv_connector.wait_for_save()
|
||||
kv_connector.clear_connector_metadata()
|
||||
|
||||
# This context manager must be used within an active forward context.
|
||||
# It encapsulates the entire KV connector lifecycle within execute_model
|
||||
@staticmethod
|
||||
@@ -84,7 +95,7 @@ class KVConnectorModelRunnerMixin:
|
||||
def _get_kv_connector_output(
|
||||
scheduler_output: "SchedulerOutput",
|
||||
wait_for_save: bool = True,
|
||||
clear_metadata: bool = True,
|
||||
defer_finalize: bool = False,
|
||||
) -> Generator[KVConnectorOutput, None, None]:
|
||||
output = KVConnectorOutput()
|
||||
|
||||
@@ -102,7 +113,7 @@ class KVConnectorModelRunnerMixin:
|
||||
try:
|
||||
yield output
|
||||
finally:
|
||||
if wait_for_save:
|
||||
if wait_for_save and not defer_finalize:
|
||||
kv_connector.wait_for_save()
|
||||
|
||||
output.finished_sending, output.finished_recving = (
|
||||
@@ -113,16 +124,9 @@ class KVConnectorModelRunnerMixin:
|
||||
output.kv_connector_stats = kv_connector.get_kv_connector_stats()
|
||||
output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events()
|
||||
|
||||
if clear_metadata:
|
||||
if not defer_finalize:
|
||||
kv_connector.clear_connector_metadata()
|
||||
|
||||
@staticmethod
|
||||
def clear_kv_connector_metadata() -> None:
|
||||
"""Clear the KV connector metadata. Call after draft model runs."""
|
||||
if has_kv_transfer_group():
|
||||
kv_connector = get_kv_transfer_group()
|
||||
kv_connector.clear_connector_metadata()
|
||||
|
||||
@staticmethod
|
||||
def use_uniform_kv_cache(
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
|
||||
Reference in New Issue
Block a user