diff --git a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh index abdf88ad6..c35f4bfe8 100755 --- a/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh @@ -12,6 +12,7 @@ tp_configs=( "GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case "GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" "GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" + "GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192" # SW model ) dp_ep_configs=( "DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1) @@ -26,6 +27,14 @@ else configs=("${tp_configs[@]}") fi +if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then + # Append ENABLE_HMA_FLAG=1 to each config in the selected array + echo "ENABLE_HMA_FLAG is set, appending ENABLE_HMA_FLAG=1 to each config" + for i in "${!configs[@]}"; do + configs[$i]="ENABLE_HMA_FLAG=1 ${configs[$i]}" + done +fi + run_tests() { local label=$1 local extra_args=$2 diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index 673236625..fe9524960 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -5,6 +5,12 @@ set -xe KV_BUFFER_DEVICE="cuda" # Default to cuda ATTENTION_BACKEND="" # Default to empty (use vllm default) CROSS_LAYERS_BLOCKS="False" +ENABLE_HMA_VAR="" # Default to empty (HMA disabled by default for kv connector) +# Check for ENABLE_HMA_FLAG environment variable +if [[ -n "${ENABLE_HMA_FLAG:-}" ]]; then + 
ENABLE_HMA_VAR="--no-disable-hybrid-kv-cache-manager" +fi + while [[ $# -gt 0 ]]; do case $1 in --kv_buffer_device) @@ -31,6 +37,12 @@ echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE" if [[ -n "$ATTENTION_BACKEND" ]]; then echo "Using attention backend: $ATTENTION_BACKEND" fi +if [[ -n "$ENABLE_HMA_VAR" ]]; then + echo "HMA (Hybrid KV Cache Manager) enabled" +fi +if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then + echo "vLLM serve extra args: $VLLM_SERVE_EXTRA_ARGS" +fi DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then @@ -70,6 +82,8 @@ DECODER_TP_SIZE=${DECODER_TP_SIZE:-1} GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2} PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-128} DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-128} +# Comma-separated extra args for vllm serve (e.g. --max-model-len,2048) +VLLM_SERVE_EXTRA_ARGS=${VLLM_SERVE_EXTRA_ARGS:-} # Find the git repository root directory GIT_ROOT=$(git rev-parse --show-toplevel) @@ -151,14 +165,24 @@ run_tests_for_model() { --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --tensor-parallel-size $PREFILLER_TP_SIZE \ --kv-transfer-config '$KV_CONFIG'" + if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then + IFS=',' read -r -a extra_args <<< "$VLLM_SERVE_EXTRA_ARGS" + for arg in "${extra_args[@]}"; do + BASE_CMD="${BASE_CMD} $arg" + done + fi # Add attention backend config if specified if [[ -n "$ATTENTION_BACKEND" ]]; then BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND" fi + # Add HMA flag if specified + if [[ -n "$ENABLE_HMA_VAR" ]]; then + BASE_CMD="${BASE_CMD} $ENABLE_HMA_VAR" + fi + FULL_CMD="$BASE_CMD" - eval "$FULL_CMD &" # Store host and port for proxy configuration @@ -193,12 +217,23 @@ run_tests_for_model() { --block-size ${DECODE_BLOCK_SIZE} \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --kv-transfer-config '$KV_CONFIG'" + if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then + IFS=',' read -r -a extra_args <<< 
"$VLLM_SERVE_EXTRA_ARGS" + for arg in "${extra_args[@]}"; do + BASE_CMD="${BASE_CMD} $arg" + done + fi # Add attention backend config if specified if [[ -n "$ATTENTION_BACKEND" ]]; then BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND" fi + # Add HMA flag if specified + if [[ -n "$ENABLE_HMA_VAR" ]]; then + BASE_CMD="${BASE_CMD} $ENABLE_HMA_VAR" + fi + # DP-EP attention mode if [[ -z "$DP_EP" ]]; then BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE" diff --git a/tests/v1/kv_connector/nixl_integration/test_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_accuracy.py index a70f4caeb..674e65c25 100644 --- a/tests/v1/kv_connector/nixl_integration/test_accuracy.py +++ b/tests/v1/kv_connector/nixl_integration/test_accuracy.py @@ -17,6 +17,7 @@ EXPECTED_VALUES = { "deepseek-ai/deepseek-vl2-small": 0.59, "deepseek-ai/deepseek-vl2-tiny": 0.19, "deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65, + "google/gemma-3-4b-it": 0.74, } SIMPLE_PROMPT = ( diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 15ca74db3..d59a9cbdd 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -59,7 +59,12 @@ from vllm.v1.request import RequestStatus from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin from vllm.v1.worker.utils import AttentionGroup -from .utils import create_request, create_scheduler, create_vllm_config +from .utils import ( + create_request, + create_scheduler, + create_vllm_config, + make_kv_cache_config, +) @pytest.fixture(scope="module", autouse=True) @@ -263,7 +268,7 @@ def test_basic_interface(): req_meta = kv_connector_metadata.reqs_to_recv[request_id] for block_id, block in zip( - req_meta.local_block_ids, + req_meta.local_block_ids[0], scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[ request_id ], @@ -327,7 +332,9 @@ def 
test_kv_transfer_handshake(dist_init): # Prefill connector will register KV cache to populate proper handshake # metadata. - prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + prefill_connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 ) @@ -367,13 +374,17 @@ def test_kv_transfer_handshake(dist_init): do_remote_decode=True, ) request.status = RequestStatus.FINISHED_LENGTH_CAPPED - delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished( - request, [0, 1, 2] + delay, kv_connector_metadata = ( + scheduler.get_kv_connector().request_finished_all_groups( + request, ([0, 1, 2],) + ) ) assert delay # Decode connector will be able to create handshake with the prefill connector. - decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + decode_connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) decode_connector.register_kv_caches(kv_caches) # Here we are testing the retrieval of NIXLAgentMetadata. @@ -404,9 +415,16 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): REMOTE_ENGINE_ID = "remote_engine" def __init__( - self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs + self, + *args, + hand_shake_latency: float = 1.8, + kv_cache_layout="HND", + kv_cache_config=None, + **kwargs, ): - super().__init__(*args, **kwargs) + if kv_cache_config is None: + kv_cache_config = make_kv_cache_config(block_size=16) + super().__init__(*args, kv_cache_config=kv_cache_config, **kwargs) self._hand_shake_latency = hand_shake_latency self.kv_cache_layout = kv_cache_layout # Mock register_kv_caches attribute needed for tests that do not call it. @@ -507,7 +525,9 @@ class TestNixlHandshake: request_id = "req_id" # Test worker role in decode server. 
- connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) connector.connector_worker = FakeNixlConnectorWorker( vllm_config, connector.engine_id, hand_shake_latency=0 ) @@ -528,13 +548,15 @@ class TestNixlHandshake: num_xfers -= 1 metadata.add_new_req_to_recv( request_id=request_id, - local_block_ids=[num_xfers + 1, num_xfers + 2, num_xfers + 3], + local_block_ids=([num_xfers + 1, num_xfers + 2, num_xfers + 3],), kv_transfer_params={ - "remote_block_ids": [ - num_xfers + 4, - num_xfers + 5, - num_xfers + 6, - ], + "remote_block_ids": ( + [ + num_xfers + 4, + num_xfers + 5, + num_xfers + 6, + ], + ), "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, "remote_request_id": f"prefill-{request_id}", "remote_host": "localhost", @@ -594,16 +616,18 @@ class TestNixlHandshake: vllm_config.parallel_config.tensor_parallel_size = decode_tp_size # Test worker role in decode server. 
- connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) connector.connector_worker = FakeNixlConnectorWorker( vllm_config, connector.engine_id ) metadata = NixlConnectorMetadata() metadata.add_new_req_to_recv( request_id="id", - local_block_ids=[1, 2, 3], + local_block_ids=([1, 2, 3],), kv_transfer_params={ - "remote_block_ids": [4, 5, 6], + "remote_block_ids": ([4, 5, 6],), "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, "remote_request_id": "prefill-id", "remote_host": "localhost", @@ -652,7 +676,9 @@ class TestNixlHandshake: local_tp_size = 1 vllm_config.parallel_config.tensor_parallel_size = local_tp_size - connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) connector.connector_worker = FakeNixlConnectorWorker( vllm_config, connector.engine_id, hand_shake_latency=0 ) @@ -717,8 +743,12 @@ class TestNixlHandshake: p_tp_size = 2 # Build two separate connectors/workers to emulate P TP=2 ranks. - conn_p0 = NixlConnector(vllm_config, KVConnectorRole.WORKER) - conn_p1 = NixlConnector(vllm_config, KVConnectorRole.WORKER) + conn_p0 = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) + conn_p1 = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) conn_p0.connector_worker = FakeNixlConnectorWorker( vllm_config, conn_p0.engine_id, hand_shake_latency=0 ) @@ -815,7 +845,9 @@ class TestNixlHandshake: vllm_config = create_vllm_config() # Test worker role in decode server. 
- connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) connector.connector_worker = FakeNixlConnectorWorker( vllm_config, connector.engine_id ) @@ -827,9 +859,9 @@ class TestNixlHandshake: for i in range(total_reqs): metadata.add_new_req_to_recv( request_id=f"id_{i}", - local_block_ids=[1, 2, 3], + local_block_ids=([1, 2, 3],), kv_transfer_params={ - "remote_block_ids": [4, 5, 6], + "remote_block_ids": ([4, 5, 6],), "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, "remote_request_id": f"prefill-id-{i}", "remote_host": "localhost", @@ -884,7 +916,9 @@ class TestNixlHandshake: return_value=2, ): # Initialize connector and worker (with fake NIXL wrapper) - connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) connector.connector_worker = FakeNixlConnectorWorker( vllm_config, connector.engine_id, hand_shake_latency=0 ) @@ -934,7 +968,9 @@ class TestNixlHandshake: return_value=2, ): # Initialize connector and worker (with fake NIXL wrapper) - connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) connector.connector_worker = FakeNixlConnectorWorker( vllm_config, connector.engine_id, @@ -979,7 +1015,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init): vllm_config = create_vllm_config() # Test worker role in decode server. 
- connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) connector.connector_worker = FakeNixlConnectorWorker( vllm_config, connector.engine_id, hand_shake_latency=0 ) @@ -993,9 +1031,9 @@ def test_kv_connector_stats(default_vllm_config, dist_init): metadata = NixlConnectorMetadata() metadata.add_new_req_to_recv( request_id=request_id, - local_block_ids=[1, 2, 3], + local_block_ids=([1, 2, 3],), kv_transfer_params={ - "remote_block_ids": [4, 5, 6], + "remote_block_ids": ([4, 5, 6],), "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, "remote_request_id": f"prefill-{request_id}", "remote_host": "localhost", @@ -1448,7 +1486,9 @@ def test_register_kv_caches( mock_get_attn_backend.return_value = backend_cls # Create connector - connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) connector.connector_worker = FakeNixlConnectorWorker( vllm_config, connector.engine_id, hand_shake_latency=0 ) @@ -1676,7 +1716,9 @@ def test_kv_buffer_to_nixl_memory_types( ), ): # noqa: E501 # Create connector and replace its worker with a fake one for isolation - connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) # Verify get_reg_descs was called with the correct memory_type assert connector.connector_worker.kv_buffer_device == kv_buffer_device @@ -1692,9 +1734,15 @@ def test_shutdown_cleans_up_resources(default_vllm_config, dist_init): vllm_config = create_vllm_config() scheduler = NixlConnectorScheduler( - vllm_config, vllm_config.kv_transfer_config.engine_id + vllm_config, + vllm_config.kv_transfer_config.engine_id, + make_kv_cache_config(block_size=16), + ) + worker = NixlConnectorWorker( + vllm_config, + 
vllm_config.kv_transfer_config.engine_id, + make_kv_cache_config(block_size=16), ) - worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id) nixl_wrapper = worker.nixl_wrapper with ( @@ -1756,7 +1804,9 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_ scheduler = create_scheduler(vllm_config) # KVConnector Worker in P - connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) connector.connector_worker = FakeNixlConnectorWorker( vllm_config, connector.engine_id, hand_shake_latency=0 ) @@ -1875,12 +1925,14 @@ class FailingNixlWrapper(FakeNixlWrapper): ("transfer_exception", {"fail_transfer_exception": True}, True), ], ) +@pytest.mark.parametrize("enable_hma", [False, True]) def test_transfer_failure_logging( default_vllm_config, dist_init, failure_type, wrapper_config, needs_get_finished, + enable_hma, ): """Test that transfer failures are logged with structured context. @@ -1897,9 +1949,16 @@ def test_transfer_failure_logging( vllm_config = create_vllm_config() - connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector = NixlConnector( + vllm_config, + KVConnectorRole.WORKER, + make_kv_cache_config(block_size=16, hma_enabled=enable_hma), + ) connector.connector_worker = FakeNixlConnectorWorker( - vllm_config, connector.engine_id, hand_shake_latency=0.0 + vllm_config, + connector.engine_id, + hand_shake_latency=0.0, + kv_cache_config=connector._kv_cache_config, ) # Configure FailingNixlWrapper to fail in the specified way @@ -1910,8 +1969,17 @@ def test_transfer_failure_logging( # For notification_failed, we need empty local blocks # (full cache hit path to trigger send_notif) - local_blocks = [] if failure_type == "notification_failed" else [10, 11, 12] - remote_blocks = [20, 21, 22] + local_blocks: tuple[()] | tuple[list[int], ...] 
+ if enable_hma: + # HMA enabled: multiple groups (FA + SW) + local_blocks = ( + () if failure_type == "notification_failed" else ([10, 11, 12], [13, 14]) + ) + remote_blocks = [[20, 21, 22], [23, 24]] + else: + # HMA disabled: single group + local_blocks = () if failure_type == "notification_failed" else ([10, 11, 12],) + remote_blocks = [[20, 21, 22]] metadata = NixlConnectorMetadata() metadata.add_new_req_to_recv( @@ -2007,7 +2075,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init): """Test that handshake failures mark blocks invalid and return via get_finished.""" vllm_config = create_vllm_config() - connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) connector.connector_worker = FakeNixlConnectorWorker( vllm_config, connector.engine_id, hand_shake_latency=0.1 ) @@ -2017,9 +2087,9 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init): metadata = NixlConnectorMetadata() metadata.add_new_req_to_recv( request_id=request_id, - local_block_ids=[1, 2, 3], + local_block_ids=([1, 2, 3],), kv_transfer_params={ - "remote_block_ids": [4, 5, 6], + "remote_block_ids": ([4, 5, 6],), "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, "remote_request_id": f"prefill-{request_id}", "remote_host": "localhost", @@ -2058,7 +2128,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init) and return via get_finished.""" vllm_config = create_vllm_config() - connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector = NixlConnector( + vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) connector.connector_worker = FakeNixlConnectorWorker( vllm_config, connector.engine_id, hand_shake_latency=0 ) @@ -2068,9 +2140,9 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init) metadata = NixlConnectorMetadata() 
metadata.add_new_req_to_recv( request_id=request_id, - local_block_ids=[7, 8, 9], + local_block_ids=([7, 8, 9],), kv_transfer_params={ - "remote_block_ids": [10, 11, 12], + "remote_block_ids": ([10, 11, 12],), "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, "remote_request_id": f"prefill-{request_id}", "remote_host": "localhost", @@ -2154,7 +2226,9 @@ def test_compatibility_hash_validation( "enforce_handshake_compat": enforce_handshake_compat }, ) - decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER) + decode_connector = NixlConnector( + local_vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) decode_worker = decode_connector.connector_worker kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape( num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 @@ -2267,7 +2341,9 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) model="facebook/opt-125m", block_size=16, ) - decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER) + decode_connector = NixlConnector( + local_vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) + ) decode_worker = decode_connector.connector_worker backend = get_current_attn_backend(local_vllm_config) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py new file mode 100644 index 000000000..636d51402 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for NixlConnectorScheduler sw_sizes calculation with HMA.""" + +from unittest.mock import patch + +import pytest + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig +from vllm.v1.core.single_type_kv_cache_manager import ( + FullAttentionManager, + SlidingWindowManager, +) + +from .utils import 
( + create_vllm_config, + make_kv_cache_config, +) + + +@pytest.mark.cpu_test +@pytest.mark.parametrize( + "hma_enabled,expected_sw_sizes", + [ + # HMA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128) + (True, [0, 128 + 1]), + # HMA disabled: only FullAttentionSpec (0) + (False, [0]), + ], +) +@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform") +def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes): + """Test sw_sizes is correctly computed based on HMA enabled/disabled.""" + from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + NixlConnectorScheduler, + ) + + mock_platform.device_type = "cpu" + + block_size = 16 + vllm_config = create_vllm_config(block_size=block_size) + # SW 2048 tokens=>128 blocks + kv_cache_config = make_kv_cache_config( + block_size=block_size, hma_enabled=hma_enabled, sw_size=2048 + ) + + scheduler = NixlConnectorScheduler( + vllm_config=vllm_config, + engine_id="test-engine", + kv_cache_config=kv_cache_config, + ) + # in number of blocks + assert scheduler.blocks_per_sw == expected_sw_sizes, ( + f"Expected sw_sizes={expected_sw_sizes}, got {scheduler.blocks_per_sw}" + ) + + +@pytest.mark.cpu_test +def test_logical_to_kernel_block_ids_with_hma(): + """Test _logical_to_kernel_block_ids expands blocks when HMA is enabled. + + When HMA is enabled, the logical block size may differ from the kernel + block size. Each logical block maps to multiple kernel blocks. 
+ """ + from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + NixlConnectorWorker, + ) + + # Create a mock worker with just the required attributes + # (use __new__ to skip __init__) + worker = object.__new__(NixlConnectorWorker) + + # Simulate HMA scenario: logical block size = 32, kernel block size = 16 + # So each logical block maps to 2 kernel blocks eg [0]->[0,1] + worker._physical_blocks_per_logical_kv_block = 2 + + # Test conversion: FA + SW group + logical_block_ids = [[0, 1, 2], [3, 4]] + kernel_block_ids = worker._logical_to_kernel_block_ids(logical_block_ids) + + expected_kernel_block_ids = [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]] + assert kernel_block_ids == expected_kernel_block_ids, ( + f"Expected {expected_kernel_block_ids}, got {kernel_block_ids}" + ) + + +@pytest.mark.parametrize("model_name, sw_size", [("google/gemma-3-1b-it", 512)]) +def test_fewer_blocks_with_hma(monkeypatch, model_name, sw_size): + """Test that a prefill instance returns fewer "remote blocks" for the SWA groups + when sequence exceeds the sliding window. 
+ """ + kv_transfer_config = KVTransferConfig( + kv_connector="NixlConnector", + kv_role="kv_both", + ) + block_size = 16 + llm_kwargs = { + "model": model_name, + "enforce_eager": True, + "gpu_memory_utilization": 0.5, + "kv_transfer_config": kv_transfer_config, + "max_model_len": 2048, + # NOTE: Make sure HMA is enabled + "disable_hybrid_kv_cache_manager": False, + "max_num_batched_tokens": 1024, + "enable_prefix_caching": False, + "block_size": block_size, + } + + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + + def run_hma_test(llm: LLM): + remote_prefill_opts = { + "do_remote_decode": True, + "do_remote_prefill": False, + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host": None, + "remote_port": None, + } + # Simulate sidecar request + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=1, + extra_args={"kv_transfer_params": remote_prefill_opts}, + ) + scheduler = llm.llm_engine.engine_core.engine_core.scheduler + kv_managers = scheduler.kv_cache_manager.coordinator.single_type_managers + # HMA enabled with FA + SWA groups + assert len(kv_managers) > 2 + for kv_manager in kv_managers: + assert isinstance(kv_manager, (SlidingWindowManager, FullAttentionManager)) + req_to_blocks = kv_managers[0].req_to_blocks + assert len(req_to_blocks) == 0 + + # Process some request with length exceeding the sliding window + outputs = llm.generate(["hi" * 1401], sampling_params) + kv_params = outputs[0].kv_transfer_params + + # +1 to account for overlapping window across blocks. 
+ expected_num_remote_blocks = sw_size // block_size + 1 + remote_block_ids = kv_params["remote_block_ids"] + assert ( + len(remote_block_ids[0]) + == expected_num_remote_blocks + < len(remote_block_ids[-1]) + ) + for group_block_ids in remote_block_ids[:-1]: + assert len(group_block_ids) == expected_num_remote_blocks + + def run_test_and_cleanup(): + llm = LLM(**llm_kwargs) + try: + run_hma_test(llm) + finally: + llm.llm_engine.engine_core.shutdown() + + run_test_and_cleanup() + + +@pytest.mark.cpu_test +def test_nixl_metadata_hma_block_ids_structure(): + """ + Test that NixlConnectorMetadata correctly stores block IDs for multiple + KV cache groups when HMA is enabled. + """ + from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + NixlConnectorMetadata, + ) + + metadata = NixlConnectorMetadata() + + # Add request with block IDs for 2 groups (FA + SW) + fa_blocks = [0, 1, 2, 3, 4, 5, 6, 7] # 8 blocks for FA + sw_blocks = [8, 9, 10, 11] # 4 blocks for SW (clipped) + + metadata.add_new_req_to_recv( + request_id="test-req-hma", + local_block_ids=(fa_blocks, sw_blocks), + kv_transfer_params={ + "remote_block_ids": ([10, 11, 12, 13, 14, 15, 16, 17], [18, 19, 20, 21]), + "remote_engine_id": "remote-engine", + "remote_request_id": "prefill-test-req-hma", + "remote_host": "localhost", + "remote_port": 1234, + "tp_size": 1, + }, + ) + + assert "test-req-hma" in metadata.reqs_to_recv + req_meta = metadata.reqs_to_recv["test-req-hma"] + + # Verify local block IDs structure + assert len(req_meta.local_block_ids) == 2 + assert list(req_meta.local_block_ids[0]) == fa_blocks + assert list(req_meta.local_block_ids[1]) == sw_blocks + + # Verify remote block IDs structure + assert req_meta.remote is not None + assert len(req_meta.remote.block_ids) == 2 + assert list(req_meta.remote.block_ids[0]) == [10, 11, 12, 13, 14, 15, 16, 17] + assert list(req_meta.remote.block_ids[1]) == [18, 19, 20, 21] diff --git 
a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py index b2ec2ddfb..b656e0809 100644 --- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -208,7 +208,9 @@ def test_prefix_cache_lifecycle(): # Ensure we send all block ids, including the partial blocks, # even if there is a cache hit. - assert len(kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + 1) + # remote_block_ids is BlockIds (tuple of lists); sum block counts across groups. + num_remote_blocks = sum(len(g) for g in kv_transfer_params["remote_block_ids"]) + assert num_remote_blocks == (NUM_EXTERNAL_FULL_BLOCKS + 1) # STEP (2): Ensure it is freed. scheduler_output = scheduler.schedule() diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 7539da3e9..d26729981 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -36,6 +36,7 @@ from vllm.v1.kv_cache_interface import ( FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, + SlidingWindowSpec, ) from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import Request @@ -142,24 +143,26 @@ def create_vllm_config( def create_scheduler( vllm_config: VllmConfig, num_blocks: int = 10000, + kv_cache_config: KVCacheConfig | None = None, ) -> Scheduler: """Initialize Scheduler For Testing.""" block_size = vllm_config.cache_config.block_size - kv_cache_config = KVCacheConfig( - num_blocks=num_blocks, # A large number of blocks to hold all requests - kv_cache_tensors=[], - kv_cache_groups=[ - KVCacheGroupSpec( - ["layer"], - FullAttentionSpec( - block_size=block_size, - num_kv_heads=1, - head_size=1, - dtype=torch.float32, - ), - ) - ], - ) + if kv_cache_config is None: + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, # A large number of blocks to hold all requests + kv_cache_tensors=[], + 
kv_cache_groups=[ + KVCacheGroupSpec( + ["layer"], + FullAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + ), + ) + ], + ) vllm_config.cache_config.num_gpu_blocks = num_blocks return Scheduler( vllm_config=vllm_config, @@ -412,3 +415,38 @@ KVConnectorFactory.register_connector( KVConnectorFactory.register_connector( "MockKVConnector", __name__, MockKVConnector.__name__ ) + + +def make_kv_cache_config( + block_size: int, + hma_enabled: bool = False, + sw_size: int = 128, + num_blocks: int = 100, +) -> KVCacheConfig: + kv_cache_groups = [ + KVCacheGroupSpec( + ["layer0", "layer2"], + FullAttentionSpec( + block_size=block_size, + num_kv_heads=4, + head_size=16, + dtype=torch.float16, + ), + ) + ] + if hma_enabled: + kv_cache_groups.append( + KVCacheGroupSpec( + ["layer1", "layer3"], + SlidingWindowSpec( + block_size=block_size, + num_kv_heads=4, + head_size=16, + dtype=torch.float16, + sliding_window=sw_size, + ), + ) + ) + return KVCacheConfig( + num_blocks=num_blocks, kv_cache_tensors=[], kv_cache_groups=kv_cache_groups + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index fb6bbf7b5..eb93ea324 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -24,6 +24,9 @@ if TYPE_CHECKING: logger = init_logger(__name__) EngineId = str +# block ids as returned by the hybrid KV cache manager. list[list[int]] allows +# mutability and is for connector internal use only. +BlockIds = tuple[list[int], ...] 
| list[list[int]] def get_kv_connector_cache_layout(): diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index c5a5b0450..fa0dd6f67 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -3,7 +3,6 @@ import contextlib import copy import logging -import math import os import queue import sys @@ -24,6 +23,7 @@ import zmq from vllm import envs from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.utils import ( + BlockIds, EngineId, TpKVTopology, get_current_attn_backend, @@ -38,6 +38,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorHandshakeMetadata, KVConnectorMetadata, KVConnectorRole, + SupportsHMA, ) from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( KVConnectorPromMetrics, @@ -53,10 +54,12 @@ from vllm.distributed.parallel_state import ( from vllm.forward_context import ForwardContext from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.utils.math_utils import cdiv from vllm.utils.network_utils import make_zmq_path, make_zmq_socket from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, SlidingWindowSpec from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: @@ -205,6 +208,7 @@ def compute_nixl_compatibility_hash( model_config = vllm_config.model_config cache_config = vllm_config.cache_config + is_hma_enabled = not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager factors = { # Version compatibility @@ -220,6 +224,7 @@ def compute_nixl_compatibility_hash( "attn_backend_name": attn_backend_name, "cache_dtype": str(cache_config.cache_dtype), 
"cross_layers_blocks": cross_layers_blocks, + "is_hma_enabled": is_hma_enabled, } compat_hash = hash_factors(factors) @@ -238,7 +243,7 @@ def compute_nixl_compatibility_hash( @dataclass class RemoteMeta: - block_ids: list[int] + block_ids: BlockIds host: str port: int engine_id: str @@ -247,9 +252,9 @@ class RemoteMeta: @dataclass class ReqMeta: - local_block_ids: list[int] + local_block_ids: BlockIds # To be used when logical block size does not match the kernel block size - local_physical_block_ids: list[int] + local_physical_block_ids: BlockIds tp_size: int remote: RemoteMeta | None = None @@ -264,7 +269,7 @@ class NixlConnectorMetadata(KVConnectorMetadata): def _add_new_req( self, - local_block_ids: list[int], + local_block_ids: BlockIds, kv_transfer_params: dict[str, Any], ) -> ReqMeta: return ReqMeta( @@ -277,7 +282,7 @@ class NixlConnectorMetadata(KVConnectorMetadata): def add_new_req_to_save( self, request_id: ReqId, - local_block_ids: list[int], + local_block_ids: BlockIds, kv_transfer_params: dict[str, Any], ): self.reqs_to_save[request_id] = self._add_new_req( @@ -287,7 +292,7 @@ class NixlConnectorMetadata(KVConnectorMetadata): def add_new_req_to_recv( self, request_id: ReqId, - local_block_ids: list[int], + local_block_ids: BlockIds, kv_transfer_params: dict[str, Any], ): req = self._add_new_req(local_block_ids, kv_transfer_params) @@ -301,7 +306,7 @@ class NixlConnectorMetadata(KVConnectorMetadata): self.reqs_to_recv[request_id] = req -class NixlConnector(KVConnectorBase_V1): +class NixlConnector(KVConnectorBase_V1, SupportsHMA): @property def prefer_cross_layer_blocks(self) -> bool: backend = get_current_attn_backend(self._vllm_config) @@ -326,22 +331,27 @@ class NixlConnector(KVConnectorBase_V1): self, vllm_config: VllmConfig, role: KVConnectorRole, - kv_cache_config: "KVCacheConfig | None" = None, + kv_cache_config: "KVCacheConfig", ): super().__init__(vllm_config, role, kv_cache_config) assert vllm_config.kv_transfer_config is not None assert 
vllm_config.kv_transfer_config.engine_id is not None + for group in kv_cache_config.kv_cache_groups: + if isinstance(group.kv_cache_spec, MambaSpec): + raise ValueError("NixlConnector does not support Mamba models.") self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id self.kv_transfer_config = vllm_config.kv_transfer_config if role == KVConnectorRole.SCHEDULER: self.connector_scheduler: NixlConnectorScheduler | None = ( - NixlConnectorScheduler(vllm_config, self.engine_id) + NixlConnectorScheduler(vllm_config, self.engine_id, kv_cache_config) ) self.connector_worker: NixlConnectorWorker | None = None elif role == KVConnectorRole.WORKER: self.connector_scheduler = None - self.connector_worker = NixlConnectorWorker(vllm_config, self.engine_id) + self.connector_worker = NixlConnectorWorker( + vllm_config, self.engine_id, kv_cache_config + ) ############################################################ # Class Methods @@ -392,10 +402,10 @@ class NixlConnector(KVConnectorBase_V1): assert self.connector_scheduler is not None return self.connector_scheduler.build_connector_meta(scheduler_output) - def request_finished( + def request_finished_all_groups( self, request: "Request", - block_ids: list[int], + block_ids: tuple[list[int], ...], ) -> tuple[bool, dict[str, Any] | None]: assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) @@ -518,10 +528,13 @@ class NixlConnector(KVConnectorBase_V1): class NixlConnectorScheduler: """Implementation of Scheduler side methods""" - def __init__(self, vllm_config: VllmConfig, engine_id: str): + def __init__( + self, vllm_config: VllmConfig, engine_id: str, kv_cache_config: "KVCacheConfig" + ): self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size self.engine_id: EngineId = engine_id + self.kv_cache_config = kv_cache_config self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST self.side_channel_port = ( 
envs.VLLM_NIXL_SIDE_CHANNEL_PORT @@ -534,8 +547,18 @@ class NixlConnectorScheduler: self.use_host_buffer = ( vllm_config.kv_transfer_config.kv_buffer_device == "cpu" ) + self._is_hma_required = ( + not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager + # Also handle unlikely SW-only model case instead of checking num_groups>1. + and any( + not isinstance(g.kv_cache_spec, FullAttentionSpec) + for g in kv_cache_config.kv_cache_groups + ) + ) logger.info("Initializing NIXL Scheduler %s", engine_id) + if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager: + logger.info("Hybrid Memory Allocator is enabled with NIXL") # Background thread for handling new handshake requests. self._nixl_handshake_listener_t: threading.Thread | None = None @@ -545,7 +568,7 @@ class NixlConnectorScheduler: # Requests that need to start recv/send. # New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. - self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} + self._reqs_need_recv: dict[ReqId, tuple[Request, BlockIds]] = {} self._reqs_need_save: dict[ReqId, Request] = {} # Reqs to send and their expiration time self._reqs_need_send: dict[ReqId, float] = {} @@ -554,12 +577,54 @@ class NixlConnectorScheduler: # remote prefill or aborted. self._reqs_not_processed: set[ReqId] = set() + # Gather Sliding Window sizes for each kv cache group (if any) in number of + # blocks per KV cache group. This is used to clip the local attention window. + sw_sizes_tokens: list[tuple[int, int]] = [ + (g.kv_cache_spec.sliding_window, g.kv_cache_spec.block_size) + if isinstance(g.kv_cache_spec, SlidingWindowSpec) + else (0, self.block_size) + for g in kv_cache_config.kv_cache_groups + ] + # cdiv(n_tokens, block_size) gives blocks/window; add 1 to conservatively + # account for boundary overlap eg window isn't fully aligned with blocks. 
+ self.blocks_per_sw = [ + cdiv(n_tokens, block_size) + 1 if n_tokens else 0 + for n_tokens, block_size in sw_sizes_tokens + ] + def shutdown(self): self._stop_event.set() if self._nixl_handshake_listener_t is not None: self._nixl_handshake_listener_t.join() self._nixl_handshake_listener_t = None + def get_sw_clipped_blocks(self, block_ids: BlockIds) -> BlockIds: + """ + Clip the number of blocks to the sliding window size for each kv cache group + that employs SWA. + This is necessary because the KV Cache manager initially allocates blocks for + the entire sequence length, and successively cleans up blocks that are outside + the window prior to the `request_finished_all_groups` hook. + """ + if len(block_ids) == 0 or not self._is_hma_required: + # No blocks to clip eg Full prefix cache hit or not a hybrid model. + return block_ids + # NOTE (NickLucche) This logic is currently handled at the connector level + # because offloading connectors might want to receive the whole sequence even + # for SWA groups. We will abstract this logic once the interface is more stable + assert len(block_ids) == len(self.blocks_per_sw), ( + "Number of KV cache groups must match" + ) + # For non-SWA groups, blocks_per_sw is 0 so we return all block_ids unchanged + return tuple( + [ + blocks[-self.blocks_per_sw[i] :] + if self.blocks_per_sw[i] > 0 + else blocks + for i, blocks in enumerate(block_ids) + ] + ) + def set_xfer_handshake_metadata( self, metadata: dict[int, KVConnectorHandshakeMetadata] ) -> None: @@ -707,12 +772,18 @@ class NixlConnectorScheduler: # If remote_blocks and num_external_tokens = 0, we have # a full prefix cache hit on the D worker. We need to call # send_notif in _read_blocks to free the memory on the P. - local_block_ids = ( - blocks.get_unhashed_block_ids() + + unhashed_local_block_ids: BlockIds = ( + blocks.get_unhashed_block_ids_all_groups() if num_external_tokens > 0 - else [] + else () ) - # Get unhashed blocks to pull from remote. 
+ local_block_ids = self.get_sw_clipped_blocks( + unhashed_local_block_ids + ) + + # Get unhashed blocks to pull from remote. Mind that a full prefix + # cache hit is indicated with an empty list. self._reqs_need_recv[request.request_id] = ( request, local_block_ids, @@ -753,9 +824,10 @@ class NixlConnectorScheduler: req = req_to_save assert req.kv_transfer_params is not None + clipped_block_id_groups = self.get_sw_clipped_blocks(new_block_id_groups) meta.add_new_req_to_save( request_id=req_id, - local_block_ids=new_block_id_groups[0], + local_block_ids=clipped_block_id_groups, kv_transfer_params=req.kv_transfer_params, ) assert scheduler_output.num_scheduled_tokens is not None @@ -786,7 +858,7 @@ class NixlConnectorScheduler: def request_finished( self, request: "Request", - block_ids: list[int], + block_ids: BlockIds, ) -> tuple[bool, dict[str, Any] | None]: """ Once a request is finished, determine whether request blocks @@ -828,7 +900,7 @@ class NixlConnectorScheduler: # TODO: check whether block_ids actually ever be 0. If not we could # remove the conditional below - delay_free_blocks = len(block_ids) > 0 + delay_free_blocks = any(len(group) > 0 for group in block_ids) if delay_free_blocks: # Prefill request on remote. It will be read from D upon completion @@ -841,6 +913,11 @@ class NixlConnectorScheduler: self._reqs_need_send[request.request_id] = ( time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT ) + # NOTE HMA will "mark" empty/null blocks in groups with 0s (eg SWA ones), + # trimming down after allocating for the whole sequence length. Empty + # blocks are always at the start of the list. + # Here we "unpad" blocks to send the actual remote blocks to be read. 
+ block_ids = self.get_sw_clipped_blocks(block_ids) return delay_free_blocks, dict( do_remote_prefill=True, @@ -857,7 +934,9 @@ class NixlConnectorScheduler: class NixlConnectorWorker: """Implementation of Worker side methods""" - def __init__(self, vllm_config: VllmConfig, engine_id: str): + def __init__( + self, vllm_config: VllmConfig, engine_id: str, kv_cache_config: "KVCacheConfig" + ): if NixlWrapper is None: logger.error("NIXL is not available") raise RuntimeError("NIXL is not available") @@ -875,6 +954,14 @@ class NixlConnectorWorker: self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config( "backends", ["UCX"] ) + self._is_hma_required = ( + not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager + and any( + not isinstance(g.kv_cache_spec, FullAttentionSpec) + for g in kv_cache_config.kv_cache_groups + ) + ) + self.kv_cache_config = kv_cache_config # Agent. non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"] @@ -1017,10 +1104,6 @@ class NixlConnectorWorker: self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config - # TODO(mgoin): remove this once we have hybrid memory allocator - # Optimization for models with local attention (Llama 4) - # List of block window sizes for each layer for local attention - self.block_window_per_layer: list[int | None] = [] self.use_mla = self.model_config.use_mla # Get the attention backend from the first layer @@ -1030,8 +1113,8 @@ class NixlConnectorWorker: self.backend_name = self.attn_backend.get_name() self.kv_cache_layout = get_kv_cache_layout() self.host_buffer_kv_cache_layout = self.kv_cache_layout - logger.debug("Detected attention backend %s", self.backend_name) - logger.debug("Detected kv cache layout %s", self.kv_cache_layout) + logger.info("Detected attention backend %s", self.backend_name) + logger.info("Detected kv cache layout %s", self.kv_cache_layout) # lazy initialized in register_kv_caches self.compat_hash: str | None = None @@ -1238,9 
+1321,15 @@ class NixlConnectorWorker: "remote_request_id": meta.remote.request_id, "remote_host": meta.remote.host, "remote_port": meta.remote.port, - "num_local_blocks": len(meta.local_block_ids), - "num_remote_blocks": len(meta.remote.block_ids), - "local_block_ids_sample": meta.local_block_ids[:10], + "num_local_blocks": sum( + len(group) for group in meta.local_block_ids + ), + "num_remote_blocks": sum( + len(group) for group in meta.remote.block_ids + ), + "local_block_ids_sample": meta.local_block_ids[0][:10] + if meta.local_block_ids + else [], } ) @@ -1301,8 +1390,10 @@ class NixlConnectorWorker: error=e, meta=meta, ) - if req_meta := self._recving_metadata.get(req_id): - self._invalid_block_ids.update(req_meta.local_block_ids) + if ( + req_meta := self._recving_metadata.get(req_id) + ) and not self._is_hma_required: + self._invalid_block_ids.update(req_meta.local_block_ids[0]) self._failed_recv_reqs.add(req_id) fut.add_done_callback(request_ready) @@ -1370,6 +1461,10 @@ class NixlConnectorWorker: for cache in cache_list: base_addr = cache.data_ptr() if base_addr in seen_base_addresses: + # NOTE (NickLucche) HMA employs memory pooling to share tensors + # across groups. This results in skipping all tensors but the ones + # pointed to by group0. Also, generally we will have more blocks + # per tensor but fewer regions. continue logger.debug( @@ -1457,28 +1552,6 @@ class NixlConnectorWorker: self.register_local_xfer_handler(self.block_size) ) - # TODO(mgoin): Hybrid memory allocator is currently disabled for - # models with local attention (Llama 4). Can remove this once enabled. 
- if self.model_config.hf_config.model_type == "llama4": - from transformers import Llama4TextConfig - - assert isinstance(self.model_config.hf_text_config, Llama4TextConfig) - llama4_config = self.model_config.hf_text_config - no_rope_layers = llama4_config.no_rope_layers - chunk_size = llama4_config.attention_chunk_size - chunk_block_size = math.ceil(chunk_size / self.block_size) - for layer_idx in range(self.num_layers): - # no_rope_layers[layer_idx] == 0 means NoPE (global) - # Any other value means RoPE (local chunked) - is_local_attention = no_rope_layers[layer_idx] != 0 - block_window = chunk_block_size if is_local_attention else None - self.block_window_per_layer.append(block_window) - logger.debug( - "Llama 4 block window per layer mapping: %s", - self.block_window_per_layer, - ) - assert len(self.block_window_per_layer) == self.num_layers - # After KV Caches registered, listen for new connections. agent_metadata = NixlAgentMetadata( engine_id=self.engine_id, @@ -1767,6 +1840,11 @@ class NixlConnectorWorker: # Num kv_heads > tp_size and P TP > D TP case, not supported assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id)) + if self._is_hma_required: + assert block_size_ratio == 1, ( + "HMA does not support different remote block size yet" + ) + kv_cache_layout = ( self.kv_cache_layout if not self.use_host_buffer @@ -1781,6 +1859,9 @@ class NixlConnectorWorker: "Remote is HND and local is NHD, enabled additional permute " "on local device KV." 
) + assert not self._is_hma_required, ( + "HMA does not support block size post processing" + ) self.enable_permute_local_kv = True else: raise RuntimeError( @@ -1836,13 +1917,15 @@ class NixlConnectorWorker: assert self.copy_blocks is not None local_block_ids = meta.local_physical_block_ids - self.copy_blocks( - self.host_xfer_buffers, - self.device_kv_caches, - local_block_ids, - local_block_ids, - "h2d", - ) + # TODO (NickLucche) D2H<>H2D ops could benefit from coalescing io across groups + for group_block_ids in local_block_ids: + self.copy_blocks( + self.host_xfer_buffers, + self.device_kv_caches, + group_block_ids, + group_block_ids, + "h2d", + ) if logger.isEnabledFor(logging.DEBUG): logger.debug( "synced recved kv of request[%s] to device kv buffer," @@ -1868,13 +1951,14 @@ class NixlConnectorWorker: ",".join(map(str, meta.local_physical_block_ids)), ) # blocking - self.copy_blocks( - self.device_kv_caches, - self.host_xfer_buffers, - meta.local_physical_block_ids, - meta.local_physical_block_ids, - "d2h", - ) + for group_block_ids in meta.local_physical_block_ids: + self.copy_blocks( + self.device_kv_caches, + self.host_xfer_buffers, + group_block_ids, + group_block_ids, + "d2h", + ) def post_process_device_kv_on_receive( self, @@ -1973,8 +2057,9 @@ class NixlConnectorWorker: if not self.use_mla and ( block_size_ratio > 1 or self.enable_permute_local_kv ): + assert not self._is_hma_required block_ids_for_blocksize_post_process[block_size_ratio].append( - meta.local_physical_block_ids + meta.local_physical_block_ids[0] ) for ( block_size_ratio, @@ -2106,8 +2191,9 @@ class NixlConnectorWorker: handle: The transfer handle. """ # Use .get() here as the metadata cleanup is handled by get_finished() - if meta := self._recving_metadata.get(req_id): - self._invalid_block_ids.update(meta.local_block_ids) + # TODO (NickLucche) handle failed transfer for HMA. 
+ if (meta := self._recving_metadata.get(req_id)) and not self._is_hma_required: + self._invalid_block_ids.update(meta.local_block_ids[0]) self.nixl_wrapper.release_xfer_handle(handle) self.xfer_stats.record_failed_transfer() @@ -2230,8 +2316,8 @@ class NixlConnectorWorker: def _read_blocks( self, - local_block_ids: list[int], - remote_block_ids: list[int], + local_block_ids: BlockIds, + remote_block_ids: BlockIds, dst_engine_id: str, request_id: str, remote_request_id: str, @@ -2246,22 +2332,30 @@ class NixlConnectorWorker: assert self.kv_topo is not None block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) if block_size_ratio > 1: - local_block_ids = self.get_mapped_blocks( - np.asarray(local_block_ids), block_size_ratio - ) - if len(local_block_ids) > len(remote_block_ids): + # TODO (NickLucche) assume HMA is off. Change to handle multiple KV groups. + assert not self._is_hma_required + local_block_ids0 = local_block_ids[0] if local_block_ids else [] + remote_block_ids0 = remote_block_ids[0] + local_block_ids_mapped = self.get_mapped_blocks( + np.asarray(local_block_ids0), block_size_ratio + ).tolist() + if len(local_block_ids_mapped) > len(remote_block_ids0): # NOTE: # get_mapped_blocks will always expand block_ids for n times. 
# ex: # prefill block_ids with block_size as 4: # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # Local decode block_ids with block_size as 16: [1, 2, 3] - # expland ecode block_ids with get_mapped_blocks from [1, 2, 3] to + # expanded decode block_ids with get_mapped_blocks from [1, 2, 3] to # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] # Then we clip local to align with prefill # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] to # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - local_block_ids = local_block_ids[: len(remote_block_ids)] + local_block_ids_mapped = local_block_ids_mapped[ + : len(remote_block_ids0) + ] + local_block_ids = [local_block_ids_mapped] if local_block_ids_mapped else [] + remote_block_ids = [remote_block_ids0] # NOTE(rob): having the staging blocks be on the READER side is # not going to work well (since we will have to call rearrange tensors). # after we detect the txn is complete (which means we cannot make the @@ -2269,8 +2363,7 @@ class NixlConnectorWorker: # then we will need to have the staging blocks on the remote side. # NOTE(rob): according to nvidia the staging blocks are used to - # saturate IB with heterogeneous TP sizes. We should remove the staging - # blocks until we are ready. + # saturate IB with heterogeneous TP sizes. # Number of D TP workers that will read from dst P. Propagate info # on notification so that dst worker can wait before freeing blocks. @@ -2278,8 +2371,8 @@ class NixlConnectorWorker: # Full prefix cache hit: do not need to read remote blocks, # just notify P worker that we have the blocks we need. - num_local_blocks = len(local_block_ids) - if num_local_blocks == 0: + if len(local_block_ids) == 0: + # A full prefix cache hit is indicated with an empty list. agent_name = self._remote_agents[dst_engine_id][remote_rank] try: self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id) @@ -2297,66 +2390,34 @@ class NixlConnectorWorker: self.xfer_stats.record_failed_notification() return - # Partial prefix cache hit: just read uncomputed blocks. 
- num_remote_blocks = len(remote_block_ids) - assert num_local_blocks <= num_remote_blocks - if num_local_blocks < num_remote_blocks: - remote_block_ids = remote_block_ids[-num_local_blocks:] + assert ( + len(remote_block_ids) + == len(local_block_ids) + == len(self.kv_cache_config.kv_cache_groups) + ) + remote_block_ids = list(remote_block_ids) + for i, remote_group in enumerate(remote_block_ids): + num_remote_blocks = len(remote_group) + num_local_blocks = len(local_block_ids[i]) + assert num_local_blocks <= num_remote_blocks + # Partial prefix cache hit: just read uncomputed blocks. + if num_local_blocks < num_remote_blocks: + remote_block_ids[i] = remote_group[-num_local_blocks:] # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from # corresponding rank. With heterogeneous TP, fixing D>P, the D tp # workers will issue xfers to parts of the P worker remote kv caches. # Get descs ids. - local_block_descs_ids: np.ndarray - remote_block_descs_ids: np.ndarray - - if not self.block_window_per_layer: - # Default case: assume global attention - remote_block_descs_ids = self._get_block_descs_ids( - dst_engine_id, - remote_block_ids, - ) - local_block_descs_ids = self._get_block_descs_ids( - self.engine_id, - local_block_ids, - block_size_ratio=block_size_ratio, - ) - else: - # TODO(mgoin): remove this once we have hybrid memory allocator - # Optimization for models with local attention (Llama 4) - local_descs_list = [] - remote_descs_list = [] - for layer_idx, block_window in enumerate(self.block_window_per_layer): - # For each layer: - if block_window is None: - # If not chunked, we just use the - # full block lists (global attention) - layer_local_block_ids = local_block_ids - layer_remote_block_ids = remote_block_ids - else: - # If chunked, get the last block_window blocks - layer_local_block_ids = local_block_ids[-block_window:] - layer_remote_block_ids = remote_block_ids[-block_window:] - - # Get descs ids for the layer. 
- layer_local_desc_ids = self._get_block_descs_ids( - self.engine_id, - layer_local_block_ids, - layer_idx, - block_size_ratio=block_size_ratio, - ) - layer_remote_desc_ids = self._get_block_descs_ids( - dst_engine_id, - layer_remote_block_ids, - layer_idx, - ) - - local_descs_list.append(layer_local_desc_ids) - remote_descs_list.append(layer_remote_desc_ids) - - local_block_descs_ids = np.concatenate(local_descs_list) - remote_block_descs_ids = np.concatenate(remote_descs_list) + remote_block_descs_ids = self._get_block_descs_ids( + dst_engine_id, + remote_block_ids, + ) + local_block_descs_ids = self._get_block_descs_ids( + self.engine_id, + local_block_ids, + block_size_ratio=block_size_ratio, + ) assert len(local_block_descs_ids) == len(remote_block_descs_ids) @@ -2387,14 +2448,18 @@ class NixlConnectorWorker: dst_engine_id=dst_engine_id, remote_rank=remote_rank, ) - if meta := self._recving_metadata.get(request_id): - self._invalid_block_ids.update(meta.local_block_ids) + if ( + meta := self._recving_metadata.get(request_id) + ) and not self._is_hma_required: + self._invalid_block_ids.update(meta.local_block_ids[0]) self.xfer_stats.record_failed_transfer() if handle is not None: self.nixl_wrapper.release_xfer_handle(handle) self._failed_recv_reqs.add(request_id) - def get_mapped_blocks(self, block_ids, block_size_ratio): + def get_mapped_blocks( + self, block_ids: np.ndarray, block_size_ratio: int + ) -> np.ndarray: """ Calculates the new set of block IDs by mapping every element in the (potentially sparse) input array. @@ -2416,41 +2481,32 @@ class NixlConnectorWorker: def _get_block_descs_ids( self, engine_id: str, - block_ids: list[int], - layer_idx: int | None = None, + block_ids: BlockIds, block_size_ratio: float | None = None, ) -> np.ndarray: """ Get the descs ids for a set of block ids. - If layer_idx is provided, we use the region_ids for the given layer. - Otherwise, we use all regions. 
+ When HMA is enabled number of descriptors across kv cache groups might differ. + A single flattened array is returned for all groups anyway. """ - if layer_idx is None: - region_ids = np.arange(self.num_regions) - else: - assert layer_idx < self.num_layers - if self.num_layers < self.num_regions: - # If we have more regions than layers, we assume that - # the regions are organized as [K0, V0, K1, V1, ...] - # and we select K_i and V_i - assert 2 * self.num_layers == self.num_regions - region_ids = np.arange(2 * layer_idx, 2 * layer_idx + 2) - else: - # Otherwise, we assume we have MLA and select i-th layer - assert self.num_layers == self.num_regions - region_ids = np.arange(layer_idx, layer_idx + 1) - + region_ids = np.arange(self.num_regions) + # NOTE (NickLucche) With HMA, every kv group has the same number of layers and + # layers from different groups share the same kv tensor. + # eg block_ids=[[1, 2], [3]]->blocks [1, 2] need to be read across all regions, + # same for [3], but group0-group1 blocks will always differ (different areas). + # Therefore we can just flatten the block_ids and compute the descs ids for all + # groups at once. num_blocks = self.dst_num_blocks[engine_id] if block_size_ratio is not None: num_blocks = int(num_blocks * block_size_ratio) # Compute the desc ids for each block. region_ids = region_ids[:, None] - block_ids = np.array(block_ids)[None, :] + block_ids = np.concatenate(block_ids)[None, :] descs_ids = region_ids * num_blocks + block_ids return descs_ids.flatten() - def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> list[int]: + def _logical_to_kernel_block_ids(self, block_ids: BlockIds) -> BlockIds: """ Convert logical block ids to kernel physical block ids. 
This is required when the logical block size (the one set by the user) @@ -2459,13 +2515,17 @@ class NixlConnectorWorker: if self._physical_blocks_per_logical_kv_block == 1: # Noop when physical and logical block sizes are the same return block_ids - block_ids_np = np.array(block_ids) block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape( 1, -1 ) - return BlockTable.map_to_kernel_blocks( - block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange - ).tolist() + return [ + BlockTable.map_to_kernel_blocks( + np.array(group), + self._physical_blocks_per_logical_kv_block, + block_arange, + ).tolist() + for group in block_ids + ] def get_backend_aware_kv_block_len(self, layer_idx: int) -> int: """ diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 7f8d80475..ee198a57f 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -84,6 +84,18 @@ class KVCacheBlocks: assert len(self.blocks) == 1, "Only one group is supported" return [block.block_id for block in self.blocks[0] if block.block_hash is None] + def get_unhashed_block_ids_all_groups(self) -> list[list[int]]: + """Get block_ids of unhashed blocks from KVCacheBlocks instance.""" + # Skip padding blocks. + return [ + [ + block.block_id + for block in group + if block.block_hash is None and not block.is_null + ] + for group in self.blocks + ] + def new_empty(self) -> "KVCacheBlocks": """ Creates a new KVCacheBlocks instance with no blocks.