[Bugfix][Core] Fix CPU memory leak from Request reference cycle in prefix caching (#34183)
Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -236,7 +236,7 @@ def test_prefix_caching_for_multi_turn():
|
||||
req._all_token_ids = req.prompt_token_ids.copy()
|
||||
req.all_token_ids = ConstantList(req._all_token_ids)
|
||||
req.block_hashes = []
|
||||
req.block_hashes = req.get_hash_new_full_blocks()
|
||||
req.update_block_hashes()
|
||||
|
||||
# Schedule the next-turn requests.
|
||||
for req in next_turn_requests:
|
||||
|
||||
@@ -982,10 +982,8 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
session._all_token_ids.extend(update.prompt_token_ids or ())
|
||||
session.prompt_token_ids.extend(update.prompt_token_ids or ())
|
||||
# Update block hashes for the new tokens
|
||||
# (mirrors Request.append_output_token_ids)
|
||||
if session.get_hash_new_full_blocks is not None:
|
||||
session.block_hashes.extend(session.get_hash_new_full_blocks())
|
||||
# Update block hashes for the new tokens.
|
||||
session.update_block_hashes()
|
||||
session.num_prompt_tokens = len(session.prompt_token_ids)
|
||||
session.arrival_time = update.arrival_time
|
||||
session.sampling_params = update.sampling_params
|
||||
|
||||
@@ -6,7 +6,6 @@ import time
|
||||
from collections import deque
|
||||
from collections.abc import Callable, Mapping
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
@@ -164,10 +163,11 @@ class Request:
|
||||
self.num_external_computed_tokens = 0
|
||||
|
||||
self.block_hashes: list[BlockHash] = []
|
||||
self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None
|
||||
if block_hasher is not None:
|
||||
self.get_hash_new_full_blocks = partial(block_hasher, self)
|
||||
self.block_hashes = self.get_hash_new_full_blocks()
|
||||
# Store the block hasher without binding self to avoid creating a
|
||||
# reference cycle (Request -> partial -> Request) that prevents
|
||||
# immediate garbage collection via reference counting.
|
||||
self._block_hasher: Callable[[Request], list[BlockHash]] | None = block_hasher
|
||||
self.update_block_hashes()
|
||||
|
||||
self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()
|
||||
|
||||
@@ -212,8 +212,12 @@ class Request:
|
||||
self._output_token_ids.extend(token_ids)
|
||||
self._all_token_ids.extend(token_ids)
|
||||
|
||||
if self.get_hash_new_full_blocks is not None:
|
||||
self.block_hashes.extend(self.get_hash_new_full_blocks())
|
||||
self.update_block_hashes()
|
||||
|
||||
def update_block_hashes(self) -> None:
|
||||
"""Compute block hashes for any new full blocks and append them."""
|
||||
if self._block_hasher is not None:
|
||||
self.block_hashes.extend(self._block_hasher(self))
|
||||
|
||||
@property
|
||||
def use_structured_output(self) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user