[P/D] Mooncake: Add unit tests and minor fixes for mooncake connector (#36946)
Signed-off-by: Tianchen Ding <dtcccc@linux.alibaba.com>
This commit is contained in:
756
tests/v1/kv_connector/unit/test_mooncake_connector.py
Normal file
756
tests/v1/kv_connector/unit/test_mooncake_connector.py
Normal file
@@ -0,0 +1,756 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import zmq.asyncio
|
||||
|
||||
from vllm.config import set_current_vllm_config
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector import (
|
||||
KVConnectorRole,
|
||||
MooncakeConnector,
|
||||
MooncakeConnectorMetadata,
|
||||
MooncakeXferMetadata,
|
||||
MooncakeXferResponse,
|
||||
MooncakeXferResponseStatus,
|
||||
PullReqMeta,
|
||||
SendBlockMeta,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_utils import (
|
||||
MooncakeBootstrapServer,
|
||||
)
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
from .utils import create_request, create_scheduler, create_vllm_config
|
||||
|
||||
|
||||
class FakeMooncakeWrapper:
|
||||
"""Mock Mooncake TransferEngine for unit testing environments."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def initialize(self, local_hostname, metadata_server, protocol, device_name) -> int:
|
||||
return 0
|
||||
|
||||
def get_rpc_port(self) -> int:
|
||||
return 12345
|
||||
|
||||
def batch_transfer_sync_write(
|
||||
self, target_hostname, buffers, peer_buffer_addresses, lengths
|
||||
) -> int:
|
||||
return 0
|
||||
|
||||
def batch_register_memory(self, buffer_addresses, capacities) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
def test_basic_interface():
|
||||
"""Unit test for basic MooncakeConnector interface functionality."""
|
||||
|
||||
vllm_config = create_vllm_config(
|
||||
kv_connector="MooncakeConnector", kv_role="kv_consumer"
|
||||
)
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
|
||||
# 2 Full Blocks and 1 Half Block.
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
request = create_request(
|
||||
request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True,
|
||||
)
|
||||
request_id = request.request_id
|
||||
request.kv_transfer_params.update(
|
||||
{
|
||||
"transfer_id": request_id,
|
||||
"remote_bootstrap_addr": 54321,
|
||||
}
|
||||
)
|
||||
|
||||
scheduler.add_request(request)
|
||||
|
||||
# Remote Prefill, triggers NixlConnectorMetadata.
|
||||
scheduler_output = scheduler.schedule()
|
||||
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
||||
assert kv_connector_metadata is not None
|
||||
assert isinstance(kv_connector_metadata, MooncakeConnectorMetadata)
|
||||
|
||||
assert len(kv_connector_metadata.reqs_to_recv) == 1
|
||||
assert request_id in kv_connector_metadata.reqs_to_recv["my-engine-id"]
|
||||
req_meta = kv_connector_metadata.reqs_to_recv["my-engine-id"][request_id]
|
||||
|
||||
for block_id, block in zip(
|
||||
req_meta.local_block_ids,
|
||||
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
|
||||
request_id
|
||||
],
|
||||
):
|
||||
assert block_id == block.block_id
|
||||
|
||||
|
||||
def test_prompt_less_than_block_size():
|
||||
"""Test that we can handle case where prompt is < block."""
|
||||
|
||||
vllm_config = create_vllm_config(
|
||||
kv_connector="MooncakeConnector", kv_role="kv_consumer"
|
||||
)
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
|
||||
# Half of a block.
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
NUM_TOKENS = int(BLOCK_SIZE * 0.5)
|
||||
|
||||
# Request will have 1 partial remote block.
|
||||
request = create_request(
|
||||
request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True,
|
||||
num_remote_blocks=1,
|
||||
)
|
||||
request.kv_transfer_params.update(
|
||||
{
|
||||
"transfer_id": request.request_id,
|
||||
"remote_bootstrap_addr": 54321,
|
||||
}
|
||||
)
|
||||
|
||||
scheduler.add_request(request)
|
||||
scheduler_output = scheduler.schedule()
|
||||
|
||||
# This request will read async.
|
||||
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
||||
assert kv_connector_metadata is not None
|
||||
assert isinstance(kv_connector_metadata, MooncakeConnectorMetadata)
|
||||
assert len(kv_connector_metadata.reqs_to_recv["my-engine-id"]) == 1
|
||||
assert len(scheduler_output.scheduled_new_reqs) == 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bootstrap_server():
|
||||
"""Fixture to launch and cleanup a Mooncake Bootstrap HTTP Server."""
|
||||
|
||||
port = get_open_port()
|
||||
server = MooncakeBootstrapServer("127.0.0.1", port)
|
||||
server.start()
|
||||
yield server
|
||||
server.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bootstrap_server(bootstrap_server: MooncakeBootstrapServer):
|
||||
"""
|
||||
Tests the bootstrap server's api for worker registration and querying.
|
||||
|
||||
Validates DP/TP/PP rank indexing and error handling for duplicate registrations.
|
||||
"""
|
||||
|
||||
import httpx
|
||||
|
||||
base_url = f"http://127.0.0.1:{bootstrap_server.port}"
|
||||
|
||||
# Query when empty
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"{base_url}/query")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {}
|
||||
|
||||
# Register a worker
|
||||
payload1 = {
|
||||
"engine_id": "eng-1",
|
||||
"dp_rank": 0,
|
||||
"tp_rank": 0,
|
||||
"pp_rank": 0,
|
||||
"addr": "tcp://1.1.1.1:1111",
|
||||
}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(f"{base_url}/register", json=payload1)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
# Query after registration
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"{base_url}/query")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "0" in data
|
||||
assert data["0"]["engine_id"] == "eng-1"
|
||||
assert data["0"]["worker_addr"]["0"]["0"] == "tcp://1.1.1.1:1111"
|
||||
|
||||
# Test failure: re-registering the same worker
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(f"{base_url}/register", json=payload1)
|
||||
assert response.status_code == 400
|
||||
assert "is already registered" in response.text
|
||||
|
||||
# Test failure: engine_id mismatch for same dp_rank
|
||||
payload3_fail = {
|
||||
"engine_id": "eng-2",
|
||||
"dp_rank": 0,
|
||||
"tp_rank": 1,
|
||||
"pp_rank": 0,
|
||||
"addr": "tcp://3.3.3.3:3333",
|
||||
}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(f"{base_url}/register", json=payload3_fail)
|
||||
assert response.status_code == 400
|
||||
assert "Engine ID mismatch" in response.text
|
||||
|
||||
|
||||
def test_scheduler_request_finished():
|
||||
"""
|
||||
Tests the scheduler-side logic when a request finishes.
|
||||
|
||||
Differentiates between 'Finished' (requires transfer)
|
||||
and 'Aborted' (immediate free).
|
||||
"""
|
||||
|
||||
vllm_config = create_vllm_config(
|
||||
kv_connector="MooncakeConnector", kv_role="kv_producer"
|
||||
)
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
scheduler_connector = scheduler.get_kv_connector().connector_scheduler
|
||||
|
||||
request = create_request(request_id=1, do_remote_decode=True)
|
||||
request.kv_transfer_params["transfer_id"] = request.request_id
|
||||
|
||||
# Case: Capped length (Successful prefill, need to send to decoder)
|
||||
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
|
||||
delay_free, _ = scheduler_connector.request_finished(request, block_ids=[10, 11])
|
||||
assert delay_free is True
|
||||
assert "id-1" in scheduler_connector._reqs_need_send
|
||||
assert scheduler_connector._reqs_need_send["id-1"][1] == [10, 11]
|
||||
|
||||
# Case: Aborted (No need to transfer, free blocks immediately)
|
||||
scheduler_connector._reqs_need_send.clear()
|
||||
request.status = RequestStatus.FINISHED_ABORTED
|
||||
delay_free, _ = scheduler_connector.request_finished(request, block_ids=[12])
|
||||
assert delay_free is False
|
||||
assert len(scheduler_connector._reqs_need_send) == 0
|
||||
assert "id-1" in scheduler_connector._reqs_not_processed
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def patch_worker_dependencies():
|
||||
"""Helper to mock all distributed and network dependencies for Worker tests."""
|
||||
|
||||
with (
|
||||
patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.TransferEngine",
|
||||
FakeMooncakeWrapper,
|
||||
),
|
||||
patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.get_ip",
|
||||
return_value="127.0.0.1",
|
||||
),
|
||||
patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.get_tensor_model_parallel_rank",
|
||||
return_value=0,
|
||||
),
|
||||
patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.get_tensor_model_parallel_world_size",
|
||||
return_value=1,
|
||||
),
|
||||
patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.get_pp_group"
|
||||
) as mock_pp,
|
||||
patch("vllm.distributed.parallel_state.is_local_first_rank", return_value=True),
|
||||
patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.should_launch_bootstrap_server",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.make_zmq_socket"
|
||||
) as mock_make_zmq,
|
||||
patch("httpx.AsyncClient") as mock_async_client,
|
||||
):
|
||||
# Mock PP group
|
||||
mock_pp_group = MagicMock()
|
||||
mock_pp_group.rank_in_group = 0
|
||||
mock_pp.return_value = mock_pp_group
|
||||
|
||||
# Mock ZMQ socket
|
||||
mock_socket_object = AsyncMock()
|
||||
mock_socket_object.setsockopt = MagicMock()
|
||||
mock_socket_ctx = MagicMock()
|
||||
mock_socket_ctx.__enter__.return_value = mock_socket_object
|
||||
mock_make_zmq.return_value = mock_socket_ctx
|
||||
|
||||
# Mock httpx client
|
||||
mock_http_client_instance = AsyncMock()
|
||||
mock_async_client.return_value = mock_http_client_instance
|
||||
|
||||
yield {
|
||||
"mock_make_zmq": mock_make_zmq,
|
||||
"mock_socket_object": mock_socket_object,
|
||||
"mock_async_client": mock_async_client,
|
||||
"mock_http_client": mock_http_client_instance,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.TransferEngine",
|
||||
FakeMooncakeWrapper,
|
||||
)
|
||||
async def test_kv_producer(monkeypatch):
|
||||
"""
|
||||
Simulates a Producer Worker (Prefiller) receiving a transfer request
|
||||
from a Consumer (Decoder).
|
||||
|
||||
Verifies memory offset calculation: ptr = base_addr + block_id * block_len.
|
||||
"""
|
||||
|
||||
monkeypatch.setenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "5")
|
||||
vllm_config = create_vllm_config(
|
||||
kv_connector="MooncakeConnector", kv_role="kv_producer"
|
||||
)
|
||||
|
||||
with set_current_vllm_config(vllm_config), patch_worker_dependencies():
|
||||
prefill_connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
prefill_worker = prefill_connector.connector_worker
|
||||
prefill_worker.kv_caches_base_addr = [0x1000]
|
||||
block_len = 4096
|
||||
prefill_worker.block_len_per_layer = [block_len]
|
||||
|
||||
# Override loop to use current test loop
|
||||
origin_sender_loop = prefill_worker.sender_loop
|
||||
prefill_worker.sender_loop = asyncio.get_event_loop()
|
||||
|
||||
# A request is finished on Producer and ready to be sent.
|
||||
transfer_id = "xfer-req-1"
|
||||
send_meta = SendBlockMeta(
|
||||
p_req_id="p-req-1",
|
||||
transfer_id=transfer_id,
|
||||
local_block_ids=[10, 11],
|
||||
ready=asyncio.Event(),
|
||||
)
|
||||
prefill_worker.reqs_need_send[transfer_id] = send_meta
|
||||
send_meta.ready.set()
|
||||
|
||||
# Remote consumer request metadata
|
||||
xfer_meta = MooncakeXferMetadata(
|
||||
remote_hostname="consumer-host",
|
||||
remote_port=54321,
|
||||
remote_tp_size=1,
|
||||
remote_tp_rank=0,
|
||||
req_blocks={"d-req-1": (transfer_id, [20, 21])},
|
||||
kv_caches_base_addr=[0x2000],
|
||||
block_lens=[block_len],
|
||||
)
|
||||
|
||||
mock_socket = AsyncMock(spec=zmq.asyncio.Socket)
|
||||
mock_socket.send_multipart = AsyncMock()
|
||||
identity = b"consumer-id"
|
||||
|
||||
with patch.object(
|
||||
prefill_worker, "_send_blocks", return_value=0
|
||||
) as mock_send_blocks:
|
||||
# Normal case: 2 blocks to 2 blocks
|
||||
# Worker processes the consumer's request
|
||||
await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta)
|
||||
# Verify transfer parameters are correct
|
||||
src_ptr = 0x1000 + 10 * block_len
|
||||
dst_ptr = 0x2000 + 20 * block_len
|
||||
length = 2 * block_len
|
||||
mock_send_blocks.assert_called_once_with(
|
||||
"consumer-host:54321", [src_ptr], [dst_ptr], [length]
|
||||
)
|
||||
mock_socket.send_multipart.assert_called_once()
|
||||
|
||||
# Verify the response sent back to the consumer
|
||||
sent_call = mock_socket.send_multipart.call_args[0][0]
|
||||
sent_identity, sent_payload = sent_call
|
||||
assert sent_identity == identity
|
||||
response = prefill_worker._xfer_resp_decoder.decode(sent_payload)
|
||||
assert response.status == MooncakeXferResponseStatus.FINISH
|
||||
assert response.ok_reqs == ["d-req-1"]
|
||||
|
||||
# Verify internal state cleanup
|
||||
assert transfer_id not in prefill_worker.reqs_need_send
|
||||
assert "p-req-1" in prefill_worker.finished_sending_reqs
|
||||
|
||||
# More cases:
|
||||
# Consumer only needs 1 block (less than P)
|
||||
mock_send_blocks.reset_mock()
|
||||
mock_socket.send_multipart.reset_mock()
|
||||
prefill_worker.reqs_need_send[transfer_id] = send_meta
|
||||
send_meta.sent = 0
|
||||
send_meta.ready.set()
|
||||
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20])
|
||||
# Worker processes the consumer's request
|
||||
await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta)
|
||||
# Verify transfer parameters are correct: 11 to 20
|
||||
src_ptr = 0x1000 + 11 * block_len
|
||||
dst_ptr = 0x2000 + 20 * block_len
|
||||
length = 1 * block_len
|
||||
mock_send_blocks.assert_called_once_with(
|
||||
"consumer-host:54321", [src_ptr], [dst_ptr], [length]
|
||||
)
|
||||
mock_socket.send_multipart.assert_called_once()
|
||||
|
||||
# Consumer needs 3 blocks (more than P, error case)
|
||||
mock_send_blocks.reset_mock()
|
||||
mock_socket.send_multipart.reset_mock()
|
||||
prefill_worker.reqs_need_send[transfer_id] = send_meta
|
||||
send_meta.sent = 0
|
||||
send_meta.ready.set()
|
||||
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20, 21, 22])
|
||||
# Worker processes the consumer's request
|
||||
await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta)
|
||||
# This should not be called because error.
|
||||
mock_send_blocks.assert_not_called()
|
||||
mock_socket.send_multipart.assert_called_once()
|
||||
_, sent_payload = mock_socket.send_multipart.call_args[0][0]
|
||||
response = prefill_worker._xfer_resp_decoder.decode(sent_payload)
|
||||
assert response.err_msg == "P num blocks less than D"
|
||||
assert response.err_reqs == ["d-req-1"]
|
||||
|
||||
# Timeout
|
||||
mock_send_blocks.reset_mock()
|
||||
mock_socket.send_multipart.reset_mock()
|
||||
prefill_worker.reqs_need_send[transfer_id] = send_meta
|
||||
send_meta.sent = 0
|
||||
send_meta.ready.clear()
|
||||
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20, 21])
|
||||
# Worker processes the consumer's request
|
||||
await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta)
|
||||
# This should not be called because timeout.
|
||||
mock_send_blocks.assert_not_called()
|
||||
mock_socket.send_multipart.assert_called_once()
|
||||
_, sent_payload = mock_socket.send_multipart.call_args[0][0]
|
||||
response = prefill_worker._xfer_resp_decoder.decode(sent_payload)
|
||||
assert response.err_msg == "Timeout waiting for P side ready."
|
||||
assert response.err_reqs == ["d-req-1"]
|
||||
|
||||
# Transfer error
|
||||
with patch.object(
|
||||
prefill_worker, "_send_blocks", return_value=123
|
||||
) as mock_send_blocks:
|
||||
mock_socket.send_multipart.reset_mock()
|
||||
prefill_worker.reqs_need_send[transfer_id] = send_meta
|
||||
send_meta.sent = 0
|
||||
send_meta.ready.set()
|
||||
xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20, 21])
|
||||
# Worker processes the consumer's request
|
||||
await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta)
|
||||
mock_send_blocks.assert_called_once()
|
||||
mock_socket.send_multipart.assert_called_once()
|
||||
_, sent_payload = mock_socket.send_multipart.call_args[0][0]
|
||||
response = prefill_worker._xfer_resp_decoder.decode(sent_payload)
|
||||
assert response.err_msg == "Mooncake transfer engine returned 123"
|
||||
assert response.err_reqs == ["d-req-1"]
|
||||
|
||||
# Clean up
|
||||
prefill_worker.sender_loop = origin_sender_loop
|
||||
prefill_worker.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_kv_consumuer(monkeypatch):
|
||||
"""
|
||||
Simulates a Consumer Worker (Decoder) initiating a pull from a Producer.
|
||||
|
||||
Verifies that MooncakeXferMetadata is correctly serialized and sent via ZMQ.
|
||||
"""
|
||||
|
||||
vllm_config = create_vllm_config(
|
||||
kv_connector="MooncakeConnector", kv_role="kv_consumer"
|
||||
)
|
||||
|
||||
with set_current_vllm_config(vllm_config), patch_worker_dependencies() as mocks:
|
||||
decode_connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
decode_worker = decode_connector.connector_worker
|
||||
decode_worker.kv_caches_base_addr = [0x1000]
|
||||
decode_worker.rpc_port = 54321
|
||||
|
||||
# A request to pull data arrives.
|
||||
pull_metas = {
|
||||
"d-req-1": PullReqMeta(
|
||||
d_req_id="d-req-1",
|
||||
transfer_id="xfer-req-1",
|
||||
local_block_ids=[100, 101],
|
||||
remote_engine_id="p-engine",
|
||||
remote_bootstrap_addr="http://bootstrap:33333",
|
||||
pull_tasks_count=1,
|
||||
)
|
||||
}
|
||||
decode_worker._remote_agents = {"p-engine": {0: {0: "tcp://producer:1234"}}}
|
||||
decode_worker._tp_size["p-engine"] = 1
|
||||
|
||||
# Mock the response from the producer.
|
||||
mock_response = MooncakeXferResponse(
|
||||
status=MooncakeXferResponseStatus.FINISH, ok_reqs=["d-req-1"]
|
||||
)
|
||||
encoded_response = decode_worker._encoder.encode(mock_response)
|
||||
mocks["mock_socket_object"].recv.return_value = encoded_response
|
||||
|
||||
# Trigger the receive logic.
|
||||
decode_worker.receive_kv("p-engine", pull_metas)
|
||||
await asyncio.sleep(1) # Allow async task to run
|
||||
|
||||
# Verify the metadata sent to the producer.
|
||||
mocks["mock_make_zmq"].assert_called_with(
|
||||
decode_worker.async_zmq_ctx,
|
||||
"tcp://producer:1234",
|
||||
zmq.DEALER,
|
||||
bind=False,
|
||||
linger=0,
|
||||
)
|
||||
sent_payload = mocks["mock_socket_object"].send.call_args[0][0]
|
||||
sent_meta = decode_worker._xfer_meta_decoder.decode(sent_payload)
|
||||
|
||||
assert sent_meta.remote_hostname == "127.0.0.1"
|
||||
assert sent_meta.remote_port == 54321
|
||||
assert sent_meta.req_blocks["d-req-1"] == ("xfer-req-1", [100, 101])
|
||||
|
||||
# Verify internal state is updated correctly.
|
||||
assert "d-req-1" in decode_worker.finished_recving_reqs
|
||||
|
||||
# Clean up
|
||||
decode_worker.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_worker_get_finished_timeout(monkeypatch):
|
||||
"""Tests the cleanup mechanism for requests."""
|
||||
|
||||
vllm_config = create_vllm_config(
|
||||
kv_connector="MooncakeConnector", kv_role="kv_producer"
|
||||
)
|
||||
with set_current_vllm_config(vllm_config), patch_worker_dependencies():
|
||||
prefill_connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
prefill_worker = prefill_connector.connector_worker
|
||||
|
||||
# Add an expired request (expire_time is in the past).
|
||||
prefill_worker.reqs_need_send["tx-expired"] = SendBlockMeta(
|
||||
p_req_id="p-req-expired",
|
||||
transfer_id="tx-expired",
|
||||
local_block_ids=[1, 2],
|
||||
ready=MagicMock(),
|
||||
expire_time=time.perf_counter() - 100,
|
||||
)
|
||||
|
||||
# Add a non-expired request.
|
||||
prefill_worker.reqs_need_send["tx-active"] = SendBlockMeta(
|
||||
p_req_id="p-req-active",
|
||||
transfer_id="tx-active",
|
||||
local_block_ids=[3, 4],
|
||||
ready=MagicMock(),
|
||||
expire_time=time.perf_counter() + 100,
|
||||
)
|
||||
|
||||
finished_reqs = await prefill_worker.fetch_finished_sending_reqs()
|
||||
|
||||
assert "p-req-expired" in finished_reqs
|
||||
assert "p-req-active" not in finished_reqs
|
||||
assert "tx-expired" not in prefill_worker.reqs_need_send
|
||||
assert "tx-active" in prefill_worker.reqs_need_send
|
||||
|
||||
|
||||
def test_register_kv_caches():
|
||||
"""Tests the memory registration logic with the underlying Mooncake engine."""
|
||||
|
||||
vllm_config = create_vllm_config(
|
||||
kv_connector="MooncakeConnector", kv_role="kv_consumer"
|
||||
)
|
||||
|
||||
with (
|
||||
set_current_vllm_config(vllm_config),
|
||||
patch_worker_dependencies(),
|
||||
patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.threading.Event"
|
||||
),
|
||||
patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.threading.Thread"
|
||||
) as mock_thread,
|
||||
):
|
||||
connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
worker = connector.connector_worker
|
||||
mock_thread.return_value.is_alive.return_value = False
|
||||
|
||||
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
|
||||
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
|
||||
)
|
||||
tensor1 = torch.zeros(*kv_cache_shape, dtype=torch.float16)
|
||||
tensor2 = torch.zeros(*kv_cache_shape, dtype=torch.float16)
|
||||
kv_caches = {"layer0": tensor1, "layer1": tensor2}
|
||||
|
||||
with patch.object(
|
||||
worker.engine, "batch_register_memory", return_value=0
|
||||
) as mock_batch_register:
|
||||
connector.register_kv_caches(kv_caches)
|
||||
|
||||
mock_batch_register.assert_called_once()
|
||||
registered_ptrs, registered_lens = mock_batch_register.call_args[0]
|
||||
expected_ptrs = {
|
||||
tensor.data_ptr()
|
||||
for kv_pair in kv_caches.values()
|
||||
for tensor in kv_pair
|
||||
}
|
||||
assert set(registered_ptrs) == expected_ptrs
|
||||
assert set(registered_lens) == {tensor1[0].nbytes}
|
||||
|
||||
# Verify block_len_per_layer is set correctly.
|
||||
assert len(worker.block_len_per_layer) == len(registered_ptrs)
|
||||
for bl in worker.block_len_per_layer:
|
||||
assert bl == tensor1[0].nbytes // tensor1.shape[1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake."
|
||||
"mooncake_connector.TransferEngine",
|
||||
FakeMooncakeWrapper,
|
||||
)
|
||||
@pytest.mark.parametrize("d_tp_size", [1, 4], ids=["p_tp2_d_tp1", "p_tp2_d_tp4"])
|
||||
async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
|
||||
"""
|
||||
Tests heterogeneous TP support in the producer transfer path.
|
||||
|
||||
Verifies correct pointer and offset calculation when producer TP=2
|
||||
sends to consumer with TP=1 (P>D) or TP=4 (P<D).
|
||||
|
||||
Parametrized cases:
|
||||
- P TP=2 > D TP=1: one D rank receives; dst_offset based on P rank
|
||||
- P TP=2 < D TP=4: two D ranks receive; src_offset based on D rank
|
||||
"""
|
||||
|
||||
P_TP_SIZE = 2
|
||||
P_TP_RANK = 0
|
||||
LOCAL_BLOCK_LEN = 4096
|
||||
|
||||
local_block_len = LOCAL_BLOCK_LEN
|
||||
remote_block_len = LOCAL_BLOCK_LEN * P_TP_SIZE // d_tp_size
|
||||
|
||||
monkeypatch.setenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "5")
|
||||
vllm_config = create_vllm_config(
|
||||
kv_connector="MooncakeConnector", kv_role="kv_producer"
|
||||
)
|
||||
|
||||
with set_current_vllm_config(vllm_config), patch_worker_dependencies():
|
||||
prefill_connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
prefill_worker = prefill_connector.connector_worker
|
||||
|
||||
# Override TP rank/size to simulate P TP=2
|
||||
prefill_worker.tp_rank = P_TP_RANK
|
||||
prefill_worker.tp_size = P_TP_SIZE
|
||||
# Update shared dict so kv_topo sees correct TP size
|
||||
prefill_worker._tp_size[prefill_worker.engine_id] = P_TP_SIZE
|
||||
prefill_worker.kv_topo.tp_rank = P_TP_RANK
|
||||
|
||||
prefill_worker.kv_caches_base_addr = [0x1000]
|
||||
prefill_worker.block_len_per_layer = [local_block_len]
|
||||
|
||||
origin_sender_loop = prefill_worker.sender_loop
|
||||
prefill_worker.sender_loop = asyncio.get_event_loop()
|
||||
|
||||
transfer_id = "xfer-hetero-1"
|
||||
local_block_ids = [10, 11]
|
||||
send_meta = SendBlockMeta(
|
||||
p_req_id="p-req-h1",
|
||||
transfer_id=transfer_id,
|
||||
local_block_ids=local_block_ids,
|
||||
ready=asyncio.Event(),
|
||||
)
|
||||
prefill_worker.reqs_need_send[transfer_id] = send_meta
|
||||
send_meta.ready.set()
|
||||
|
||||
# Compute target D ranks using the production code path
|
||||
target_d_ranks = prefill_worker.kv_topo.get_target_remote_ranks(d_tp_size)
|
||||
|
||||
mock_socket = AsyncMock(spec=zmq.asyncio.Socket)
|
||||
mock_socket.send_multipart = AsyncMock()
|
||||
identity = b"consumer-hetero"
|
||||
|
||||
# Assign different remote block IDs per D rank
|
||||
d_rank_remote_blocks = {
|
||||
rank: [20 + i * 10, 21 + i * 10] for i, rank in enumerate(target_d_ranks)
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
prefill_worker, "_send_blocks", return_value=0
|
||||
) as mock_send_blocks:
|
||||
for d_rank in target_d_ranks:
|
||||
remote_block_ids = d_rank_remote_blocks[d_rank]
|
||||
xfer_meta = MooncakeXferMetadata(
|
||||
remote_hostname="consumer-host",
|
||||
remote_port=54321,
|
||||
remote_tp_size=d_tp_size,
|
||||
remote_tp_rank=d_rank,
|
||||
req_blocks={
|
||||
f"d-req-h1-r{d_rank}": (
|
||||
transfer_id,
|
||||
remote_block_ids,
|
||||
)
|
||||
},
|
||||
kv_caches_base_addr=[0x2000],
|
||||
block_lens=[remote_block_len],
|
||||
)
|
||||
|
||||
mock_send_blocks.reset_mock()
|
||||
mock_socket.send_multipart.reset_mock()
|
||||
|
||||
await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta)
|
||||
|
||||
# Verify _send_blocks was called
|
||||
mock_send_blocks.assert_called_once()
|
||||
call_args = mock_send_blocks.call_args[0]
|
||||
src_ptrs = call_args[1]
|
||||
dst_ptrs = call_args[2]
|
||||
lengths = call_args[3]
|
||||
|
||||
# Heterogeneous TP: blocks cannot be coalesced because
|
||||
# local and remote block_lens differ
|
||||
assert len(src_ptrs) == len(local_block_ids)
|
||||
assert len(dst_ptrs) == len(local_block_ids)
|
||||
assert len(lengths) == len(local_block_ids)
|
||||
|
||||
# Compute expected offsets based on TP ratio
|
||||
if d_tp_size <= P_TP_SIZE:
|
||||
tp_ratio = P_TP_SIZE // d_tp_size
|
||||
expected_src_off = 0
|
||||
expected_dst_off = (P_TP_RANK % tp_ratio) * local_block_len
|
||||
expected_xfer_len = local_block_len
|
||||
else:
|
||||
ratio_abs = d_tp_size // P_TP_SIZE
|
||||
expected_src_off = (d_rank % ratio_abs) * remote_block_len
|
||||
expected_dst_off = 0
|
||||
expected_xfer_len = remote_block_len
|
||||
|
||||
for idx, (lblk, rblk) in enumerate(
|
||||
zip(local_block_ids, remote_block_ids)
|
||||
):
|
||||
assert src_ptrs[idx] == (
|
||||
0x1000 + lblk * local_block_len + expected_src_off
|
||||
)
|
||||
assert dst_ptrs[idx] == (
|
||||
0x2000 + rblk * remote_block_len + expected_dst_off
|
||||
)
|
||||
assert lengths[idx] == expected_xfer_len
|
||||
|
||||
# Verify successful response sent back to consumer
|
||||
mock_socket.send_multipart.assert_called_once()
|
||||
_, sent_payload = mock_socket.send_multipart.call_args[0][0]
|
||||
response = prefill_worker._xfer_resp_decoder.decode(sent_payload)
|
||||
assert response.status == MooncakeXferResponseStatus.FINISH
|
||||
assert response.ok_reqs == [f"d-req-h1-r{d_rank}"]
|
||||
|
||||
# After serving all D ranks, the request should be complete
|
||||
assert transfer_id not in prefill_worker.reqs_need_send
|
||||
assert "p-req-h1" in prefill_worker.finished_sending_reqs
|
||||
|
||||
prefill_worker.sender_loop = origin_sender_loop
|
||||
prefill_worker.shutdown()
|
||||
@@ -100,6 +100,8 @@ def create_vllm_config(
|
||||
hf_overrides: dict[str, Any] | None = None,
|
||||
attention_backend: str | None = None,
|
||||
kv_load_failure_policy: Literal["recompute", "fail"] = "fail",
|
||||
kv_connector: str = "NixlConnector",
|
||||
kv_role: str = "kv_both",
|
||||
) -> VllmConfig:
|
||||
"""Initialize VllmConfig For Testing."""
|
||||
model_config = ModelConfig(
|
||||
@@ -124,8 +126,8 @@ def create_vllm_config(
|
||||
enable_prefix_caching=True,
|
||||
)
|
||||
kv_transfer_config = KVTransferConfig(
|
||||
kv_connector="NixlConnector",
|
||||
kv_role="kv_both",
|
||||
kv_connector=kv_connector,
|
||||
kv_role=kv_role,
|
||||
enable_permute_local_kv=enable_permute_local_kv,
|
||||
kv_connector_extra_config=kv_connector_extra_config or {},
|
||||
kv_load_failure_policy=kv_load_failure_policy,
|
||||
|
||||
@@ -47,14 +47,17 @@ from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
try:
|
||||
from mooncake.engine import TransferEngine
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"Please install mooncake by following the instructions at "
|
||||
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "
|
||||
"to run VLLM with MooncakeTransferEngine."
|
||||
) from e
|
||||
)
|
||||
TransferEngine = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
@@ -64,8 +67,6 @@ if TYPE_CHECKING:
|
||||
ReqId = str # Internal scheduler request ID
|
||||
TransferId = str # KV transfer coordination ID (shared by P/D)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TransferRegion:
|
||||
@@ -638,6 +639,9 @@ class MooncakeConnectorWorker:
|
||||
"""Implementation of Worker side methods"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||
if TransferEngine is None:
|
||||
logger.error("Mooncake is not available")
|
||||
raise RuntimeError("Mooncake is not available")
|
||||
logger.info("Initializing Mooncake Transfer Engine worker %s", engine_id)
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
@@ -721,9 +725,7 @@ class MooncakeConnectorWorker:
|
||||
# Start bootstrap server on global rank 0.
|
||||
if should_launch_bootstrap_server(vllm_config):
|
||||
_, port = get_mooncake_bootstrap_addr(vllm_config)
|
||||
self.bootstrap_server = MooncakeBootstrapServer(
|
||||
vllm_config, "0.0.0.0", port
|
||||
)
|
||||
self.bootstrap_server = MooncakeBootstrapServer("0.0.0.0", port)
|
||||
self.bootstrap_server.start()
|
||||
|
||||
if not self.is_kv_producer:
|
||||
@@ -778,7 +780,9 @@ class MooncakeConnectorWorker:
|
||||
if self.sender_loop.is_running():
|
||||
self.sender_loop.call_soon_threadsafe(self.sender_loop.stop)
|
||||
self._sender_listener_t.join()
|
||||
if should_launch_bootstrap_server(self.vllm_config):
|
||||
if should_launch_bootstrap_server(self.vllm_config) and hasattr(
|
||||
self, "bootstrap_server"
|
||||
):
|
||||
self.bootstrap_server.shutdown()
|
||||
if not self.is_kv_producer and self.receiver_loop.is_running():
|
||||
self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop)
|
||||
|
||||
@@ -8,7 +8,6 @@ import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import EngineId
|
||||
from vllm.logger import init_logger
|
||||
|
||||
@@ -38,7 +37,7 @@ class MooncakeBootstrapServer:
|
||||
Prefiller workers register their connection info (IP, port, ranks) here.
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, host: str, port: int):
|
||||
def __init__(self, host: str, port: int):
|
||||
self.workers: dict[int, EngineEntry] = {}
|
||||
|
||||
self.host = host
|
||||
|
||||
Reference in New Issue
Block a user