diff --git a/tests/v1/kv_connector/unit/test_mooncake_connector.py b/tests/v1/kv_connector/unit/test_mooncake_connector.py new file mode 100644 index 000000000..f21f8ecdc --- /dev/null +++ b/tests/v1/kv_connector/unit/test_mooncake_connector.py @@ -0,0 +1,756 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import contextlib +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import torch +import zmq.asyncio + +from vllm.config import set_current_vllm_config +from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector import ( + KVConnectorRole, + MooncakeConnector, + MooncakeConnectorMetadata, + MooncakeXferMetadata, + MooncakeXferResponse, + MooncakeXferResponseStatus, + PullReqMeta, + SendBlockMeta, +) +from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_utils import ( + MooncakeBootstrapServer, +) +from vllm.utils.network_utils import get_open_port +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend +from vllm.v1.request import RequestStatus + +from .utils import create_request, create_scheduler, create_vllm_config + + +class FakeMooncakeWrapper: + """Mock Mooncake TransferEngine for unit testing environments.""" + + def __init__(self, *args, **kwargs): + pass + + def initialize(self, local_hostname, metadata_server, protocol, device_name) -> int: + return 0 + + def get_rpc_port(self) -> int: + return 12345 + + def batch_transfer_sync_write( + self, target_hostname, buffers, peer_buffer_addresses, lengths + ) -> int: + return 0 + + def batch_register_memory(self, buffer_addresses, capacities) -> int: + return 0 + + +def test_basic_interface(): + """Unit test for basic MooncakeConnector interface functionality.""" + + vllm_config = create_vllm_config( + kv_connector="MooncakeConnector", kv_role="kv_consumer" + ) + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + ) + request_id = request.request_id + request.kv_transfer_params.update( + { + "transfer_id": request_id, + "remote_bootstrap_addr": 54321, + } + ) + + scheduler.add_request(request) + + # Remote Prefill, triggers MooncakeConnectorMetadata. + scheduler_output = scheduler.schedule() + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + assert isinstance(kv_connector_metadata, MooncakeConnectorMetadata) + + assert len(kv_connector_metadata.reqs_to_recv) == 1 + assert request_id in kv_connector_metadata.reqs_to_recv["my-engine-id"] + req_meta = kv_connector_metadata.reqs_to_recv["my-engine-id"][request_id] + + for block_id, block in zip( + req_meta.local_block_ids, + scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[ + request_id + ], + ): + assert block_id == block.block_id + + +def test_prompt_less_than_block_size(): + """Test that we can handle the case where the prompt is < 1 block.""" + + vllm_config = create_vllm_config( + kv_connector="MooncakeConnector", kv_role="kv_consumer" + ) + scheduler = create_scheduler(vllm_config) + + # Half of a block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_TOKENS = int(BLOCK_SIZE * 0.5) + + # Request will have 1 partial remote block.
+ request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + num_remote_blocks=1, + ) + request.kv_transfer_params.update( + { + "transfer_id": request.request_id, + "remote_bootstrap_addr": 54321, + } + ) + + scheduler.add_request(request) + scheduler_output = scheduler.schedule() + + # This request will read KV asynchronously. + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + assert isinstance(kv_connector_metadata, MooncakeConnectorMetadata) + assert len(kv_connector_metadata.reqs_to_recv["my-engine-id"]) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 + + +@pytest.fixture +def bootstrap_server(): + """Fixture to launch and clean up a Mooncake Bootstrap HTTP Server.""" + + port = get_open_port() + server = MooncakeBootstrapServer("127.0.0.1", port) + server.start() + yield server + server.shutdown() + + +@pytest.mark.asyncio +async def test_bootstrap_server(bootstrap_server: MooncakeBootstrapServer): + """ + Tests the bootstrap server's API for worker registration and querying. + + Validates DP/TP/PP rank indexing and error handling for duplicate registrations. + """ + + import httpx + + base_url = f"http://127.0.0.1:{bootstrap_server.port}" + + # Query when empty + async with httpx.AsyncClient() as client: + response = await client.get(f"{base_url}/query") + assert response.status_code == 200 + assert response.json() == {} + + # Register a worker + payload1 = { + "engine_id": "eng-1", + "dp_rank": 0, + "tp_rank": 0, + "pp_rank": 0, + "addr": "tcp://1.1.1.1:1111", + } + async with httpx.AsyncClient() as client: + response = await client.post(f"{base_url}/register", json=payload1) + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + # Query after registration + async with httpx.AsyncClient() as client: + response = await client.get(f"{base_url}/query") + assert response.status_code == 200 + data = response.json() + assert "0" in data + assert data["0"]["engine_id"] == "eng-1" + assert data["0"]["worker_addr"]["0"]["0"] == "tcp://1.1.1.1:1111" + + # Test failure: re-registering the same worker + async with httpx.AsyncClient() as client: + response = await client.post(f"{base_url}/register", json=payload1) + assert response.status_code == 400 + assert "is already registered" in response.text + + # Test failure: engine_id mismatch for same dp_rank + payload3_fail = { + "engine_id": "eng-2", + "dp_rank": 0, + "tp_rank": 1, + "pp_rank": 0, + "addr": "tcp://3.3.3.3:3333", + } + async with httpx.AsyncClient() as client: + response = await client.post(f"{base_url}/register", json=payload3_fail) + assert response.status_code == 400 + assert "Engine ID mismatch" in response.text + + +def test_scheduler_request_finished(): + """ + Tests the scheduler-side logic when a request finishes. + + Differentiates between 'Finished' (requires transfer) + and 'Aborted' (immediate free).
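+ 'Finished' keeps blocks alive (delay_free=True) until the consumer pulls them; + 'Aborted' frees them immediately (delay_free=False).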
+ """ + + vllm_config = create_vllm_config( + kv_connector="MooncakeConnector", kv_role="kv_producer" + ) + scheduler = create_scheduler(vllm_config) + scheduler_connector = scheduler.get_kv_connector().connector_scheduler + + request = create_request(request_id=1, do_remote_decode=True) + request.kv_transfer_params["transfer_id"] = request.request_id + + # Case: Capped length (Successful prefill, need to send to decoder) + request.status = RequestStatus.FINISHED_LENGTH_CAPPED + delay_free, _ = scheduler_connector.request_finished(request, block_ids=[10, 11]) + assert delay_free is True + assert "id-1" in scheduler_connector._reqs_need_send + assert scheduler_connector._reqs_need_send["id-1"][1] == [10, 11] + + # Case: Aborted (No need to transfer, free blocks immediately) + scheduler_connector._reqs_need_send.clear() + request.status = RequestStatus.FINISHED_ABORTED + delay_free, _ = scheduler_connector.request_finished(request, block_ids=[12]) + assert delay_free is False + assert len(scheduler_connector._reqs_need_send) == 0 + assert "id-1" in scheduler_connector._reqs_not_processed + + +@contextlib.contextmanager +def patch_worker_dependencies(): + """Helper to mock all distributed and network dependencies for Worker tests.""" + + with ( + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.TransferEngine", + FakeMooncakeWrapper, + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.get_ip", + return_value="127.0.0.1", + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.get_tensor_model_parallel_rank", + return_value=0, + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.get_tensor_model_parallel_world_size", + return_value=1, + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.get_pp_group" + ) as mock_pp, + patch("vllm.distributed.parallel_state.is_local_first_rank", return_value=True), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.should_launch_bootstrap_server", + return_value=False, + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.make_zmq_socket" + ) as mock_make_zmq, + patch("httpx.AsyncClient") as mock_async_client, + ): + # Mock PP group + mock_pp_group = MagicMock() + mock_pp_group.rank_in_group = 0 + mock_pp.return_value = mock_pp_group + + # Mock ZMQ socket + mock_socket_object = AsyncMock() + mock_socket_object.setsockopt = MagicMock() + mock_socket_ctx = MagicMock() + mock_socket_ctx.__enter__.return_value = mock_socket_object + mock_make_zmq.return_value = mock_socket_ctx + + # Mock httpx client + mock_http_client_instance = AsyncMock() + mock_async_client.return_value = mock_http_client_instance + + yield { + "mock_make_zmq": mock_make_zmq, + "mock_socket_object": mock_socket_object, + "mock_async_client": mock_async_client, + "mock_http_client": mock_http_client_instance, + } + + +@pytest.mark.asyncio +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.TransferEngine", + FakeMooncakeWrapper, +) +async def test_kv_producer(monkeypatch): + """ + Simulates a Producer Worker (Prefiller) receiving a transfer request + from a Consumer (Decoder). + + Verifies memory offset calculation: ptr = base_addr + block_id * block_len. 
+ """ + + monkeypatch.setenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "5") + vllm_config = create_vllm_config( + kv_connector="MooncakeConnector", kv_role="kv_producer" + ) + + with set_current_vllm_config(vllm_config), patch_worker_dependencies(): + prefill_connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER) + prefill_worker = prefill_connector.connector_worker + prefill_worker.kv_caches_base_addr = [0x1000] + block_len = 4096 + prefill_worker.block_len_per_layer = [block_len] + + # Override loop to use current test loop + origin_sender_loop = prefill_worker.sender_loop + prefill_worker.sender_loop = asyncio.get_event_loop() + + # A request is finished on Producer and ready to be sent. + transfer_id = "xfer-req-1" + send_meta = SendBlockMeta( + p_req_id="p-req-1", + transfer_id=transfer_id, + local_block_ids=[10, 11], + ready=asyncio.Event(), + ) + prefill_worker.reqs_need_send[transfer_id] = send_meta + send_meta.ready.set() + + # Remote consumer request metadata + xfer_meta = MooncakeXferMetadata( + remote_hostname="consumer-host", + remote_port=54321, + remote_tp_size=1, + remote_tp_rank=0, + req_blocks={"d-req-1": (transfer_id, [20, 21])}, + kv_caches_base_addr=[0x2000], + block_lens=[block_len], + ) + + mock_socket = AsyncMock(spec=zmq.asyncio.Socket) + mock_socket.send_multipart = AsyncMock() + identity = b"consumer-id" + + with patch.object( + prefill_worker, "_send_blocks", return_value=0 + ) as mock_send_blocks: + # Normal case: 2 blocks to 2 blocks + # Worker processes the consumer's request + await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta) + # Verify transfer parameters are correct + src_ptr = 0x1000 + 10 * block_len + dst_ptr = 0x2000 + 20 * block_len + length = 2 * block_len + mock_send_blocks.assert_called_once_with( + "consumer-host:54321", [src_ptr], [dst_ptr], [length] + ) + mock_socket.send_multipart.assert_called_once() + + # Verify the response sent back to the consumer + sent_call = mock_socket.send_multipart.call_args[0][0] + sent_identity, sent_payload = sent_call + assert sent_identity == identity + response = prefill_worker._xfer_resp_decoder.decode(sent_payload) + assert response.status == MooncakeXferResponseStatus.FINISH + assert response.ok_reqs == ["d-req-1"] + + # Verify internal state cleanup + assert transfer_id not in prefill_worker.reqs_need_send + assert "p-req-1" in prefill_worker.finished_sending_reqs + + # More cases: + # Consumer only needs 1 block (less than P) + mock_send_blocks.reset_mock() + mock_socket.send_multipart.reset_mock() + prefill_worker.reqs_need_send[transfer_id] = send_meta + send_meta.sent = 0 + send_meta.ready.set() + xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20]) + # Worker processes the consumer's request + await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta) + # Verify transfer parameters are correct: 11 to 20 + src_ptr = 0x1000 + 11 * block_len + dst_ptr = 0x2000 + 20 * block_len + length = 1 * block_len + mock_send_blocks.assert_called_once_with( + "consumer-host:54321", [src_ptr], [dst_ptr], [length] + ) + mock_socket.send_multipart.assert_called_once() + + # Consumer needs 3 blocks (more than P, error case) + mock_send_blocks.reset_mock() + mock_socket.send_multipart.reset_mock() + prefill_worker.reqs_need_send[transfer_id] = send_meta + send_meta.sent = 0 + send_meta.ready.set() + xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20, 21, 22]) + # Worker processes the consumer's request + await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta) 
+ # This should not be called because of the error. + mock_send_blocks.assert_not_called() + mock_socket.send_multipart.assert_called_once() + _, sent_payload = mock_socket.send_multipart.call_args[0][0] + response = prefill_worker._xfer_resp_decoder.decode(sent_payload) + assert response.err_msg == "P num blocks less than D" + assert response.err_reqs == ["d-req-1"] + + # Timeout + mock_send_blocks.reset_mock() + mock_socket.send_multipart.reset_mock() + prefill_worker.reqs_need_send[transfer_id] = send_meta + send_meta.sent = 0 + send_meta.ready.clear() + xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20, 21]) + # Worker processes the consumer's request + await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta) + # This should not be called because of the timeout. + mock_send_blocks.assert_not_called() + mock_socket.send_multipart.assert_called_once() + _, sent_payload = mock_socket.send_multipart.call_args[0][0] + response = prefill_worker._xfer_resp_decoder.decode(sent_payload) + assert response.err_msg == "Timeout waiting for P side ready." + assert response.err_reqs == ["d-req-1"] + + # Transfer error + with patch.object( + prefill_worker, "_send_blocks", return_value=123 + ) as mock_send_blocks: + mock_socket.send_multipart.reset_mock() + prefill_worker.reqs_need_send[transfer_id] = send_meta + send_meta.sent = 0 + send_meta.ready.set() + xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20, 21]) + # Worker processes the consumer's request + await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta) + mock_send_blocks.assert_called_once() + mock_socket.send_multipart.assert_called_once() + _, sent_payload = mock_socket.send_multipart.call_args[0][0] + response = prefill_worker._xfer_resp_decoder.decode(sent_payload) + assert response.err_msg == "Mooncake transfer engine returned 123" + assert response.err_reqs == ["d-req-1"] + + # Clean up + prefill_worker.sender_loop = origin_sender_loop + prefill_worker.shutdown() + + +@pytest.mark.asyncio +async def test_kv_consumer(monkeypatch): + """ + Simulates a Consumer Worker (Decoder) initiating a pull from a Producer. + + Verifies that MooncakeXferMetadata is correctly serialized and sent via ZMQ. + """ + + vllm_config = create_vllm_config( + kv_connector="MooncakeConnector", kv_role="kv_consumer" + ) + + with set_current_vllm_config(vllm_config), patch_worker_dependencies() as mocks: + decode_connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER) + decode_worker = decode_connector.connector_worker + decode_worker.kv_caches_base_addr = [0x1000] + decode_worker.rpc_port = 54321 + + # A request to pull data arrives. + pull_metas = { + "d-req-1": PullReqMeta( + d_req_id="d-req-1", + transfer_id="xfer-req-1", + local_block_ids=[100, 101], + remote_engine_id="p-engine", + remote_bootstrap_addr="http://bootstrap:33333", + pull_tasks_count=1, + ) + } + decode_worker._remote_agents = {"p-engine": {0: {0: "tcp://producer:1234"}}} + decode_worker._tp_size["p-engine"] = 1 + + # Mock the response from the producer. + mock_response = MooncakeXferResponse( + status=MooncakeXferResponseStatus.FINISH, ok_reqs=["d-req-1"] + ) + encoded_response = decode_worker._encoder.encode(mock_response) + mocks["mock_socket_object"].recv.return_value = encoded_response + + # Trigger the receive logic. + decode_worker.receive_kv("p-engine", pull_metas) + await asyncio.sleep(1) # Allow async task to run + + # Verify the metadata sent to the producer.
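+ # The consumer resolves the producer's address from _remote_agents and + # dials it over a ZMQ DEALER socket.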
+ mocks["mock_make_zmq"].assert_called_with( + decode_worker.async_zmq_ctx, + "tcp://producer:1234", + zmq.DEALER, + bind=False, + linger=0, + ) + sent_payload = mocks["mock_socket_object"].send.call_args[0][0] + sent_meta = decode_worker._xfer_meta_decoder.decode(sent_payload) + + assert sent_meta.remote_hostname == "127.0.0.1" + assert sent_meta.remote_port == 54321 + assert sent_meta.req_blocks["d-req-1"] == ("xfer-req-1", [100, 101]) + + # Verify internal state is updated correctly. + assert "d-req-1" in decode_worker.finished_recving_reqs + + # Clean up + decode_worker.shutdown() + + +@pytest.mark.asyncio +async def test_worker_get_finished_timeout(monkeypatch): + """Tests the cleanup mechanism for requests.""" + + vllm_config = create_vllm_config( + kv_connector="MooncakeConnector", kv_role="kv_producer" + ) + with set_current_vllm_config(vllm_config), patch_worker_dependencies(): + prefill_connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER) + prefill_worker = prefill_connector.connector_worker + + # Add an expired request (expire_time is in the past). + prefill_worker.reqs_need_send["tx-expired"] = SendBlockMeta( + p_req_id="p-req-expired", + transfer_id="tx-expired", + local_block_ids=[1, 2], + ready=MagicMock(), + expire_time=time.perf_counter() - 100, + ) + + # Add a non-expired request. + prefill_worker.reqs_need_send["tx-active"] = SendBlockMeta( + p_req_id="p-req-active", + transfer_id="tx-active", + local_block_ids=[3, 4], + ready=MagicMock(), + expire_time=time.perf_counter() + 100, + ) + + finished_reqs = await prefill_worker.fetch_finished_sending_reqs() + + assert "p-req-expired" in finished_reqs + assert "p-req-active" not in finished_reqs + assert "tx-expired" not in prefill_worker.reqs_need_send + assert "tx-active" in prefill_worker.reqs_need_send + + +def test_register_kv_caches(): + """Tests the memory registration logic with the underlying Mooncake engine.""" + + vllm_config = create_vllm_config( + kv_connector="MooncakeConnector", kv_role="kv_consumer" + ) + + with ( + set_current_vllm_config(vllm_config), + patch_worker_dependencies(), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.threading.Event" + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.threading.Thread" + ) as mock_thread, + ): + connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER) + worker = connector.connector_worker + mock_thread.return_value.is_alive.return_value = False + + kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( + num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 + ) + tensor1 = torch.zeros(*kv_cache_shape, dtype=torch.float16) + tensor2 = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = {"layer0": tensor1, "layer1": tensor2} + + with patch.object( + worker.engine, "batch_register_memory", return_value=0 + ) as mock_batch_register: + connector.register_kv_caches(kv_caches) + + mock_batch_register.assert_called_once() + registered_ptrs, registered_lens = mock_batch_register.call_args[0] + expected_ptrs = { + tensor.data_ptr() + for kv_pair in kv_caches.values() + for tensor in kv_pair + } + assert set(registered_ptrs) == expected_ptrs + assert set(registered_lens) == {tensor1[0].nbytes} + + # Verify block_len_per_layer is set correctly. 
+ assert len(worker.block_len_per_layer) == len(registered_ptrs) + for bl in worker.block_len_per_layer: + assert bl == tensor1[0].nbytes // tensor1.shape[1] + + +@pytest.mark.asyncio +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.mooncake." + "mooncake_connector.TransferEngine", + FakeMooncakeWrapper, +) +@pytest.mark.parametrize("d_tp_size", [1, 4], ids=["p_tp2_d_tp1", "p_tp2_d_tp4"]) +async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size): + """ + Tests heterogeneous TP support in the producer transfer path. + + Verifies correct pointer and offset calculation when producer TP=2 + sends to consumer with TP=1 (P>D) or TP=4 (P<D). + + - P TP=2 > D TP=1: one D rank receives; dst_offset based on P rank + - P TP=2 < D TP=4: two D ranks receive; src_offset based on D rank + """ + + P_TP_SIZE = 2 + P_TP_RANK = 0 + LOCAL_BLOCK_LEN = 4096 + + local_block_len = LOCAL_BLOCK_LEN + remote_block_len = LOCAL_BLOCK_LEN * P_TP_SIZE // d_tp_size + + monkeypatch.setenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "5") + vllm_config = create_vllm_config( + kv_connector="MooncakeConnector", kv_role="kv_producer" + ) + + with set_current_vllm_config(vllm_config), patch_worker_dependencies(): + prefill_connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER) + prefill_worker = prefill_connector.connector_worker + + # Override TP rank/size to simulate P TP=2 + prefill_worker.tp_rank = P_TP_RANK + prefill_worker.tp_size = P_TP_SIZE + # Update shared dict so kv_topo sees correct TP size + prefill_worker._tp_size[prefill_worker.engine_id] = P_TP_SIZE + prefill_worker.kv_topo.tp_rank = P_TP_RANK + + prefill_worker.kv_caches_base_addr = [0x1000] + prefill_worker.block_len_per_layer = [local_block_len] + + origin_sender_loop = prefill_worker.sender_loop + prefill_worker.sender_loop = asyncio.get_event_loop() + + transfer_id = "xfer-hetero-1" + local_block_ids = [10, 11] + send_meta = SendBlockMeta( + p_req_id="p-req-h1", + transfer_id=transfer_id, + local_block_ids=local_block_ids, + ready=asyncio.Event(), + ) + prefill_worker.reqs_need_send[transfer_id] = send_meta + send_meta.ready.set() + + # Compute target D ranks using the production code path + target_d_ranks = prefill_worker.kv_topo.get_target_remote_ranks(d_tp_size) + + mock_socket = AsyncMock(spec=zmq.asyncio.Socket) + mock_socket.send_multipart = AsyncMock() + identity = b"consumer-hetero" + + # Assign different remote block IDs per D rank + d_rank_remote_blocks = { + rank: [20 + i * 10, 21 + i * 10] for i, rank in enumerate(target_d_ranks) + } + + with patch.object( + prefill_worker, "_send_blocks", return_value=0 + ) as mock_send_blocks: + for d_rank in target_d_ranks: + remote_block_ids = d_rank_remote_blocks[d_rank] + xfer_meta = MooncakeXferMetadata( + remote_hostname="consumer-host", + remote_port=54321, + remote_tp_size=d_tp_size, + remote_tp_rank=d_rank, + req_blocks={ + f"d-req-h1-r{d_rank}": ( + transfer_id, + remote_block_ids, + ) + }, + kv_caches_base_addr=[0x2000], + block_lens=[remote_block_len], + ) + + mock_send_blocks.reset_mock() + mock_socket.send_multipart.reset_mock() + + await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta) + + # Verify _send_blocks was called + mock_send_blocks.assert_called_once() + call_args = mock_send_blocks.call_args[0] + src_ptrs = call_args[1] + dst_ptrs = call_args[2] + lengths = call_args[3] + + # Heterogeneous TP: blocks cannot be coalesced because + # local and remote block_lens differ + assert len(src_ptrs) == len(local_block_ids) + assert len(dst_ptrs) == len(local_block_ids)
assert len(lengths) == len(local_block_ids) + + # Compute expected offsets based on TP ratio + if d_tp_size <= P_TP_SIZE: + tp_ratio = P_TP_SIZE // d_tp_size + expected_src_off = 0 + expected_dst_off = (P_TP_RANK % tp_ratio) * local_block_len + expected_xfer_len = local_block_len + else: + ratio_abs = d_tp_size // P_TP_SIZE + expected_src_off = (d_rank % ratio_abs) * remote_block_len + expected_dst_off = 0 + expected_xfer_len = remote_block_len + + for idx, (lblk, rblk) in enumerate( + zip(local_block_ids, remote_block_ids) + ): + assert src_ptrs[idx] == ( + 0x1000 + lblk * local_block_len + expected_src_off + ) + assert dst_ptrs[idx] == ( + 0x2000 + rblk * remote_block_len + expected_dst_off + ) + assert lengths[idx] == expected_xfer_len + + # Verify successful response sent back to consumer + mock_socket.send_multipart.assert_called_once() + _, sent_payload = mock_socket.send_multipart.call_args[0][0] + response = prefill_worker._xfer_resp_decoder.decode(sent_payload) + assert response.status == MooncakeXferResponseStatus.FINISH + assert response.ok_reqs == [f"d-req-h1-r{d_rank}"] + + # After serving all D ranks, the request should be complete + assert transfer_id not in prefill_worker.reqs_need_send + assert "p-req-h1" in prefill_worker.finished_sending_reqs + + prefill_worker.sender_loop = origin_sender_loop + prefill_worker.shutdown() diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 1e2a05f0e..75dc47947 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -100,6 +100,8 @@ def create_vllm_config( hf_overrides: dict[str, Any] | None = None, attention_backend: str | None = None, kv_load_failure_policy: Literal["recompute", "fail"] = "fail", + kv_connector: str = "NixlConnector", + kv_role: str = "kv_both", ) -> VllmConfig: """Initialize VllmConfig For Testing.""" model_config = ModelConfig( @@ -124,8 +126,8 @@ def create_vllm_config( enable_prefix_caching=True, ) kv_transfer_config = KVTransferConfig( - kv_connector="NixlConnector", - kv_role="kv_both", + kv_connector=kv_connector, + kv_role=kv_role, enable_permute_local_kv=enable_permute_local_kv, kv_connector_extra_config=kv_connector_extra_config or {}, kv_load_failure_policy=kv_load_failure_policy, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py index 45258e0d3..b49a01664 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py @@ -47,14 +47,17 @@ from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import RequestStatus +logger = init_logger(__name__) + try: from mooncake.engine import TransferEngine -except ImportError as e: - raise ImportError( +except ImportError: + logger.warning( "Please install mooncake by following the instructions at " "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " "to run VLLM with MooncakeTransferEngine." 
- ) from e + ) + TransferEngine = None if TYPE_CHECKING: from vllm.v1.core.kv_cache_manager import KVCacheBlocks @@ -64,8 +67,6 @@ if TYPE_CHECKING: ReqId = str # Internal scheduler request ID TransferId = str # KV transfer coordination ID (shared by P/D) -logger = init_logger(__name__) - @dataclass(frozen=True) class TransferRegion: @@ -638,6 +639,9 @@ class MooncakeConnectorWorker: """Implementation of Worker side methods""" def __init__(self, vllm_config: VllmConfig, engine_id: str): + if TransferEngine is None: + logger.error("Mooncake is not available") + raise RuntimeError("Mooncake is not available") logger.info("Initializing Mooncake Transfer Engine worker %s", engine_id) self.vllm_config = vllm_config @@ -721,9 +725,7 @@ class MooncakeConnectorWorker: # Start bootstrap server on global rank 0. if should_launch_bootstrap_server(vllm_config): _, port = get_mooncake_bootstrap_addr(vllm_config) - self.bootstrap_server = MooncakeBootstrapServer( - vllm_config, "0.0.0.0", port - ) + self.bootstrap_server = MooncakeBootstrapServer("0.0.0.0", port) self.bootstrap_server.start() if not self.is_kv_producer: @@ -778,7 +780,9 @@ class MooncakeConnectorWorker: if self.sender_loop.is_running(): self.sender_loop.call_soon_threadsafe(self.sender_loop.stop) self._sender_listener_t.join() - if should_launch_bootstrap_server(self.vllm_config): + if should_launch_bootstrap_server(self.vllm_config) and hasattr( + self, "bootstrap_server" + ): self.bootstrap_server.shutdown() if not self.is_kv_producer and self.receiver_loop.is_running(): self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py index d1a994670..2d158387f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py @@ -8,7 +8,6 @@ import uvicorn from fastapi import FastAPI, HTTPException from pydantic import BaseModel -from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.utils import EngineId from vllm.logger import init_logger @@ -38,7 +37,7 @@ class MooncakeBootstrapServer: Prefiller workers register their connection info (IP, port, ranks) here. """ - def __init__(self, vllm_config: VllmConfig, host: str, port: int): + def __init__(self, host: str, port: int): self.workers: dict[int, EngineEntry] = {} self.host = host
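A condensed sketch of the guarded-import pattern the connector diff introduces (illustrative only; the authoritative code and signatures are in the diff above):

try:
    from mooncake.engine import TransferEngine
except ImportError:
    # Import-time failure degrades to a warning so scheduler-side
    # processes, which never drive transfers, can still load the module.
    TransferEngine = None

class MooncakeConnectorWorker:
    def __init__(self, vllm_config, engine_id):
        # Only the worker actually performs transfers, so the hard
        # failure is deferred to the point where Mooncake is required.
        if TransferEngine is None:
            raise RuntimeError("Mooncake is not available")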