class _FakeShm:
    """In-memory stand-in for ``multiprocessing.shared_memory.SharedMemory``.

    Only the parts of the interface that Hf3fsClient touches are provided:

    * ``buf``      - a memoryview over a zero-filled bytearray, suitable for
                     ``torch.frombuffer``.
    * ``unlink()`` - counted in ``unlink_call_count``.
    * ``close()``  - counted in ``close_call_count`` and recorded in ``closed``.
    """

    def __init__(self, size: int = 1024):
        backing = bytearray(size)
        self._data = backing
        self.buf = memoryview(backing)
        # Bookkeeping inspected by the tests.
        self.close_call_count = 0
        self.unlink_call_count = 0
        self.closed = False

    def unlink(self):
        """Record one unlink() call."""
        self.unlink_call_count = self.unlink_call_count + 1

    def close(self):
        """Record one close() call and flag the stub as closed."""
        self.close_call_count = self.close_call_count + 1
        self.closed = True
def test_constructor_failure_after_file_open_cleans_file(self, tmp_path):
    """A constructor failure that happens after os.open() must close the fd."""
    mod = self._MOD
    fake_fd = 55
    ctor_kwargs = dict(
        path=str(tmp_path / "fail.bin"),
        size=1024,
        bytes_per_page=256,
        entries=4,
    )
    # extract_mount_point is the first call after the file is opened and
    # registered, so making it raise exercises the "file only" cleanup path.
    mount_failure = RuntimeError("mount point not found")

    with (
        patch(f"{mod}.HF3FS_AVAILABLE", True),
        patch(f"{mod}.register_fd"),
        patch(f"{mod}.deregister_fd"),
        patch(f"{mod}.extract_mount_point", side_effect=mount_failure),
        patch("os.open", return_value=fake_fd),
        patch("os.ftruncate"),
        patch("os.close") as close_spy,
        patch("torch.cuda.Stream", return_value=MagicMock()),
        pytest.raises(RuntimeError, match="mount point not found"),
    ):
        Hf3fsClient(**ctor_kwargs)

    close_spy.assert_called_once_with(fake_fd)
def test_release_resources_on_empty_state_is_safe(self, tmp_path):
    """Calling _release_resources() after close() must be a silent no-op."""
    client = self._make_client(tmp_path)[0]
    dereg_target = f"{self._MOD}.deregister_fd"

    # First pass: close() drops every handle the client holds.
    with patch(dereg_target), patch("os.close"):
        client.close()

    # Second pass: nothing is left, so neither teardown hook may fire.
    with (
        patch(dereg_target) as second_dereg,
        patch("os.close") as second_os_close,
    ):
        client._release_resources()  # must not raise

    second_dereg.assert_not_called()
    second_os_close.assert_not_called()
def _make_cuda_event():
    """Build the event argument for batch_write().

    A genuine ``torch.cuda.Event`` is used when CUDA is usable; on CPU-only
    hosts a ``MagicMock`` is handed out instead.
    """
    return torch.cuda.Event() if torch.cuda.is_available() else MagicMock()
def test_batch_write_out_of_bounds_returns_error(self, tmp_path):
    """A write whose offset equals the file size must be reported as -1."""
    page_bytes = 128
    file_bytes = page_bytes * 4
    client = MockHf3fsClient(
        path=str(tmp_path / "oob_write"),
        size=file_bytes,
        bytes_per_page=page_bytes,
        entries=4,
    )
    # One page worth of float32 values (4 bytes each).
    payload = torch.ones(page_bytes // 4, dtype=torch.float32)
    cuda_event = _make_cuda_event()

    # offset == file size, i.e. the first byte past the end of the file
    outcome = client.batch_write([file_bytes], [payload], cuda_event)

    assert outcome[0] == -1, "Out-of-bounds write should return -1"
    client.close()
def test_flush_and_close_no_error(self, tmp_path):
    """Smoke test: flush() followed by close() completes without raising."""
    backing_file = tmp_path / "flush_close"
    client = MockHf3fsClient(
        path=str(backing_file),
        size=1024,
        bytes_per_page=128,
        entries=4,
    )
    client.flush()
    client.close()
def test_aggregate_two_stats(self):
    """aggregate() concatenates the duration lists and adds up the counters."""
    target = HF3FSKVConnectorStats()
    target.record_success_task_duration("Saved", 0.1)
    target.record_success_task_duration("Loaded", 0.2)

    other = HF3FSKVConnectorStats()
    other.record_success_task_duration("Saved", 0.3)
    other.record_failed_task_count("Loaded")

    target.aggregate(other)

    # Two tasks were recorded on each side, hence four in total.
    expected_durations = {
        "save_duration": [0.1, 0.3],
        "load_duration": [0.2],
    }
    for key, durations in expected_durations.items():
        assert target.data[key] == pytest.approx(durations)
    assert target.data["num_failed_load"] == 1
    assert target.data["num_transfer_task"] == 4
def test_release_pages_restores_count(self):
    """Pages handed back via release_pages() are counted as free again."""
    total_pages = 4
    rank_meta = RankFileMetadata(
        rank_id=0,
        num_pages=total_pages,
        free_pages=list(range(total_pages)),
    )

    taken = rank_meta.allocate_pages(2)
    assert rank_meta.get_free_page_count() == 2

    rank_meta.release_pages(taken)
    assert rank_meta.get_free_page_count() == 4
class TestKeyMetadata:
    """Checks KeyMetadata's per-rank page bookkeeping and completion flag."""

    def test_is_complete_false_until_all_ranks(self):
        """is_complete() flips to True only after both ranks have a page."""
        meta = KeyMetadata(key="k", rank_to_page={}, tp_world_size=2)
        assert meta.is_complete() is False

        # (rank, page) registrations and the completion state expected after each
        steps = [((0, 5), False), ((1, 10), True)]
        for (rank, page), expected in steps:
            meta.add_rank_page(rank, page)
            assert meta.is_complete() is expected

    def test_get_rank_page_returns_none_for_missing_rank(self):
        """A rank without an entry yields None; a known rank yields its page."""
        meta = KeyMetadata(key="k", rank_to_page={0: 3}, tp_world_size=2)
        known = meta.get_rank_page(0)
        missing = meta.get_rank_page(1)
        assert known == 3
        assert missing is None

    def test_get_all_pages(self):
        """get_all_pages() exposes every (rank, page) pair."""
        meta = KeyMetadata(key="k", rank_to_page={0: 1, 1: 2}, tp_world_size=2)
        assert set(meta.get_all_pages()) == {(0, 1), (1, 2)}
def test_basic_allocation_and_confirm(self):
    """After allocate + confirm, the key exists and resolves to its page."""
    rank = 0
    state = GlobalMetadataState()
    state.initialize_rank(rank, 4)

    allocation = state.allocate_pages_for_keys(rank, [("K", "")])
    page = allocation["K"]
    assert page >= 0

    state.confirm_write_for_keys(rank, [("K", page)])

    exists = state.batch_key_exists(["K"])
    assert exists == [True]
    where = state.get_key_locations(rank, ["K"])
    assert where == [page]
def test_confirm_write_releases_pages(self):
    """pages_to_release passed to confirm_write_for_keys go back to the pool."""
    state = GlobalMetadataState()
    state.initialize_rank(0, 3)

    def free_count():
        # Looked up freshly on every call, exactly as a caller would.
        return state.rank_metadata[0].get_free_page_count()

    page = state.allocate_pages_for_keys(0, [("K", "")])["K"]
    before_release = free_count()

    state.confirm_write_for_keys(0, [("K", page)], pages_to_release=[page])
    after_release = free_count()

    assert after_release == before_release + 1
a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -211,15 +211,18 @@ KVConnectorFactory.register_connector( "vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector", "MooncakeConnector", ) - KVConnectorFactory.register_connector( "FlexKVConnectorV1", "vllm.distributed.kv_transfer.kv_connector.v1.flexkv_connector", "FlexKVConnectorV1", ) - KVConnectorFactory.register_connector( "SimpleCPUOffloadConnector", "vllm.distributed.kv_transfer.kv_connector.v1.simple_cpu_offload_connector", "SimpleCPUOffloadConnector", ) +KVConnectorFactory.register_connector( + "HF3FSKVConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_connector", + "HF3FSKVConnector", +) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/hf3fs/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/hf3fs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/hf3fs/hf3fs_client.py b/vllm/distributed/kv_transfer/kv_connector/v1/hf3fs/hf3fs_client.py new file mode 100644 index 000000000..a54233453 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/hf3fs/hf3fs_client.py @@ -0,0 +1,298 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import logging +import multiprocessing +import os +import threading +from functools import wraps +from pathlib import Path + +import torch +import torch.utils.cpp_extension +from torch.utils.cpp_extension import load + +root = Path(__file__).parent.resolve() +cuda_include_path = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include") +hf3fs_utils = load( + name="hf3fs_utils", + sources=[f"{root}/utils/hf3fs_utils.cpp"], + extra_include_paths=[cuda_include_path], +) + +logger = logging.getLogger(__name__) + +HF3FS_AVAILABLE = True +try: + from hf3fs_fuse.io import ( + deregister_fd, + extract_mount_point, + make_ioring, + 
make_iovec, + register_fd, + ) +except ImportError: + HF3FS_AVAILABLE = False + + +def rsynchronized(): + def _decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + with self.rlock: + return func(self, *args, **kwargs) + + return wrapper + + return _decorator + + +def wsynchronized(): + def _decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + with self.wlock: + return func(self, *args, **kwargs) + + return wrapper + + return _decorator + + +class Hf3fsClient: + def __init__(self, path: str, size: int, bytes_per_page: int, entries: int): + """Initialize the HF3FS client with hf3fs_fuse. + + Args: + path: Path to the file used for storage + size: Total size of the storage file in bytes + bytes_per_page: Size of each page in bytes + entries: Maximum number of concurrent operations + """ + if not HF3FS_AVAILABLE: + raise ImportError( + "hf3fs_fuse.io is not available. Please install the hf3fs_fuse package." + ) + + self.path = path + self.size = size + self.bytes_per_page = bytes_per_page + self.entries = entries + + self._closed = False + + self.file = None + self.shm_r = None + self.shm_w = None + self.ior_r = None + self.ior_w = None + self.iov_r = None + self.iov_w = None + try: + # Create the file if it doesn't exist and set its size + self.file = os.open(self.path, os.O_RDWR | os.O_CREAT) + os.ftruncate(self.file, size) + register_fd(self.file) + + self.hf3fs_mount_point = extract_mount_point(path) + self.bs = self.bytes_per_page + self.shm_r = multiprocessing.shared_memory.SharedMemory( + size=self.bs * self.entries, create=True + ) + self.shm_w = multiprocessing.shared_memory.SharedMemory( + size=self.bs * self.entries, create=True + ) + + self.shm_r_tensor = torch.frombuffer(self.shm_r.buf, dtype=torch.uint8) + self.shm_w_tensor = torch.frombuffer(self.shm_w.buf, dtype=torch.uint8) + + numel = self.bs * self.entries + self.r_pinned = torch.empty( + numel, + dtype=torch.uint8, + device="cpu", + pin_memory=True, + ) + 
self.w_pinned = torch.empty( + numel, + dtype=torch.uint8, + device="cpu", + pin_memory=True, + ) + + self.numa = -1 + self.ior_r = make_ioring( + self.hf3fs_mount_point, + self.entries, + for_read=True, + timeout=1, + numa=self.numa, + ) + self.ior_w = make_ioring( + self.hf3fs_mount_point, + self.entries, + for_read=False, + timeout=1, + numa=self.numa, + ) + self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point) + self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point) + self.shm_r.unlink() + self.shm_w.unlink() + + self.rlock = threading.RLock() + self.wlock = threading.RLock() + + self.stream = torch.cuda.Stream() + self.stream_ptr_int = self.stream.cuda_stream + + except Exception: + self._release_resources() + raise + + logger.debug( + "Initialized HF3FS client with file: %s, size: %s bytes", path, size + ) + + def _release_resources(self) -> None: + """Release all acquired resources safely""" + # iov must be released before ioring and shm + for attr in ("iov_r", "iov_w", "ior_r", "ior_w"): + obj = getattr(self, attr, None) + if obj is not None: + del obj + setattr(self, attr, None) + + for attr in ("shm_r", "shm_w"): + shm = getattr(self, attr, None) + if shm is not None: + try: + shm.close() + except Exception as e: + logger.warning("Failed to close %s: %s", attr, e) + setattr(self, attr, None) + + if self.file is not None: + try: + deregister_fd(self.file) + except Exception as e: + logger.warning("deregister_fd failed: %s", e) + try: + os.close(self.file) + except OSError as e: + logger.warning("os.close failed: %s", e) + self.file = None + + @rsynchronized() + def batch_read(self, offsets: list[int], tensors: list[torch.Tensor]) -> list[int]: + """Read data from the file at specified offsets into tensors. 
+ + Args: + offsets: List of byte offsets to read from + tensors: List of tensors to read data into + + Returns: + List of operation results (0 for success, non-zero for error) + """ + self.check(offsets, tensors) + assert self.ior_r is not None + assert self.iov_r is not None + + # prepare + current = 0 + for offset, tensor in zip(offsets, tensors): + size = tensor.numel() * tensor.itemsize + self.ior_r.prepare( + self.iov_r[current : current + size], True, self.file, offset + ) + current += size + + # submit + ionum = len(offsets) + resv = self.ior_r.submit().wait(min_results=ionum) + + # results + with torch.cuda.stream(self.stream): + hf3fs_utils.read_shm( + self.shm_r_tensor, self.r_pinned, tensors, self.stream_ptr_int + ) + results = [res.result for res in resv] + + return results + + @wsynchronized() + def batch_write( + self, offsets: list[int], tensors: list[torch.Tensor], event: torch.cuda.Event + ) -> list[int]: + """Write data from tensors to the file at specified offsets. + + Args: + offsets: List of byte offsets to write to + tensors: List of tensors containing data to write + + Returns: + List of operation results (0 for success, non-zero for error) + """ + + self.check(offsets, tensors) + assert self.ior_w is not None + assert self.iov_w is not None + + # prepare + with torch.cuda.stream(self.stream): + self.stream.wait_event(event) + hf3fs_utils.write_shm( + tensors, self.shm_w_tensor, self.w_pinned, self.stream_ptr_int + ) + + current = 0 + for offset, tensor in zip(offsets, tensors): + size = tensor.numel() * tensor.itemsize + self.ior_w.prepare( + self.iov_w[current : current + size], False, self.file, offset + ) + current += size + + # submit + ionum = len(offsets) + resv = self.ior_w.submit().wait(min_results=ionum) + + # results + results = [res.result for res in resv] + + return results + + def check(self, offsets: list[int], tensors: list[torch.Tensor]) -> None: + sizes = [t.numel() * t.itemsize for t in tensors] + if any( + [ + 
def check(self, offsets: list[int], tensors: list[torch.Tensor]) -> None:
    """Validate a batch request; close the client and raise if it is invalid.

    A batch is rejected when it has more operations than ``self.entries``,
    when ``offsets`` and ``tensors`` differ in length, when any operation
    falls outside ``[0, self.size]``, or when any tensor is larger than
    ``self.bytes_per_page``.

    Raises:
        ValueError: If any of the conditions above holds. ``close()`` is
            called before raising.
    """
    byte_sizes = [t.numel() * t.itemsize for t in tensors]

    # All four conditions are evaluated up front, exactly like the list
    # literal they replace.
    too_many_ops = len(offsets) > self.entries
    length_mismatch = len(offsets) != len(byte_sizes)
    out_of_file = any(
        offset < 0 or offset + nbytes > self.size
        for offset, nbytes in zip(offsets, byte_sizes)
    )
    oversized = any(nbytes > self.bytes_per_page for nbytes in byte_sizes)

    if too_many_ops or length_mismatch or out_of_file or oversized:
        self.close()
        raise ValueError("Hf3fsClient.check Failed")


def get_size(self) -> int:
    """Return the size in bytes stored in ``self.size``."""
    return self.size


def close(self) -> None:
    """Release the client's resources once; later calls are no-ops."""
    if not self._closed:
        self._closed = True
        self._release_resources()


