diff --git a/tests/v1/kv_offload/test_cpu_offloading.py b/tests/v1/kv_offload/test_cpu_offloading.py
index 103675608..d3db828dc 100644
--- a/tests/v1/kv_offload/test_cpu_offloading.py
+++ b/tests/v1/kv_offload/test_cpu_offloading.py
@@ -22,6 +22,17 @@ if current_platform.is_cuda():
 elif current_platform.is_rocm():
     ATTN_BACKENDS = ["TRITON_ATTN"]
 
+# Maximum time (seconds) to wait for the async CPU offload transfer
+# to complete before giving up.
+_RESET_CACHE_TIMEOUT = 30 if current_platform.is_rocm() else 10
+
+# ZMQ poll timeout (ms) for the first event.
+_FIRST_EVENT_POLL_MS = 10_000 if current_platform.is_rocm() else 1000
+
+# Hard ceiling (seconds) on how long get_new_cpu_stored_events may loop,
+# to prevent hangs if non-CPU events keep arriving indefinitely.
+_EVENT_DRAIN_TIMEOUT = 60
+
 
 class MockSubscriber:
     """Helper class to receive and verify published events"""
@@ -47,9 +58,10 @@ class MockSubscriber:
         poller = zmq.Poller()
         poller.register(self.sub, zmq.POLLIN)
 
-        timeout = 1000  # 1 second
-        while True:
-            events = dict(poller.poll(timeout))
+        poll_ms = _FIRST_EVENT_POLL_MS
+        deadline = time.monotonic() + _EVENT_DRAIN_TIMEOUT
+        while time.monotonic() < deadline:
+            events = dict(poller.poll(poll_ms))
             if events.get(self.sub) != zmq.POLLIN:
                 return cpu_stored_events
 
@@ -63,13 +75,32 @@ class MockSubscriber:
             for event in event_batch.events:
                 if isinstance(event, BlockStored) and event.medium == "CPU":
                     cpu_stored_events.append(event)
-            timeout = 100
+            poll_ms = 100
+
+        return cpu_stored_events
 
     def close(self):
         """Clean up resources"""
         self.sub.close()
 
 
+def _wait_for_prefix_cache_reset(llm: LLM) -> None:
+    """Wait for async offload transfers to finish so prefix cache can reset.
+
+    The GPU-to-CPU offload runs on a CUDA stream asynchronously. While blocks
+    are still held by the offload worker, ``reset_prefix_cache`` returns
+    ``False``. Retry with a short sleep until it succeeds or we time out.
+    """
+    deadline = time.monotonic() + _RESET_CACHE_TIMEOUT
+    while not llm.reset_prefix_cache():
+        if time.monotonic() > deadline:
+            raise TimeoutError(
+                "reset_prefix_cache did not succeed within "
+                f"{_RESET_CACHE_TIMEOUT}s - async offload may be stuck"
+            )
+        time.sleep(0.1)
+
+
 def _latency_test(llm: LLM, subscriber: MockSubscriber):
     sampling_params = SamplingParams(max_tokens=1)
 
@@ -95,10 +126,16 @@ def _latency_test(llm: LLM, subscriber: MockSubscriber):
         gpu_hit_time = time.time() - start_time
         total_gpu_hit_time += gpu_hit_time
 
-        # reset prefix cache to avoid GPU hit.
-        llm.reset_prefix_cache()
+        # Wait for the async CPU offload to finish, then reset prefix cache
+        # so the next generate() must reload from CPU rather than GPU.
+        _wait_for_prefix_cache_reset(llm)
 
-        assert subscriber.get_new_cpu_stored_events()
+        # Verify CPU stored events arrived (offload is done before we
+        # attempt to load from CPU).
+        assert subscriber.get_new_cpu_stored_events(), (
+            f"No CPU stored events received on iteration {i}; "
+            "async offload may not have completed in time"
+        )
 
         # run generation again - this should trigger loading from CPU
         start_time = time.time()
@@ -185,6 +222,8 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
         kv_events_config=kv_events_config,
         kv_transfer_config=kv_transfer_config,
         attention_config={"backend": attn_backend},
+        # ROCm: batch size 1 to reduce variability
+        **({"max_num_seqs": 1} if current_platform.is_rocm() else {}),
     )
 
     events_endpoint = events_endpoint.replace("*", "127.0.0.1")
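
Note on the drain pattern used in get_new_cpu_stored_events: wait generously
for the first event, then poll with a short timeout for follow-up batches,
with the whole loop bounded by a hard monotonic deadline so a chatty publisher
cannot hang the test. A minimal standalone sketch of the same technique,
using queue.Queue as a stand-in for the ZMQ subscriber socket (all names
below are illustrative, not vLLM or pyzmq APIs):

import queue
import time


def drain_events(q: queue.Queue,
                 first_timeout_s: float = 1.0,
                 follow_timeout_s: float = 0.1,
                 hard_deadline_s: float = 60.0) -> list:
    """Collect queued events until a quiet period or a hard deadline.

    Mirrors the test's pattern: a long timeout for the first event, a short
    timeout once events are flowing, and a monotonic deadline so the loop
    terminates even if events keep arriving indefinitely.
    """
    events = []
    timeout = first_timeout_s
    deadline = time.monotonic() + hard_deadline_s
    while time.monotonic() < deadline:
        try:
            events.append(q.get(timeout=timeout))
        except queue.Empty:
            break  # quiet period: nothing arrived within the current timeout
        timeout = follow_timeout_s  # events are flowing; poll faster
    return events


if __name__ == "__main__":
    q: queue.Queue = queue.Queue()
    for i in range(3):
        q.put(f"event-{i}")
    print(drain_events(q))  # ['event-0', 'event-1', 'event-2']

Shrinking the timeout after the first event keeps the common case fast, while
the deadline guards the pathological case the patch comments call out: a
stream of non-matching events that would otherwise reset the poll forever.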