[KVConnector] Support 3FS KVConnector (#37636)

Signed-off-by: wuchenxin <wuchenxin.wcx@alibaba-inc.com>
Signed-off-by: ibifrost <47308427+ibifrost@users.noreply.github.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
ibifrost
2026-04-07 23:46:00 +08:00
committed by GitHub
parent 98e1a43af7
commit 96b5004b71
14 changed files with 3353 additions and 2 deletions

View File

@@ -0,0 +1,284 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for resource management in hf3fs_client.py: constructor failure cleanup
and idempotent close(). Tests use mock to replace real I/O operations
(hf3fs_fuse.io, SharedMemory, os, CUDA).
Requires hf3fs_fuse.io to be installed; skipped otherwise.
"""
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
# Availability flag: flipped to False if either the hf3fs_fuse bindings or
# the vLLM hf3fs client module cannot be imported; gates the tests below.
HF3FS_AVAILABLE = True
try:
    from hf3fs_fuse.io import (  # noqa: F401
        deregister_fd,
        extract_mount_point,
        make_ioring,
        make_iovec,
        register_fd,
    )
    from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_client import (
        Hf3fsClient,
    )
except Exception:
    # Broad catch is deliberate: any import-time failure (missing package,
    # native-library load error, ...) should skip the tests, not error out.
    HF3FS_AVAILABLE = False

# Reusable skip marker for environments without the 3FS FUSE bindings.
requires_hf3fs = pytest.mark.skipif(
    not HF3FS_AVAILABLE,
    reason="hf3fs_fuse.io is not available on this machine",
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _FakeShm:
"""Shared-memory stub matching the multiprocessing.shared_memory.SharedMemory
interface used by Hf3fsClient:
Attributes accessed by the constructor:
.buf memoryview / buffer-protocol object consumed by torch.frombuffer
Methods called during normal lifetime:
.unlink() called right after the iovec is set up
.close() called in _release_resources()
"""
def __init__(self, size: int = 1024):
self._data = bytearray(size)
self.buf = memoryview(self._data)
self.closed = False
self.close_call_count = 0
self.unlink_call_count = 0
def close(self):
self.closed = True
self.close_call_count += 1
def unlink(self):
self.unlink_call_count += 1
# ===========================================================================
# TestHf3fsClientResourceManagement
# ===========================================================================
@requires_hf3fs
class TestHf3fsClientResourceManagement:
    """Tests for constructor failure cleanup and idempotent close()."""

    # Dotted module path of the code under test; all patch targets below are
    # resolved relative to this module's globals.
    _MOD = "vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_client"

    # ------------------------------------------------------------------
    # Helper: build a minimal Hf3fsClient bypassing all real I/O so that
    # we can fully control its internal state.
    # ------------------------------------------------------------------
    def _make_client(self, tmp_path):
        """Return a fully-mocked Hf3fsClient with controllable internals."""
        fake_shm_r = _FakeShm()
        fake_shm_w = _FakeShm()
        # Patch every external touch point of the constructor: hf3fs
        # bindings, shared memory, file descriptors, and torch helpers.
        patcher_list: list[Any] = [
            patch(f"{self._MOD}.HF3FS_AVAILABLE", True),
            patch(f"{self._MOD}.register_fd"),
            patch(f"{self._MOD}.deregister_fd"),
            patch(f"{self._MOD}.extract_mount_point", return_value="/mnt/hf3fs"),
            patch(f"{self._MOD}.make_ioring", return_value=MagicMock()),
            patch(f"{self._MOD}.make_iovec", return_value=MagicMock()),
            patch(
                "multiprocessing.shared_memory.SharedMemory",
                # First SharedMemory() call yields the read-side shm, the
                # second the write-side shm — presumably matching the
                # constructor's allocation order (TODO confirm in client).
                side_effect=[fake_shm_r, fake_shm_w],
            ),
            patch("os.open", return_value=99),
            patch("os.ftruncate"),
            patch("os.close"),
            patch("os.fsync"),
            patch("torch.cuda.Stream", return_value=MagicMock()),
            patch("torch.frombuffer", return_value=MagicMock()),
            patch("torch.empty", return_value=MagicMock()),
        ]
        for p in patcher_list:
            p.start()
        try:
            client = Hf3fsClient(
                path=str(tmp_path / "test.bin"),
                size=1024,
                bytes_per_page=256,
                entries=4,
            )
        finally:
            # Always undo the patches, even when construction raises.
            for p in patcher_list:
                p.stop()
        # Manually point internal handles to our controllable fakes so that
        # assertions after close() can inspect them directly.
        client.shm_r = fake_shm_r
        client.shm_w = fake_shm_w
        client.file = 99
        return client, fake_shm_r, fake_shm_w

    # ------------------------------------------------------------------
    # close() idempotency
    # ------------------------------------------------------------------
    def test_close_idempotent_and_handles_cleared(self, tmp_path):
        """Multiple close() calls must not raise; deregister_fd called exactly
        once, all handles set to None, shm.close() invoked."""
        client, shm_r, shm_w = self._make_client(tmp_path)
        # Re-patch deregister_fd/os.close here so the call counts reflect
        # only what close() does, not constructor activity.
        with (
            patch(f"{self._MOD}.deregister_fd") as mock_dereg,
            patch("os.close"),
        ):
            client.close()  # first close
            client.close()  # second close — must be no-op
            client.close()  # third close — must be no-op
        assert client._closed is True
        assert mock_dereg.call_count == 1, (
            f"deregister_fd called {mock_dereg.call_count} times; expected 1"
        )
        # Every resource handle must have been nulled out by close().
        for attr in ("iov_r", "iov_w", "ior_r", "ior_w", "shm_r", "shm_w", "file"):
            assert getattr(client, attr) is None, f"{attr} should be None after close()"
        assert shm_r.closed is True
        assert shm_w.closed is True

    def test_flush_after_close_is_noop(self, tmp_path):
        """flush() after close() must silently do nothing (no fsync call)."""
        client, _, _ = self._make_client(tmp_path)
        with (
            patch(f"{self._MOD}.deregister_fd"),
            patch("os.close"),
            patch("os.fsync") as mock_fsync,
        ):
            client.close()
            client.flush()
            mock_fsync.assert_not_called()

    # ------------------------------------------------------------------
    # Constructor failure leaves no leaked resources
    # ------------------------------------------------------------------
    def test_constructor_failure_after_file_open_cleans_file(self, tmp_path):
        """If the constructor raises after os.open(), the fd must be closed."""
        with (
            patch(f"{self._MOD}.HF3FS_AVAILABLE", True),
            patch(f"{self._MOD}.register_fd"),
            patch(f"{self._MOD}.deregister_fd"),
            # Fail at mount-point extraction, i.e. after the fd was opened.
            patch(
                f"{self._MOD}.extract_mount_point",
                side_effect=RuntimeError("mount point not found"),
            ),
            patch("os.open", return_value=55),
            patch("os.ftruncate"),
            patch("os.close") as mock_os_close,
            patch("torch.cuda.Stream", return_value=MagicMock()),
            pytest.raises(RuntimeError, match="mount point not found"),
        ):
            Hf3fsClient(
                path=str(tmp_path / "fail.bin"),
                size=1024,
                bytes_per_page=256,
                entries=4,
            )
        # Cleanup must close exactly the fd returned by os.open above.
        mock_os_close.assert_called_once_with(55)

    def test_constructor_failure_after_shm_alloc_closes_shm(self, tmp_path):
        """Constructor raises after SharedMemory creation → both shm objects closed."""
        fake_shm_r = _FakeShm()
        fake_shm_w = _FakeShm()
        with (
            patch(f"{self._MOD}.HF3FS_AVAILABLE", True),
            patch(f"{self._MOD}.register_fd"),
            patch(f"{self._MOD}.deregister_fd"),
            patch(f"{self._MOD}.extract_mount_point", return_value="/mnt/hf3fs"),
            patch(
                "multiprocessing.shared_memory.SharedMemory",
                side_effect=[fake_shm_r, fake_shm_w],
            ),
            patch("os.open", return_value=66),
            patch("os.ftruncate"),
            patch("os.close"),
            patch("torch.frombuffer", return_value=MagicMock()),
            patch("torch.empty", return_value=MagicMock()),
            # Fail at ioring creation, i.e. after both shm blocks exist.
            patch(
                f"{self._MOD}.make_ioring",
                side_effect=RuntimeError("ioring init failed"),
            ),
            patch(f"{self._MOD}.make_iovec", return_value=MagicMock()),
            patch("torch.cuda.Stream", return_value=MagicMock()),
            pytest.raises(RuntimeError, match="ioring init failed"),
        ):
            Hf3fsClient(
                path=str(tmp_path / "fail2.bin"),
                size=1024,
                bytes_per_page=256,
                entries=4,
            )
        assert fake_shm_r.closed is True, (
            "shm_r was not closed after constructor failure"
        )
        assert fake_shm_w.closed is True, (
            "shm_w was not closed after constructor failure"
        )

    def test_constructor_failure_does_not_close_unallocated_shm(self, tmp_path):
        """Failure before SharedMemory is created must not raise AttributeError
        or TypeError from cleanup."""
        # Note: SharedMemory is intentionally NOT patched here — the failure
        # happens before it would be created, and cleanup must cope with
        # shm handles that never existed.
        with (
            patch(f"{self._MOD}.HF3FS_AVAILABLE", True),
            patch(f"{self._MOD}.register_fd"),
            patch(f"{self._MOD}.deregister_fd"),
            patch(
                f"{self._MOD}.extract_mount_point",
                side_effect=RuntimeError("early failure"),
            ),
            patch("os.open", return_value=77),
            patch("os.ftruncate"),
            patch("os.close"),
            patch("torch.cuda.Stream", return_value=MagicMock()),
            pytest.raises(RuntimeError, match="early failure"),
        ):
            Hf3fsClient(
                path=str(tmp_path / "early_fail.bin"),
                size=1024,
                bytes_per_page=256,
                entries=4,
            )

    # ------------------------------------------------------------------
    # _release_resources on already-cleared state must be a no-op
    # ------------------------------------------------------------------
    def test_release_resources_on_empty_state_is_safe(self, tmp_path):
        """_release_resources() on a fully-cleared client must not raise."""
        client, _, _ = self._make_client(tmp_path)
        with (
            patch(f"{self._MOD}.deregister_fd"),
            patch("os.close"),
        ):
            client.close()  # clears all handles
        # Fresh mocks so we can verify the second release touches nothing.
        with (
            patch(f"{self._MOD}.deregister_fd") as mock_dereg2,
            patch("os.close") as mock_os_close2,
        ):
            client._release_resources()  # must not raise
            mock_dereg2.assert_not_called()
            mock_os_close2.assert_not_called()

View File

@@ -0,0 +1,230 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for HF3FS KV Connector high-level components:
- TestHf3fsMockClient : file-backed mock client I/O correctness
- TestHF3FSKVConnectorStats: metric collection, aggregation, serialisation
"""
import os
from unittest.mock import MagicMock
import pytest
import torch
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_connector import (
HF3FSKVConnectorStats,
)
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.utils.hf3fs_mock_client import (
Hf3fsClient as MockHf3fsClient,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@pytest.fixture
def hf3fs_stats():
    """Fresh HF3FSKVConnectorStats instance (a new object per test)."""
    return HF3FSKVConnectorStats()
def _make_cuda_event():
"""Return a real CUDA event when available, otherwise a MagicMock."""
if torch.cuda.is_available():
return torch.cuda.Event()
return MagicMock()
# ===========================================================================
# TestHf3fsMockClient
# ===========================================================================
class TestHf3fsMockClient:
    """Tests for hf3fs_mock_client.Hf3fsClient (file-backend mock)."""

    def test_init_creates_file(self, tmp_path):
        """Initializing the client should create the backing file."""
        path = str(tmp_path / "test_file")
        client = MockHf3fsClient(path=path, size=4096, bytes_per_page=512, entries=4)
        assert os.path.exists(path), "Backing file should be created on init"
        # The backing file must be pre-sized to the full requested capacity.
        assert os.path.getsize(path) == 4096
        client.close()

    @pytest.mark.parametrize(
        "dtype, bytes_per_page",
        [
            (torch.float32, 512),
            (torch.float16, 256),
            (torch.bfloat16, 256),
        ],
        ids=["float32", "float16", "bfloat16"],
    )
    def test_batch_write_and_read_dtype(self, tmp_path, dtype, bytes_per_page):
        """Write a tensor of the given dtype and verify round-trip correctness."""
        path = str(tmp_path / f"rw_{dtype}")
        client = MockHf3fsClient(
            path=path, size=bytes_per_page * 8, bytes_per_page=bytes_per_page, entries=4
        )
        # Size the tensor so it fills exactly one page for this dtype.
        elem_size = torch.tensor([], dtype=dtype).element_size()
        numel = bytes_per_page // elem_size
        tensor_write = torch.arange(numel, dtype=dtype)
        event = _make_cuda_event()
        # batch_write/batch_read report bytes transferred per tensor.
        results = client.batch_write([0], [tensor_write], event)
        assert results == [bytes_per_page], f"Write should succeed, got {results}"
        tensor_read = torch.zeros(numel, dtype=dtype)
        results = client.batch_read([0], [tensor_read])
        assert results == [bytes_per_page], f"Read should succeed, got {results}"
        assert torch.equal(tensor_write, tensor_read), (
            "Read tensor should match written tensor"
        )
        client.close()

    def test_batch_read_empty_file_returns_error(self, tmp_path):
        """Reading out-of-bounds offset should return -1."""
        bytes_per_page = 128
        size = bytes_per_page * 4
        path = str(tmp_path / "empty_read")
        client = MockHf3fsClient(
            path=path, size=size, bytes_per_page=bytes_per_page, entries=4
        )
        # float32 is 4 bytes per element.
        numel = bytes_per_page // 4
        tensor_read = torch.zeros(numel, dtype=torch.float32)
        results = client.batch_read([size], [tensor_read])  # offset == size => OOB
        assert results[0] == -1, "Out-of-bounds read should return -1"
        client.close()

    def test_batch_write_out_of_bounds_returns_error(self, tmp_path):
        """Writing at an offset beyond file size should return -1."""
        bytes_per_page = 128
        size = bytes_per_page * 4
        path = str(tmp_path / "oob_write")
        client = MockHf3fsClient(
            path=path, size=size, bytes_per_page=bytes_per_page, entries=4
        )
        numel = bytes_per_page // 4
        tensor = torch.ones(numel, dtype=torch.float32)
        event = _make_cuda_event()
        results = client.batch_write([size], [tensor], event)  # OOB offset
        assert results[0] == -1, "Out-of-bounds write should return -1"
        client.close()

    def test_multiple_tensors_rw(self, tmp_path):
        """Write multiple tensors at different offsets, then read all back."""
        bytes_per_page = 128
        n = 4
        path = str(tmp_path / "multi_rw")
        client = MockHf3fsClient(
            path=path,
            size=bytes_per_page * n * 2,
            bytes_per_page=bytes_per_page,
            entries=8,
        )
        # Tensor i is filled with the constant value i so any offset mix-up
        # would surface as a value mismatch on read-back.
        tensors_write = [
            torch.full((bytes_per_page // 4,), float(i), dtype=torch.float32)
            for i in range(n)
        ]
        offsets = [i * bytes_per_page for i in range(n)]
        event = _make_cuda_event()
        results = client.batch_write(offsets, tensors_write, event)
        assert all(r == bytes_per_page for r in results)
        tensors_read = [
            torch.zeros(bytes_per_page // 4, dtype=torch.float32) for _ in range(n)
        ]
        results = client.batch_read(offsets, tensors_read)
        assert all(r == bytes_per_page for r in results)
        for i, (tw, tr) in enumerate(zip(tensors_write, tensors_read)):
            assert torch.allclose(tw, tr), f"Tensor {i} mismatch after round-trip"
        client.close()

    def test_flush_and_close_no_error(self, tmp_path):
        """flush() and close() should not raise exceptions."""
        path = str(tmp_path / "flush_close")
        client = MockHf3fsClient(path=path, size=1024, bytes_per_page=128, entries=4)
        client.flush()
        client.close()
# ===========================================================================
# TestHF3FSKVConnectorStats
# ===========================================================================
class TestHF3FSKVConnectorStats:
    """Tests for HF3FSKVConnectorStats metric collection and aggregation."""

    def test_initial_is_empty(self, hf3fs_stats):
        """Fresh stats object should report is_empty() == True."""
        assert hf3fs_stats.is_empty() is True

    @pytest.mark.parametrize(
        "task_type, duration_key",
        [
            ("Saved", "save_duration"),
            ("Loaded", "load_duration"),
        ],
        ids=["save", "load"],
    )
    def test_record_success_duration(self, hf3fs_stats, task_type, duration_key):
        """Recording a successful task should update duration list and total count."""
        hf3fs_stats.record_success_task_duration(task_type, 0.5)
        assert not hf3fs_stats.is_empty()
        assert len(hf3fs_stats.data[duration_key]) == 1
        assert hf3fs_stats.data[duration_key][0] == pytest.approx(0.5)
        # Both success and failure records count toward num_transfer_task.
        assert hf3fs_stats.data["num_transfer_task"] == 1

    @pytest.mark.parametrize(
        "task_type, failed_key",
        [
            ("Saved", "num_failed_save"),
            ("Loaded", "num_failed_load"),
        ],
        ids=["save", "load"],
    )
    def test_record_failed_task(self, hf3fs_stats, task_type, failed_key):
        """Recording a failed task should increment the corresponding counter."""
        hf3fs_stats.record_failed_task_count(task_type)
        assert hf3fs_stats.data[failed_key] == 1
        assert hf3fs_stats.data["num_transfer_task"] == 1

    def test_aggregate_two_stats(self):
        """aggregate() should merge save/load duration lists and sum counters."""
        stats1 = HF3FSKVConnectorStats()
        stats1.record_success_task_duration("Saved", 0.1)
        stats1.record_success_task_duration("Loaded", 0.2)
        stats2 = HF3FSKVConnectorStats()
        stats2.record_success_task_duration("Saved", 0.3)
        stats2.record_failed_task_count("Loaded")
        stats1.aggregate(stats2)
        assert stats1.data["save_duration"] == pytest.approx([0.1, 0.3])
        assert stats1.data["load_duration"] == pytest.approx([0.2])
        assert stats1.data["num_failed_load"] == 1
        # 2 tasks from stats1 + 2 from stats2 (1 success + 1 failure) = 4.
        assert stats1.data["num_transfer_task"] == 4

    def test_reduce_with_data(self):
        """reduce() computes correct averages when data is present."""
        stats = HF3FSKVConnectorStats()
        stats.record_success_task_duration("Saved", 1.0)
        stats.record_success_task_duration("Saved", 3.0)
        result = stats.reduce()
        assert result["Num save task success"] == pytest.approx(2.0, rel=0.01)
        assert result["Num save task failed"] == pytest.approx(0.0, rel=0.01)
        # mean(1.0 s, 3.0 s) = 2.0 s, reported in milliseconds.
        assert result["Avg save duration (ms)"] == pytest.approx(2000.0, rel=0.01)

    def test_clone_and_reset(self, hf3fs_stats):
        """clone_and_reset() returns a copy with data and resets the original."""
        hf3fs_stats.record_success_task_duration("Saved", 0.7)
        hf3fs_stats.record_success_task_duration("Loaded", 0.4)
        clone = hf3fs_stats.clone_and_reset()
        assert clone.data["num_transfer_task"] == 2
        assert hf3fs_stats.is_empty()

View File

@@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for HF3FS metadata server data structures and allocation logic:
- RankFileMetadata : page allocation / release primitives
- KeyMetadata : per-key rank-page tracking and completion detection
- GlobalMetadataState : coordinated allocation with cache-hit semantics
"""
import pytest
from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_metadata_server import (
GlobalMetadataState,
KeyMetadata,
RankFileMetadata,
)
# ===========================================================================
# TestRankFileMetadata
# ===========================================================================
class TestRankFileMetadata:
    """Unit tests for RankFileMetadata page allocation primitives."""

    @pytest.mark.parametrize(
        "alloc_count, expected_pages",
        [(3, 3), (5, 0)],
        ids=["alloc_partial", "alloc_exceeds"],
    )
    def test_allocate_pages(self, alloc_count, expected_pages):
        """allocate_pages returns correct pages or empty list when insufficient."""
        meta = RankFileMetadata(rank_id=0, num_pages=3, free_pages=list(range(3)))
        allocated = meta.allocate_pages(alloc_count)
        assert len(allocated) == expected_pages
        if expected_pages > 0:
            # Returning the pages must restore the pool to its initial size.
            meta.release_pages(allocated)
            assert meta.get_free_page_count() == 3

    def test_release_pages_restores_count(self):
        """Releasing allocated pages returns them to the free pool."""
        meta = RankFileMetadata(rank_id=0, num_pages=4, free_pages=list(range(4)))
        taken = meta.allocate_pages(2)
        assert meta.get_free_page_count() == 2
        meta.release_pages(taken)
        assert meta.get_free_page_count() == 4

    def test_release_pages_no_duplicates(self):
        """Releasing the same page twice must not create duplicates."""
        meta = RankFileMetadata(rank_id=0, num_pages=3, free_pages=list(range(3)))
        meta.allocate_pages(1)  # consumes page 0
        for _ in range(2):
            meta.release_pages([0])  # double-release of the same page
        assert meta.get_free_page_count() == 3
# ===========================================================================
# TestKeyMetadata
# ===========================================================================
class TestKeyMetadata:
    """Unit tests for KeyMetadata completion tracking."""

    def test_is_complete_false_until_all_ranks(self):
        """is_complete() returns True only when all ranks confirmed."""
        meta = KeyMetadata(key="k", rank_to_page={}, tp_world_size=2)
        # Incomplete while fewer than tp_world_size ranks have reported.
        assert meta.is_complete() is False
        meta.add_rank_page(0, 5)
        assert meta.is_complete() is False
        meta.add_rank_page(1, 10)
        assert meta.is_complete() is True

    def test_get_rank_page_returns_none_for_missing_rank(self):
        """get_rank_page() returns None when the rank has no entry."""
        meta = KeyMetadata(key="k", rank_to_page={0: 3}, tp_world_size=2)
        assert meta.get_rank_page(0) == 3
        assert meta.get_rank_page(1) is None

    def test_get_all_pages(self):
        """get_all_pages() returns all (rank, page) pairs."""
        meta = KeyMetadata(key="k", rank_to_page={0: 1, 1: 2}, tp_world_size=2)
        assert set(meta.get_all_pages()) == {(0, 1), (1, 2)}
# ===========================================================================
# TestGlobalMetadataStateAllocation
# ===========================================================================
class TestGlobalMetadataStateAllocation:
    """Tests for GlobalMetadataState allocation and cache-hit semantics."""

    def test_uninitialized_rank_raises_on_allocate(self):
        """allocate_pages_for_keys raises for an uninitialized rank."""
        state = GlobalMetadataState()
        # NOTE(review): the original `pytest.raises((ValueError, Exception))`
        # was redundant — Exception subsumes ValueError, so the tuple never
        # checked the specific type. Simplified to the equivalent broad
        # catch; tighten to ValueError once the server's exception contract
        # is confirmed.
        with pytest.raises(Exception):
            state.allocate_pages_for_keys(99, [("key", "")])

    def test_uninitialized_rank_raises_on_get_locations(self):
        """get_key_locations raises for an uninitialized rank."""
        state = GlobalMetadataState()
        # Same simplification as above: Exception subsumed ValueError in the
        # original tuple; tighten once the exception type is confirmed.
        with pytest.raises(Exception):
            state.get_key_locations(99, ["any_key"])

    def test_basic_allocation_and_confirm(self):
        """Allocating a page and confirming it marks the key as complete."""
        state = GlobalMetadataState()
        state.initialize_rank(0, 4)
        results = state.allocate_pages_for_keys(0, [("K", "")])
        # A non-negative page id signals a successful allocation.
        assert results["K"] >= 0
        state.confirm_write_for_keys(0, [("K", results["K"])])
        assert state.batch_key_exists(["K"]) == [True]
        locations = state.get_key_locations(0, ["K"])
        assert locations == [results["K"]]

    def test_allocate_pages_cache_hit_does_not_leak_pages(self):
        """Cache-hit key must not consume a page from the free pool;
        the pre-allocated slot must be returned before reusing the existing page.
        """
        state = GlobalMetadataState()
        state.initialize_rank(0, 5)  # 5 free pages: [0,1,2,3,4]
        # Simulate a key that has already been fully written and confirmed.
        state.key_metadata["K_cached"] = KeyMetadata(
            key="K_cached", rank_to_page={0: 2}, tp_world_size=1
        )
        free_before = state.rank_metadata[0].get_free_page_count()  # 5
        results = state.allocate_pages_for_keys(0, [("K_cached", ""), ("K_new", "")])
        free_after = state.rank_metadata[0].get_free_page_count()
        # Cache-hit key must reuse its existing page.
        assert results["K_cached"] == 2, (
            f"Cache-hit key should reuse page 2, got {results['K_cached']}"
        )
        # New key must receive a valid page.
        assert results["K_new"] >= 0, (
            f"New key should get a valid page, got {results['K_new']}"
        )
        # Exactly one page consumed from the free pool (only for K_new).
        assert free_before - free_after == 1, (
            f"Expected 1 page consumed, got delta={free_before - free_after}"
        )

    def test_allocate_pages_all_cache_hits_frees_all_slots(self):
        """When every key in the batch is a cache hit, no pages are consumed."""
        state = GlobalMetadataState()
        state.initialize_rank(0, 5)
        # Pre-register two confirmed keys on pages 0 and 1.
        for key, page in (("K1", 0), ("K2", 1)):
            state.key_metadata[key] = KeyMetadata(
                key=key, rank_to_page={0: page}, tp_world_size=1
            )
        free_before = state.rank_metadata[0].get_free_page_count()
        results = state.allocate_pages_for_keys(0, [("K1", ""), ("K2", "")])
        free_after = state.rank_metadata[0].get_free_page_count()
        assert results["K1"] == 0
        assert results["K2"] == 1
        assert free_after == free_before, (
            f"All-cache-hit batch must not consume free pages; "
            f"before={free_before}, after={free_after}"
        )

    def test_allocate_returns_minus_one_when_pool_exhausted(self):
        """If the free pool is exhausted, all new keys receive -1."""
        state = GlobalMetadataState()
        state.initialize_rank(0, 1)  # only 1 free page
        results = state.allocate_pages_for_keys(0, [("K1", ""), ("K2", "")])
        # allocate_pages uses all-or-nothing: 2 needed but only 1 available → []
        assert all(v == -1 for v in results.values()), f"Expected all -1, got {results}"

    def test_confirm_write_releases_pages(self):
        """confirm_write_for_keys with pages_to_release returns them to pool."""
        state = GlobalMetadataState()
        state.initialize_rank(0, 3)
        results = state.allocate_pages_for_keys(0, [("K", "")])
        page = results["K"]
        free_after_alloc = state.rank_metadata[0].get_free_page_count()
        state.confirm_write_for_keys(0, [("K", page)], pages_to_release=[page])
        free_after_release = state.rank_metadata[0].get_free_page_count()
        # The released page must be back in the free pool.
        assert free_after_release == free_after_alloc + 1