diff --git a/tests/v1/core/test_async_scheduler.py b/tests/v1/core/test_async_scheduler.py index e0645ed43..a77ae81ba 100644 --- a/tests/v1/core/test_async_scheduler.py +++ b/tests/v1/core/test_async_scheduler.py @@ -236,7 +236,7 @@ def test_prefix_caching_for_multi_turn(): req._all_token_ids = req.prompt_token_ids.copy() req.all_token_ids = ConstantList(req._all_token_ids) req.block_hashes = [] - req.block_hashes = req.get_hash_new_full_blocks() + req.update_block_hashes() # Schedule the next-turn requests. for req in next_turn_requests: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 90ca58441..aa3bc6e2c 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -982,10 +982,8 @@ class Scheduler(SchedulerInterface): session._all_token_ids.extend(update.prompt_token_ids or ()) session.prompt_token_ids.extend(update.prompt_token_ids or ()) - # Update block hashes for the new tokens - # (mirrors Request.append_output_token_ids) - if session.get_hash_new_full_blocks is not None: - session.block_hashes.extend(session.get_hash_new_full_blocks()) + # Update block hashes for the new tokens. + session.update_block_hashes() session.num_prompt_tokens = len(session.prompt_token_ids) session.arrival_time = update.arrival_time session.sampling_params = update.sampling_params diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 3b829875f..970b7e1eb 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -6,7 +6,6 @@ import time from collections import deque from collections.abc import Callable, Mapping from dataclasses import dataclass -from functools import partial from typing import TYPE_CHECKING, Any import torch @@ -164,10 +163,11 @@ class Request: self.num_external_computed_tokens = 0 self.block_hashes: list[BlockHash] = [] - self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None - if block_hasher is not None: - self.get_hash_new_full_blocks = partial(block_hasher, self) - self.block_hashes = self.get_hash_new_full_blocks() + # Store the block hasher without binding self to avoid creating a + # reference cycle (Request -> partial -> Request) that prevents + # immediate garbage collection via reference counting. + self._block_hasher: Callable[[Request], list[BlockHash]] | None = block_hasher + self.update_block_hashes() self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache() @@ -212,8 +212,12 @@ class Request: self._output_token_ids.extend(token_ids) self._all_token_ids.extend(token_ids) - if self.get_hash_new_full_blocks is not None: - self.block_hashes.extend(self.get_hash_new_full_blocks()) + self.update_block_hashes() + + def update_block_hashes(self) -> None: + """Compute block hashes for any new full blocks and append them.""" + if self._block_hasher is not None: + self.block_hashes.extend(self._block_hasher(self)) @property def use_structured_output(self) -> bool: