[KVConnector] Support 3FS KVConnector (#37636)
Signed-off-by: wuchenxin <wuchenxin.wcx@alibaba-inc.com> Signed-off-by: ibifrost <47308427+ibifrost@users.noreply.github.com> Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
284
tests/v1/kv_connector/unit/test_hf3fs_client.py
Normal file
284
tests/v1/kv_connector/unit/test_hf3fs_client.py
Normal file
@@ -0,0 +1,284 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests for resource management in hf3fs_client.py: constructor failure cleanup
|
||||
and idempotent close(). Tests use mock to replace real I/O operations
|
||||
(hf3fs_fuse.io, SharedMemory, os, CUDA).
|
||||
Requires hf3fs_fuse.io to be installed; skipped otherwise.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
HF3FS_AVAILABLE = True
|
||||
try:
|
||||
from hf3fs_fuse.io import ( # noqa: F401
|
||||
deregister_fd,
|
||||
extract_mount_point,
|
||||
make_ioring,
|
||||
make_iovec,
|
||||
register_fd,
|
||||
)
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_client import (
|
||||
Hf3fsClient,
|
||||
)
|
||||
except Exception:
|
||||
HF3FS_AVAILABLE = False
|
||||
|
||||
requires_hf3fs = pytest.mark.skipif(
|
||||
not HF3FS_AVAILABLE,
|
||||
reason="hf3fs_fuse.io is not available on this machine",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeShm:
|
||||
"""Shared-memory stub matching the multiprocessing.shared_memory.SharedMemory
|
||||
interface used by Hf3fsClient:
|
||||
|
||||
Attributes accessed by the constructor:
|
||||
.buf – memoryview / buffer-protocol object consumed by torch.frombuffer
|
||||
Methods called during normal lifetime:
|
||||
.unlink() – called right after the iovec is set up
|
||||
.close() – called in _release_resources()
|
||||
"""
|
||||
|
||||
def __init__(self, size: int = 1024):
|
||||
self._data = bytearray(size)
|
||||
self.buf = memoryview(self._data)
|
||||
self.closed = False
|
||||
self.close_call_count = 0
|
||||
self.unlink_call_count = 0
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
self.close_call_count += 1
|
||||
|
||||
def unlink(self):
|
||||
self.unlink_call_count += 1
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# TestHf3fsClientResourceManagement
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
@requires_hf3fs
class TestHf3fsClientResourceManagement:
    """Tests for constructor failure cleanup and idempotent close()."""

    # Dotted module path of the code under test; every patch() target below
    # is resolved against names as imported *inside* this module.
    _MOD = "vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_client"

    # ------------------------------------------------------------------
    # Helper: build a minimal Hf3fsClient bypassing all real I/O so that
    # we can fully control its internal state.
    # ------------------------------------------------------------------

    def _make_client(self, tmp_path):
        """Return a fully-mocked Hf3fsClient with controllable internals.

        All hf3fs/os/torch entry points are patched only for the duration of
        the constructor call; afterwards the client's shm/file handles are
        re-pointed at our fakes so tests can inspect them after close().
        """
        fake_shm_r = _FakeShm()
        fake_shm_w = _FakeShm()

        patcher_list: list[Any] = [
            patch(f"{self._MOD}.HF3FS_AVAILABLE", True),
            patch(f"{self._MOD}.register_fd"),
            patch(f"{self._MOD}.deregister_fd"),
            patch(f"{self._MOD}.extract_mount_point", return_value="/mnt/hf3fs"),
            patch(f"{self._MOD}.make_ioring", return_value=MagicMock()),
            patch(f"{self._MOD}.make_iovec", return_value=MagicMock()),
            patch(
                "multiprocessing.shared_memory.SharedMemory",
                # First allocation is the read-side shm, second the write-side.
                side_effect=[fake_shm_r, fake_shm_w],
            ),
            patch("os.open", return_value=99),
            patch("os.ftruncate"),
            patch("os.close"),
            patch("os.fsync"),
            patch("torch.cuda.Stream", return_value=MagicMock()),
            patch("torch.frombuffer", return_value=MagicMock()),
            patch("torch.empty", return_value=MagicMock()),
        ]
        for p in patcher_list:
            p.start()

        try:
            client = Hf3fsClient(
                path=str(tmp_path / "test.bin"),
                size=1024,
                bytes_per_page=256,
                entries=4,
            )
        finally:
            # Always undo the patches, even if the constructor raised.
            for p in patcher_list:
                p.stop()

        # Manually point internal handles to our controllable fakes so that
        # assertions after close() can inspect them directly.
        client.shm_r = fake_shm_r
        client.shm_w = fake_shm_w
        client.file = 99
        return client, fake_shm_r, fake_shm_w

    # ------------------------------------------------------------------
    # close() idempotency
    # ------------------------------------------------------------------

    def test_close_idempotent_and_handles_cleared(self, tmp_path):
        """Multiple close() calls must not raise; deregister_fd called exactly
        once, all handles set to None, shm.close() invoked."""
        client, shm_r, shm_w = self._make_client(tmp_path)

        with (
            patch(f"{self._MOD}.deregister_fd") as mock_dereg,
            patch("os.close"),
        ):
            client.close()  # first close
            client.close()  # second close — must be no-op
            client.close()  # third close — must be no-op

        assert client._closed is True
        assert mock_dereg.call_count == 1, (
            f"deregister_fd called {mock_dereg.call_count} times; expected 1"
        )
        # Every internal handle must be dropped so nothing can be reused.
        for attr in ("iov_r", "iov_w", "ior_r", "ior_w", "shm_r", "shm_w", "file"):
            assert getattr(client, attr) is None, f"{attr} should be None after close()"
        assert shm_r.closed is True
        assert shm_w.closed is True

    def test_flush_after_close_is_noop(self, tmp_path):
        """flush() after close() must silently do nothing (no fsync call)."""
        client, _, _ = self._make_client(tmp_path)

        with (
            patch(f"{self._MOD}.deregister_fd"),
            patch("os.close"),
            patch("os.fsync") as mock_fsync,
        ):
            client.close()
            client.flush()

        mock_fsync.assert_not_called()

    # ------------------------------------------------------------------
    # Constructor failure leaves no leaked resources
    # ------------------------------------------------------------------

    def test_constructor_failure_after_file_open_cleans_file(self, tmp_path):
        """If the constructor raises after os.open(), the fd must be closed."""
        with (
            patch(f"{self._MOD}.HF3FS_AVAILABLE", True),
            patch(f"{self._MOD}.register_fd"),
            patch(f"{self._MOD}.deregister_fd"),
            patch(
                # Fail at mount-point resolution, i.e. right after the file
                # descriptor was opened.
                f"{self._MOD}.extract_mount_point",
                side_effect=RuntimeError("mount point not found"),
            ),
            patch("os.open", return_value=55),
            patch("os.ftruncate"),
            patch("os.close") as mock_os_close,
            patch("torch.cuda.Stream", return_value=MagicMock()),
            pytest.raises(RuntimeError, match="mount point not found"),
        ):
            Hf3fsClient(
                path=str(tmp_path / "fail.bin"),
                size=1024,
                bytes_per_page=256,
                entries=4,
            )

        mock_os_close.assert_called_once_with(55)

    def test_constructor_failure_after_shm_alloc_closes_shm(self, tmp_path):
        """Constructor raises after SharedMemory creation → both shm objects closed."""
        fake_shm_r = _FakeShm()
        fake_shm_w = _FakeShm()

        with (
            patch(f"{self._MOD}.HF3FS_AVAILABLE", True),
            patch(f"{self._MOD}.register_fd"),
            patch(f"{self._MOD}.deregister_fd"),
            patch(f"{self._MOD}.extract_mount_point", return_value="/mnt/hf3fs"),
            patch(
                "multiprocessing.shared_memory.SharedMemory",
                side_effect=[fake_shm_r, fake_shm_w],
            ),
            patch("os.open", return_value=66),
            patch("os.ftruncate"),
            patch("os.close"),
            patch("torch.frombuffer", return_value=MagicMock()),
            patch("torch.empty", return_value=MagicMock()),
            patch(
                # Fail after both shared-memory segments were allocated.
                f"{self._MOD}.make_ioring",
                side_effect=RuntimeError("ioring init failed"),
            ),
            patch(f"{self._MOD}.make_iovec", return_value=MagicMock()),
            patch("torch.cuda.Stream", return_value=MagicMock()),
            pytest.raises(RuntimeError, match="ioring init failed"),
        ):
            Hf3fsClient(
                path=str(tmp_path / "fail2.bin"),
                size=1024,
                bytes_per_page=256,
                entries=4,
            )

        assert fake_shm_r.closed is True, (
            "shm_r was not closed after constructor failure"
        )
        assert fake_shm_w.closed is True, (
            "shm_w was not closed after constructor failure"
        )

    def test_constructor_failure_does_not_close_unallocated_shm(self, tmp_path):
        """Failure before SharedMemory is created must not raise AttributeError
        or TypeError from cleanup."""
        with (
            patch(f"{self._MOD}.HF3FS_AVAILABLE", True),
            patch(f"{self._MOD}.register_fd"),
            patch(f"{self._MOD}.deregister_fd"),
            patch(
                f"{self._MOD}.extract_mount_point",
                side_effect=RuntimeError("early failure"),
            ),
            patch("os.open", return_value=77),
            patch("os.ftruncate"),
            patch("os.close"),
            patch("torch.cuda.Stream", return_value=MagicMock()),
            pytest.raises(RuntimeError, match="early failure"),
        ):
            Hf3fsClient(
                path=str(tmp_path / "early_fail.bin"),
                size=1024,
                bytes_per_page=256,
                entries=4,
            )

    # ------------------------------------------------------------------
    # _release_resources on already-cleared state must be a no-op
    # ------------------------------------------------------------------

    def test_release_resources_on_empty_state_is_safe(self, tmp_path):
        """_release_resources() on a fully-cleared client must not raise."""
        client, _, _ = self._make_client(tmp_path)

        with (
            patch(f"{self._MOD}.deregister_fd"),
            patch("os.close"),
        ):
            client.close()  # clears all handles

        with (
            patch(f"{self._MOD}.deregister_fd") as mock_dereg2,
            patch("os.close") as mock_os_close2,
        ):
            client._release_resources()  # must not raise

        # Nothing was held any more, so nothing may be released again.
        mock_dereg2.assert_not_called()
        mock_os_close2.assert_not_called()
|
||||
230
tests/v1/kv_connector/unit/test_hf3fs_connector.py
Normal file
230
tests/v1/kv_connector/unit/test_hf3fs_connector.py
Normal file
@@ -0,0 +1,230 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests for HF3FS KV Connector high-level components:
|
||||
- TestHf3fsMockClient : file-backed mock client I/O correctness
|
||||
- TestHF3FSKVConnectorStats: metric collection, aggregation, serialisation
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_connector import (
|
||||
HF3FSKVConnectorStats,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.utils.hf3fs_mock_client import (
|
||||
Hf3fsClient as MockHf3fsClient,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
def hf3fs_stats():
    """Fresh HF3FSKVConnectorStats instance (new object per test)."""
    return HF3FSKVConnectorStats()
|
||||
|
||||
|
||||
def _make_cuda_event():
|
||||
"""Return a real CUDA event when available, otherwise a MagicMock."""
|
||||
if torch.cuda.is_available():
|
||||
return torch.cuda.Event()
|
||||
return MagicMock()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# TestHf3fsMockClient
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestHf3fsMockClient:
    """Tests for hf3fs_mock_client.Hf3fsClient (file-backend mock)."""

    def test_init_creates_file(self, tmp_path):
        """Initializing the client should create the backing file."""
        path = str(tmp_path / "test_file")
        client = MockHf3fsClient(path=path, size=4096, bytes_per_page=512, entries=4)
        assert os.path.exists(path), "Backing file should be created on init"
        assert os.path.getsize(path) == 4096
        client.close()

    @pytest.mark.parametrize(
        "dtype, bytes_per_page",
        [
            (torch.float32, 512),
            (torch.float16, 256),
            (torch.bfloat16, 256),
        ],
        ids=["float32", "float16", "bfloat16"],
    )
    def test_batch_write_and_read_dtype(self, tmp_path, dtype, bytes_per_page):
        """Write a tensor of the given dtype and verify round-trip correctness."""
        path = str(tmp_path / f"rw_{dtype}")
        client = MockHf3fsClient(
            path=path, size=bytes_per_page * 8, bytes_per_page=bytes_per_page, entries=4
        )
        # Derive how many elements of this dtype fill exactly one page.
        elem_size = torch.tensor([], dtype=dtype).element_size()
        numel = bytes_per_page // elem_size
        tensor_write = torch.arange(numel, dtype=dtype)
        event = _make_cuda_event()

        results = client.batch_write([0], [tensor_write], event)
        assert results == [bytes_per_page], f"Write should succeed, got {results}"

        tensor_read = torch.zeros(numel, dtype=dtype)
        results = client.batch_read([0], [tensor_read])
        assert results == [bytes_per_page], f"Read should succeed, got {results}"
        assert torch.equal(tensor_write, tensor_read), (
            "Read tensor should match written tensor"
        )
        client.close()

    def test_batch_read_empty_file_returns_error(self, tmp_path):
        """Reading out-of-bounds offset should return -1."""
        bytes_per_page = 128
        size = bytes_per_page * 4
        path = str(tmp_path / "empty_read")
        client = MockHf3fsClient(
            path=path, size=size, bytes_per_page=bytes_per_page, entries=4
        )
        numel = bytes_per_page // 4  # float32 is 4 bytes per element
        tensor_read = torch.zeros(numel, dtype=torch.float32)
        results = client.batch_read([size], [tensor_read])  # offset == size => OOB
        assert results[0] == -1, "Out-of-bounds read should return -1"
        client.close()

    def test_batch_write_out_of_bounds_returns_error(self, tmp_path):
        """Writing at an offset beyond file size should return -1."""
        bytes_per_page = 128
        size = bytes_per_page * 4
        path = str(tmp_path / "oob_write")
        client = MockHf3fsClient(
            path=path, size=size, bytes_per_page=bytes_per_page, entries=4
        )
        numel = bytes_per_page // 4
        tensor = torch.ones(numel, dtype=torch.float32)
        event = _make_cuda_event()
        results = client.batch_write([size], [tensor], event)  # OOB offset
        assert results[0] == -1, "Out-of-bounds write should return -1"
        client.close()

    def test_multiple_tensors_rw(self, tmp_path):
        """Write multiple tensors at different offsets, then read all back."""
        bytes_per_page = 128
        n = 4
        path = str(tmp_path / "multi_rw")
        client = MockHf3fsClient(
            path=path,
            size=bytes_per_page * n * 2,
            bytes_per_page=bytes_per_page,
            entries=8,
        )
        # Tensor i is filled with the constant value i, so a mix-up between
        # offsets would be detected on read-back.
        tensors_write = [
            torch.full((bytes_per_page // 4,), float(i), dtype=torch.float32)
            for i in range(n)
        ]
        offsets = [i * bytes_per_page for i in range(n)]
        event = _make_cuda_event()

        results = client.batch_write(offsets, tensors_write, event)
        assert all(r == bytes_per_page for r in results)

        tensors_read = [
            torch.zeros(bytes_per_page // 4, dtype=torch.float32) for _ in range(n)
        ]
        results = client.batch_read(offsets, tensors_read)
        assert all(r == bytes_per_page for r in results)

        for i, (tw, tr) in enumerate(zip(tensors_write, tensors_read)):
            assert torch.allclose(tw, tr), f"Tensor {i} mismatch after round-trip"
        client.close()

    def test_flush_and_close_no_error(self, tmp_path):
        """flush() and close() should not raise exceptions."""
        path = str(tmp_path / "flush_close")
        client = MockHf3fsClient(path=path, size=1024, bytes_per_page=128, entries=4)
        client.flush()
        client.close()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# TestHF3FSKVConnectorStats
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestHF3FSKVConnectorStats:
    """Tests for HF3FSKVConnectorStats metric collection and aggregation."""

    def test_initial_is_empty(self, hf3fs_stats):
        """Fresh stats object should report is_empty() == True."""
        assert hf3fs_stats.is_empty() is True

    @pytest.mark.parametrize(
        "task_type, duration_key",
        [
            ("Saved", "save_duration"),
            ("Loaded", "load_duration"),
        ],
        ids=["save", "load"],
    )
    def test_record_success_duration(self, hf3fs_stats, task_type, duration_key):
        """Recording a successful task should update duration list and total count."""
        hf3fs_stats.record_success_task_duration(task_type, 0.5)
        assert not hf3fs_stats.is_empty()
        assert len(hf3fs_stats.data[duration_key]) == 1
        assert hf3fs_stats.data[duration_key][0] == pytest.approx(0.5)
        assert hf3fs_stats.data["num_transfer_task"] == 1

    @pytest.mark.parametrize(
        "task_type, failed_key",
        [
            ("Saved", "num_failed_save"),
            ("Loaded", "num_failed_load"),
        ],
        ids=["save", "load"],
    )
    def test_record_failed_task(self, hf3fs_stats, task_type, failed_key):
        """Recording a failed task should increment the corresponding counter."""
        hf3fs_stats.record_failed_task_count(task_type)
        assert hf3fs_stats.data[failed_key] == 1
        # Failed tasks still count toward the overall transfer-task total.
        assert hf3fs_stats.data["num_transfer_task"] == 1

    def test_aggregate_two_stats(self):
        """aggregate() should merge save/load duration lists and sum counters."""
        stats1 = HF3FSKVConnectorStats()
        stats1.record_success_task_duration("Saved", 0.1)
        stats1.record_success_task_duration("Loaded", 0.2)

        stats2 = HF3FSKVConnectorStats()
        stats2.record_success_task_duration("Saved", 0.3)
        stats2.record_failed_task_count("Loaded")

        stats1.aggregate(stats2)
        assert stats1.data["save_duration"] == pytest.approx([0.1, 0.3])
        assert stats1.data["load_duration"] == pytest.approx([0.2])
        assert stats1.data["num_failed_load"] == 1
        # 3 successes + 1 failure across both objects.
        assert stats1.data["num_transfer_task"] == 4

    def test_reduce_with_data(self):
        """reduce() computes correct averages when data is present."""
        stats = HF3FSKVConnectorStats()
        stats.record_success_task_duration("Saved", 1.0)
        stats.record_success_task_duration("Saved", 3.0)
        result = stats.reduce()
        assert result["Num save task success"] == pytest.approx(2.0, rel=0.01)
        assert result["Num save task failed"] == pytest.approx(0.0, rel=0.01)
        # Mean of 1.0 s and 3.0 s, reported in milliseconds.
        assert result["Avg save duration (ms)"] == pytest.approx(2000.0, rel=0.01)

    def test_clone_and_reset(self, hf3fs_stats):
        """clone_and_reset() returns a copy with data and resets the original."""
        hf3fs_stats.record_success_task_duration("Saved", 0.7)
        hf3fs_stats.record_success_task_duration("Loaded", 0.4)

        clone = hf3fs_stats.clone_and_reset()
        assert clone.data["num_transfer_task"] == 2
        assert hf3fs_stats.is_empty()
|
||||
193
tests/v1/kv_connector/unit/test_hf3fs_metadata_server.py
Normal file
193
tests/v1/kv_connector/unit/test_hf3fs_metadata_server.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests for HF3FS metadata server data structures and allocation logic:
|
||||
- RankFileMetadata : page allocation / release primitives
|
||||
- KeyMetadata : per-key rank-page tracking and completion detection
|
||||
- GlobalMetadataState : coordinated allocation with cache-hit semantics
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_metadata_server import (
|
||||
GlobalMetadataState,
|
||||
KeyMetadata,
|
||||
RankFileMetadata,
|
||||
)
|
||||
|
||||
# ===========================================================================
|
||||
# TestRankFileMetadata
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestRankFileMetadata:
    """Unit tests for RankFileMetadata page allocation primitives."""

    @pytest.mark.parametrize(
        "alloc_count, expected_pages",
        # Allocation is all-or-nothing: asking for more pages than exist
        # yields an empty result rather than a partial one.
        [(3, 3), (5, 0)],
        ids=["alloc_partial", "alloc_exceeds"],
    )
    def test_allocate_pages(self, alloc_count, expected_pages):
        """allocate_pages returns correct pages or empty list when insufficient."""
        rank_meta = RankFileMetadata(rank_id=0, num_pages=3, free_pages=list(range(3)))
        pages = rank_meta.allocate_pages(alloc_count)
        assert len(pages) == expected_pages
        if expected_pages > 0:
            rank_meta.release_pages(pages)
            assert rank_meta.get_free_page_count() == 3

    def test_release_pages_restores_count(self):
        """Releasing allocated pages returns them to the free pool."""
        rank_meta = RankFileMetadata(rank_id=0, num_pages=4, free_pages=list(range(4)))
        pages = rank_meta.allocate_pages(2)
        assert rank_meta.get_free_page_count() == 2
        rank_meta.release_pages(pages)
        assert rank_meta.get_free_page_count() == 4

    def test_release_pages_no_duplicates(self):
        """Releasing the same page twice must not create duplicates."""
        rank_meta = RankFileMetadata(rank_id=0, num_pages=3, free_pages=list(range(3)))
        rank_meta.allocate_pages(1)  # takes page 0
        rank_meta.release_pages([0])
        rank_meta.release_pages([0])  # second release of the same page
        assert rank_meta.get_free_page_count() == 3
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# TestKeyMetadata
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestKeyMetadata:
    """Unit tests for KeyMetadata completion tracking."""

    def test_is_complete_false_until_all_ranks(self):
        """is_complete() returns True only when all ranks confirmed."""
        key_meta = KeyMetadata(key="k", rank_to_page={}, tp_world_size=2)
        assert key_meta.is_complete() is False
        key_meta.add_rank_page(0, 5)
        # Only 1 of the 2 TP ranks has reported — still incomplete.
        assert key_meta.is_complete() is False
        key_meta.add_rank_page(1, 10)
        assert key_meta.is_complete() is True

    def test_get_rank_page_returns_none_for_missing_rank(self):
        """get_rank_page() returns None when the rank has no entry."""
        key_meta = KeyMetadata(key="k", rank_to_page={0: 3}, tp_world_size=2)
        assert key_meta.get_rank_page(0) == 3
        assert key_meta.get_rank_page(1) is None

    def test_get_all_pages(self):
        """get_all_pages() returns all (rank, page) pairs."""
        key_meta = KeyMetadata(key="k", rank_to_page={0: 1, 1: 2}, tp_world_size=2)
        pairs = key_meta.get_all_pages()
        # Order is not part of the contract; compare as a set.
        assert set(pairs) == {(0, 1), (1, 2)}
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# TestGlobalMetadataStateAllocation
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestGlobalMetadataStateAllocation:
    """Tests for GlobalMetadataState allocation and cache-hit semantics."""

    def test_uninitialized_rank_raises_on_allocate(self):
        """allocate_pages_for_keys raises for an unknown rank."""
        state = GlobalMetadataState()
        # Fixed: the original used pytest.raises((ValueError, Exception)).
        # ValueError is a subclass of Exception, so the tuple was redundant
        # and equivalent to the plain Exception below.
        # NOTE(review): the server is documented to raise ValueError here —
        # tighten this to pytest.raises(ValueError) once confirmed.
        with pytest.raises(Exception):
            state.allocate_pages_for_keys(99, [("key", "")])

    def test_uninitialized_rank_raises_on_get_locations(self):
        """get_key_locations raises for an unknown rank."""
        state = GlobalMetadataState()
        # Same redundant-tuple fix as above; see NOTE(review) there.
        with pytest.raises(Exception):
            state.get_key_locations(99, ["any_key"])

    def test_basic_allocation_and_confirm(self):
        """Allocating a page and confirming it marks the key as complete."""
        state = GlobalMetadataState()
        state.initialize_rank(0, 4)

        results = state.allocate_pages_for_keys(0, [("K", "")])
        assert results["K"] >= 0  # a valid page index was handed out

        state.confirm_write_for_keys(0, [("K", results["K"])])
        assert state.batch_key_exists(["K"]) == [True]
        locations = state.get_key_locations(0, ["K"])
        assert locations == [results["K"]]

    def test_allocate_pages_cache_hit_does_not_leak_pages(self):
        """Cache-hit key must not consume a page from the free pool;
        the pre-allocated slot must be returned before reusing the existing page.
        """
        state = GlobalMetadataState()
        state.initialize_rank(0, 5)  # 5 free pages: [0,1,2,3,4]

        # Simulate a key that has already been fully written and confirmed.
        state.key_metadata["K_cached"] = KeyMetadata(
            key="K_cached", rank_to_page={0: 2}, tp_world_size=1
        )

        free_before = state.rank_metadata[0].get_free_page_count()  # 5

        results = state.allocate_pages_for_keys(0, [("K_cached", ""), ("K_new", "")])

        free_after = state.rank_metadata[0].get_free_page_count()

        # Cache-hit key must reuse its existing page.
        assert results["K_cached"] == 2, (
            f"Cache-hit key should reuse page 2, got {results['K_cached']}"
        )
        # New key must receive a valid page.
        assert results["K_new"] >= 0, (
            f"New key should get a valid page, got {results['K_new']}"
        )
        # Exactly one page consumed from the free pool.
        assert free_before - free_after == 1, (
            f"Expected 1 page consumed, got delta={free_before - free_after}"
        )

    def test_allocate_pages_all_cache_hits_frees_all_slots(self):
        """When every key in the batch is a cache hit, no pages are consumed."""
        state = GlobalMetadataState()
        state.initialize_rank(0, 5)

        # Pre-populate two fully-confirmed keys on pages 0 and 1.
        for key, page in (("K1", 0), ("K2", 1)):
            state.key_metadata[key] = KeyMetadata(
                key=key, rank_to_page={0: page}, tp_world_size=1
            )

        free_before = state.rank_metadata[0].get_free_page_count()
        results = state.allocate_pages_for_keys(0, [("K1", ""), ("K2", "")])
        free_after = state.rank_metadata[0].get_free_page_count()

        assert results["K1"] == 0
        assert results["K2"] == 1
        assert free_after == free_before, (
            f"All-cache-hit batch must not consume free pages; "
            f"before={free_before}, after={free_after}"
        )

    def test_allocate_returns_minus_one_when_pool_exhausted(self):
        """If the free pool is exhausted, all new keys receive -1."""
        state = GlobalMetadataState()
        state.initialize_rank(0, 1)  # only 1 free page

        results = state.allocate_pages_for_keys(0, [("K1", ""), ("K2", "")])
        # allocate_pages uses all-or-nothing: 2 needed but only 1 available → []
        assert all(v == -1 for v in results.values()), f"Expected all -1, got {results}"

    def test_confirm_write_releases_pages(self):
        """confirm_write_for_keys with pages_to_release returns them to pool."""
        state = GlobalMetadataState()
        state.initialize_rank(0, 3)

        results = state.allocate_pages_for_keys(0, [("K", "")])
        page = results["K"]
        free_after_alloc = state.rank_metadata[0].get_free_page_count()

        state.confirm_write_for_keys(0, [("K", page)], pages_to_release=[page])
        free_after_release = state.rank_metadata[0].get_free_page_count()

        # The released page goes straight back to the free pool.
        assert free_after_release == free_after_alloc + 1
|
||||
Reference in New Issue
Block a user