[Core][KVConnector] Support HMA+NixlConnector (#35758)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -12,6 +12,7 @@ tp_configs=(
|
||||
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case
|
||||
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
|
||||
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
|
||||
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192" # SW model
|
||||
)
|
||||
dp_ep_configs=(
|
||||
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1)
|
||||
@@ -26,6 +27,14 @@ else
|
||||
configs=("${tp_configs[@]}")
|
||||
fi
|
||||
|
||||
if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then
|
||||
# Append ENABLE_HMA_FLAG=1 to each config in the selected array
|
||||
echo "ENABLE_HMA_FLAG is set, appending ENABLE_HMA_FLAG=1 to each config"
|
||||
for i in "${!configs[@]}"; do
|
||||
configs[$i]="ENABLE_HMA_FLAG=1 ${configs[$i]}"
|
||||
done
|
||||
fi
|
||||
|
||||
run_tests() {
|
||||
local label=$1
|
||||
local extra_args=$2
|
||||
|
||||
@@ -5,6 +5,12 @@ set -xe
|
||||
KV_BUFFER_DEVICE="cuda" # Default to cuda
|
||||
ATTENTION_BACKEND="" # Default to empty (use vllm default)
|
||||
CROSS_LAYERS_BLOCKS="False"
|
||||
ENABLE_HMA_VAR="" # Default to empty (HMA disabled by default for kv connector)
|
||||
# Check for ENABLE_HMA_FLAG environment variable
|
||||
if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then
|
||||
ENABLE_HMA_VAR="--no-disable-hybrid-kv-cache-manager"
|
||||
fi
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--kv_buffer_device)
|
||||
@@ -31,6 +37,12 @@ echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
|
||||
if [[ -n "$ATTENTION_BACKEND" ]]; then
|
||||
echo "Using attention backend: $ATTENTION_BACKEND"
|
||||
fi
|
||||
if [[ -n "$ENABLE_HMA_VAR" ]]; then
|
||||
echo "HMA (Hybrid KV Cache Manager) enabled"
|
||||
fi
|
||||
if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
|
||||
echo "vLLM serve extra args: $VLLM_SERVE_EXTRA_ARGS"
|
||||
fi
|
||||
|
||||
DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD
|
||||
if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
|
||||
@@ -70,6 +82,8 @@ DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
|
||||
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2}
|
||||
PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-128}
|
||||
DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-128}
|
||||
# Comma-separated extra args for vllm serve (e.g. --max-model-len,2048)
|
||||
VLLM_SERVE_EXTRA_ARGS=${VLLM_SERVE_EXTRA_ARGS:-}
|
||||
|
||||
# Find the git repository root directory
|
||||
GIT_ROOT=$(git rev-parse --show-toplevel)
|
||||
@@ -151,14 +165,24 @@ run_tests_for_model() {
|
||||
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
|
||||
--tensor-parallel-size $PREFILLER_TP_SIZE \
|
||||
--kv-transfer-config '$KV_CONFIG'"
|
||||
if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
|
||||
IFS=',' read -r -a extra_args <<< "$VLLM_SERVE_EXTRA_ARGS"
|
||||
for arg in "${extra_args[@]}"; do
|
||||
BASE_CMD="${BASE_CMD} $arg"
|
||||
done
|
||||
fi
|
||||
|
||||
# Add attention backend config if specified
|
||||
if [[ -n "$ATTENTION_BACKEND" ]]; then
|
||||
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
|
||||
fi
|
||||
|
||||
# Add HMA flag if specified
|
||||
if [[ -n "$ENABLE_HMA_VAR" ]]; then
|
||||
BASE_CMD="${BASE_CMD} $ENABLE_HMA_VAR"
|
||||
fi
|
||||
|
||||
FULL_CMD="$BASE_CMD"
|
||||
|
||||
eval "$FULL_CMD &"
|
||||
|
||||
# Store host and port for proxy configuration
|
||||
@@ -193,12 +217,23 @@ run_tests_for_model() {
|
||||
--block-size ${DECODE_BLOCK_SIZE} \
|
||||
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
|
||||
--kv-transfer-config '$KV_CONFIG'"
|
||||
if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
|
||||
IFS=',' read -r -a extra_args <<< "$VLLM_SERVE_EXTRA_ARGS"
|
||||
for arg in "${extra_args[@]}"; do
|
||||
BASE_CMD="${BASE_CMD} $arg"
|
||||
done
|
||||
fi
|
||||
|
||||
# Add attention backend config if specified
|
||||
if [[ -n "$ATTENTION_BACKEND" ]]; then
|
||||
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
|
||||
fi
|
||||
|
||||
# Add HMA flag if specified
|
||||
if [[ -n "$ENABLE_HMA_VAR" ]]; then
|
||||
BASE_CMD="${BASE_CMD} $ENABLE_HMA_VAR"
|
||||
fi
|
||||
|
||||
# DP-EP attention mode
|
||||
if [[ -z "$DP_EP" ]]; then
|
||||
BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE"
|
||||
|
||||
@@ -17,6 +17,7 @@ EXPECTED_VALUES = {
|
||||
"deepseek-ai/deepseek-vl2-small": 0.59,
|
||||
"deepseek-ai/deepseek-vl2-tiny": 0.19,
|
||||
"deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65,
|
||||
"google/gemma-3-4b-it": 0.74,
|
||||
}
|
||||
|
||||
SIMPLE_PROMPT = (
|
||||
|
||||
@@ -59,7 +59,12 @@ from vllm.v1.request import RequestStatus
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
from .utils import create_request, create_scheduler, create_vllm_config
|
||||
from .utils import (
|
||||
create_request,
|
||||
create_scheduler,
|
||||
create_vllm_config,
|
||||
make_kv_cache_config,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
@@ -263,7 +268,7 @@ def test_basic_interface():
|
||||
req_meta = kv_connector_metadata.reqs_to_recv[request_id]
|
||||
|
||||
for block_id, block in zip(
|
||||
req_meta.local_block_ids,
|
||||
req_meta.local_block_ids[0],
|
||||
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
|
||||
request_id
|
||||
],
|
||||
@@ -327,7 +332,9 @@ def test_kv_transfer_handshake(dist_init):
|
||||
|
||||
# Prefill connector will register KV cache to populate proper handshake
|
||||
# metadata.
|
||||
prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
prefill_connector = NixlConnector(
|
||||
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
|
||||
)
|
||||
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
|
||||
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
|
||||
)
|
||||
@@ -367,13 +374,17 @@ def test_kv_transfer_handshake(dist_init):
|
||||
do_remote_decode=True,
|
||||
)
|
||||
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
|
||||
delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
|
||||
request, [0, 1, 2]
|
||||
delay, kv_connector_metadata = (
|
||||
scheduler.get_kv_connector().request_finished_all_groups(
|
||||
request, ([0, 1, 2],)
|
||||
)
|
||||
)
|
||||
assert delay
|
||||
|
||||
# Decode connector will be able to create handshake with the prefill connector.
|
||||
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
decode_connector = NixlConnector(
|
||||
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
|
||||
)
|
||||
decode_connector.register_kv_caches(kv_caches)
|
||||
|
||||
# Here we are testing the retrieval of NIXLAgentMetadata.
|
||||
@@ -404,9 +415,16 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
REMOTE_ENGINE_ID = "remote_engine"
|
||||
|
||||
def __init__(
|
||||
self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs
|
||||
self,
|
||||
*args,
|
||||
hand_shake_latency: float = 1.8,
|
||||
kv_cache_layout="HND",
|
||||
kv_cache_config=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
if kv_cache_config is None:
|
||||
kv_cache_config = make_kv_cache_config(block_size=16)
|
||||
super().__init__(*args, kv_cache_config=kv_cache_config, **kwargs)
|
||||
self._hand_shake_latency = hand_shake_latency
|
||||
self.kv_cache_layout = kv_cache_layout
|
||||
# Mock register_kv_caches attribute needed for tests that do not call it.
|
||||
@@ -507,7 +525,9 @@ class TestNixlHandshake:
|
||||
request_id = "req_id"
|
||||
|
||||
# Test worker role in decode server.
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector = NixlConnector(
|
||||
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
|
||||
)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0
|
||||
)
|
||||
@@ -528,13 +548,15 @@ class TestNixlHandshake:
|
||||
num_xfers -= 1
|
||||
metadata.add_new_req_to_recv(
|
||||
request_id=request_id,
|
||||
local_block_ids=[num_xfers + 1, num_xfers + 2, num_xfers + 3],
|
||||
local_block_ids=([num_xfers + 1, num_xfers + 2, num_xfers + 3],),
|
||||
kv_transfer_params={
|
||||
"remote_block_ids": [
|
||||
num_xfers + 4,
|
||||
num_xfers + 5,
|
||||
num_xfers + 6,
|
||||
],
|
||||
"remote_block_ids": (
|
||||
[
|
||||
num_xfers + 4,
|
||||
num_xfers + 5,
|
||||
num_xfers + 6,
|
||||
],
|
||||
),
|
||||
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
"remote_request_id": f"prefill-{request_id}",
|
||||
"remote_host": "localhost",
|
||||
@@ -594,16 +616,18 @@ class TestNixlHandshake:
|
||||
vllm_config.parallel_config.tensor_parallel_size = decode_tp_size
|
||||
|
||||
# Test worker role in decode server.
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector = NixlConnector(
|
||||
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
|
||||
)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id
|
||||
)
|
||||
metadata = NixlConnectorMetadata()
|
||||
metadata.add_new_req_to_recv(
|
||||
request_id="id",
|
||||
local_block_ids=[1, 2, 3],
|
||||
local_block_ids=([1, 2, 3],),
|
||||
kv_transfer_params={
|
||||
"remote_block_ids": [4, 5, 6],
|
||||
"remote_block_ids": ([4, 5, 6],),
|
||||
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
"remote_request_id": "prefill-id",
|
||||
"remote_host": "localhost",
|
||||
@@ -652,7 +676,9 @@ class TestNixlHandshake:
|
||||
local_tp_size = 1
|
||||
vllm_config.parallel_config.tensor_parallel_size = local_tp_size
|
||||
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector = NixlConnector(
|
||||
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
|
||||
)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0
|
||||
)
|
||||
@@ -717,8 +743,12 @@ class TestNixlHandshake:
|
||||
p_tp_size = 2
|
||||
|
||||
# Build two separate connectors/workers to emulate P TP=2 ranks.
|
||||
conn_p0 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
conn_p1 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
conn_p0 = NixlConnector(
|
||||
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
|
||||
)
|
||||
conn_p1 = NixlConnector(
|
||||
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
|
||||
)
|
||||
conn_p0.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, conn_p0.engine_id, hand_shake_latency=0
|
||||
)
|
||||
@@ -815,7 +845,9 @@ class TestNixlHandshake:
|
||||
vllm_config = create_vllm_config()
|
||||
|
||||
# Test worker role in decode server.
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector = NixlConnector(
|
||||
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
|
||||
)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id
|
||||
)
|
||||
@@ -827,9 +859,9 @@ class TestNixlHandshake:
|
||||
for i in range(total_reqs):
|
||||
metadata.add_new_req_to_recv(
|
||||
request_id=f"id_{i}",
|
||||
local_block_ids=[1, 2, 3],
|
||||
local_block_ids=([1, 2, 3],),
|
||||
kv_transfer_params={
|
||||
"remote_block_ids": [4, 5, 6],
|
||||
"remote_block_ids": ([4, 5, 6],),
|
||||
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
"remote_request_id": f"prefill-id-{i}",
|
||||
"remote_host": "localhost",
|
||||
@@ -884,7 +916,9 @@ class TestNixlHandshake:
|
||||
return_value=2,
|
||||
):
|
||||
# Initialize connector and worker (with fake NIXL wrapper)
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector = NixlConnector(
|
||||
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
|
||||
)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0
|
||||
)
|
||||
@@ -934,7 +968,9 @@ class TestNixlHandshake:
|
||||
return_value=2,
|
||||
):
|
||||
# Initialize connector and worker (with fake NIXL wrapper)
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector = NixlConnector(
|
||||
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
|
||||
)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config,
|
||||
connector.engine_id,
|
||||
@@ -979,7 +1015,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
|
||||
vllm_config = create_vllm_config()
|
||||
|
||||
# Test worker role in decode server.
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector = NixlConnector(
|
||||
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
|
||||
)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0
|
||||
)
|
||||
@@ -993,9 +1031,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
|
||||
metadata = NixlConnectorMetadata()
|
||||
metadata.add_new_req_to_recv(
|
||||
request_id=request_id,
|
||||
local_block_ids=[1, 2, 3],
|
||||
local_block_ids=([1, 2, 3],),
|
||||
kv_transfer_params={
|
||||
"remote_block_ids": [4, 5, 6],
|
||||
"remote_block_ids": ([4, 5, 6],),
|
||||
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
"remote_request_id": f"prefill-{request_id}",
|
||||
"remote_host": "localhost",
|
||||
@@ -1448,7 +1486,9 @@ def test_register_kv_caches(
|
||||
mock_get_attn_backend.return_value = backend_cls
|
||||
|
||||
# Create connector
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector = NixlConnector(
|
||||
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
|
||||
)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0
|
||||
)
|
||||
@@ -1676,7 +1716,9 @@ def test_kv_buffer_to_nixl_memory_types(
|
||||
),
|
||||
): # noqa: E501
|
||||
# Create connector and replace its worker with a fake one for isolation
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector = NixlConnector(
|
||||
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
|
||||
)
|
||||
|
||||
# Verify get_reg_descs was called with the correct memory_type
|
||||
assert connector.connector_worker.kv_buffer_device == kv_buffer_device
|
||||
@@ -1692,9 +1734,15 @@ def test_shutdown_cleans_up_resources(default_vllm_config, dist_init):
|
||||
vllm_config = create_vllm_config()
|
||||
|
||||
scheduler = NixlConnectorScheduler(
|
||||
vllm_config, vllm_config.kv_transfer_config.engine_id
|
||||
vllm_config,
|
||||
vllm_config.kv_transfer_config.engine_id,
|
||||
make_kv_cache_config(block_size=16),
|
||||
)
|
||||
worker = NixlConnectorWorker(
|
||||
vllm_config,
|
||||
vllm_config.kv_transfer_config.engine_id,
|
||||
make_kv_cache_config(block_size=16),
|
||||
)
|
||||
worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id)
|
||||
nixl_wrapper = worker.nixl_wrapper
|
||||
|
||||
with (
|
||||
@@ -1756,7 +1804,9 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_
|
||||
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
# KVConnector Worker in P
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector = NixlConnector(
|
||||
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
|
||||
)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0
|
||||
)
|
||||
@@ -1875,12 +1925,14 @@ class FailingNixlWrapper(FakeNixlWrapper):
|
||||
("transfer_exception", {"fail_transfer_exception": True}, True),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("enable_hma", [False, True])
|
||||
def test_transfer_failure_logging(
|
||||
default_vllm_config,
|
||||
dist_init,
|
||||
failure_type,
|
||||
wrapper_config,
|
||||
needs_get_finished,
|
||||
enable_hma,
|
||||
):
|
||||
"""Test that transfer failures are logged with structured context.
|
||||
|
||||
@@ -1897,9 +1949,16 @@ def test_transfer_failure_logging(
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector = NixlConnector(
|
||||
vllm_config,
|
||||
KVConnectorRole.WORKER,
|
||||
make_kv_cache_config(block_size=16, hma_enabled=enable_hma),
|
||||
)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0.0
|
||||
vllm_config,
|
||||
connector.engine_id,
|
||||
hand_shake_latency=0.0,
|
||||
kv_cache_config=connector._kv_cache_config,
|
||||
)
|
||||
|
||||
# Configure FailingNixlWrapper to fail in the specified way
|
||||
@@ -1910,8 +1969,17 @@ def test_transfer_failure_logging(
|
||||
|
||||
# For notification_failed, we need empty local blocks
|
||||
# (full cache hit path to trigger send_notif)
|
||||
local_blocks = [] if failure_type == "notification_failed" else [10, 11, 12]
|
||||
remote_blocks = [20, 21, 22]
|
||||
local_blocks: tuple[()] | tuple[list[int], ...]
|
||||
if enable_hma:
|
||||
# HMA enabled: multiple groups (FA + SW)
|
||||
local_blocks = (
|
||||
() if failure_type == "notification_failed" else ([10, 11, 12], [13, 14])
|
||||
)
|
||||
remote_blocks = [[20, 21, 22], [23, 24]]
|
||||
else:
|
||||
# HMA disabled: single group
|
||||
local_blocks = () if failure_type == "notification_failed" else ([10, 11, 12],)
|
||||
remote_blocks = [[20, 21, 22]]
|
||||
|
||||
metadata = NixlConnectorMetadata()
|
||||
metadata.add_new_req_to_recv(
|
||||
@@ -2007,7 +2075,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
|
||||
"""Test that handshake failures mark blocks invalid and return via get_finished."""
|
||||
vllm_config = create_vllm_config()
|
||||
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector = NixlConnector(
|
||||
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
|
||||
)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0.1
|
||||
)
|
||||
@@ -2017,9 +2087,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
|
||||
metadata = NixlConnectorMetadata()
|
||||
metadata.add_new_req_to_recv(
|
||||
request_id=request_id,
|
||||
local_block_ids=[1, 2, 3],
|
||||
local_block_ids=([1, 2, 3],),
|
||||
kv_transfer_params={
|
||||
"remote_block_ids": [4, 5, 6],
|
||||
"remote_block_ids": ([4, 5, 6],),
|
||||
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
"remote_request_id": f"prefill-{request_id}",
|
||||
"remote_host": "localhost",
|
||||
@@ -2058,7 +2128,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
|
||||
and return via get_finished."""
|
||||
vllm_config = create_vllm_config()
|
||||
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector = NixlConnector(
|
||||
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
|
||||
)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0
|
||||
)
|
||||
@@ -2068,9 +2140,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
|
||||
metadata = NixlConnectorMetadata()
|
||||
metadata.add_new_req_to_recv(
|
||||
request_id=request_id,
|
||||
local_block_ids=[7, 8, 9],
|
||||
local_block_ids=([7, 8, 9],),
|
||||
kv_transfer_params={
|
||||
"remote_block_ids": [10, 11, 12],
|
||||
"remote_block_ids": ([10, 11, 12],),
|
||||
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
"remote_request_id": f"prefill-{request_id}",
|
||||
"remote_host": "localhost",
|
||||
@@ -2154,7 +2226,9 @@ def test_compatibility_hash_validation(
|
||||
"enforce_handshake_compat": enforce_handshake_compat
|
||||
},
|
||||
)
|
||||
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER)
|
||||
decode_connector = NixlConnector(
|
||||
local_vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
|
||||
)
|
||||
decode_worker = decode_connector.connector_worker
|
||||
kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape(
|
||||
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
|
||||
@@ -2267,7 +2341,9 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
|
||||
model="facebook/opt-125m",
|
||||
block_size=16,
|
||||
)
|
||||
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER)
|
||||
decode_connector = NixlConnector(
|
||||
local_vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
|
||||
)
|
||||
decode_worker = decode_connector.connector_worker
|
||||
|
||||
backend = get_current_attn_backend(local_vllm_config)
|
||||
|
||||
203
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
Normal file
203
tests/v1/kv_connector/unit/test_nixl_connector_hma.py
Normal file
@@ -0,0 +1,203 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for NixlConnectorScheduler sw_sizes calculation with HMA."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||
FullAttentionManager,
|
||||
SlidingWindowManager,
|
||||
)
|
||||
|
||||
from .utils import (
|
||||
create_vllm_config,
|
||||
make_kv_cache_config,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.cpu_test
@pytest.mark.parametrize(
    "hma_enabled,expected_sw_sizes",
    [
        # With HMA: one full-attention group (0) followed by one
        # sliding-window group (2048 tokens / 16 per block = 128, plus 1).
        (True, [0, 128 + 1]),
        # Without HMA: a single full-attention group (0).
        (False, [0]),
    ],
)
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform")
def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes):
    """Verify the scheduler's per-group sliding-window sizes, with and without HMA.

    Builds a ``NixlConnectorScheduler`` from a KV cache config created by
    ``make_kv_cache_config`` and compares ``scheduler.blocks_per_sw`` against
    the parametrized expectation (expressed in blocks, not tokens).
    """
    # Imported locally so the platform patch above is already in place.
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorScheduler,
    )

    mock_platform.device_type = "cpu"

    tokens_per_block = 16
    vllm_config = create_vllm_config(block_size=tokens_per_block)
    # A 2048-token sliding window spans 128 blocks of 16 tokens.
    cache_config = make_kv_cache_config(
        block_size=tokens_per_block,
        hma_enabled=hma_enabled,
        sw_size=2048,
    )

    scheduler = NixlConnectorScheduler(
        vllm_config=vllm_config,
        engine_id="test-engine",
        kv_cache_config=cache_config,
    )

    # Sizes are compared in number of blocks.
    assert scheduler.blocks_per_sw == expected_sw_sizes, (
        f"Expected sw_sizes={expected_sw_sizes}, got {scheduler.blocks_per_sw}"
    )