def flush(self) -> None:
    """``fsync`` the backing file unless the client is closed or has no file."""
    if self._closed or self.file is None:
        return
    os.fsync(self.file)
HF3FSClient: 3FS Client Implementation +""" + +import atexit +import concurrent +import copy +import hashlib +import os +import queue +import signal +import threading +import time +from concurrent.futures import Future +from dataclasses import dataclass +from queue import Empty +from typing import Any, Optional + +import numpy as np +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_metadata_server import ( + Hf3fsGlobalMetadataClient as Hf3fsMetadataClient, +) +from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.utils import ( + gather_scatter_helper, +) +from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.utils.common import ( + AtomicCounter, + HF3FSConnectorMetadata, + HF3FSRequestMetadata, + LoadBlockInfo, + RequestSchedulingState, +) +from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.utils.gather_scatter_helper import ( # noqa: E501 + CopyBufferAllocator, +) +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorPromMetrics, + KVConnectorStats, + PromMetric, + PromMetricT, +) +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.forward_context import ForwardContext +from vllm.logger import init_logger +from vllm.v1.attention.backend import AttentionMetadata +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.metrics.utils import create_metric_per_engine +from vllm.v1.request import Request + +HF3FS_AVAILABLE = True +Hf3fsClient = None +try: + from hf3fs_fuse.io import deregister_fd # noqa: F401 + + from vllm.distributed.kv_transfer.kv_connector.v1.hf3fs.hf3fs_client import ( + Hf3fsClient as _RealClient, + ) + + Hf3fsClient = _RealClient +except Exception: + 
class AsyncOperationManager:
    """
    Manages async save/load operations with background threads.

    One thread drains the save queue and one drains the load queue. Each task
    fans its file I/O out to a thread pool in batches of
    ``DEFAULT_MAX_IO_ENTRIES`` and reports its outcome through a ``Future``
    whose result is ``True`` on success and ``False`` on failure.
    """

    def __init__(self, connector: "HF3FSKVConnector"):
        # Store connector reference and extract commonly used attributes
        self._connector = connector
        self._device = connector._device
        self._dtype = connector._dtype
        self._shape_per_page = connector._shape_per_page
        self._bytes_per_page = connector._bytes_per_page
        self._rank = connector._rank
        self._numjobs = connector._numjobs
        self._max_device_buffer_count = connector._max_device_buffer_count

        # Operation tracking, keyed by request id
        self._save_futures: dict[str, list[Future]] = {}
        self._load_futures: dict[str, Future] = {}
        # Finished requests whose saves were still running when reported
        self._pending_finished_requests: set[str] = set()

        # Initialize resources
        self._init_cuda_resources()
        self._init_worker_threads()

        # Metrics
        self.hf3fs_stats = HF3FSKVConnectorStats()

        logger.info("AsyncOperationManager initialized for rank %d", self._rank)

    def _init_cuda_resources(self) -> None:
        """Initialize CUDA streams, events and buffer allocators."""
        self._save_stream = torch.cuda.Stream()
        self._load_stream = torch.cuda.Stream()
        self._save_event = torch.cuda.Event()

        # Saves and loads each get their own, identically configured pool
        self._save_buffer_allocator = CopyBufferAllocator(
            self._device,
            self._dtype,
            self._shape_per_page,
            self._max_device_buffer_count,
        )
        self._load_buffer_allocator = CopyBufferAllocator(
            self._device,
            self._dtype,
            self._shape_per_page,
            self._max_device_buffer_count,
        )

    def _init_worker_threads(self) -> None:
        """Initialize worker threads and I/O executor."""
        self._stop_event = threading.Event()
        self._save_queue: queue.Queue[Any] = queue.Queue()
        self._load_queue: queue.Queue[Any] = queue.Queue()

        # I/O thread pool
        self._io_executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=self._numjobs,
            thread_name_prefix=f"HF3FS-Rank{self._rank}",
        )

        # Background worker threads
        self._save_thread = threading.Thread(target=self._save_worker, daemon=True)
        self._load_thread = threading.Thread(target=self._load_worker, daemon=True)
        self._save_thread.start()
        self._load_thread.start()

    def submit_save_operation(self, request_id: str, block_ids, block_hashes) -> Future:
        """Submit a save operation for async execution.

        An event is recorded on the current CUDA stream at submission time;
        the save worker waits on it before reading the KV cache.
        """
        future: Future[Any] = Future()
        main_stream_event = torch.cuda.Event()
        main_stream_event.record()
        self._save_queue.put(
            (request_id, block_ids, block_hashes, future, main_stream_event)
        )
        self._save_futures.setdefault(request_id, []).append(future)
        return future

    def submit_load_operation(self, request_id: str, block_ids, block_hashes) -> Future:
        """Submit a load operation for async execution.

        Only the most recently submitted future is tracked per request.
        """
        future: Future[Any] = Future()
        self._load_queue.put((request_id, block_ids, block_hashes, future))
        self._load_futures[request_id] = future
        return future

    def get_finished_operations(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str], set[str]]:
        """Return ``(request ids done saving, request ids done loading)``."""
        completed_saves = self._check_completed_saves(finished_req_ids)
        completed_loads = self._check_completed_loads()

        if completed_saves or completed_loads:
            logger.info(
                "HF3FS Connector Completed: %d saves, %d loads operations",
                len(completed_saves),
                len(completed_loads),
            )

        return completed_saves, completed_loads

    def _check_completed_saves(self, finished_req_ids: set[str]) -> set[str]:
        """Check for completed save operations."""
        completed = set()

        # Requests reported finished earlier whose saves were still running
        for request_id in list(self._pending_finished_requests):
            if request_id not in self._save_futures:
                continue
            if not self._all_saves_done(request_id):
                continue
            completed.add(request_id)
            self._save_futures.pop(request_id)
            self._pending_finished_requests.remove(request_id)

        # Newly finished requests
        for request_id in finished_req_ids:
            if request_id not in self._save_futures:
                # No save is tracked for this request
                completed.add(request_id)
            elif self._all_saves_done(request_id):
                completed.add(request_id)
                self._save_futures.pop(request_id)
            else:
                self._pending_finished_requests.add(request_id)

        return completed

    def _check_completed_loads(self) -> set[str]:
        """Check for completed load operations."""
        completed = set()
        for request_id in list(self._load_futures):
            if self._load_futures[request_id].done():
                completed.add(request_id)
                self._load_futures.pop(request_id)
        return completed

    def _all_saves_done(self, request_id: str) -> bool:
        """Check if all save operations for a request are completed."""
        return all(future.done() for future in self._save_futures[request_id])

    def _save_worker(self) -> None:
        """Background worker for handling save operations."""
        torch.accelerator.set_device_index(self._device)
        while not self._stop_event.is_set():
            try:
                # The timeout lets the loop notice the stop event
                task = self._save_queue.get(block=True, timeout=1)
                self._handle_save_task(task)
            except Empty:
                continue
            except Exception as e:
                logger.error("Save worker error: %s", e)

    def _load_worker(self) -> None:
        """Background worker for handling load operations."""
        torch.accelerator.set_device_index(self._device)
        while not self._stop_event.is_set():
            try:
                # The timeout lets the loop notice the stop event
                task = self._load_queue.get(block=True, timeout=1)
                self._handle_load_task(task)
            except Empty:
                continue
            except Exception as e:
                logger.error("Load worker error: %s", e)

    def _run_batched_io(self, io_method: str, offsets, buffers, *extra_args) -> bool:
        """Run ``io_method`` on the storage clients in batches and wait for all.

        Every submitted batch is awaited before this method returns or
        raises, so callers can safely hand the buffers back to their pool
        afterwards.

        Returns:
            ``True`` iff every per-operation result equals ``_bytes_per_page``.

        Raises:
            Exception: Whatever a batch raised, re-raised by ``result()``.
        """
        io_futures: list[Future] = []
        try:
            for start in range(0, len(offsets), DEFAULT_MAX_IO_ENTRIES):
                stop = start + DEFAULT_MAX_IO_ENTRIES
                client = self._connector._clients[self._connector._ac.next()]
                io_futures.append(
                    self._io_executor.submit(
                        getattr(client, io_method),
                        offsets[start:stop],
                        buffers[start:stop],
                        *extra_args,
                    )
                )
        finally:
            # Without this, a failure in an early batch would free the
            # buffers while later batches are still using them.
            concurrent.futures.wait(io_futures)

        return all(
            result == self._bytes_per_page
            for io_future in io_futures
            for result in io_future.result()
        )

    def _handle_save_task(self, task) -> None:
        """Handle individual save task with proper stream synchronization."""
        request_id, block_ids, block_hashes, future, main_stream_event = task
        start_time = time.perf_counter()
        buffers = None
        try:
            # Step1: Allocate storage pages
            key_pairs = [(hash_val, "") for hash_val in block_hashes]
            allocation_results = (
                self._connector._metadata_client.allocate_pages_for_keys(
                    self._rank, key_pairs
                )
            )

            if any(result[1] < 0 for result in allocation_results):
                return self._fail_task(
                    "Saved", "Page allocation failed", request_id, future
                )

            page_indices = [result[1] for result in allocation_results]
            offsets = [idx * self._bytes_per_page for idx in page_indices]

            # Step2: Allocate buffers and gather KV cache data
            # NOTE(review): pages allocated above are only handed back on the
            # "write failed" path below, not on the other failure paths -
            # confirm whether the metadata server reclaims them.
            buffers = self._save_buffer_allocator.alloc_buffer(len(block_ids))
            if buffers is None:
                return self._fail_task(
                    "Saved",
                    f"Buffer allocation failed for {len(block_ids)} blocks",
                    request_id,
                    future,
                )

            with torch.cuda.stream(self._save_stream):
                self._save_stream.wait_event(main_stream_event)  # Wait for main stream
                self._connector._gather_or_scatter_kv_caches(
                    block_ids, buffers, "gather"
                )

            save_stream_event = torch.cuda.Event()
            save_stream_event.record(self._save_stream)  # Record gather completion

            # Step3: Write data in batches and wait for every batch
            write_success = self._run_batched_io(
                "batch_write", offsets, buffers, save_stream_event
            )

            # Step4: Confirm writes to metadata server
            if write_success:
                written_keys = list(zip(block_hashes, page_indices))
                self._connector._metadata_client.confirm_write_for_keys(
                    self._rank, written_keys, []
                )
                self._save_buffer_allocator.free_buffer(buffers)
                return self._succeed_task(
                    "Saved", start_time, request_id, len(block_ids), future
                )

            self._connector._metadata_client.confirm_write_for_keys(
                self._rank, [], page_indices
            )
            self._save_buffer_allocator.free_buffer(buffers)
            return self._fail_task(
                "Saved", "Write operation failed", request_id, future
            )

        except Exception as e:
            if buffers is not None:
                self._save_buffer_allocator.free_buffer(buffers)
            return self._fail_task(
                "Saved", f"Task execution error: {e}", request_id, future
            )

    def _handle_load_task(self, task) -> None:
        """Handle individual load task."""
        request_id, block_ids, block_hashes, future = task
        start_time = time.perf_counter()
        buffers = None
        try:
            # Step1: Get block locations from metadata server
            page_indices = self._connector._metadata_client.get_key_locations(
                self._rank, block_hashes
            )

            if any(idx is None for idx in page_indices):
                return self._fail_task("Loaded", "Blocks not found", request_id, future)

            # Allocate read buffer
            buffers = self._load_buffer_allocator.alloc_buffer(len(block_ids))
            if buffers is None:
                return self._fail_task(
                    "Loaded",
                    f"Buffer allocation failed for {len(block_ids)} blocks",
                    request_id,
                    future,
                )

            # Step2: Read data in batches and wait for every batch
            offsets = [idx * self._bytes_per_page for idx in page_indices]
            read_success = self._run_batched_io("batch_read", offsets, buffers)

            if not read_success:
                self._load_buffer_allocator.free_buffer(buffers)
                return self._fail_task(
                    "Loaded", "Read operation failed", request_id, future
                )

            # Step3: Scatter data back to KV cache
            with torch.cuda.stream(self._load_stream):
                self._connector._gather_or_scatter_kv_caches(
                    block_ids, buffers, "scatter"
                )

            # The scatter must finish before the buffers go back to the pool
            self._load_stream.synchronize()
            self._load_buffer_allocator.free_buffer(buffers)
            return self._succeed_task(
                "Loaded", start_time, request_id, len(block_ids), future
            )

        except Exception as e:
            if buffers is not None:
                self._load_buffer_allocator.free_buffer(buffers)
            return self._fail_task(
                "Loaded", f"Task execution error: {e}", request_id, future
            )

    def _fail_task(
        self, operation: str, error_msg: str, request_id: str, future: Future
    ) -> None:
        """Helper to fail task with error logging; resolves ``future`` to ``False``."""
        logger.error(
            "%s for %s request %s",
            error_msg,
            operation,
            request_id,
        )
        self.hf3fs_stats.record_failed_task_count(operation)
        future.set_result(False)

    def _succeed_task(
        self,
        operation: str,
        start_time: float,
        request_id: str,
        block_count: int,
        future: Future,
    ) -> None:
        """Helper to succeed task with logging; resolves ``future`` to ``True``."""
        duration = time.perf_counter() - start_time
        logger.info(
            "%s %s: %d blocks in %.2fs",
            operation,
            request_id,
            block_count,
            duration,
        )
        self.hf3fs_stats.record_success_task_duration(operation, duration)
        future.set_result(True)

    def shutdown(self) -> None:
        """Clean shutdown of all background threads and resources."""
        self._stop_event.set()
        self._save_thread.join()
        self._load_thread.join()
        self._io_executor.shutdown(wait=True)
        logger.info("AsyncOperationManager shutdown completed")
class HF3FSKVConnector(KVConnectorBase_V1):
    """HF3FS KV Connector implementation."""

    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: "KVCacheConfig",
    ):
        """Read the HF3FS settings and register shutdown hooks.

        Settings come from ``kv_transfer_config`` extra config (keys prefixed
        ``hf3fs_``). In the scheduler role a metadata client is also created
        and initialized.

        Args:
            vllm_config: Engine configuration; ``kv_transfer_config`` must be set.
            role: Whether this instance runs in the scheduler or a worker.
            kv_cache_config: Forwarded unchanged to the base class.
        """
        super().__init__(
            vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
        )

        # Core configuration
        self._vllm_config = vllm_config
        self._role = role
        self._block_size = vllm_config.cache_config.block_size
        self._use_mla = vllm_config.model_config.use_mla
        self._model_config = vllm_config.model_config

        logger.info("Using MLA: %s", self._use_mla)

        # HF3FS configuration
        kv_config = vllm_config.kv_transfer_config
        assert kv_config is not None

        self._storage_path = kv_config.get_from_extra_config(
            "hf3fs_storage_path", "/vllm-workspace/mnt/hf3fs"
        )
        self._metadata_server_url = kv_config.get_from_extra_config(
            "hf3fs_metadata_server_url", "http://localhost:18000"
        )
        self._file_size = kv_config.get_from_extra_config(
            "hf3fs_file_size", 1024 * 1024 * 1024
        )
        self._numjobs = kv_config.get_from_extra_config("hf3fs_client_numjobs", 16)
        self._max_device_buffer_count = kv_config.get_from_extra_config(
            "hf3fs_max_device_buffer_count", 128
        )
        # Never fewer buffers than one full I/O batch per job
        self._max_device_buffer_count = max(
            self._max_device_buffer_count, self._numjobs * DEFAULT_MAX_IO_ENTRIES
        )

        if self._role == KVConnectorRole.SCHEDULER:
            self._scheduling_states: dict[str, RequestSchedulingState] = {}
            self._metadata_client = Hf3fsMetadataClient()
            self._metadata_client.initialize(0, role="scheduler")

        self._install_shutdown_hooks()

        logger.info(
            "HF3FSKVConnector initialized: path=%s, role=%s",
            self._storage_path,
            self._role.name,
        )

    def _install_shutdown_hooks(self) -> None:
        """Arrange for ``close()`` to run at exit and on SIGINT/SIGTERM/SIGQUIT.

        After ``close()`` each handler hands the signal on to whatever was
        installed before it, so the process still terminates or raises
        ``KeyboardInterrupt`` as it would have without this connector.
        """
        atexit.register(self.close)

        # signal.signal() raises ValueError outside the main thread
        if threading.current_thread() is not threading.main_thread():
            logger.warning(
                "HF3FSKVConnector created outside the main thread; "
                "signal handlers not installed"
            )
            return

        for signum in (signal.SIGINT, signal.SIGTERM, signal.SIGQUIT):
            previous = signal.getsignal(signum)

            # ``previous`` is bound as a default to avoid late binding
            def _handler(sig, frame, _previous=previous):
                self.close()
                if callable(_previous):
                    _previous(sig, frame)
                elif _previous == signal.SIG_DFL:
                    # Restore the default action and re-deliver the signal
                    signal.signal(sig, signal.SIG_DFL)
                    signal.raise_signal(sig)
                # SIG_IGN / None: the signal was being ignored; keep it so

            signal.signal(signum, _handler)

############################################################
# Worker Side Methods
############################################################

def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]) -> None:
    """Store the KV cache tensors and bring up storage clients and async I/O."""
    self._kv_caches = kv_caches
    self._setup_kv_cache_config()
    self._setup_storage_clients()
    self._async_manager = AsyncOperationManager(self)


def _setup_kv_cache_config(self):
    """Derive device, dtype and per-page geometry from the first KV cache tensor.

    Sets ``_device``, ``_dtype``, ``_local_total_tokens``,
    ``_local_block_size``, ``_bytes_per_page``, ``_shape_per_page`` and
    ``_kvcache_ptrs`` (one data pointer per registered cache tensor).
    """
    sample = next(iter(self._kv_caches.values()))
    self._device = sample.device
    self._dtype = sample.dtype
    itemsize = sample.element_size()
    num_layers = len(self._kv_caches)

    if self._use_mla:
        assert len(sample.shape) == 3, "MLA format should have 3 dimensions"
        # MLA format: [num_blocks, block_size, head_size]
        num_blocks, block_size, head_size = sample.shape
        bytes_per_layer_block = block_size * head_size * itemsize
        page_shape = [num_layers, block_size, head_size]
    else:
        # MHA format: [2, num_blocks, block_size, num_heads, head_size]
        _, num_blocks, block_size, num_heads, head_size = sample.shape
        bytes_per_layer_block = 2 * block_size * num_heads * head_size * itemsize
        page_shape = [num_layers, 2, block_size, num_heads * head_size]

    self._local_total_tokens = num_blocks * block_size
    self._local_block_size = block_size
    self._bytes_per_page = bytes_per_layer_block * num_layers
    self._shape_per_page = page_shape

    layer_ptrs = [cache.data_ptr() for cache in self._kv_caches.values()]
    self._kvcache_ptrs = torch.tensor(
        layer_ptrs, dtype=torch.int64, device=self._device
    )


def _setup_storage_clients(self):
    """Create the per-rank data file clients and the worker metadata client.

    Raises:
        Exception: Any initialization error, re-raised after being logged.
    """
    os.makedirs(self._storage_path, exist_ok=True)

    self._rank = get_tensor_model_parallel_rank()
    data_file = os.path.join(
        self._storage_path, f"hf3fs_vllm_data_file_{self._rank}"
    )

    try:
        # One client per I/O job, picked through the counter
        self._ac = AtomicCounter(self._numjobs)
        assert Hf3fsClient is not None
        client_kwargs = dict(
            path=data_file,
            size=self._file_size,
            bytes_per_page=self._bytes_per_page,
            entries=DEFAULT_MAX_IO_ENTRIES,
        )
        self._clients = [Hf3fsClient(**client_kwargs) for _ in range(self._numjobs)]

        # The file holds this many whole pages
        page_count = self._file_size // self._bytes_per_page
        self._metadata_client = Hf3fsMetadataClient()
        self._metadata_client.initialize(self._rank, page_count, role="worker")
    except Exception as e:
        logger.error("HF3FS client initialization failed: %s", e)
        raise
def save_kv_layer(
    self,
    layer_name: str,
    kv_layer: torch.Tensor,
    attn_metadata: "AttentionMetadata",
    **kwargs,
) -> None:
    """HF3FSConnector does not do layerwise saving."""
    return None


def wait_for_save(self) -> None:
    """Queue async saves for every request in the connector metadata that has a save op.

    Blocks are submitted in chunks of at most ``_max_device_buffer_count``.
    """
    metadata = self._get_connector_metadata()
    if not isinstance(metadata, HF3FSConnectorMetadata):
        logger.error("Invalid metadata type: %s", type(metadata))
        return

    for req in metadata.requests:
        save_op = req.save_block_op
        if save_op is None:
            continue

        first_block = save_op.skip_leading_blocks
        hashes = self._generate_block_hashes(req.token_ids, first_block)
        # Keep exactly as many block ids as there are hashes
        ids = req.block_ids[first_block : first_block + len(hashes)]

        for start in range(0, len(ids), self._max_device_buffer_count):
            stop = start + self._max_device_buffer_count
            self._async_manager.submit_save_operation(
                req.request_id, ids[start:stop], hashes[start:stop]
            )


def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
    """Queue async loads for every request in the connector metadata that has a load op.

    Blocks are submitted in chunks of at most ``_max_device_buffer_count``.
    """
    metadata = self._get_connector_metadata()
    if not isinstance(metadata, HF3FSConnectorMetadata):
        logger.error("Invalid metadata type for loading")
        return

    for req in metadata.requests:
        load_op = req.load_block_op
        if load_op is None:
            continue

        ids = req.block_ids[: load_op.num_blocks_to_load]
        hashes = self._generate_block_hashes(
            req.token_ids, load_op.num_computed_blocks, len(ids)
        )

        for start in range(0, len(ids), self._max_device_buffer_count):
            stop = start + self._max_device_buffer_count
            self._async_manager.submit_load_operation(
                req.request_id, ids[start:stop], hashes[start:stop]
            )
request.token_ids, load_op.num_computed_blocks, len(block_ids) + ) + + for i in range(0, len(block_ids), self._max_device_buffer_count): + batch_block_ids = block_ids[i : i + self._max_device_buffer_count] + batch_block_hashes = block_hashes[i : i + self._max_device_buffer_count] + self._async_manager.submit_load_operation( + request.request_id, batch_block_ids, batch_block_hashes + ) + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[set[str] | None, set[str] | None]: + return self._async_manager.get_finished_operations(finished_req_ids) + + def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]: + """ + Get the KV connector stats collected during the last interval. + """ + # Clear stats for next iteration + if ( + hasattr(self, "_async_manager") + and not self._async_manager.hf3fs_stats.is_empty() + ): + return self._async_manager.hf3fs_stats.clone_and_reset() + return None + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + return True, None + + def get_num_new_matched_tokens( + self, request: "Request", num_computed_tokens: int + ) -> tuple[int, bool]: + """Get number of new tokens that can be loaded from external cache.""" + try: + state = self._get_or_create_scheduling_state(request.request_id) + state.request = request + assert request.prompt_token_ids is not None + + num_tokens_to_check = self._align_to_block_size( + len(request.prompt_token_ids) - 1 + ) + + if num_tokens_to_check <= num_computed_tokens: + state.load_op = LoadBlockInfo( + num_computed_blocks=num_computed_tokens // self._block_size, + num_blocks_to_load=0, + need_fetch_block_ids=[], + ) + return 0, False + + token_ids_to_check = 
def update_state_after_alloc(
    self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
) -> None:
    """Update state after block allocation.

    When external tokens were granted and the state still needs loading, the
    newly allocated block ids are appended to the load plan and the request
    is marked ``"WAITING_TO_LOAD"``.

    Raises:
        AssertionError: If the granted block count differs from the plan.
    """
    state = self._get_or_create_scheduling_state(request.request_id)
    state.request = request

    nothing_to_load = num_external_tokens <= 0 or not state.needs_loading()
    if nothing_to_load:
        return

    # Validate block allocation
    load_op = state.load_op
    assert load_op is not None
    expected_blocks = load_op.num_blocks_to_load
    actual_blocks = num_external_tokens // self._block_size
    assert actual_blocks == expected_blocks, (
        f"Block count mismatch for {request.request_id}: "
        f"expected {expected_blocks}, got {actual_blocks}"
    )

    if actual_blocks <= 0:
        return

    # Update load operation with allocated block IDs
    # NOTE(review): the phase change is assumed to belong to this branch; the
    # flattened source does not show its nesting - confirm.
    load_op.need_fetch_block_ids.extend(blocks.get_unhashed_block_ids())
    state.phase = "WAITING_TO_LOAD"
def build_connector_meta(
    self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
    """Build connector metadata for scheduling step.

    State for finished requests is dropped first. Then requests waiting to
    load, new requests and cached requests are processed, in that order.
    """
    for finished_id in scheduler_output.finished_req_ids:
        self._scheduling_states.pop(finished_id, None)

    metadata = HF3FSConnectorMetadata()
    self._process_waiting_to_load_requests(metadata)
    self._process_new_requests(scheduler_output, metadata)
    self._process_cached_requests(scheduler_output, metadata)
    return metadata


def _process_waiting_to_load_requests(
    self, metadata: HF3FSConnectorMetadata
) -> None:
    """Process requests waiting to load: emit their load metadata and activate them."""
    for state in list(self._scheduling_states.values()):
        if not state.is_ready_to_load():
            continue

        load_op = state.load_op
        assert load_op is not None
        assert (
            state.request is not None and state.request.prompt_token_ids is not None
        )

        # Blocks already computed plus blocks about to be loaded
        num_cached_blocks = load_op.num_computed_blocks + load_op.num_blocks_to_load
        covered_tokens = num_cached_blocks * self._block_size

        # Initialize token_ids and allocated_block_ids for loading
        state.token_ids = state.request.prompt_token_ids[:covered_tokens].copy()
        state.allocated_block_ids = load_op.need_fetch_block_ids.copy()

        request_metadata = HF3FSRequestMetadata.from_scheduling_state(
            state, self._block_size, load_op, num_cached_blocks
        )
        if request_metadata:
            metadata.add_request(request_metadata)
        # NOTE(review): assumed unconditional; the flattened source does not
        # show whether this sits inside the ``if`` above - confirm.
        state.phase = "ACTIVE"


def _process_new_requests(
    self, scheduler_output: SchedulerOutput, metadata: HF3FSConnectorMetadata
) -> None:
    """Process new requests: initialise their state and emit save metadata."""
    for new_req in scheduler_output.scheduled_new_reqs:
        req_id = new_req.req_id
        state = self._get_or_create_scheduling_state(req_id)

        # Tokens covered once this step has run
        scheduled = scheduler_output.num_scheduled_tokens[req_id]
        num_tokens_to_compute = new_req.num_computed_tokens + scheduled
        self._initialize_state_from_new_request(
            state, new_req, num_tokens_to_compute
        )

        # Create save metadata (skip cached blocks if any)
        if state.load_op:
            num_cached_blocks = (
                state.load_op.num_computed_blocks + state.load_op.num_blocks_to_load
            )
        else:
            num_cached_blocks = None

        request_metadata = HF3FSRequestMetadata.from_scheduling_state(
            state, self._block_size, None, num_cached_blocks
        )
        if request_metadata:
            metadata.add_request(request_metadata)
        # NOTE(review): assumed unconditional; the flattened source does not
        # show whether this sits inside the ``if`` above - confirm.
        state.phase = "ACTIVE"
def _process_cached_requests(
    self, scheduler_output: SchedulerOutput, metadata: HF3FSConnectorMetadata
) -> None:
    """Process cached requests: extend their state by this step and emit save metadata."""
    cached = scheduler_output.scheduled_cached_reqs
    for position, req_id in enumerate(cached.req_ids):
        state = self._get_or_create_scheduling_state(req_id)
        assert state.request is not None

        # Tokens scheduled this step follow those already tracked
        scheduled = scheduler_output.num_scheduled_tokens[req_id]
        tracked = len(state.token_ids)
        fresh_tokens = state.request.all_token_ids[tracked : tracked + scheduled]
        fresh_blocks = cached.new_block_ids[position]

        state.update_tokens_and_blocks(fresh_tokens, fresh_blocks)

        # Create save metadata
        request_metadata = HF3FSRequestMetadata.from_scheduling_state(
            state, self._block_size, None
        )
        if request_metadata:
            metadata.add_request(request_metadata)


@classmethod
def build_kv_connector_stats(
    cls, data: dict[str, Any] | None = None
) -> Optional["KVConnectorStats"]:
    """
    KVConnectorStats resolution method. This method allows dynamically
    registered connectors to return their own KVConnectorStats object,
    which can implement custom aggregation logic on the data dict.
    """
    if data is None:
        return HF3FSKVConnectorStats()
    return HF3FSKVConnectorStats(data=data)


@classmethod
def build_prom_metrics(
    cls,
    vllm_config: VllmConfig,
    metric_types: dict[type[PromMetric], type[PromMetricT]],
    labelnames: list[str],
    per_engine_labelvalues: dict[int, list[object]],
) -> KVConnectorPromMetrics:
    """Build the Prometheus metrics object for this connector."""
    return HF3FSPromMetrics(
        vllm_config, metric_types, labelnames, per_engine_labelvalues
    )


def close(self) -> None:
    """Shut down the async manager and close all storage clients.

    Errors are logged rather than raised.
    """
    try:
        if hasattr(self, "_async_manager"):
            self._async_manager.shutdown()

        if hasattr(self, "_clients"):
            for storage_client in self._clients:
                storage_client.close()
            logger.info("HF3FS clients closed")
    except Exception as e:
        logger.error("Connector shutdown error: %s", e)


############################################################
# Utility Methods
############################################################

def _get_or_create_scheduling_state(
    self, request_id: str
) -> RequestSchedulingState:
    """Get existing or create new scheduling state."""
    states = self._scheduling_states
    if request_id not in states:
        states[request_id] = RequestSchedulingState(request_id=request_id)
    return states[request_id]


def _initialize_state_from_new_request(
    self, state: RequestSchedulingState, request, num_tokens_to_compute: int
) -> None:
    """Initialize state from new request data.

    ``request.block_ids`` may be a flat list or a list of lists; in the
    nested case only the first inner list is used.
    """
    # Handle different block_ids formats in vLLM 0.9.0+
    raw_ids = request.block_ids
    source = raw_ids[0] if isinstance(raw_ids[0], list) else raw_ids

    state.token_ids = request.prompt_token_ids[:num_tokens_to_compute].copy()
    state.allocated_block_ids = source.copy()
    state.num_saved_blocks = 0
@dataclass
class HF3FSKVConnectorStats(KVConnectorStats):
    """Accumulates save/load timings and failure counts for HF3FS transfers.

    ``data`` holds two duration lists (one entry per successful task), two
    failure counters and a running total of recorded tasks.
    """

    def __post_init__(self):
        # Nothing was passed in: start from a zeroed container.
        if not self.data:
            self.reset()

    def reset(self):
        """Replace ``data`` with a fresh, zeroed container."""
        # Must be serializable
        self.data: dict[str, Any] = {
            "save_duration": [],
            "load_duration": [],
            "num_failed_save": 0,
            "num_failed_load": 0,
            "num_transfer_task": 0,
        }

    def aggregate(self, other: "KVConnectorStats") -> "KVConnectorStats":
        """Fold ``other`` into this object in place and return ``self``."""
        if other.is_empty():
            return self

        for name, incoming in other.data.items():
            current = self.data[name]
            if isinstance(current, list):
                # Duration lists are concatenated.
                current.extend(incoming)
            else:  # int
                self.data[name] += incoming
        return self

    def reduce(self) -> dict[str, int | float]:
        """Return compact summary statistics suitable for CLI logging."""
        if self.is_empty():
            return dict.fromkeys(
                (
                    "Num transfers task",
                    "Num save task success",
                    "Num save task failed",
                    "Num load task success",
                    "Num load task failed",
                    "Avg save duration (ms)",
                    "P90 save duration (ms)",
                    "Avg load duration (ms)",
                    "P90 load duration (ms)",
                ),
                0,
            )

        stats = self.data
        save_ok = len(stats["save_duration"] or [])
        load_ok = len(stats["load_duration"] or [])

        # A single zero stands in for an empty list so mean/percentile are
        # well defined.
        save_secs = np.asarray(stats["save_duration"]) if save_ok else np.zeros(1)
        load_secs = np.asarray(stats["load_duration"]) if load_ok else np.zeros(1)

        return {
            "Num transfers task": stats["num_transfer_task"],
            "Num save task success": save_ok,
            "Num save task failed": stats["num_failed_save"],
            "Num load task success": load_ok,
            "Num load task failed": stats["num_failed_load"],
            "Avg save duration (ms)": round(save_secs.mean() * 1e3, 3),
            "P90 save duration (ms)": round(np.percentile(save_secs, 90) * 1e3, 3),
            "Avg load duration (ms)": round(load_secs.mean() * 1e3, 3),
            "P90 load duration (ms)": round(np.percentile(load_secs, 90) * 1e3, 3),
        }

    def is_empty(self) -> bool:
        """Return True when no task has been recorded since the last reset."""
        return self.data["num_transfer_task"] == 0

    def record_success_task_duration(self, operation, duration):
        """Record the duration of a successful "Saved" or "Loaded" task."""
        bucket = None
        if operation == "Saved":
            bucket = "save_duration"
        elif operation == "Loaded":
            bucket = "load_duration"

        if bucket is not None:
            self.data[bucket].append(duration)
        # The task total is bumped regardless of the operation label.
        self.data["num_transfer_task"] += 1

    def record_failed_task_count(self, operation):
        """Count one failed "Saved" or "Loaded" task."""
        counter = None
        if operation == "Saved":
            counter = "num_failed_save"
        elif operation == "Loaded":
            counter = "num_failed_load"

        if counter is not None:
            self.data[counter] += 1
        # The task total is bumped regardless of the operation label.
        self.data["num_transfer_task"] += 1

    def clone_and_reset(self):
        """Return a shallow copy holding the current data, then reset self."""
        # reset() rebinds self.data to a new dict, so the copy keeps the old one.
        snapshot = copy.copy(self)
        self.reset()
        return snapshot
documentation="Number of failed HF3FS KV load.", + labelnames=labelnames, + ) + self.hf3fs_num_failed_load = create_metric_per_engine( + hf3fs_num_failed_load, self.per_engine_labelvalues + ) + + def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): + for prom_obj, list_item_key in zip( + [ + self.hf3fs_save_duration, + self.hf3fs_load_duration, + ], + [ + "save_duration", + "load_duration", + ], + ): + for list_item in transfer_stats_data[list_item_key]: + prom_obj[engine_idx].observe(list_item) + for counter_obj, counter_item_key in zip( + [ + self.hf3fs_num_failed_save, + self.hf3fs_num_failed_load, + ], + [ + "num_failed_save", + "num_failed_load", + ], + ): + counter_obj[engine_idx].inc(transfer_stats_data[counter_item_key]) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/hf3fs/hf3fs_metadata_server.py b/vllm/distributed/kv_transfer/kv_connector/v1/hf3fs/hf3fs_metadata_server.py new file mode 100644 index 000000000..72792e5eb --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/hf3fs/hf3fs_metadata_server.py @@ -0,0 +1,530 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +HF3FS Metadata Server with key-based organization. 
@dataclass
class RankFileMetadata:
    """Manages file page allocation for a single rank.

    ``free_pages`` is the ordered pool of page indices that may be handed
    out; valid indices are ``0 <= index < num_pages``.
    """

    rank_id: int
    num_pages: int
    free_pages: list[int]

    def allocate_pages(self, num_pages: int) -> list[int]:
        """Allocate ``num_pages`` pages from the front of the free pool.

        Allocation is all-or-nothing.

        Args:
            num_pages: Number of pages requested.

        Returns:
            The allocated page indices, or an empty list when the request
            is non-positive or exceeds the number of free pages.
        """
        # A negative count would otherwise slice the pool from the wrong end.
        if num_pages <= 0 or len(self.free_pages) < num_pages:
            return []

        allocated = self.free_pages[:num_pages]
        self.free_pages = self.free_pages[num_pages:]
        return allocated

    def release_pages(self, page_indices: list[int]) -> None:
        """Release pages back to the free pool.

        Indices that are already free, that are not integers, or that fall
        outside ``[0, num_pages)`` are ignored, so a sentinel such as ``-1``
        can never be handed out later as a real page.

        Args:
            page_indices: Page indices to return to the pool.
        """
        # Set membership keeps this linear instead of scanning the list
        # once per released page.
        already_free = set(self.free_pages)
        for page_idx in page_indices:
            if not isinstance(page_idx, int):
                continue
            if not 0 <= page_idx < self.num_pages:
                continue
            if page_idx in already_free:
                continue
            already_free.add(page_idx)
            self.free_pages.append(page_idx)

    def get_free_page_count(self) -> int:
        """Get current number of free pages."""
        return len(self.free_pages)
class GlobalMetadataState:
    """Manages global metadata state across all ranks and keys.

    Every public method runs under ``global_lock`` (a re-entrant lock), so
    each call observes and mutates the rank and key tables atomically.
    """

    def __init__(self):
        self.global_lock = threading.RLock()
        # rank id -> page pool of that rank
        self.rank_metadata: dict[int, RankFileMetadata] = {}
        # key -> per-rank confirmed page
        self.key_metadata: dict[str, KeyMetadata] = {}

    def _require_rank(self, rank: int) -> RankFileMetadata:
        """Return the page pool of ``rank`` or raise if it is unknown."""
        if rank not in self.rank_metadata:
            raise ValueError(f"Rank {rank} not initialized")
        return self.rank_metadata[rank]

    def clear(self) -> None:
        """Clear all metadata state."""
        with self.global_lock:
            self.rank_metadata.clear()
            self.key_metadata.clear()
            logger.info("Cleared all metadata state")

    def initialize_rank(self, rank: int, num_pages: int) -> None:
        """Initialize a new rank with specified number of pages.

        A rank that is already known is left untouched.
        """
        with self.global_lock:
            if rank not in self.rank_metadata:
                self.rank_metadata[rank] = RankFileMetadata(
                    rank, num_pages, list(range(num_pages))
                )
                logger.info("Initialized rank %s with %s pages", rank, num_pages)

    def allocate_pages_for_keys(
        self, rank: int, keys: list[tuple[str, str]]
    ) -> dict[str, int]:
        """Allocate one page for each key on the specified rank.

        Args:
            rank: Rank ID to allocate pages on
            keys: List of (key, prefix_key) pairs to allocate pages for

        Returns:
            Dictionary mapping key -> allocated page index. A key that is
            already complete and has a page on this rank maps to that
            existing page; a key for which no page was available maps to -1.

        Raises:
            ValueError: If the rank has not been initialized.
        """
        with self.global_lock:
            rank_meta = self._require_rank(rank)

            # Batch allocate pages for all keys
            num_pages_needed = len(keys)
            allocated_pages = rank_meta.allocate_pages(num_pages_needed)

            if len(allocated_pages) < num_pages_needed:
                logger.warning(
                    "Rank %s only allocated %s pages for %s keys",
                    rank,
                    len(allocated_pages),
                    num_pages_needed,
                )

            allocation_results = {}
            for i, (key, _prefix_key) in enumerate(keys):
                has_new_page = i < len(allocated_pages)

                key_meta = self.key_metadata.get(key)
                if (
                    key_meta is not None
                    and key_meta.is_complete()
                    and rank in key_meta.rank_to_page
                ):
                    # key is already fully written, reuse the existing page
                    # and release the allocated page back to the free pool.
                    if has_new_page:
                        rank_meta.release_pages([allocated_pages[i]])
                    allocation_results[key] = key_meta.rank_to_page[rank]
                    continue

                # -1 signals that no page was available for this key.
                allocation_results[key] = allocated_pages[i] if has_new_page else -1

            return allocation_results

    def confirm_write_for_keys(
        self,
        rank: int,
        key_confirmations: list[tuple[str, int]],
        pages_to_release: list[int] | None = None,
    ) -> None:
        """Confirm write operations for keys and update metadata.

        Args:
            rank: Rank ID that confirmed the writes
            key_confirmations: List of (key, page_index) tuples
            pages_to_release: List of page indices to release back to free pool

        Raises:
            ValueError: If the rank has not been initialized.
        """
        with self.global_lock:
            # Validate before mutating anything, in line with the other
            # per-rank operations. Otherwise a confirmation from an unknown
            # rank could create a key whose expected rank count excludes
            # that rank, or fail only after the key table was updated.
            rank_meta = self._require_rank(rank)

            # Confirm successful writes
            for key, page_index in key_confirmations:
                if key not in self.key_metadata:
                    # The expected rank count is taken from the number of
                    # ranks initialized at the time of the first confirmation.
                    tp_world_size = len(self.rank_metadata)
                    self.key_metadata[key] = KeyMetadata(key, {}, tp_world_size)

                # Add confirmed page to key metadata
                self.key_metadata[key].add_rank_page(rank, page_index)

            # Release specified pages back to free pool
            if pages_to_release:
                rank_meta.release_pages(pages_to_release)
                logger.debug(
                    "Released %s pages on rank %s: %s",
                    len(pages_to_release),
                    rank,
                    pages_to_release,
                )

    def batch_key_exists(self, keys: list[str]) -> list[bool]:
        """Check if keys exist in metadata and all ranks have confirmed writes.

        Args:
            keys: List of keys to check

        Returns:
            List of boolean values indicating key existence and completion
        """
        with self.global_lock:
            results = []
            for key in keys:
                key_meta = self.key_metadata.get(key)
                # Unknown keys and partially written keys both count as absent.
                results.append(key_meta is not None and key_meta.is_complete())
            return results

    def get_key_locations(self, rank: int, keys: list[str]) -> list[int | None]:
        """Get page indices for keys on a specific rank.

        Args:
            rank: Rank ID to query
            keys: List of keys to look up

        Returns:
            List of page indices in the same order as input keys (None if the
            key is unknown, incomplete, or has no page on this rank)

        Raises:
            ValueError: If the rank has not been initialized.
        """
        with self.global_lock:
            self._require_rank(rank)

            results: list[int | None] = []
            for key in keys:
                key_meta = self.key_metadata.get(key)
                if key_meta is not None and key_meta.is_complete():
                    results.append(key_meta.get_rank_page(rank))
                else:
                    results.append(None)

            return results
return self._json_response( + {"message": "Scheduler role does not require initialization"} + ) + + if role == "worker" and num_pages > 0: + self.state.initialize_rank(rank, num_pages) + return self._json_response( + {"message": f"Rank {rank} initialized with {num_pages} pages"} + ) + else: + raise HTTPException( + status_code=400, detail="Invalid initialization parameters" + ) + + async def batch_allocate_pages_for_keys(self, request: Request): + """Allocate one page for each key on a specific rank.""" + data = await self._read_json(request) + rank = data.get("rank") + keys = data.get("keys", []) + + # Validate input format + if rank is None or not isinstance(keys, list): + raise HTTPException( + status_code=400, detail="Invalid request format: need 'rank' and 'keys'" + ) + + try: + # Perform allocation + results = self.state.allocate_pages_for_keys(rank, keys) + + # Convert results to response format + response = {"rank": rank, "results": list(results.items())} + return self._json_response(response) + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Allocation failed: {str(e)}" + ) from e + + async def confirm_write_for_keys(self, request: Request): + """Confirm write operations for keys.""" + data = await self._read_json(request) + rank = data.get("rank") + confirmations = data.get("confirmations", []) + pages_to_release = data.get("pages_to_release", []) + + # Validate input format + if rank is None or not isinstance(confirmations, list): + raise HTTPException( + status_code=400, + detail="Invalid request format: need 'rank' and 'confirmations'", + ) + + try: + self.state.confirm_write_for_keys(rank, confirmations, pages_to_release) + + return Response(status_code=204) + + except Exception as e: + logger.error("Confirm write for keys failed: %s", e) + raise HTTPException( + status_code=500, detail=f"Confirmation failed: {str(e)}" + ) from e + + async def batch_key_exists(self, request: Request): + """Check if multiple keys exist in 
metadata.""" + data = await self._read_json(request) + keys = data.get("keys", []) + + if not isinstance(keys, list): + raise HTTPException(status_code=400, detail="Invalid keys format") + + try: + exists_results = self.state.batch_key_exists(keys) + return self._json_response({"exists": exists_results}) + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Key existence check failed: {str(e)}" + ) from e + + async def get_key_locations(self, request: Request): + """Get page indices for keys on a specific rank.""" + data = await self._read_json(request) + rank = data.get("rank") + keys = data.get("keys", []) + + # Validate input format + if rank is None or not isinstance(keys, list): + raise HTTPException( + status_code=400, detail="Invalid request format: need 'rank' and 'keys'" + ) + + try: + # Get key locations + locations = self.state.get_key_locations(rank, keys) + return self._json_response({"locations": locations}) + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to get key locations: {str(e)}" + ) from e + + async def clear(self, request: Request): + """Clear the metadata server.""" + self.state.clear() + return Response(status_code=204) + + def run(self, host: str = "0.0.0.0", port: int = 18000): + """Run the metadata server.""" + import uvicorn + + logger.info("Starting improved metadata server on http://%s:%s", host, port) + uvicorn.run(self.app, host=host, port=port) + + +# --- Client implementation --- +class Hf3fsMetadataInterface(ABC): + """Interface for HF3FS metadata operations.""" + + @abstractmethod + def initialize(self, rank: int, num_pages: int = 0, role: str = "worker") -> None: + """Initialize the metadata service with specified number of pages.""" + pass + + @abstractmethod + def allocate_pages_for_keys( + self, rank: int, keys: list[tuple[str, str]] + ) -> list[tuple[str, int]]: + """Allocate one page for each key on the specified rank.""" + pass + + @abstractmethod + def 
class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
    """Global HTTP metadata client for HF3FS.

    All operations are POST requests against ``base_url`` sent through one
    shared ``requests.Session`` configured with a retry strategy.
    """

    def __init__(
        self,
        base_url: str = "http://localhost:18000",
        max_retries: int = 3,
        timeout: float | None = None,
    ):
        """Create the client.

        Args:
            base_url: Root URL of the metadata server; a trailing slash is
                stripped.
            max_retries: Total number of retries for a failed request.
            timeout: Per-request timeout in seconds passed to ``requests``.
                ``None`` (the default) waits indefinitely.
        """
        self.base_url = base_url.rstrip("/")
        self._timeout = timeout
        self._session = requests.Session()

        retry_strategy = Retry(
            total=max_retries,
            backoff_factor=0.3,
            status_forcelist=[500, 502, 503, 504],
            allowed_methods=["GET", "POST"],
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        # Mount for both schemes so an https base_url also gets the retry
        # strategy instead of the session's default adapter.
        for scheme in ("http://", "https://"):
            self._session.mount(scheme, adapter)

    def _post(self, endpoint: str, json_data: dict) -> dict:
        """Make POST request to metadata server.

        Args:
            endpoint: Path relative to ``base_url`` (no leading slash).
            json_data: Payload serialized as the JSON request body.

        Returns:
            The decoded JSON response, or ``{}`` for a 204 / empty body.

        Raises:
            RuntimeError: If the request fails or returns an error status.
        """
        try:
            url = f"{self.base_url}/{endpoint}"
            headers = {"Content-Type": "application/json"}
            if HAS_ORJSON:
                payload = orjson.dumps(json_data)
            else:
                import json

                payload = json.dumps(json_data).encode("utf-8")
            response = self._session.post(
                url, data=payload, headers=headers, timeout=self._timeout
            )
            response.raise_for_status()

            if response.status_code == 204 or not response.content:
                return {}
            if HAS_ORJSON:
                return orjson.loads(response.content)
            else:
                return response.json()
        except requests.exceptions.RequestException as e:
            logger.error("Failed to POST to %s after retries: %s", endpoint, e)
            raise RuntimeError(f"Failed to connect to metadata server: {e}") from e

    def initialize(self, rank: int, num_pages: int = 0, role: str = "worker") -> None:
        """Initialize a rank with specified number of pages."""
        self._post(f"rank/{rank}/initialize", {"num_pages": num_pages, "role": role})

    def allocate_pages_for_keys(
        self, rank: int, keys: list[tuple[str, str]]
    ) -> list[tuple[str, int]]:
        """Allocate pages for keys on the specified rank.

        Returns:
            The ``results`` entry of the server response, or an empty list
            when the response carries none.
        """
        response = self._post("keys/batch_allocate", {"rank": rank, "keys": keys})

        # Fall back to an empty list so the result matches the declared
        # list return type.
        return response.get("results", [])

    def confirm_write_for_keys(
        self,
        rank: int,
        key_confirmations: list[tuple[str, int]],
        pages_to_release: list[int] | None = None,
    ) -> None:
        """Confirm write operations for keys and optionally release pages."""
        payload = {
            "rank": rank,
            "confirmations": key_confirmations,
            "pages_to_release": pages_to_release or [],
        }

        self._post("keys/confirm_write", payload)

    def batch_key_exists(self, keys: list[str]) -> list[bool]:
        """Check if keys exist and are complete across all ranks."""
        response = self._post("keys/batch_exists", {"keys": keys})
        return response.get("exists", [])

    def get_key_locations(self, rank: int, keys: list[str]) -> list[int]:
        """Get page indices for keys on a specific rank.

        The ``locations`` entry of the response is returned as received;
        it may contain null entries, so callers should be prepared for
        ``None`` values in the list.
        """
        response = self._post("keys/get_locations", {"rank": rank, "keys": keys})
        return response.get("locations", [])
class AtomicCounter:
    """Lock-protected counter that cycles through ``0 .. n - 1``."""

    def __init__(self, n: int):
        assert n > 0, "Counter size must be positive"
        self._lock = threading.Lock()
        self._n = n
        self._value = 0

    def next(self) -> int:
        """Return the current value and advance, wrapping around at ``n``."""
        with self._lock:
            # Read and advance under the lock so no two callers get the
            # same slot from one step.
            issued, self._value = self._value, (self._value + 1) % self._n
        return issued
@dataclass
class HF3FSRequestMetadata:
    """Per-request metadata carried by the HF3FS connector."""

    request_id: str
    token_ids: list[int]
    block_ids: list[int]
    load_block_op: LoadBlockInfo | None = None
    save_block_op: SaveBlockInfo | None = None

    @staticmethod
    def from_scheduling_state(
        state: "RequestSchedulingState",
        block_size: int,
        load_op: LoadBlockInfo | None = None,
        skip_leading_blocks: int | None = None,
    ) -> Optional["HF3FSRequestMetadata"]:
        """Build request metadata from a scheduling state.

        Returns None when there is no load operation and no full block
        beyond ``state.num_saved_blocks``. Otherwise ``state.num_saved_blocks``
        is advanced to the number of full blocks currently held.
        """
        full_blocks = len(state.token_ids) // block_size
        saved_so_far = state.num_saved_blocks

        # Nothing new to save and nothing to load: no metadata needed.
        if load_op is None and full_blocks <= saved_so_far:
            return None

        # The skip count uses the saved-block count from before the update.
        if skip_leading_blocks is None:
            leading = saved_so_far
        else:
            leading = skip_leading_blocks

        state.num_saved_blocks = full_blocks
        save_op = SaveBlockInfo(skip_leading_blocks=leading)
        return HF3FSRequestMetadata(
            request_id=state.request_id,
            token_ids=state.token_ids,
            block_ids=state.allocated_block_ids,
            load_block_op=load_op,
            save_block_op=save_op,
        )
tl.load(source_ptr + source_offset + offset, mask=mask) + tl.store(kv_cache_ptr + target_offset + offset, val, mask=mask) + else: + # MHA format: source [num_layers, 2, num_tokens_in_block, hidden_size] + # MHA format: target [2, total_token_in_kvcache, hidden_size] + source_offset_k = ( + layer_idx * num_tokens_in_block * 2 + token_pos + ) * hidden_size + source_offset_v = ( + layer_idx * num_tokens_in_block * 2 + num_tokens_in_block + token_pos + ) * hidden_size + + target_offset_k = token_idx * hidden_size + target_offset_v = (total_token_in_kvcache + token_idx) * hidden_size + + for i in range(0, hidden_size, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size + + val_k = tl.load(source_ptr + source_offset_k + offset, mask=mask) + val_v = tl.load(source_ptr + source_offset_v + offset, mask=mask) + + tl.store(kv_cache_ptr + target_offset_k + offset, val_k, mask=mask) + tl.store(kv_cache_ptr + target_offset_v + offset, val_v, mask=mask) + + +@triton.jit +def kv_cache_gather_kernel( + kv_cache_ptrs_ptr, + dst_ptr, + token_indices_ptr, + num_tokens_in_block, + hidden_size, + total_token_in_kvcache, + num_layers, + is_mla, + BLOCK_SIZE: tl.constexpr, +): + layer_idx = tl.program_id(0) + token_pos = tl.program_id(1) + + if layer_idx >= num_layers or token_pos >= num_tokens_in_block: + return + + token_idx = tl.load(token_indices_ptr + token_pos) + kv_cache_ptr = tl.cast(tl.load(kv_cache_ptrs_ptr + layer_idx), dst_ptr.dtype) + + if token_idx >= total_token_in_kvcache: + return + + if is_mla: + # MLA format: source [total_token_in_kvcache, hidden_size] (per layer) + # MLA format: dst [num_layers, num_tokens_in_block, hidden_size] + kvcache_offset = token_idx * hidden_size + dst_offset = (layer_idx * num_tokens_in_block + token_pos) * hidden_size + + for i in range(0, hidden_size, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size + val = tl.load(kv_cache_ptr + kvcache_offset + offset, mask=mask) + 
def scatter_kv_caches(
    kv_caches_ptrs: torch.Tensor,
    total_token_in_kvcache: int,
    src_tensor: torch.Tensor,
    token_indices: list[int],
    is_mla: bool = False,
) -> None:
    """Scatter KV cache data from source tensor to KV cache storage.

    The source shape is validated against the number of layer pointers and
    the number of token indices, mirroring the checks in ``gather_kv_caches``:
    the kernel derives its source offsets from those two counts, so a
    mismatched tensor would be read with the wrong strides.

    Args:
        kv_caches_ptrs: Tensor of KV cache pointers (one per layer)
        total_token_in_kvcache: Total number of tokens in KV cache
        src_tensor: Source tensor containing data to scatter
            - MHA format: [num_layers, 2, num_tokens_in_block, hidden_size]
            - MLA format: [num_layers, num_tokens_in_block, hidden_size]
        token_indices: List of token positions to update
        is_mla: Whether using MLA model format

    Raises:
        AssertionError: If ``src_tensor`` does not have the expected shape.
    """
    num_layers = kv_caches_ptrs.shape[0]
    num_tokens_in_block = len(token_indices)

    if is_mla:
        # MLA: src_tensor is [num_layers, num_tokens_in_block, hidden_size]
        assert len(src_tensor.shape) == 3, (
            f"MLA src_tensor should be 3D, got {src_tensor.shape}"
        )
        assert src_tensor.shape[0] == num_layers, (
            f"Layer count mismatch: {src_tensor.shape[0]} vs {num_layers}"
        )
        assert src_tensor.shape[1] == num_tokens_in_block, (
            f"Token count mismatch: {src_tensor.shape[1]} vs {num_tokens_in_block}"
        )
        hidden_size = src_tensor.shape[2]
    else:
        # MHA: src_tensor is [num_layers, 2, num_tokens_in_block, hidden_size]
        assert len(src_tensor.shape) == 4, (
            f"MHA src_tensor should be 4D, got {src_tensor.shape}"
        )
        assert src_tensor.shape[0] == num_layers, (
            f"Layer count mismatch: {src_tensor.shape[0]} vs {num_layers}"
        )
        assert src_tensor.shape[1] == 2, (
            f"MHA should have 2 (K,V) components, got {src_tensor.shape[1]}"
        )
        assert src_tensor.shape[2] == num_tokens_in_block, (
            f"Token count mismatch: {src_tensor.shape[2]} vs {num_tokens_in_block}"
        )
        hidden_size = src_tensor.shape[3]

    # Token indices are built on the CPU and moved to the source's device.
    device = src_tensor.device
    token_indices_tensor = torch.tensor(
        token_indices, dtype=torch.int32, device="cpu"
    ).to(device, non_blocking=True)

    # One program per (layer, token position) pair.
    grid = (num_layers, num_tokens_in_block)
    BLOCK_SIZE = 128

    kv_cache_scatter_kernel[grid](
        kv_caches_ptrs,
        src_tensor,
        token_indices_tensor,
        num_tokens_in_block,
        hidden_size,
        total_token_in_kvcache,
        num_layers,
        is_mla,
        BLOCK_SIZE=BLOCK_SIZE,
    )
+ + Args: + kv_caches_ptrs: Tensor of KV cache pointers (one per layer) + total_token_in_kvcache: Total number of tokens in KV cache + dst_tensor: Destination tensor to store gathered data + - MHA format: [num_layers, 2, num_tokens_in_block, hidden_size] + - MLA format: [num_layers, num_tokens_in_block, hidden_size] + token_indices: List of token positions to gather + is_mla: Whether using MLA model format + """ + num_layers = kv_caches_ptrs.shape[0] + num_tokens_in_block = len(token_indices) + + if is_mla: + # MLA: dst_tensor is [num_layers, num_tokens_in_block, hidden_size] + assert len(dst_tensor.shape) == 3, ( + f"MLA dst_tensor should be 3D, got {dst_tensor.shape}" + ) + assert dst_tensor.shape[0] == num_layers, ( + f"Layer count mismatch: {dst_tensor.shape[0]} vs {num_layers}" + ) + assert dst_tensor.shape[1] == num_tokens_in_block, ( + f"Token count mismatch: {dst_tensor.shape[1]} vs {num_tokens_in_block}" + ) + hidden_size = dst_tensor.shape[2] + else: + # MHA: dst_tensor is [num_layers, 2, num_tokens_in_block, hidden_size] + assert len(dst_tensor.shape) == 4, ( + f"MHA dst_tensor should be 4D, got {dst_tensor.shape}" + ) + assert dst_tensor.shape[0] == num_layers, ( + f"Layer count mismatch: {dst_tensor.shape[0]} vs {num_layers}" + ) + assert dst_tensor.shape[1] == 2, ( + f"MHA should have 2 (K,V) components, got {dst_tensor.shape[1]}" + ) + assert dst_tensor.shape[2] == num_tokens_in_block, ( + f"Token count mismatch: {dst_tensor.shape[2]} vs {num_tokens_in_block}" + ) + hidden_size = dst_tensor.shape[3] + + device = dst_tensor.device + token_indices_tensor = torch.tensor( + token_indices, dtype=torch.int32, device="cpu" + ).to(device, non_blocking=True) + + grid = (num_layers, num_tokens_in_block) + BLOCK_SIZE = 128 + + kv_cache_gather_kernel[grid]( + kv_caches_ptrs, + dst_tensor, + token_indices_tensor, + num_tokens_in_block, + hidden_size, + total_token_in_kvcache, + num_layers, + is_mla, + BLOCK_SIZE=BLOCK_SIZE, + ) + + +class CopyBufferAllocator: + 
"""Memory pool for tensor buffers to avoid frequent allocation/deallocation.""" + + def __init__( + self, device: torch.device, dtype: torch.dtype, shape: list, max_count: int + ): + self._shape = shape + self._max_count = max_count + self._device = device + self._free_buffers = [ + torch.empty(shape, dtype=dtype, device=device) for _ in range(max_count) + ] + self._inuse_count = 0 + + def alloc_buffer(self, count: int) -> list[torch.Tensor] | None: + """Allocate buffers from the pool.""" + if count == 0: + return [] + + if self._inuse_count + count <= self._max_count: + self._inuse_count += count + result = self._free_buffers[-count:] + del self._free_buffers[-count:] + return result + return None + + def free_buffer(self, buffers: list[torch.Tensor]) -> None: + """Return buffers to the pool.""" + if not buffers: + return + + if self._inuse_count >= len(buffers): + self._inuse_count -= len(buffers) + self._free_buffers.extend(buffers) + else: + raise RuntimeError("Attempted to free more buffers than allocated") + + +logger = init_logger(__name__) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/hf3fs/utils/hf3fs_mock_client.py b/vllm/distributed/kv_transfer/kv_connector/v1/hf3fs/utils/hf3fs_mock_client.py new file mode 100644 index 000000000..3914663a6 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/hf3fs/utils/hf3fs_mock_client.py @@ -0,0 +1,133 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import logging +import os + +import torch + +logger = logging.getLogger(__name__) +HF3FS_AVAILABLE = True + + +class Hf3fsClient: + """Mock HF3FS client using file backend for debugging and testing.""" + + def __init__(self, path: str, size: int, bytes_per_page: int, entries: int): + self._size = size + self._bytes_per_page = bytes_per_page + self._entries = entries + self._file_path = path + + self._ensure_file_exists() + logger.debug("Initialized mock HF3FS client: %s (%d bytes)", 
class Hf3fsClient:
    """Mock HF3FS client using file backend for debugging and testing.

    Storage is a single regular file of ``size`` bytes. ``batch_read`` and
    ``batch_write`` report one status per request: ``bytes_per_page`` on
    success, ``-1`` on failure.
    """

    def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
        """Remember the configuration and make sure the backing file exists.

        Args:
            path: Location of the backing file.
            size: Total size of the backing file in bytes.
            bytes_per_page: Exact byte count every read/write must have.
            entries: Stored but not used by this mock.
        """
        self._size = size
        self._bytes_per_page = bytes_per_page
        self._entries = entries
        self._file_path = path

        self._ensure_file_exists()
        logger.debug("Initialized mock HF3FS client: %s (%d bytes)", path, size)

    def _ensure_file_exists(self) -> None:
        """Create the backing file, or grow it, so it holds ``size`` bytes.

        A file left over from an earlier run with a smaller ``size`` would
        make in-range reads come back short; it is extended instead. A file
        that is already large enough is left untouched.
        """
        if not os.path.exists(self._file_path):
            with open(self._file_path, "w+b") as f:
                f.truncate(self._size)
        elif os.path.getsize(self._file_path) < self._size:
            with open(self._file_path, "r+b") as f:
                f.truncate(self._size)

    def batch_read(self, offsets: list[int], tensors: list[torch.Tensor]) -> list[int]:
        """Read data from file at specified offsets into tensors.

        Args:
            offsets: Byte offset in the backing file for each request.
            tensors: Destination tensors, filled in place.

        Returns:
            One entry per request: ``bytes_per_page`` on success, ``-1`` when
            the range is out of bounds, the size does not equal
            ``bytes_per_page``, or an exception interrupted the batch.
        """
        results = []

        try:
            with open(self._file_path, "rb") as f:
                for offset, tensor in zip(offsets, tensors):
                    num_bytes = tensor.numel() * tensor.element_size()

                    if offset < 0 or offset + num_bytes > self._size:
                        results.append(-1)
                        continue

                    f.seek(offset)
                    buffer_data = f.read(num_bytes)

                    if len(buffer_data) == num_bytes == self._bytes_per_page:
                        tensor_data = self._convert_buffer_to_tensor(
                            buffer_data, tensor.dtype
                        )
                        tensor.copy_(
                            tensor_data.reshape(tensor.shape).to(tensor.device)
                        )
                        results.append(self._bytes_per_page)
                    else:
                        logger.error(
                            "Read size mismatch: got %d, expected %d",
                            len(buffer_data),
                            num_bytes,
                        )
                        results.append(-1)
        except Exception as e:
            logger.error("Batch read error: %s", e)
            # Mark every request that was not processed as failed.
            results.extend([-1] * (len(offsets) - len(results)))

        return results

    def _convert_buffer_to_tensor(
        self, buffer_data: bytes, dtype: torch.dtype
    ) -> torch.Tensor:
        """Convert buffer data to tensor with proper dtype handling."""
        # torch.frombuffer warns on read-only buffers such as ``bytes``;
        # a bytearray copy gives it a writable one.
        writable = bytearray(buffer_data)
        if dtype == torch.bfloat16:
            # bfloat16 is decoded as uint16 and reinterpreted.
            tensor_data = torch.frombuffer(writable, dtype=torch.uint16)
            return tensor_data.view(dtype=torch.bfloat16)
        return torch.frombuffer(writable, dtype=dtype)

    def batch_write(
        self, offsets: list[int], tensors: list[torch.Tensor], event: torch.cuda.Event
    ) -> list[int]:
        """Write data from tensors to file at specified offsets.

        Args:
            offsets: Byte offset in the backing file for each request.
            tensors: Source tensors.
            event: CUDA event the current stream waits on before the
                tensors are read.

        Returns:
            One entry per request: ``bytes_per_page`` on success, ``-1`` when
            the range is out of bounds, the size does not equal
            ``bytes_per_page``, or an exception interrupted the batch.
        """
        results = []

        try:
            torch.cuda.current_stream().wait_event(event)

            # Convert tensors to bytes
            data_bytes_list = [self._tensor_to_bytes(tensor) for tensor in tensors]

            # Write to file
            with open(self._file_path, "r+b") as f:
                for offset, data_bytes in zip(offsets, data_bytes_list):
                    if offset < 0 or offset + len(data_bytes) > self._size:
                        results.append(-1)
                        continue

                    f.seek(offset)
                    bytes_written = f.write(data_bytes)

                    if bytes_written == len(data_bytes) == self._bytes_per_page:
                        results.append(self._bytes_per_page)
                    else:
                        logger.error(
                            "Write size mismatch: wrote %d, expected %d",
                            bytes_written,
                            self._bytes_per_page,
                        )
                        results.append(-1)

        except Exception as e:
            logger.error("Batch write error: %s", e)
            # Mark every request that was not processed as failed.
            results.extend([-1] * (len(offsets) - len(results)))

        return results

    def _tensor_to_bytes(self, tensor: torch.Tensor) -> bytes:
        """Convert tensor to bytes with proper dtype handling."""
        cpu_tensor = tensor.cpu()
        if cpu_tensor.dtype == torch.bfloat16:
            # bfloat16 is reinterpreted as uint16 before the numpy conversion.
            return cpu_tensor.view(dtype=torch.uint16).numpy().tobytes()
        return cpu_tensor.numpy().tobytes()

    def get_size(self) -> int:
        """Get the total size of the storage file."""
        return self._size

    def close(self) -> None:
        """Close the client (no-op for file backend)."""
        pass

    def flush(self) -> None:
        """Flush any pending writes (no-op for file backend)."""
        pass
for (size_t i = 0; i < dst.size(); ++i) { + auto& t = dst[i]; + size_t t_bytes = t.numel() * t.element_size(); + char* dst_ptr = static_cast(t.data_ptr()); + cudaMemcpyAsync(dst_ptr, src_ptr + current, t_bytes, cudaMemcpyHostToDevice, + stream); + current += t_bytes; + } + cudaStreamSynchronize(stream); +} + +void write_shm(const std::vector src, torch::Tensor& shm, + const torch::Tensor& pin, uint64_t stream_ptr) { + py::gil_scoped_release release; + + cudaStream_t stream = reinterpret_cast(stream_ptr); + + // Copy from GPU tensors to pinned memory + char* dst_ptr = static_cast(pin.data_ptr()); + size_t current = 0; + for (size_t i = 0; i < src.size(); ++i) { + auto& t = src[i]; + size_t t_bytes = t.numel() * t.element_size(); + char* src_ptr = static_cast(t.data_ptr()); + cudaMemcpyAsync(dst_ptr + current, src_ptr, t_bytes, cudaMemcpyDeviceToHost, + stream); + current += t_bytes; + } + cudaStreamSynchronize(stream); + + // Copy from pinned memory to shared memory + char* shm_ptr = static_cast(shm.data_ptr()); + std::memcpy(shm_ptr, dst_ptr, shm.numel() * shm.element_size()); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("read_shm", &read_shm, "Read tensors from shared memory"); + m.def("write_shm", &write_shm, "Write tensors to shared memory"); +} \ No newline at end of file