[v1] [P/D] Adding LMCache KV connector for v1 (#16625)

Author: Yihua Cheng
Date: 2025-04-25 22:03:38 -05:00 (committed by GitHub)
Parent: 68af5f6c5c
Commit: 5e83a7277f

12 changed files with 793 additions and 0 deletions

View File: configs/lmcache-decoder-config.yaml

@@ -0,0 +1,13 @@
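# Decoder-side config: the decoder acts as the NIXL receiver and waits for
# KV cache pushed from the prefiller (see disagg_vllm_launcher.sh below).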
local_cpu: False
max_local_cpu_size: 0
#local_disk:
max_local_disk_size: 0
remote_serde: NULL
enable_nixl: True
nixl_role: "receiver"
nixl_peer_host: "localhost"
nixl_peer_port: 55555
nixl_buffer_size: 1073741824 # 1GB
nixl_buffer_device: "cuda"
nixl_enable_gc: True

View File: configs/lmcache-prefiller-config.yaml

@@ -0,0 +1,13 @@
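# Prefiller-side config: the prefiller acts as the NIXL sender and pushes
# KV cache to the decoder after prefill (see disagg_vllm_launcher.sh below).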
local_cpu: False
max_local_cpu_size: 0
#local_disk:
max_local_disk_size: 0
remote_serde: NULL
enable_nixl: True
nixl_role: "sender"
nixl_peer_host: "localhost"
nixl_peer_port: 55555
nixl_buffer_size: 1073741824 # 1GB
nixl_buffer_device: "cuda"
nixl_enable_gc: True
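Note that the two configs are mirror images: identical NIXL peer host/port and buffer settings, with opposite nixl_role values, since the prefiller (sender) pushes KV cache over a single NIXL channel to the decoder (receiver).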

View File

@@ -0,0 +1,136 @@
#!/bin/bash
echo "Warning: LMCache disaggregated prefill support for vLLM v1 is experimental and subject to change."
PIDS=()
# Switch to the directory of the current script
cd "$(dirname "${BASH_SOURCE[0]}")"
check_hf_token() {
    if [ -z "$HF_TOKEN" ]; then
        echo "HF_TOKEN is not set. Please set it to your Hugging Face token."
        exit 1
    fi
    if [[ "$HF_TOKEN" != hf_* ]]; then
        echo "HF_TOKEN is not a valid Hugging Face token. Please set it to your Hugging Face token."
        exit 1
    fi
    echo "HF_TOKEN is set and valid."
}
check_num_gpus() {
    # Require at least 2 GPUs (one for the prefiller, one for the decoder).
    num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
    if [ "$num_gpus" -lt 2 ]; then
        echo "You need at least 2 GPUs to run disaggregated prefill."
        exit 1
    else
        echo "Found $num_gpus GPUs."
    fi
}
ensure_python_library_installed() {
    echo "Checking if $1 is installed..."
    if ! python -c "import $1" > /dev/null 2>&1; then
        if [ "$1" == "nixl" ]; then
            echo "$1 is not installed. Please refer to https://github.com/ai-dynamo/nixl for installation."
        else
            echo "$1 is not installed. Please install it via pip install $1."
        fi
        exit 1
    else
        echo "$1 is installed."
    fi
}
cleanup() {
    echo "Stopping everything…"
    trap - INT TERM        # prevent re-entrancy
    kill -- -$$            # negative PID == “this whole process-group”
    wait                   # reap children so we don't leave zombies
    exit 0
}
wait_for_server() {
    local port=$1
    local timeout_seconds=1200
    local start_time=$(date +%s)
    echo "Waiting for server on port $port..."
    while true; do
        # Any HTTP response (even 405 for a GET) means the server socket is up;
        # curl only fails while the connection is still refused.
        if curl -s "localhost:${port}/v1/completions" > /dev/null; then
            return 0
        fi
        local now=$(date +%s)
        if (( now - start_time >= timeout_seconds )); then
            echo "Timeout waiting for server"
            return 1
        fi
        sleep 1
    done
}
main() {
    check_hf_token
    check_num_gpus
    ensure_python_library_installed lmcache
    ensure_python_library_installed nixl
    ensure_python_library_installed pandas
    ensure_python_library_installed datasets
    ensure_python_library_installed vllm

    trap cleanup INT
    trap cleanup USR1
    trap cleanup TERM

    echo "Launching prefiller, decoder and proxy..."
    echo "Please check prefiller.log, decoder.log and proxy.log for logs."

    bash disagg_vllm_launcher.sh prefiller \
        > >(tee prefiller.log) 2>&1 &
    prefiller_pid=$!
    PIDS+=($prefiller_pid)

    bash disagg_vllm_launcher.sh decoder \
        > >(tee decoder.log) 2>&1 &
    decoder_pid=$!
    PIDS+=($decoder_pid)

    python3 disagg_proxy_server.py \
        --host localhost \
        --port 9000 \
        --prefiller-host localhost \
        --prefiller-port 8100 \
        --decoder-host localhost \
        --decoder-port 8200 \
        > >(tee proxy.log) 2>&1 &
    proxy_pid=$!
    PIDS+=($proxy_pid)

    wait_for_server 8100
    wait_for_server 8200
    wait_for_server 9000

    echo "All servers are up. Starting benchmark..."

    # begin benchmark
    cd ../../../benchmarks/
    python benchmark_serving.py --port 9000 --seed $(date +%s) \
        --model meta-llama/Llama-3.1-8B-Instruct \
        --dataset-name random --random-input-len 7500 --random-output-len 200 \
        --num-prompts 200 --burstiness 100 --request-rate 3.6 | tee benchmark.log

    echo "Benchmarking done. Cleaning up..."
    cleanup
}
main
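For a quick manual check once all three servers are up (a sketch, assuming the default ports and model used above), a single completion can be sent through the proxy:

curl -s http://localhost:9000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "meta-llama/Llama-3.1-8B-Instruct", "prompt": "The capital of France is", "max_tokens": 16}'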

View File: disagg_proxy_server.py

@@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
import time
from contextlib import asynccontextmanager
import httpx
import numpy as np
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan context manager to handle startup and shutdown events.
    """
    # Startup: Initialize clients
    prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1'
    decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1'
    app.state.prefill_client = httpx.AsyncClient(timeout=None,
                                                 base_url=prefiller_base_url)
    app.state.decode_client = httpx.AsyncClient(timeout=None,
                                                base_url=decoder_base_url)

    yield

    # Shutdown: Close clients
    await app.state.prefill_client.aclose()
    await app.state.decode_client.aclose()


# Update FastAPI app initialization to use lifespan
app = FastAPI(lifespan=lifespan)
class StatsCalculator:

    def __init__(self):
        self._stats = []
        self._last_log_time = time.time()

    def add(self, value):
        self._stats.append(value)
        if time.time() - self._last_log_time > 5:
            self._log_stats()
            self._last_log_time = time.time()

    def _log_stats(self):
        # Print average, median, and 99th percentile
        np_arr = np.array(self._stats)
        output_str = f"\nNum requests: {len(self._stats)}" + \
            "\nPrefill node TTFT stats:" + \
            f"\n - Average (ms): {np.mean(np_arr)}" + \
            f"\n - Median (ms): {np.median(np_arr)}" + \
            f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n"
        print("===============================", output_str,
              "===============================")


stats_calculator = StatsCalculator()
counter = 0
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--prefiller-host", type=str, default="localhost")
    parser.add_argument("--prefiller-port", type=int, default=8100)
    parser.add_argument("--decoder-host", type=str, default="localhost")
    parser.add_argument("--decoder-port", type=int, default=8200)
    args = parser.parse_args()
    return args


# Initialize variables to hold the persistent clients
app.state.prefill_client = None
app.state.decode_client = None
async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
                                  req_data: dict):
    """
    Send a request to a service using a persistent client.
    """
    req_data = req_data.copy()
    # Prefill-only request: generating a single token is enough for the
    # prefiller to populate the KV cache that the decoder will reuse.
    req_data['max_tokens'] = 1
    if 'max_completion_tokens' in req_data:
        req_data['max_completion_tokens'] = 1

    headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
    response = await client.post(endpoint, json=req_data, headers=headers)
    response.raise_for_status()

    return response
async def stream_service_response(client: httpx.AsyncClient, endpoint: str,
                                  req_data: dict):
    """
    Asynchronously stream the response from a service using a persistent client.
    """
    headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
    async with client.stream("POST", endpoint, json=req_data,
                             headers=headers) as response:
        response.raise_for_status()
        async for chunk in response.aiter_bytes():
            yield chunk
@app.post("/v1/completions")
async def handle_completions(request: Request):
global counter, stats_calculator
counter += 1
st = time.time()
try:
req_data = await request.json()
# Send request to prefill service, ignore the response
await send_request_to_service(app.state.prefill_client, "/completions",
req_data)
et = time.time()
stats_calculator.add(et - st)
# Stream response from decode service
async def generate_stream():
async for chunk in stream_service_response(app.state.decode_client,
"/completions",
req_data):
yield chunk
return StreamingResponse(generate_stream(),
media_type="application/json")
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server"
" - completions endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
global counter, stats_calculator
counter += 1
st = time.time()
try:
req_data = await request.json()
# Send request to prefill service, ignore the response
await send_request_to_service(app.state.prefill_client,
"/chat/completions", req_data)
et = time.time()
stats_calculator.add(et - st)
# Stream response from decode service
async def generate_stream():
async for chunk in stream_service_response(app.state.decode_client,
"/chat/completions",
req_data):
yield chunk
return StreamingResponse(generate_stream(),
media_type="application/json")
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server "
" - chat completions endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
if __name__ == '__main__':
    global_args = parse_args()

    import uvicorn
    uvicorn.run(app, host=global_args.host, port=global_args.port)
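As a usage sketch (assuming the proxy is running on its default localhost:9000 with the servers launched by disagg_vllm_launcher.sh), the chat endpoint can be exercised end to end; the prefill request goes to the prefiller and the streamed tokens come back from the decoder:

curl -s -N http://localhost:9000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "meta-llama/Llama-3.1-8B-Instruct", "messages": [{"role": "user", "content": "Say hello."}], "stream": true, "max_tokens": 32}'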

View File: disagg_vllm_launcher.sh

@@ -0,0 +1,59 @@
#!/bin/bash
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

if [[ $# -lt 1 ]]; then
    echo "Usage: $0 <prefiller | decoder> [model]"
    exit 1
fi

if [[ $# -eq 1 ]]; then
    echo "Using default model: meta-llama/Llama-3.1-8B-Instruct"
    MODEL="meta-llama/Llama-3.1-8B-Instruct"
else
    echo "Using model: $2"
    MODEL=$2
fi
if [[ $1 == "prefiller" ]]; then
    # Prefiller listens on port 8100
    prefill_config_file=$SCRIPT_DIR/configs/lmcache-prefiller-config.yaml

    UCX_TLS=cuda_ipc,cuda_copy,tcp \
        LMCACHE_CONFIG_FILE=$prefill_config_file \
        LMCACHE_USE_EXPERIMENTAL=True \
        VLLM_ENABLE_V1_MULTIPROCESSING=1 \
        VLLM_WORKER_MULTIPROC_METHOD=spawn \
        CUDA_VISIBLE_DEVICES=0 \
        vllm serve $MODEL \
        --port 8100 \
        --disable-log-requests \
        --enforce-eager \
        --kv-transfer-config \
        '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}'

elif [[ $1 == "decoder" ]]; then
    # Decoder listens on port 8200
    decode_config_file=$SCRIPT_DIR/configs/lmcache-decoder-config.yaml

    UCX_TLS=cuda_ipc,cuda_copy,tcp \
        LMCACHE_CONFIG_FILE=$decode_config_file \
        LMCACHE_USE_EXPERIMENTAL=True \
        VLLM_ENABLE_V1_MULTIPROCESSING=1 \
        VLLM_WORKER_MULTIPROC_METHOD=spawn \
        CUDA_VISIBLE_DEVICES=1 \
        vllm serve $MODEL \
        --port 8200 \
        --disable-log-requests \
        --enforce-eager \
        --kv-transfer-config \
        '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}'

else
    echo "Invalid role: $1"
    echo "Role should be either prefiller or decoder."
    exit 1
fi
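For reference, the launcher is invoked once per role (the example script above runs both in the background), with an optional model override as the second argument:

bash disagg_vllm_launcher.sh prefiller
bash disagg_vllm_launcher.sh decoder meta-llama/Llama-3.1-8B-Instruct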