|
||||
|
||||
|
||||
@pytest.mark.cpu_test
def test_logical_to_kernel_block_ids_with_hma():
    """Check that ``_logical_to_kernel_block_ids`` expands every logical block.

    With HMA the logical block size can be larger than the kernel block size,
    in which case one logical block corresponds to several kernel blocks.
    The test uses a ratio of 2 and two block-id groups (FA + SW).
    """
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorWorker,
    )

    # Bypass __init__ entirely: only the one attribute set below is provided.
    worker = object.__new__(NixlConnectorWorker)

    # Logical block size 32 vs. kernel block size 16, i.e. 2 kernel blocks
    # per logical block (e.g. logical [0] -> kernel [0, 1]).
    worker._physical_blocks_per_logical_kv_block = 2

    # Group 0 plays the FA role, group 1 the SW role.
    fa_logical, sw_logical = [0, 1, 2], [3, 4]
    kernel_block_ids = worker._logical_to_kernel_block_ids([fa_logical, sw_logical])

    expected_kernel_block_ids = [list(range(0, 6)), list(range(6, 10))]
    assert kernel_block_ids == expected_kernel_block_ids, (
        f"Expected {expected_kernel_block_ids}, got {kernel_block_ids}"
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name, sw_size", [("google/gemma-3-1b-it", 512)])
def test_fewer_blocks_with_hma(monkeypatch, model_name, sw_size):
    """Check that SWA groups report fewer "remote blocks" than the last group.

    A prefill-side engine is started with the hybrid KV cache manager enabled
    and fed a prompt longer than the sliding window. The returned
    ``kv_transfer_params["remote_block_ids"]`` must hold
    ``sw_size // block_size + 1`` blocks for every group except the last,
    and strictly more blocks for the last group.
    """
    block_size = 16
    transfer_config = KVTransferConfig(
        kv_connector="NixlConnector",
        kv_role="kv_both",
    )
    engine_kwargs = {
        "model": model_name,
        "enforce_eager": True,
        "gpu_memory_utilization": 0.5,
        "kv_transfer_config": transfer_config,
        "max_model_len": 2048,
        # NOTE: this is what keeps HMA enabled for the test.
        "disable_hybrid_kv_cache_manager": False,
        "max_num_batched_tokens": 1024,
        "enable_prefix_caching": False,
        "block_size": block_size,
    }

    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    def check_remote_blocks(llm: LLM) -> None:
        """Run one request and validate the per-group remote block counts."""
        # Mimic the request a sidecar would send to a prefill instance.
        transfer_request = {
            "do_remote_decode": True,
            "do_remote_prefill": False,
            "remote_engine_id": None,
            "remote_block_ids": None,
            "remote_host": None,
            "remote_port": None,
        }
        sampling_params = SamplingParams(
            temperature=0.0,
            max_tokens=1,
            extra_args={"kv_transfer_params": transfer_request},
        )

        engine_scheduler = llm.llm_engine.engine_core.engine_core.scheduler
        managers = engine_scheduler.kv_cache_manager.coordinator.single_type_managers
        # HMA enabled with FA + SWA groups.
        assert len(managers) > 2
        assert all(
            isinstance(manager, (SlidingWindowManager, FullAttentionManager))
            for manager in managers
        )
        # Nothing has been scheduled yet, so no request owns any blocks.
        assert len(managers[0].req_to_blocks) == 0

        # The prompt is long enough to exceed the sliding window.
        outputs = llm.generate(["hi" * 1401], sampling_params)
        remote_block_ids = outputs[0].kv_transfer_params["remote_block_ids"]

        # +1 accounts for the window overlapping a block boundary.
        expected_num_remote_blocks = sw_size // block_size + 1
        assert len(remote_block_ids[0]) == expected_num_remote_blocks
        assert expected_num_remote_blocks < len(remote_block_ids[-1])
        for group_block_ids in remote_block_ids[:-1]:
            assert len(group_block_ids) == expected_num_remote_blocks

    # Always shut the engine core down, even if an assertion fails.
    llm = LLM(**engine_kwargs)
    try:
        check_remote_blocks(llm)
    finally:
        llm.llm_engine.engine_core.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.cpu_test
def test_nixl_metadata_hma_block_ids_structure():
    """Check that ``NixlConnectorMetadata`` keeps block IDs separated per group.

    A request with two block-id groups (FA + SW) is registered through
    ``add_new_req_to_recv``; both the local and the remote block IDs must be
    stored as two groups with their contents unchanged.
    """
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        NixlConnectorMetadata,
    )

    request_id = "test-req-hma"

    # Local side: 8 blocks for FA, 4 (clipped) blocks for SW.
    fa_blocks = list(range(0, 8))
    sw_blocks = list(range(8, 12))
    # Remote side mirrors the same group sizes with different IDs.
    remote_fa_blocks = list(range(10, 18))
    remote_sw_blocks = list(range(18, 22))

    metadata = NixlConnectorMetadata()
    metadata.add_new_req_to_recv(
        request_id=request_id,
        local_block_ids=(fa_blocks, sw_blocks),
        kv_transfer_params={
            "remote_block_ids": (list(remote_fa_blocks), list(remote_sw_blocks)),
            "remote_engine_id": "remote-engine",
            "remote_request_id": "prefill-test-req-hma",
            "remote_host": "localhost",
            "remote_port": 1234,
            "tp_size": 1,
        },
    )

    assert request_id in metadata.reqs_to_recv
    req_meta = metadata.reqs_to_recv[request_id]

    # Local block IDs: two groups, contents preserved.
    local_groups = req_meta.local_block_ids
    assert len(local_groups) == 2
    assert list(local_groups[0]) == fa_blocks
    assert list(local_groups[1]) == sw_blocks

    # Remote block IDs: two groups, contents preserved.
    assert req_meta.remote is not None
    remote_groups = req_meta.remote.block_ids
    assert len(remote_groups) == 2
    assert list(remote_groups[0]) == remote_fa_blocks
    assert list(remote_groups[1]) == remote_sw_blocks
|
||||
@@ -208,7 +208,9 @@ def test_prefix_cache_lifecycle():
|
||||
|
||||
# Ensure we send all block ids, including the partial blocks,
|
||||
# even if there is a cache hit.
|
||||
assert len(kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + 1)
|
||||
# remote_block_ids is BlockIds (tuple of lists); sum block counts across groups.
|
||||
num_remote_blocks = sum(len(g) for g in kv_transfer_params["remote_block_ids"])
|
||||
assert num_remote_blocks == (NUM_EXTERNAL_FULL_BLOCKS + 1)
|
||||
|
||||
# STEP (2): Ensure it is freed.
|
||||
scheduler_output = scheduler.schedule()
|
||||
|
||||
@@ -36,6 +36,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheGroupSpec,
|
||||
SlidingWindowSpec,
|
||||
)
|
||||
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
||||
from vllm.v1.request import Request
|
||||
@@ -142,24 +143,26 @@ def create_vllm_config(
|
||||
def create_scheduler(
|
||||
vllm_config: VllmConfig,
|
||||
num_blocks: int = 10000,
|
||||
kv_cache_config: KVCacheConfig | None = None,
|
||||
) -> Scheduler:
|
||||
"""Initialize Scheduler For Testing."""
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=num_blocks, # A large number of blocks to hold all requests
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer"],
|
||||
FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
if kv_cache_config is None:
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=num_blocks, # A large number of blocks to hold all requests
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer"],
|
||||
FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
vllm_config.cache_config.num_gpu_blocks = num_blocks
|
||||
return Scheduler(
|
||||
vllm_config=vllm_config,
|
||||
@@ -412,3 +415,38 @@ KVConnectorFactory.register_connector(
|
||||
KVConnectorFactory.register_connector(
|
||||
"MockKVConnector", __name__, MockKVConnector.__name__
|
||||
)
|
||||
|
||||
|
||||
def make_kv_cache_config(
    block_size: int,
    hma_enabled: bool = False,
    sw_size: int = 128,
    num_blocks: int = 100,
) -> KVCacheConfig:
    """Build a small ``KVCacheConfig`` for connector tests.

    Args:
        block_size: Block size passed to every KV cache spec.
        hma_enabled: When True, a second group with a ``SlidingWindowSpec``
            over ``layer1``/``layer3`` is appended after the full-attention
            group over ``layer0``/``layer2``.
        sw_size: ``sliding_window`` value of the sliding-window spec; only
            used when ``hma_enabled`` is True.
        num_blocks: ``num_blocks`` of the returned config.

    Returns:
        A ``KVCacheConfig`` with no KV cache tensors and one or two groups.
    """
    # Both spec kinds share the same head geometry and dtype.
    common_spec_kwargs = dict(
        block_size=block_size,
        num_kv_heads=4,
        head_size=16,
        dtype=torch.float16,
    )

    groups = [
        KVCacheGroupSpec(
            ["layer0", "layer2"],
            FullAttentionSpec(**common_spec_kwargs),
        )
    ]
    if hma_enabled:
        sliding_window_spec = SlidingWindowSpec(
            **common_spec_kwargs,
            sliding_window=sw_size,
        )
        groups.append(KVCacheGroupSpec(["layer1", "layer3"], sliding_window_spec))

    return KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=groups,
    )
|
||||
|
||||
@@ -24,6 +24,9 @@ if TYPE_CHECKING:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
EngineId = str
|
||||
# Block ids as returned by the hybrid KV cache manager. The list[list[int]] form
|
||||
# allows mutability and is for connector-internal use only.
|
||||
BlockIds = tuple[list[int], ...] | list[list[int]]
|
||||
|
||||
|
||||
def get_kv_connector_cache_layout():
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import contextlib
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import queue
|
||||
import sys
|
||||
@@ -24,6 +23,7 @@ import zmq
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||
BlockIds,
|
||||
EngineId,
|
||||
TpKVTopology,
|
||||
get_current_attn_backend,
|
||||
@@ -38,6 +38,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorHandshakeMetadata,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
SupportsHMA,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorPromMetrics,
|
||||
@@ -53,10 +54,12 @@ from vllm.distributed.parallel_state import (
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
|
||||
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, SlidingWindowSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -205,6 +208,7 @@ def compute_nixl_compatibility_hash(
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
is_hma_enabled = not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager
|
||||
|
||||
factors = {
|
||||
# Version compatibility
|
||||
@@ -220,6 +224,7 @@ def compute_nixl_compatibility_hash(
|
||||
"attn_backend_name": attn_backend_name,
|
||||
"cache_dtype": str(cache_config.cache_dtype),
|
||||
"cross_layers_blocks": cross_layers_blocks,
|
||||
"is_hma_enabled": is_hma_enabled,
|
||||
}
|
||||
|
||||
compat_hash = hash_factors(factors)
|
||||
@@ -238,7 +243,7 @@ def compute_nixl_compatibility_hash(
|
||||
|
||||
@dataclass
|
||||
class RemoteMeta:
|
||||
block_ids: list[int]
|
||||
block_ids: BlockIds
|
||||
host: str
|
||||
port: int
|
||||
engine_id: str
|
||||
@@ -247,9 +252,9 @@ class RemoteMeta:
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
local_block_ids: list[int]
|
||||
local_block_ids: BlockIds
|
||||
# To be used when logical block size does not match the kernel block size
|
||||
local_physical_block_ids: list[int]
|
||||
local_physical_block_ids: BlockIds
|
||||
tp_size: int
|
||||
remote: RemoteMeta | None = None
|
||||
|
||||
@@ -264,7 +269,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
|
||||
|
||||
def _add_new_req(
|
||||
self,
|
||||
local_block_ids: list[int],
|
||||
local_block_ids: BlockIds,
|
||||
kv_transfer_params: dict[str, Any],
|
||||
) -> ReqMeta:
|
||||
return ReqMeta(
|
||||
@@ -277,7 +282,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
|
||||
def add_new_req_to_save(
|
||||
self,
|
||||
request_id: ReqId,
|
||||
local_block_ids: list[int],
|
||||
local_block_ids: BlockIds,
|
||||
kv_transfer_params: dict[str, Any],
|
||||
):
|
||||
self.reqs_to_save[request_id] = self._add_new_req(
|
||||
@@ -287,7 +292,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
|
||||
def add_new_req_to_recv(
|
||||
self,
|
||||
request_id: ReqId,
|
||||
local_block_ids: list[int],
|
||||
local_block_ids: BlockIds,
|
||||
kv_transfer_params: dict[str, Any],
|
||||
):
|
||||
req = self._add_new_req(local_block_ids, kv_transfer_params)
|
||||
@@ -301,7 +306,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
|
||||
self.reqs_to_recv[request_id] = req
|
||||
|
||||
|
||||
class NixlConnector(KVConnectorBase_V1):
|
||||
class NixlConnector(KVConnectorBase_V1, SupportsHMA):
|
||||
@property
|
||||
def prefer_cross_layer_blocks(self) -> bool:
|
||||
backend = get_current_attn_backend(self._vllm_config)
|
||||
@@ -326,22 +331,27 @@ class NixlConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
kv_cache_config: "KVCacheConfig",
|
||||
):
|
||||
super().__init__(vllm_config, role, kv_cache_config)
|
||||
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
assert vllm_config.kv_transfer_config.engine_id is not None
|
||||
for group in kv_cache_config.kv_cache_groups:
|
||||
if isinstance(group.kv_cache_spec, MambaSpec):
|
||||
raise ValueError("NixlConnector does not support Mamba models.")
|
||||
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
|
||||
self.kv_transfer_config = vllm_config.kv_transfer_config
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler: NixlConnectorScheduler | None = (
|
||||
NixlConnectorScheduler(vllm_config, self.engine_id)
|
||||
NixlConnectorScheduler(vllm_config, self.engine_id, kv_cache_config)
|
||||
)
|
||||
self.connector_worker: NixlConnectorWorker | None = None
|
||||
elif role == KVConnectorRole.WORKER:
|
||||
self.connector_scheduler = None
|
||||
self.connector_worker = NixlConnectorWorker(vllm_config, self.engine_id)
|
||||
self.connector_worker = NixlConnectorWorker(
|
||||
vllm_config, self.engine_id, kv_cache_config
|
||||
)
|
||||
|
||||
############################################################
|
||||
# Class Methods
|
||||
@@ -392,10 +402,10 @@ class NixlConnector(KVConnectorBase_V1):
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.build_connector_meta(scheduler_output)
|
||||
|
||||
def request_finished(
|
||||
def request_finished_all_groups(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
block_ids: tuple[list[int], ...],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.request_finished(request, block_ids)
|
||||
@@ -518,10 +528,13 @@ class NixlConnector(KVConnectorBase_V1):
|
||||
class NixlConnectorScheduler:
|
||||
"""Implementation of Scheduler side methods"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||
def __init__(
|
||||
self, vllm_config: VllmConfig, engine_id: str, kv_cache_config: "KVCacheConfig"
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.engine_id: EngineId = engine_id
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
|
||||
self.side_channel_port = (
|
||||
envs.VLLM_NIXL_SIDE_CHANNEL_PORT
|
||||
@@ -534,8 +547,18 @@ class NixlConnectorScheduler:
|
||||
self.use_host_buffer = (
|
||||
vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
|
||||
)
|
||||
self._is_hma_required = (
|
||||
not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager
|
||||
# Also handle unlikely SW-only model case instead of checking num_groups>1.
|
||||
and any(
|
||||
not isinstance(g.kv_cache_spec, FullAttentionSpec)
|
||||
for g in kv_cache_config.kv_cache_groups
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("Initializing NIXL Scheduler %s", engine_id)
|
||||
# HMA is active when the hybrid KV cache manager is NOT disabled, so the
# "enabled" message must be emitted on the negated flag.
if not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
    logger.info("Hybrid Memory Allocator is enabled with NIXL")
|
||||
|
||||
# Background thread for handling new handshake requests.
|
||||
self._nixl_handshake_listener_t: threading.Thread | None = None
|
||||
@@ -545,7 +568,7 @@ class NixlConnectorScheduler:
|
||||
# Requests that need to start recv/send.
|
||||
# New requests are added by update_state_after_alloc in
|
||||
# the scheduler. Used to make metadata passed to Worker.
|
||||
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
|
||||
self._reqs_need_recv: dict[ReqId, tuple[Request, BlockIds]] = {}
|
||||
self._reqs_need_save: dict[ReqId, Request] = {}
|
||||
# Reqs to send and their expiration time
|
||||
self._reqs_need_send: dict[ReqId, float] = {}
|
||||
@@ -554,12 +577,54 @@ class NixlConnectorScheduler:
|
||||
# remote prefill or aborted.
|
||||
self._reqs_not_processed: set[ReqId] = set()
|
||||
|
||||
# Gather Sliding Window sizes for each kv cache group (if any) in number of
|
||||
# blocks per KV cache group. This is used to clip the local attention window.
|
||||
sw_sizes_tokens: list[tuple[int, int]] = [
|
||||
(g.kv_cache_spec.sliding_window, g.kv_cache_spec.block_size)
|
||||
if isinstance(g.kv_cache_spec, SlidingWindowSpec)
|
||||
else (0, self.block_size)
|
||||
for g in kv_cache_config.kv_cache_groups
|
||||
]
|
||||
# cdiv(n_tokens, block_size) gives blocks/window; add 1 to conservatively
|
||||
# account for boundary overlap eg window isn't fully aligned with blocks.
|
||||
self.blocks_per_sw = [
|
||||
cdiv(n_tokens, block_size) + 1 if n_tokens else 0
|
||||
for n_tokens, block_size in sw_sizes_tokens
|
||||
]
|
||||
|
||||
def shutdown(self):
    """Signal the stop event and wait for the handshake listener, if any.

    The stop event is always set. When a handshake listener thread is
    present it is joined and the reference is then cleared.
    """
    self._stop_event.set()
    listener = self._nixl_handshake_listener_t
    if listener is None:
        # No listener thread to wait for.
        return
    listener.join()
    self._nixl_handshake_listener_t = None
|
||||
|
||||
def get_sw_clipped_blocks(self, block_ids: BlockIds) -> BlockIds:
    """
    Keep only the trailing sliding-window blocks of every SWA kv cache group.

    The KV Cache manager first allocates blocks for the whole sequence length
    and only later drops the blocks that fall outside the window, before the
    `request_finished_all_groups` hook runs; hence the clipping done here.

    `block_ids` is returned untouched when it is empty or when HMA is not
    required. Otherwise a tuple with one list per group is returned: groups
    whose `blocks_per_sw` entry is 0 (non-SWA) keep their list as is, the
    others keep their last `blocks_per_sw[i]` blocks.
    """
    if len(block_ids) == 0 or not self._is_hma_required:
        # No blocks to clip eg Full prefix cache hit or not a hybrid model.
        return block_ids
    # NOTE (NickLucche) This logic is currently handled at the connector level
    # because offloading connectors might want to receive the whole sequence even
    # for SWA groups. We will abstract this logic once the interface is more stable
    assert len(block_ids) == len(self.blocks_per_sw), (
        "Number of KV cache groups must match"
    )
    clipped_groups = []
    for window_blocks, group_blocks in zip(self.blocks_per_sw, block_ids):
        if window_blocks > 0:
            # SWA group: retain the last `window_blocks` entries only.
            clipped_groups.append(group_blocks[-window_blocks:])
        else:
            # Non-SWA group: the whole list is forwarded unchanged.
            clipped_groups.append(group_blocks)
    return tuple(clipped_groups)
|
||||
|
||||
def set_xfer_handshake_metadata(
|
||||
self, metadata: dict[int, KVConnectorHandshakeMetadata]
|
||||
) -> None:
|
||||
@@ -707,12 +772,18 @@ class NixlConnectorScheduler:
|
||||
# If remote_blocks and num_external_tokens = 0, we have
|
||||
# a full prefix cache hit on the D worker. We need to call
|
||||
# send_notif in _read_blocks to free the memory on the P.
|
||||
local_block_ids = (
|
||||
blocks.get_unhashed_block_ids()
|
||||
|
||||
unhashed_local_block_ids: BlockIds = (
|
||||
blocks.get_unhashed_block_ids_all_groups()
|
||||
if num_external_tokens > 0
|
||||
else []
|
||||
else ()
|
||||
)
|
||||
# Get unhashed blocks to pull from remote.
|
||||
local_block_ids = self.get_sw_clipped_blocks(
|
||||
unhashed_local_block_ids
|
||||
)
|
||||
|
||||
# Get unhashed blocks to pull from remote. Mind that a full prefix
|
||||
# cache hit is indicated with an empty sequence (no block groups).
|
||||
self._reqs_need_recv[request.request_id] = (
|
||||
request,
|
||||
local_block_ids,
|
||||
@@ -753,9 +824,10 @@ class NixlConnectorScheduler:
|
||||
req = req_to_save
|
||||
|
||||
assert req.kv_transfer_params is not None
|
||||
clipped_block_id_groups = self.get_sw_clipped_blocks(new_block_id_groups)
|
||||
meta.add_new_req_to_save(
|
||||
request_id=req_id,
|
||||
local_block_ids=new_block_id_groups[0],
|
||||
local_block_ids=clipped_block_id_groups,
|
||||
kv_transfer_params=req.kv_transfer_params,
|
||||
)
|
||||
assert scheduler_output.num_scheduled_tokens is not None
|
||||
@@ -786,7 +858,7 @@ class NixlConnectorScheduler:
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
block_ids: BlockIds,
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Once a request is finished, determine whether request blocks
|
||||
@@ -828,7 +900,7 @@ class NixlConnectorScheduler:
|
||||
|
||||
# TODO: check whether block_ids actually ever be 0. If not we could
|
||||
# remove the conditional below
|
||||
delay_free_blocks = len(block_ids) > 0
|
||||
delay_free_blocks = any(len(group) > 0 for group in block_ids)
|
||||
|
||||
if delay_free_blocks:
|
||||
# Prefill request on remote. It will be read from D upon completion
|
||||
@@ -841,6 +913,11 @@ class NixlConnectorScheduler:
|
||||
self._reqs_need_send[request.request_id] = (
|
||||
time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
|
||||
)
|
||||
# NOTE HMA will "mark" empty/null blocks in groups with 0s (eg SWA ones),
|
||||
# trimming down after allocating for the whole sequence length. Empty
|
||||
# blocks are always at the start of the list.
|
||||
# Here we "unpad" blocks to send the actual remote blocks to be read.
|
||||
block_ids = self.get_sw_clipped_blocks(block_ids)
|
||||
|
||||
return delay_free_blocks, dict(
|
||||
do_remote_prefill=True,
|
||||
@@ -857,7 +934,9 @@ class NixlConnectorScheduler:
|
||||
class NixlConnectorWorker:
|
||||
"""Implementation of Worker side methods"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||
def __init__(
|
||||
self, vllm_config: VllmConfig, engine_id: str, kv_cache_config: "KVCacheConfig"
|
||||
):
|
||||
if NixlWrapper is None:
|
||||
logger.error("NIXL is not available")
|
||||
raise RuntimeError("NIXL is not available")
|
||||
@@ -875,6 +954,14 @@ class NixlConnectorWorker:
|
||||
self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"backends", ["UCX"]
|
||||
)
|
||||
self._is_hma_required = (
|
||||
not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager
|
||||
and any(
|
||||
not isinstance(g.kv_cache_spec, FullAttentionSpec)
|
||||
for g in kv_cache_config.kv_cache_groups
|
||||
)
|
||||
)
|
||||
self.kv_cache_config = kv_cache_config
|
||||
|
||||
# Agent.
|
||||
non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
|
||||
@@ -1017,10 +1104,6 @@ class NixlConnectorWorker:
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
|
||||
# TODO(mgoin): remove this once we have hybrid memory allocator
|
||||
# Optimization for models with local attention (Llama 4)
|
||||
# List of block window sizes for each layer for local attention
|
||||
self.block_window_per_layer: list[int | None] = []
|
||||
self.use_mla = self.model_config.use_mla
|
||||
|
||||
# Get the attention backend from the first layer
|
||||
@@ -1030,8 +1113,8 @@ class NixlConnectorWorker:
|
||||
self.backend_name = self.attn_backend.get_name()
|
||||
self.kv_cache_layout = get_kv_cache_layout()
|
||||
self.host_buffer_kv_cache_layout = self.kv_cache_layout
|
||||
logger.debug("Detected attention backend %s", self.backend_name)
|
||||
logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
|
||||
logger.info("Detected attention backend %s", self.backend_name)
|
||||
logger.info("Detected kv cache layout %s", self.kv_cache_layout)
|
||||
|
||||
# lazy initialized in register_kv_caches
|
||||
self.compat_hash: str | None = None
|
||||
@@ -1238,9 +1321,15 @@ class NixlConnectorWorker:
|
||||
"remote_request_id": meta.remote.request_id,
|
||||
"remote_host": meta.remote.host,
|
||||
"remote_port": meta.remote.port,
|
||||
"num_local_blocks": len(meta.local_block_ids),
|
||||
"num_remote_blocks": len(meta.remote.block_ids),
|
||||
"local_block_ids_sample": meta.local_block_ids[:10],
|
||||
"num_local_blocks": sum(
|
||||
len(group) for group in meta.local_block_ids
|
||||
),
|
||||
"num_remote_blocks": sum(
|
||||
len(group) for group in meta.remote.block_ids
|
||||
),
|
||||
"local_block_ids_sample": meta.local_block_ids[0][:10]
|
||||
if meta.local_block_ids
|
||||
else [],
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1301,8 +1390,10 @@ class NixlConnectorWorker:
|
||||
error=e,
|
||||
meta=meta,
|
||||
)
|
||||
if req_meta := self._recving_metadata.get(req_id):
|
||||
self._invalid_block_ids.update(req_meta.local_block_ids)
|
||||
if (
|
||||
req_meta := self._recving_metadata.get(req_id)
|
||||
) and not self._is_hma_required:
|
||||
self._invalid_block_ids.update(req_meta.local_block_ids[0])
|
||||
self._failed_recv_reqs.add(req_id)
|
||||
|
||||
fut.add_done_callback(request_ready)
|
||||
@@ -1370,6 +1461,10 @@ class NixlConnectorWorker:
|
||||
for cache in cache_list:
|
||||
base_addr = cache.data_ptr()
|
||||
if base_addr in seen_base_addresses:
|
||||
# NOTE (NickLucche) HMA employs memory pooling to share tensors
|
||||
# across groups. This results in skipping all tensors but the ones
|
||||
# pointed to by group0. Also, generally we will have more blocks
|
||||
# per tensor but fewer regions.
|
||||
continue
|
||||
|
||||
logger.debug(
|
||||
@@ -1457,28 +1552,6 @@ class NixlConnectorWorker:
|
||||
self.register_local_xfer_handler(self.block_size)
|
||||
)
|
||||
|
||||
# TODO(mgoin): Hybrid memory allocator is currently disabled for
|
||||
# models with local attention (Llama 4). Can remove this once enabled.
|
||||
if self.model_config.hf_config.model_type == "llama4":
|
||||
from transformers import Llama4TextConfig
|
||||
|
||||
assert isinstance(self.model_config.hf_text_config, Llama4TextConfig)
|
||||
llama4_config = self.model_config.hf_text_config
|
||||
no_rope_layers = llama4_config.no_rope_layers
|
||||
chunk_size = llama4_config.attention_chunk_size
|
||||
chunk_block_size = math.ceil(chunk_size / self.block_size)
|
||||
for layer_idx in range(self.num_layers):
|
||||
# no_rope_layers[layer_idx] == 0 means NoPE (global)
|
||||
# Any other value means RoPE (local chunked)
|
||||
is_local_attention = no_rope_layers[layer_idx] != 0
|
||||
block_window = chunk_block_size if is_local_attention else None
|
||||
self.block_window_per_layer.append(block_window)
|
||||
logger.debug(
|
||||
"Llama 4 block window per layer mapping: %s",
|
||||
self.block_window_per_layer,
|
||||
)
|
||||
assert len(self.block_window_per_layer) == self.num_layers
|
||||
|
||||
# After KV Caches registered, listen for new connections.
|
||||
agent_metadata = NixlAgentMetadata(
|
||||
engine_id=self.engine_id,
|
||||
@@ -1767,6 +1840,11 @@ class NixlConnectorWorker:
|
||||
# Num kv_heads > tp_size and P TP > D TP case, not supported
|
||||
assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id))
|
||||
|
||||
if self._is_hma_required:
|
||||
assert block_size_ratio == 1, (
|
||||
"HMA does not support different remote block size yet"
|
||||
)
|
||||
|
||||
kv_cache_layout = (
|
||||
self.kv_cache_layout
|
||||
if not self.use_host_buffer
|
||||
@@ -1781,6 +1859,9 @@ class NixlConnectorWorker:
|
||||
"Remote is HND and local is NHD, enabled additional permute "
|
||||
"on local device KV."
|
||||
)
|
||||
assert not self._is_hma_required, (
|
||||
"HMA does not support block size post processing"
|
||||
)
|
||||
self.enable_permute_local_kv = True
|
||||
else:
|
||||
raise RuntimeError(
|
||||
@@ -1836,13 +1917,15 @@ class NixlConnectorWorker:
|
||||
assert self.copy_blocks is not None
|
||||
|
||||
local_block_ids = meta.local_physical_block_ids
|
||||
self.copy_blocks(
|
||||
self.host_xfer_buffers,
|
||||
self.device_kv_caches,
|
||||
local_block_ids,
|
||||
local_block_ids,
|
||||
"h2d",
|
||||
)
|
||||
# TODO (NickLucche) D2H<>H2D ops could benefit from coalescing io across groups
|
||||
for group_block_ids in local_block_ids:
|
||||
self.copy_blocks(
|
||||
self.host_xfer_buffers,
|
||||
self.device_kv_caches,
|
||||
group_block_ids,
|
||||
group_block_ids,
|
||||
"h2d",
|
||||
)
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(
|
||||
"synced recved kv of request[%s] to device kv buffer,"
|
||||
@@ -1868,13 +1951,14 @@ class NixlConnectorWorker:
|
||||
",".join(map(str, meta.local_physical_block_ids)),
|
||||
)
|
||||
# blocking
|
||||
self.copy_blocks(
|
||||
self.device_kv_caches,
|
||||
self.host_xfer_buffers,
|
||||
meta.local_physical_block_ids,
|
||||
meta.local_physical_block_ids,
|
||||
"d2h",
|
||||
)
|
||||
for group_block_ids in meta.local_physical_block_ids:
|
||||
self.copy_blocks(
|
||||
self.device_kv_caches,
|
||||
self.host_xfer_buffers,
|
||||
group_block_ids,
|
||||
group_block_ids,
|
||||
"d2h",
|
||||
)
|
||||
|
||||
def post_process_device_kv_on_receive(
|
||||
self,
|
||||
@@ -1973,8 +2057,9 @@ class NixlConnectorWorker:
|
||||
if not self.use_mla and (
|
||||
block_size_ratio > 1 or self.enable_permute_local_kv
|
||||
):
|
||||
assert not self._is_hma_required
|
||||
block_ids_for_blocksize_post_process[block_size_ratio].append(
|
||||
meta.local_physical_block_ids
|
||||
meta.local_physical_block_ids[0]
|
||||
)
|
||||
for (
|
||||
block_size_ratio,
|
||||
@@ -2106,8 +2191,9 @@ class NixlConnectorWorker:
|
||||
handle: The transfer handle.
|
||||
"""
|
||||
# Use .get() here as the metadata cleanup is handled by get_finished()
|
||||
if meta := self._recving_metadata.get(req_id):
|
||||
self._invalid_block_ids.update(meta.local_block_ids)
|
||||
# TODO (NickLucche) handle failed transfer for HMA.
|
||||
if (meta := self._recving_metadata.get(req_id)) and not self._is_hma_required:
|
||||
self._invalid_block_ids.update(meta.local_block_ids[0])
|
||||
self.nixl_wrapper.release_xfer_handle(handle)
|
||||
self.xfer_stats.record_failed_transfer()
|
||||
|
||||
@@ -2230,8 +2316,8 @@ class NixlConnectorWorker:
|
||||
|
||||
def _read_blocks(
|
||||
self,
|
||||
local_block_ids: list[int],
|
||||
remote_block_ids: list[int],
|
||||
local_block_ids: BlockIds,
|
||||
remote_block_ids: BlockIds,
|
||||
dst_engine_id: str,
|
||||
request_id: str,
|
||||
remote_request_id: str,
|
||||
@@ -2246,22 +2332,30 @@ class NixlConnectorWorker:
|
||||
assert self.kv_topo is not None
|
||||
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
|
||||
if block_size_ratio > 1:
|
||||
local_block_ids = self.get_mapped_blocks(
|
||||
np.asarray(local_block_ids), block_size_ratio
|
||||
)
|
||||
if len(local_block_ids) > len(remote_block_ids):
|
||||
# TODO (NickLucche) assume HMA is off. Change to handle multiple KV groups.
|
||||
assert not self._is_hma_required
|
||||
local_block_ids0 = local_block_ids[0] if local_block_ids else []
|
||||
remote_block_ids0 = remote_block_ids[0]
|
||||
local_block_ids_mapped = self.get_mapped_blocks(
|
||||
np.asarray(local_block_ids0), block_size_ratio
|
||||
).tolist()
|
||||
if len(local_block_ids_mapped) > len(remote_block_ids0):
|
||||
# NOTE:
|
||||
# get_mapped_blocks will always expand block_ids for n times.
|
||||
# ex:
|
||||
# prefill block_ids with block_size as 4:
|
||||
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
# Local decode block_ids with block_size as 16: [1, 2, 3]
|
||||
# expland ecode block_ids with get_mapped_blocks from [1, 2, 3] to
|
||||
# expanded decode block_ids with get_mapped_blocks from [1, 2, 3] to
|
||||
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
|
||||
# Then we clip local to align with prefill
|
||||
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] to
|
||||
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
local_block_ids = local_block_ids[: len(remote_block_ids)]
|
||||
local_block_ids_mapped = local_block_ids_mapped[
|
||||
: len(remote_block_ids0)
|
||||
]
|
||||
local_block_ids = [local_block_ids_mapped] if local_block_ids_mapped else []
|
||||
remote_block_ids = [remote_block_ids0]
|
||||
# NOTE(rob): having the staging blocks be on the READER side is
|
||||
# not going to work well (since we will have to call rearrange tensors).
|
||||
# after we detect the txn is complete (which means we cannot make the
|
||||
@@ -2269,8 +2363,7 @@ class NixlConnectorWorker:
|
||||
# then we will need to have the staging blocks on the remote side.
|
||||
|
||||
# NOTE(rob): according to nvidia the staging blocks are used to
|
||||
# saturate IB with heterogeneous TP sizes. We should remove the staging
|
||||
# blocks until we are ready.
|
||||
# saturate IB with heterogeneous TP sizes.
|
||||
|
||||
# Number of D TP workers that will read from dst P. Propagate info
|
||||
# on notification so that dst worker can wait before freeing blocks.
|
||||
@@ -2278,8 +2371,8 @@ class NixlConnectorWorker:
|
||||
|
||||
# Full prefix cache hit: do not need to read remote blocks,
|
||||
# just notify P worker that we have the blocks we need.
|
||||
num_local_blocks = len(local_block_ids)
|
||||
if num_local_blocks == 0:
|
||||
if len(local_block_ids) == 0:
|
||||
# A full prefix cache hit is indicated with an empty list.
|
||||
agent_name = self._remote_agents[dst_engine_id][remote_rank]
|
||||
try:
|
||||
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
|
||||
@@ -2297,66 +2390,34 @@ class NixlConnectorWorker:
|
||||
self.xfer_stats.record_failed_notification()
|
||||
return
|
||||
|
||||
# Partial prefix cache hit: just read uncomputed blocks.
|
||||
num_remote_blocks = len(remote_block_ids)
|
||||
assert num_local_blocks <= num_remote_blocks
|
||||
if num_local_blocks < num_remote_blocks:
|
||||
remote_block_ids = remote_block_ids[-num_local_blocks:]
|
||||
assert (
|
||||
len(remote_block_ids)
|
||||
== len(local_block_ids)
|
||||
== len(self.kv_cache_config.kv_cache_groups)
|
||||
)
|
||||
remote_block_ids = list(remote_block_ids)
|
||||
for i, remote_group in enumerate(remote_block_ids):
|
||||
num_remote_blocks = len(remote_group)
|
||||
num_local_blocks = len(local_block_ids[i])
|
||||
assert num_local_blocks <= num_remote_blocks
|
||||
# Partial prefix cache hit: just read uncomputed blocks.
|
||||
if num_local_blocks < num_remote_blocks:
|
||||
remote_block_ids[i] = remote_group[-num_local_blocks:]
|
||||
|
||||
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
|
||||
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
|
||||
# workers will issue xfers to parts of the P worker remote kv caches.
|
||||
|
||||
# Get descs ids.
|
||||
local_block_descs_ids: np.ndarray
|
||||
remote_block_descs_ids: np.ndarray
|
||||
|
||||
if not self.block_window_per_layer:
|
||||
# Default case: assume global attention
|
||||
remote_block_descs_ids = self._get_block_descs_ids(
|
||||
dst_engine_id,
|
||||
remote_block_ids,
|
||||
)
|
||||
local_block_descs_ids = self._get_block_descs_ids(
|
||||
self.engine_id,
|
||||
local_block_ids,
|
||||
block_size_ratio=block_size_ratio,
|
||||
)
|
||||
else:
|
||||
# TODO(mgoin): remove this once we have hybrid memory allocator
|
||||
# Optimization for models with local attention (Llama 4)
|
||||
local_descs_list = []
|
||||
remote_descs_list = []
|
||||
for layer_idx, block_window in enumerate(self.block_window_per_layer):
|
||||
# For each layer:
|
||||
if block_window is None:
|
||||
# If not chunked, we just use the
|
||||
# full block lists (global attention)
|
||||
layer_local_block_ids = local_block_ids
|
||||
layer_remote_block_ids = remote_block_ids
|
||||
else:
|
||||
# If chunked, get the last block_window blocks
|
||||
layer_local_block_ids = local_block_ids[-block_window:]
|
||||
layer_remote_block_ids = remote_block_ids[-block_window:]
|
||||
|
||||
# Get descs ids for the layer.
|
||||
layer_local_desc_ids = self._get_block_descs_ids(
|
||||
self.engine_id,
|
||||
layer_local_block_ids,
|
||||
layer_idx,
|
||||
block_size_ratio=block_size_ratio,
|
||||
)
|
||||
layer_remote_desc_ids = self._get_block_descs_ids(
|
||||
dst_engine_id,
|
||||
layer_remote_block_ids,
|
||||
layer_idx,
|
||||
)
|
||||
|
||||
local_descs_list.append(layer_local_desc_ids)
|
||||
remote_descs_list.append(layer_remote_desc_ids)
|
||||
|
||||
local_block_descs_ids = np.concatenate(local_descs_list)
|
||||
remote_block_descs_ids = np.concatenate(remote_descs_list)
|
||||
remote_block_descs_ids = self._get_block_descs_ids(
|
||||
dst_engine_id,
|
||||
remote_block_ids,
|
||||
)
|
||||
local_block_descs_ids = self._get_block_descs_ids(
|
||||
self.engine_id,
|
||||
local_block_ids,
|
||||
block_size_ratio=block_size_ratio,
|
||||
)
|
||||
|
||||
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
|
||||
|
||||
@@ -2387,14 +2448,18 @@ class NixlConnectorWorker:
|
||||
dst_engine_id=dst_engine_id,
|
||||
remote_rank=remote_rank,
|
||||
)
|
||||
if meta := self._recving_metadata.get(request_id):
|
||||
self._invalid_block_ids.update(meta.local_block_ids)
|
||||
if (
|
||||
meta := self._recving_metadata.get(request_id)
|
||||
) and not self._is_hma_required:
|
||||
self._invalid_block_ids.update(meta.local_block_ids[0])
|
||||
self.xfer_stats.record_failed_transfer()
|
||||
if handle is not None:
|
||||
self.nixl_wrapper.release_xfer_handle(handle)
|
||||
self._failed_recv_reqs.add(request_id)
|
||||
|
||||
def get_mapped_blocks(self, block_ids, block_size_ratio):
|
||||
def get_mapped_blocks(
|
||||
self, block_ids: np.ndarray, block_size_ratio: int
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Calculates the new set of block IDs by mapping every element
|
||||
in the (potentially sparse) input array.
|
||||
@@ -2416,41 +2481,32 @@ class NixlConnectorWorker:
|
||||
def _get_block_descs_ids(
|
||||
self,
|
||||
engine_id: str,
|
||||
block_ids: list[int],
|
||||
layer_idx: int | None = None,
|
||||
block_ids: BlockIds,
|
||||
block_size_ratio: float | None = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Get the descs ids for a set of block ids.
|
||||
If layer_idx is provided, we use the region_ids for the given layer.
|
||||
Otherwise, we use all regions.
|
||||
When HMA is enabled number of descriptors across kv cache groups might differ.
|
||||
A single flattened array is returned for all groups anyway.
|
||||
"""
|
||||
if layer_idx is None:
|
||||
region_ids = np.arange(self.num_regions)
|
||||
else:
|
||||
assert layer_idx < self.num_layers
|
||||
if self.num_layers < self.num_regions:
|
||||
# If we have more regions than layers, we assume that
|
||||
# the regions are organized as [K0, V0, K1, V1, ...]
|
||||
# and we select K_i and V_i
|
||||
assert 2 * self.num_layers == self.num_regions
|
||||
region_ids = np.arange(2 * layer_idx, 2 * layer_idx + 2)
|
||||
else:
|
||||
# Otherwise, we assume we have MLA and select i-th layer
|
||||
assert self.num_layers == self.num_regions
|
||||
region_ids = np.arange(layer_idx, layer_idx + 1)
|
||||
|
||||
region_ids = np.arange(self.num_regions)
|
||||
# NOTE (NickLucche) With HMA, every kv group has the same number of layers and
|
||||
# layers from different groups share the same kv tensor.
|
||||
# eg block_ids=[[1, 2], [3]]->blocks [1, 2] need to be read across all regions,
|
||||
# same for [3], but group0-group1 blocks will always differ (different areas).
|
||||
# Therefore we can just flatten the block_ids and compute the descs ids for all
|
||||
# groups at once.
|
||||
num_blocks = self.dst_num_blocks[engine_id]
|
||||
if block_size_ratio is not None:
|
||||
num_blocks = int(num_blocks * block_size_ratio)
|
||||
|
||||
# Compute the desc ids for each block.
|
||||
region_ids = region_ids[:, None]
|
||||
block_ids = np.array(block_ids)[None, :]
|
||||
block_ids = np.concatenate(block_ids)[None, :]
|
||||
descs_ids = region_ids * num_blocks + block_ids
|
||||
return descs_ids.flatten()
|
||||
|
||||
def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> list[int]:
|
||||
def _logical_to_kernel_block_ids(self, block_ids: BlockIds) -> BlockIds:
|
||||
"""
|
||||
Convert logical block ids to kernel physical block ids.
|
||||
This is required when the logical block size (the one set by the user)
|
||||
@@ -2459,13 +2515,17 @@ class NixlConnectorWorker:
|
||||
if self._physical_blocks_per_logical_kv_block == 1:
|
||||
# Noop when physical and logical block sizes are the same
|
||||
return block_ids
|
||||
block_ids_np = np.array(block_ids)
|
||||
block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape(
|
||||
1, -1
|
||||
)
|
||||
return BlockTable.map_to_kernel_blocks(
|
||||
block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange
|
||||
).tolist()
|
||||
return [
|
||||
BlockTable.map_to_kernel_blocks(
|
||||
np.array(group),
|
||||
self._physical_blocks_per_logical_kv_block,
|
||||
block_arange,
|
||||
).tolist()
|
||||
for group in block_ids
|
||||
]
|
||||
|
||||
def get_backend_aware_kv_block_len(self, layer_idx: int) -> int:
|
||||
"""
|
||||
|
||||
@@ -84,6 +84,18 @@ class KVCacheBlocks:
|
||||
assert len(self.blocks) == 1, "Only one group is supported"
|
||||
return [block.block_id for block in self.blocks[0] if block.block_hash is None]
|
||||
|
||||
def get_unhashed_block_ids_all_groups(self) -> list[list[int]]:
    """Return, for every group in `self.blocks`, the ids of its unhashed blocks.

    A block is reported when its `block_hash` is None and it is not a null
    (padding) block. One list is produced per group, in group order.
    """
    unhashed_per_group: list[list[int]] = []
    for group in self.blocks:
        group_ids: list[int] = []
        for block in group:
            # Skip hashed blocks and padding (null) blocks.
            if block.block_hash is not None or block.is_null:
                continue
            group_ids.append(block.block_id)
        unhashed_per_group.append(group_ids)
    return unhashed_per_group
|
||||
|
||||
def new_empty(self) -> "KVCacheBlocks":
|
||||
"""
|
||||
Creates a new KVCacheBlocks instance with no blocks.
|
||||
|
||||
Reference in New Issue
Block a user