[P/D] rework mooncake connector and introduce its bootstrap server (#31034)
Signed-off-by: Tianchen Ding <dtcccc@linux.alibaba.com> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
This commit is contained in:
@@ -19,12 +19,13 @@ Two main reasons:
|
||||
|
||||
Please refer to [examples/online_serving/disaggregated_prefill.sh](../../examples/online_serving/disaggregated_prefill.sh) for the example usage of disaggregated prefilling.
|
||||
|
||||
Now supports 5 types of connectors:
|
||||
Now supports 6 types of connectors:
|
||||
|
||||
- **ExampleConnector**: refer to [examples/offline_inference/disaggregated-prefill-v1/run.sh](../../examples/offline_inference/disaggregated-prefill-v1/run.sh) for the example usage of ExampleConnector disaggregated prefilling.
|
||||
- **LMCacheConnectorV1**: refer to [examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh](../../examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh) for the example usage of LMCacheConnectorV1 disaggregated prefilling which uses NIXL as the underlying KV transmission.
|
||||
- **NixlConnector**: refer to [tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh](../../tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh) for the example usage of NixlConnector disaggregated prefilling which support fully async send/recv. For detailed usage guide, see [NixlConnector Usage Guide](nixl_connector_usage.md).
|
||||
- **P2pNcclConnector**: refer to [examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh](../../examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh) for the example usage of P2pNcclConnector disaggregated prefilling.
|
||||
- **MooncakeConnector**: refer to [examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh](../../examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh) for the example usage of ExampleConnector disaggregated prefilling. For detailed usage guide, see [MooncakeConnector Usage Guide](mooncake_connector_usage.md).
|
||||
- **MultiConnector**: take advantage of the kv_connector_extra_config: dict[str, Any] already present in KVTransferConfig to stash all the connectors we want in an ordered list of kwargs.such as:
|
||||
|
||||
```bash
|
||||
|
||||
@@ -31,11 +31,9 @@ vllm serve Qwen/Qwen2.5-7B-Instruct --port 8020 --kv-transfer-config '{"kv_conne
|
||||
### Proxy
|
||||
|
||||
```bash
|
||||
python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --prefiller-host 192.168.0.2 --prefiller-port 8010 --decoder-host 192.168.0.3 --decoder-port 8020
|
||||
python examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py --prefill http://192.168.0.2:8010 --decode http://192.168.0.3:8020
|
||||
```
|
||||
|
||||
> NOTE: The Mooncake Connector currently uses the proxy from nixl_integration. This will be replaced with a self-developed proxy in the future.
|
||||
|
||||
Now you can send requests to the proxy server through port 8000.
|
||||
|
||||
## Environment Variables
|
||||
@@ -43,16 +41,29 @@ Now you can send requests to the proxy server through port 8000.
|
||||
- `VLLM_MOONCAKE_BOOTSTRAP_PORT`: Port for Mooncake bootstrap server
|
||||
- Default: 8998
|
||||
- Required only for prefiller instances
|
||||
- Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
|
||||
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank
|
||||
- Used for the decoder notifying the prefiller
|
||||
- For headless instances, must be the same as the master instance
|
||||
- Each instance needs a unique port on its host; using the same port number across different hosts is fine
|
||||
|
||||
- `VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT`: Timeout (in seconds) for automatically releasing the prefiller’s KV cache for a particular request. (Optional)
|
||||
- Default: 480
|
||||
- If a request is aborted and the decoder has not yet notified the prefiller, the prefill instance will release its KV-cache blocks after this timeout to avoid holding them indefinitely.
|
||||
|
||||
## KV Role Options
|
||||
## KV Transfer Config
|
||||
|
||||
### KV Role Options
|
||||
|
||||
- **kv_producer**: For prefiller instances that generate KV caches
|
||||
- **kv_consumer**: For decoder instances that consume KV caches from prefiller
|
||||
- **kv_both**: Enables symmetric functionality where the connector can act as both producer and consumer. This provides flexibility for experimental setups and scenarios where the role distinction is not predetermined.
|
||||
|
||||
### kv_connector_extra_config
|
||||
|
||||
- **num_workers**: Size of thread pool for one prefiller worker to transfer KV caches by mooncake. (default 10)
|
||||
- **mooncake_protocol**: Mooncake connector protocol. (default "rdma")
|
||||
|
||||
## Example Scripts/Code
|
||||
|
||||
Refer to these example scripts in the vLLM repository:
|
||||
|
||||
- [run_mooncake_connector.sh](../../examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh)
|
||||
- [mooncake_connector_proxy.py](../../examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py)
|
||||
|
||||
@@ -6,3 +6,4 @@ This example contains scripts that demonstrate the disaggregated serving feature
|
||||
|
||||
- `disagg_proxy_demo.py` - Demonstrates XpYd (X prefill instances, Y decode instances).
|
||||
- `kv_events.sh` - Demonstrates KV cache event publishing.
|
||||
- `mooncake_connector` - A proxy demo for MooncakeConnector.
|
||||
|
||||
@@ -0,0 +1,376 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import itertools
|
||||
import os
|
||||
import urllib
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
||||
def maybe_wrap_ipv6_address(address: str) -> str:
|
||||
try:
|
||||
ipaddress.IPv6Address(address)
|
||||
return f"[{address}]"
|
||||
except ValueError:
|
||||
return address
|
||||
|
||||
|
||||
def make_http_path(host: str, port: int) -> str:
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def prefiller_cycle(prefill_clients: list[Any]):
|
||||
while True:
|
||||
for prefill_client in prefill_clients:
|
||||
for i in range(prefill_client["dp_size"]):
|
||||
yield prefill_client, i
|
||||
|
||||
|
||||
async def get_prefiller_info(prefill_clients: list, ready: asyncio.Event):
|
||||
for prefill_client in prefill_clients:
|
||||
while True:
|
||||
try:
|
||||
# Wait for prefill service to be ready
|
||||
response = await prefill_client["client"].get("/health")
|
||||
response.raise_for_status()
|
||||
except Exception:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
response = await prefill_client["client"].get(
|
||||
prefill_client["bootstrap_addr"] + "/query"
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
break
|
||||
|
||||
for dp_rank, dp_entry in data.items():
|
||||
prefill_client["dp_engine_id"][int(dp_rank)] = dp_entry["engine_id"]
|
||||
dp_size = len(data)
|
||||
prefill_client["dp_size"] = dp_size
|
||||
print(f"Inited prefiller {prefill_client['url']} with dp_size={dp_size}")
|
||||
|
||||
ready.set()
|
||||
print("All prefiller instances are ready.")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""
|
||||
Lifespan context manager to handle startup and shutdown events.
|
||||
"""
|
||||
# Startup: Initialize client pools for prefiller and decoder services
|
||||
app.state.prefill_clients = []
|
||||
app.state.decode_clients = []
|
||||
app.state.ready = asyncio.Event()
|
||||
|
||||
# Create prefill clients
|
||||
for i, (url, bootstrap_port) in enumerate(global_args.prefill):
|
||||
parsed_url = urllib.parse.urlparse(url)
|
||||
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
||||
app.state.prefill_clients.append(
|
||||
{
|
||||
"client": httpx.AsyncClient(
|
||||
timeout=None,
|
||||
base_url=url,
|
||||
limits=httpx.Limits(
|
||||
max_connections=None,
|
||||
max_keepalive_connections=None,
|
||||
),
|
||||
),
|
||||
"url": url,
|
||||
"bootstrap_addr": make_http_path(hostname, bootstrap_port or 8998),
|
||||
"dp_engine_id": {},
|
||||
}
|
||||
)
|
||||
|
||||
# Create decode clients
|
||||
for i, url in enumerate(global_args.decode):
|
||||
parsed_url = urllib.parse.urlparse(url)
|
||||
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
||||
app.state.decode_clients.append(
|
||||
{
|
||||
"client": httpx.AsyncClient(
|
||||
timeout=None,
|
||||
base_url=url,
|
||||
limits=httpx.Limits(
|
||||
max_connections=None,
|
||||
max_keepalive_connections=None,
|
||||
),
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
asyncio.create_task(get_prefiller_info(app.state.prefill_clients, app.state.ready))
|
||||
|
||||
# Initialize round-robin iterators
|
||||
app.state.prefill_iterator = prefiller_cycle(app.state.prefill_clients)
|
||||
app.state.decode_iterator = itertools.cycle(range(len(app.state.decode_clients)))
|
||||
|
||||
print(
|
||||
f"Got {len(app.state.prefill_clients)} prefill clients "
|
||||
f"and {len(app.state.decode_clients)} decode clients."
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown: Close all clients
|
||||
for client_info in app.state.prefill_clients:
|
||||
await client_info["client"].aclose()
|
||||
|
||||
for client_info in app.state.decode_clients:
|
||||
await client_info["client"].aclose()
|
||||
|
||||
|
||||
# Update FastAPI app initialization to use lifespan
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
# Always use 127.0.0.1 as localhost binds to IPv6 which is blocked on CI
|
||||
parser.add_argument("--host", type=str, default="127.0.0.1")
|
||||
|
||||
# For prefiller instances
|
||||
parser.add_argument(
|
||||
"--prefill",
|
||||
nargs="+",
|
||||
action="append",
|
||||
dest="prefill_raw",
|
||||
metavar=("URL", "bootstrap_port"),
|
||||
help=(
|
||||
"Prefill server URL and optional bootstrap port. "
|
||||
"Can be specified multiple times. "
|
||||
"Format: --prefill URL [BOOTSTRAP_PORT]. "
|
||||
"BOOTSTRAP_PORT can be a port number, "
|
||||
"'none', or omitted (defaults to none)."
|
||||
),
|
||||
)
|
||||
|
||||
# For decoder instances
|
||||
parser.add_argument(
|
||||
"--decode",
|
||||
nargs=1,
|
||||
action="append",
|
||||
dest="decode_raw",
|
||||
metavar=("URL",),
|
||||
help="Decode server URL. Can be specified multiple times.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.prefill = _parse_prefill_urls(args.prefill_raw)
|
||||
args.decode = _parse_decode_urls(args.decode_raw)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
# From sglang router_args.py
|
||||
def _parse_prefill_urls(prefill_list):
|
||||
"""Parse prefill URLs from --prefill arguments.
|
||||
|
||||
Format: --prefill URL [BOOTSTRAP_PORT]
|
||||
Example:
|
||||
--prefill http://prefill1:8080 9000 # With bootstrap port
|
||||
--prefill http://prefill2:8080 none # Explicitly no bootstrap port
|
||||
--prefill http://prefill3:8080 # Defaults to no bootstrap port
|
||||
"""
|
||||
if not prefill_list:
|
||||
return []
|
||||
|
||||
prefill_urls = []
|
||||
for prefill_args in prefill_list:
|
||||
url = prefill_args[0]
|
||||
|
||||
# Handle optional bootstrap port
|
||||
if len(prefill_args) >= 2:
|
||||
bootstrap_port_str = prefill_args[1]
|
||||
# Handle 'none' as None
|
||||
if bootstrap_port_str.lower() == "none":
|
||||
bootstrap_port = None
|
||||
else:
|
||||
try:
|
||||
bootstrap_port = int(bootstrap_port_str)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'" # noqa: E501
|
||||
) from e
|
||||
else:
|
||||
# No bootstrap port specified, default to None
|
||||
bootstrap_port = None
|
||||
|
||||
prefill_urls.append((url, bootstrap_port))
|
||||
|
||||
return prefill_urls
|
||||
|
||||
|
||||
def _parse_decode_urls(decode_list):
|
||||
"""Parse decode URLs from --decode arguments.
|
||||
|
||||
Format: --decode URL
|
||||
Example: --decode http://decode1:8081 --decode http://decode2:8081
|
||||
"""
|
||||
if not decode_list:
|
||||
return []
|
||||
|
||||
# decode_list is a list of single-element lists due to nargs=1
|
||||
return [url[0] for url in decode_list]
|
||||
|
||||
|
||||
def get_next_client(app, service_type: str):
|
||||
"""
|
||||
Get the next client in round-robin fashion.
|
||||
|
||||
Args:
|
||||
app: The FastAPI app instance
|
||||
service_type: Either 'prefill' or 'decode'
|
||||
|
||||
Returns:
|
||||
The next client to use
|
||||
"""
|
||||
if service_type == "prefill":
|
||||
return next(app.state.prefill_iterator)
|
||||
elif service_type == "decode":
|
||||
client_idx = next(app.state.decode_iterator)
|
||||
return app.state.decode_clients[client_idx]
|
||||
else:
|
||||
raise ValueError(f"Unknown service type: {service_type}")
|
||||
|
||||
|
||||
async def send_request_to_service(
|
||||
client_info: dict, dp_rank: int, endpoint: str, req_data: dict, request_id: str
|
||||
):
|
||||
"""
|
||||
Send a request to a service using a client from the pool.
|
||||
"""
|
||||
req_data = req_data.copy()
|
||||
req_data["kv_transfer_params"] = {
|
||||
"do_remote_decode": True,
|
||||
"do_remote_prefill": False,
|
||||
"transfer_id": f"xfer-{request_id}",
|
||||
}
|
||||
req_data["stream"] = False
|
||||
req_data["max_tokens"] = 1
|
||||
if "max_completion_tokens" in req_data:
|
||||
req_data["max_completion_tokens"] = 1
|
||||
if "stream_options" in req_data:
|
||||
del req_data["stream_options"]
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
"X-Request-Id": request_id,
|
||||
"X-data-parallel-rank": str(dp_rank),
|
||||
}
|
||||
|
||||
response = await client_info["client"].post(
|
||||
endpoint, json=req_data, headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
# CRITICAL: Release connection back to pool
|
||||
await response.aclose()
|
||||
|
||||
|
||||
async def stream_service_response(
|
||||
prefill_client_info: dict,
|
||||
prefill_dp_rank: int,
|
||||
decode_client_info: dict,
|
||||
endpoint: str,
|
||||
req_data: dict,
|
||||
request_id: str,
|
||||
):
|
||||
"""
|
||||
Asynchronously stream response from a service using a client from the pool.
|
||||
"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
"X-Request-Id": request_id,
|
||||
}
|
||||
|
||||
req_data["kv_transfer_params"] = {
|
||||
"do_remote_decode": False,
|
||||
"do_remote_prefill": True,
|
||||
"remote_bootstrap_addr": prefill_client_info["bootstrap_addr"],
|
||||
"remote_engine_id": prefill_client_info["dp_engine_id"][prefill_dp_rank],
|
||||
"transfer_id": f"xfer-{request_id}",
|
||||
}
|
||||
|
||||
async with decode_client_info["client"].stream(
|
||||
"POST", endpoint, json=req_data, headers=headers
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
async for chunk in response.aiter_bytes():
|
||||
yield chunk
|
||||
|
||||
|
||||
async def _handle_completions(api: str, request: Request):
|
||||
if not app.state.ready.is_set():
|
||||
raise HTTPException(status_code=503, detail="Service Unavailable")
|
||||
|
||||
try:
|
||||
req_data = await request.json()
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
# Get the next prefill client in round-robin fashion
|
||||
prefill_client_info, prefill_dp_rank = get_next_client(request.app, "prefill")
|
||||
|
||||
# Send request to prefill service
|
||||
asyncio.create_task(
|
||||
send_request_to_service(
|
||||
prefill_client_info, prefill_dp_rank, api, req_data, request_id
|
||||
)
|
||||
)
|
||||
|
||||
decode_client_info = get_next_client(request.app, "decode")
|
||||
|
||||
# Stream response from decode service
|
||||
async def generate_stream():
|
||||
async for chunk in stream_service_response(
|
||||
prefill_client_info,
|
||||
prefill_dp_rank,
|
||||
decode_client_info,
|
||||
api,
|
||||
req_data,
|
||||
request_id=request_id,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(generate_stream(), media_type="application/json")
|
||||
|
||||
except Exception as e:
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
exc_info = sys.exc_info()
|
||||
print(f"Error occurred in disagg prefill proxy server - {api} endpoint")
|
||||
print(e)
|
||||
print("".join(traceback.format_exception(*exc_info)))
|
||||
raise
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def handle_completions(request: Request):
|
||||
return await _handle_completions("/v1/completions", request)
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def handle_chat_completions(request: Request):
|
||||
return await _handle_completions("/v1/chat/completions", request)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
global global_args
|
||||
global_args = parse_args()
|
||||
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host=global_args.host, port=global_args.port)
|
||||
@@ -0,0 +1,222 @@
|
||||
#!/bin/bash
|
||||
|
||||
# =============================================================================
|
||||
# vLLM Disaggregated Serving Script for Mooncake Connector
|
||||
# =============================================================================
|
||||
# This script demonstrates disaggregated prefill and decode serving using
|
||||
# Mooncake Connector.
|
||||
#
|
||||
# Configuration can be customized via environment variables:
|
||||
# MODEL: Model to serve
|
||||
# PREFILL_GPUS: Comma-separated GPU IDs for prefill servers
|
||||
# DECODE_GPUS: Comma-separated GPU IDs for decode servers
|
||||
# PREFILL_PORTS: Comma-separated ports for prefill servers
|
||||
# BOOTSTRAP_PORTS: Bootstrap server port launched by prefill servers
|
||||
# DECODE_PORTS: Comma-separated ports for decode servers
|
||||
# PROXY_PORT: Proxy server port used to setup P/D disaggregated connection.
|
||||
# TIMEOUT_SECONDS: Server startup timeout
|
||||
# =============================================================================
|
||||
|
||||
# Configuration - can be overridden via environment variables
|
||||
MODEL=${MODEL:-Qwen/Qwen2.5-7B-Instruct}
|
||||
TIMEOUT_SECONDS=${TIMEOUT_SECONDS:-1200}
|
||||
PROXY_PORT=${PROXY_PORT:-8000}
|
||||
|
||||
PREFILL_GPUS=${PREFILL_GPUS:-0}
|
||||
DECODE_GPUS=${DECODE_GPUS:-1}
|
||||
PREFILL_PORTS=${PREFILL_PORTS:-8010}
|
||||
BOOTSTRAP_PORTS=${BOOTSTRAP_PORTS:-8998}
|
||||
DECODE_PORTS=${DECODE_PORTS:-8020}
|
||||
|
||||
echo "Warning: Mooncake Connector support for vLLM v1 is experimental and subject to change."
|
||||
echo ""
|
||||
echo "Architecture Configuration:"
|
||||
echo " Model: $MODEL"
|
||||
echo " Prefill GPUs: $PREFILL_GPUS, Ports: $PREFILL_PORTS, Bootstrap Port:$BOOTSTRAP_PORTS"
|
||||
echo " Decode GPUs: $DECODE_GPUS, Ports: $DECODE_PORTS"
|
||||
echo " Proxy Port: $PROXY_PORT"
|
||||
echo " Timeout: ${TIMEOUT_SECONDS}s"
|
||||
echo ""
|
||||
|
||||
PIDS=()
|
||||
|
||||
# Switch to the directory of the current script
|
||||
cd "$(dirname "${BASH_SOURCE[0]}")"
|
||||
|
||||
check_required_files() {
|
||||
local files=("mooncake_connector_proxy.py")
|
||||
for file in "${files[@]}"; do
|
||||
if [[ ! -f "$file" ]]; then
|
||||
echo "Required file $file not found in $(pwd)"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
check_hf_token() {
|
||||
if [ -z "$HF_TOKEN" ]; then
|
||||
echo "HF_TOKEN is not set. Please set it to your Hugging Face token."
|
||||
echo "Example: export HF_TOKEN=your_token_here"
|
||||
exit 1
|
||||
fi
|
||||
if [[ "$HF_TOKEN" != hf_* ]]; then
|
||||
echo "HF_TOKEN is not a valid Hugging Face token. Please set it to your Hugging Face token."
|
||||
exit 1
|
||||
fi
|
||||
echo "HF_TOKEN is set and valid."
|
||||
}
|
||||
|
||||
check_num_gpus() {
|
||||
# Check if the number of GPUs are >=2 via nvidia-smi
|
||||
num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
|
||||
if [ "$num_gpus" -lt 2 ]; then
|
||||
echo "You need at least 2 GPUs to run disaggregated prefill."
|
||||
exit 1
|
||||
else
|
||||
echo "Found $num_gpus GPUs."
|
||||
fi
|
||||
}
|
||||
|
||||
ensure_python_library_installed() {
|
||||
echo "Checking if $1 is installed..."
|
||||
if ! python3 -c "import $1" > /dev/null 2>&1; then
|
||||
echo "$1 is not installed. Please install it via pip install $1."
|
||||
exit 1
|
||||
else
|
||||
echo "$1 is installed."
|
||||
fi
|
||||
}
|
||||
|
||||
cleanup() {
|
||||
echo "Stopping everything…"
|
||||
trap - INT TERM # prevent re-entrancy
|
||||
pkill -9 -f "mooncake_connector_proxy.py"
|
||||
kill -- -$$ # negative PID == "this whole process-group"
|
||||
wait # reap children so we don't leave zombies
|
||||
exit 0
|
||||
}
|
||||
|
||||
wait_for_server() {
|
||||
local port=$1
|
||||
local timeout_seconds=$TIMEOUT_SECONDS
|
||||
local start_time=$(date +%s)
|
||||
|
||||
echo "Waiting for server on port $port..."
|
||||
|
||||
while true; do
|
||||
if curl -s "localhost:${port}/v1/completions" > /dev/null; then
|
||||
echo "Server on port $port is ready."
|
||||
return 0
|
||||
fi
|
||||
|
||||
local now=$(date +%s)
|
||||
if (( now - start_time >= timeout_seconds )); then
|
||||
echo "Timeout waiting for server on port $port"
|
||||
return 1
|
||||
fi
|
||||
|
||||
sleep 1
|
||||
done
|
||||
}
|
||||
|
||||
main() {
|
||||
check_required_files
|
||||
check_hf_token
|
||||
check_num_gpus
|
||||
ensure_python_library_installed vllm
|
||||
ensure_python_library_installed mooncake.engine
|
||||
|
||||
trap cleanup INT
|
||||
trap cleanup USR1
|
||||
trap cleanup TERM
|
||||
|
||||
echo "Launching disaggregated serving components..."
|
||||
echo "Please check the log files for detailed output:"
|
||||
echo " - prefill*.log: Prefill server logs"
|
||||
echo " - decode*.log: Decode server logs"
|
||||
echo " - proxy.log: Proxy server log"
|
||||
|
||||
# Parse GPU and port arrays
|
||||
IFS=',' read -ra PREFILL_GPU_ARRAY <<< "$PREFILL_GPUS"
|
||||
IFS=',' read -ra DECODE_GPU_ARRAY <<< "$DECODE_GPUS"
|
||||
IFS=',' read -ra PREFILL_PORT_ARRAY <<< "$PREFILL_PORTS"
|
||||
IFS=',' read -ra BOOTSTRAP_PORT_ARRAY <<< "$BOOTSTRAP_PORTS"
|
||||
IFS=',' read -ra DECODE_PORT_ARRAY <<< "$DECODE_PORTS"
|
||||
|
||||
proxy_param=""
|
||||
|
||||
# =============================================================================
|
||||
# Launch Prefill Servers (X Producers)
|
||||
# =============================================================================
|
||||
echo ""
|
||||
echo "Starting ${#PREFILL_GPU_ARRAY[@]} prefill server(s)..."
|
||||
for i in "${!PREFILL_GPU_ARRAY[@]}"; do
|
||||
local gpu_id=${PREFILL_GPU_ARRAY[$i]}
|
||||
local port=${PREFILL_PORT_ARRAY[$i]}
|
||||
local bootstrap_port=${BOOTSTRAP_PORT_ARRAY[$i]}
|
||||
|
||||
echo " Prefill server $((i+1)): GPU $gpu_id, Port $port, Bootstrap Port $bootstrap_port"
|
||||
VLLM_MOONCAKE_BOOTSTRAP_PORT=$bootstrap_port CUDA_VISIBLE_DEVICES=$gpu_id vllm serve $MODEL \
|
||||
--port $port \
|
||||
--kv-transfer-config \
|
||||
"{\"kv_connector\":\"MooncakeConnector\",\"kv_role\":\"kv_producer\"}" > prefill$((i+1)).log 2>&1 &
|
||||
PIDS+=($!)
|
||||
proxy_param="${proxy_param} --prefill http://0.0.0.0:${port} $bootstrap_port"
|
||||
done
|
||||
|
||||
# =============================================================================
|
||||
# Launch Decode Servers (Y Decoders)
|
||||
# =============================================================================
|
||||
echo ""
|
||||
echo "Starting ${#DECODE_GPU_ARRAY[@]} decode server(s)..."
|
||||
for i in "${!DECODE_GPU_ARRAY[@]}"; do
|
||||
local gpu_id=${DECODE_GPU_ARRAY[$i]}
|
||||
local port=${DECODE_PORT_ARRAY[$i]}
|
||||
|
||||
echo " Decode server $((i+1)): GPU $gpu_id, Port $port"
|
||||
CUDA_VISIBLE_DEVICES=$gpu_id vllm serve $MODEL \
|
||||
--port $port \
|
||||
--kv-transfer-config \
|
||||
"{\"kv_connector\":\"MooncakeConnector\",\"kv_role\":\"kv_consumer\"}" > decode$((i+1)).log 2>&1 &
|
||||
PIDS+=($!)
|
||||
proxy_param="${proxy_param} --decode http://0.0.0.0:${port}"
|
||||
done
|
||||
|
||||
# =============================================================================
|
||||
# Launch Proxy Server
|
||||
# =============================================================================
|
||||
echo ""
|
||||
echo "Starting proxy server on port $PROXY_PORT..."
|
||||
python3 mooncake_connector_proxy.py $proxy_param --port $PROXY_PORT > proxy.log 2>&1 &
|
||||
PIDS+=($!)
|
||||
|
||||
# =============================================================================
|
||||
# Wait for All Servers to Start
|
||||
# =============================================================================
|
||||
echo ""
|
||||
echo "Waiting for all servers to start..."
|
||||
for port in "${PREFILL_PORT_ARRAY[@]}" "${DECODE_PORT_ARRAY[@]}"; do
|
||||
if ! wait_for_server $port; then
|
||||
echo "Failed to start server on port $port"
|
||||
cleanup
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
echo ""
|
||||
echo "All servers are up. Starting benchmark..."
|
||||
|
||||
# =============================================================================
|
||||
# Run Benchmark
|
||||
# =============================================================================
|
||||
vllm bench serve --port $PROXY_PORT --seed $(date +%s) \
|
||||
--backend vllm --model $MODEL \
|
||||
--dataset-name random --random-input-len 7500 --random-output-len 200 \
|
||||
--num-prompts 200 --burstiness 100 --request-rate 2 | tee benchmark.log
|
||||
|
||||
echo "Benchmarking done. Cleaning up..."
|
||||
|
||||
cleanup
|
||||
}
|
||||
|
||||
main
|
||||
@@ -198,6 +198,6 @@ KVConnectorFactory.register_connector(
|
||||
)
|
||||
KVConnectorFactory.register_connector(
|
||||
"MooncakeConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector",
|
||||
"MooncakeConnector",
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,127 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import EngineId
|
||||
from vllm.logger import init_logger
|
||||
|
||||
WorkerAddr = str
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RegisterWorkerPayload(BaseModel):
|
||||
engine_id: EngineId
|
||||
dp_rank: int
|
||||
tp_rank: int
|
||||
pp_rank: int
|
||||
addr: WorkerAddr
|
||||
|
||||
|
||||
@dataclass
|
||||
class EngineEntry:
|
||||
engine_id: EngineId
|
||||
# {tp_rank: {pp_rank: worker_addr}}
|
||||
worker_addr: dict[int, dict[int, WorkerAddr]]
|
||||
|
||||
|
||||
class MooncakeBootstrapServer:
|
||||
"""
|
||||
A centralized server running on the global rank 0 prefiller worker.
|
||||
Prefiller workers register their connection info (IP, port, ranks) here.
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, host: str, port: int):
|
||||
self.workers: dict[int, EngineEntry] = {}
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.app = FastAPI()
|
||||
self._register_routes()
|
||||
self.server_thread: threading.Thread | None = None
|
||||
self.server: uvicorn.Server | None = None
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
def _register_routes(self):
|
||||
# All methods are async. No need to use lock to protect data.
|
||||
self.app.post("/register")(self.register_worker)
|
||||
self.app.get("/query", response_model=dict[int, EngineEntry])(self.query)
|
||||
|
||||
def start(self):
|
||||
if self.server_thread:
|
||||
return
|
||||
|
||||
config = uvicorn.Config(app=self.app, host=self.host, port=self.port)
|
||||
self.server = uvicorn.Server(config=config)
|
||||
self.server_thread = threading.Thread(
|
||||
target=self.server.run, name="mooncake_bootstrap_server", daemon=True
|
||||
)
|
||||
self.server_thread.start()
|
||||
while not self.server.started:
|
||||
time.sleep(0.1) # Wait for the server to start
|
||||
logger.info("Mooncake Bootstrap Server started at %s:%d", self.host, self.port)
|
||||
|
||||
def shutdown(self):
|
||||
if self.server_thread is None or self.server is None or not self.server.started:
|
||||
return
|
||||
|
||||
self.server.should_exit = True
|
||||
self.server_thread.join()
|
||||
logger.info("Mooncake Bootstrap Server stopped.")
|
||||
|
||||
async def register_worker(self, payload: RegisterWorkerPayload):
|
||||
"""Handles registration of a prefiller worker."""
|
||||
if payload.dp_rank not in self.workers:
|
||||
self.workers[payload.dp_rank] = EngineEntry(
|
||||
engine_id=payload.engine_id,
|
||||
worker_addr={},
|
||||
)
|
||||
|
||||
dp_entry = self.workers[payload.dp_rank]
|
||||
if dp_entry.engine_id != payload.engine_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Engine ID mismatch for dp_rank={payload.dp_rank}: "
|
||||
f"expected {dp_entry.engine_id}, got {payload.engine_id}"
|
||||
),
|
||||
)
|
||||
if payload.tp_rank not in dp_entry.worker_addr:
|
||||
dp_entry.worker_addr[payload.tp_rank] = {}
|
||||
|
||||
tp_entry = dp_entry.worker_addr[payload.tp_rank]
|
||||
if payload.pp_rank in tp_entry:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Worker with dp_rank={payload.dp_rank}, "
|
||||
f"tp_rank={payload.tp_rank}, pp_rank={payload.pp_rank} "
|
||||
f"is already registered at "
|
||||
f"{tp_entry[payload.pp_rank]}, "
|
||||
f"but still want to register at {payload.addr}"
|
||||
),
|
||||
)
|
||||
|
||||
tp_entry[payload.pp_rank] = payload.addr
|
||||
logger.debug(
|
||||
"Registered worker: engine_id=%s, dp_rank=%d, tp_rank=%d, pp_rank=%d at %s",
|
||||
payload.engine_id,
|
||||
payload.dp_rank,
|
||||
payload.tp_rank,
|
||||
payload.pp_rank,
|
||||
payload.addr,
|
||||
)
|
||||
|
||||
return {"status": "ok"}
|
||||
|
||||
async def query(self) -> dict[int, EngineEntry]:
|
||||
return self.workers
|
||||
Reference in New Issue
Block a user