[Bugfix][Core] Fix CPU memory leak from Request reference cycle in prefix caching (#34183)

Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Roger Wang
2026-02-09 21:03:32 -08:00
committed by GitHub
parent 4cde2e0159
commit 8a5e0e2b2b
3 changed files with 14 additions and 12 deletions

View File

@@ -236,7 +236,7 @@ def test_prefix_caching_for_multi_turn():
req._all_token_ids = req.prompt_token_ids.copy()
req.all_token_ids = ConstantList(req._all_token_ids)
req.block_hashes = []
req.block_hashes = req.get_hash_new_full_blocks()
req.update_block_hashes()
# Schedule the next-turn requests.
for req in next_turn_requests:

View File

@@ -982,10 +982,8 @@ class Scheduler(SchedulerInterface):
session._all_token_ids.extend(update.prompt_token_ids or ())
session.prompt_token_ids.extend(update.prompt_token_ids or ())
# Update block hashes for the new tokens
# (mirrors Request.append_output_token_ids)
if session.get_hash_new_full_blocks is not None:
session.block_hashes.extend(session.get_hash_new_full_blocks())
# Update block hashes for the new tokens.
session.update_block_hashes()
session.num_prompt_tokens = len(session.prompt_token_ids)
session.arrival_time = update.arrival_time
session.sampling_params = update.sampling_params

View File

@@ -6,7 +6,6 @@ import time
from collections import deque
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any
import torch
@@ -164,10 +163,11 @@ class Request:
self.num_external_computed_tokens = 0
self.block_hashes: list[BlockHash] = []
self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None
if block_hasher is not None:
self.get_hash_new_full_blocks = partial(block_hasher, self)
self.block_hashes = self.get_hash_new_full_blocks()
# Store the block hasher without binding self to avoid creating a
# reference cycle (Request -> partial -> Request) that prevents
# immediate garbage collection via reference counting.
self._block_hasher: Callable[[Request], list[BlockHash]] | None = block_hasher
self.update_block_hashes()
self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()
@@ -212,8 +212,12 @@ class Request:
self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids)
if self.get_hash_new_full_blocks is not None:
self.block_hashes.extend(self.get_hash_new_full_blocks())
self.update_block_hashes()
def update_block_hashes(self) -> None:
    """Append hashes for any newly-completed full blocks.

    Delegates to the externally-supplied ``_block_hasher`` callable
    (stored unbound to avoid a Request -> partial -> Request reference
    cycle); does nothing when no hasher was configured.
    """
    hasher = self._block_hasher
    if hasher is None:
        return
    self.block_hashes.extend(hasher(self))
@property
def use_structured_output(self) -> bool: