[Optimization] Use Shared CachedRequestData Instance Across All Requests (#20232)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon authored 2025-06-30 09:07:50 -07:00; committed by GitHub
parent 2965c99c86
commit 2863befce3
12 changed files with 220 additions and 231 deletions
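In short: rather than creating one `CachedRequestData` per scheduled request each step (and pooling the objects per request ID to soften the allocation cost), the scheduler now fills a single struct-of-arrays instance shared by all cached requests in the step. The two files excerpted below carry the core of the change: the dataclass definition and the scheduler code that builds it. As orientation, here is a sketch of the new layout with made-up sample values (the fields and the `num_reqs` property come from the diff; the data is illustrative):

```python
# Sketch of the new batched layout (field names from the diff below;
# the sample values are illustrative, not from the commit).
from dataclasses import dataclass

@dataclass
class CachedRequestData:
    req_ids: list[str]
    resumed_from_preemption: list[bool]
    new_token_ids: list[list[int]]
    new_block_ids: list[tuple[list[int], ...]]
    num_computed_tokens: list[int]

    @property
    def num_reqs(self) -> int:
        return len(self.req_ids)

# One shared instance per step, one slot per request:
batch = CachedRequestData(
    req_ids=["req-0", "req-1"],
    resumed_from_preemption=[False, True],  # req-1 was preempted and resumed
    new_token_ids=[[101], [7, 8, 9]],
    new_block_ids=[([12],), ([3, 4],)],
    num_computed_tokens=[17, 0],
)
assert batch.num_reqs == 2
```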


@@ -83,29 +83,27 @@ class NewRequestData:
 @dataclass
 class CachedRequestData:
 
-    req_id: str
+    req_ids: list[str]
     # If resumed_from_preemption is False, new_block_ids will be appended to
     # the request's block IDs. If True, new_block_ids will be used as the
     # request's block IDs instead of appending to the existing block IDs.
-    resumed_from_preemption: bool
-    new_token_ids: list[int]
-    new_block_ids: tuple[list[int], ...]
-    num_computed_tokens: int
+    resumed_from_preemption: list[bool]
+    new_token_ids: list[list[int]]
+    new_block_ids: list[tuple[list[int], ...]]
+    num_computed_tokens: list[int]
+
+    @property
+    def num_reqs(self) -> int:
+        return len(self.req_ids)
 
     @classmethod
-    def from_request(
-        cls,
-        request: Request,
-        resumed_from_preemption: bool,
-        new_token_ids: list[int],
-        new_block_ids: tuple[list[int], ...],
-    ) -> CachedRequestData:
+    def make_empty(cls) -> CachedRequestData:
         return cls(
-            req_id=request.request_id,
-            resumed_from_preemption=resumed_from_preemption,
-            new_token_ids=new_token_ids,
-            new_block_ids=new_block_ids,
-            num_computed_tokens=request.num_computed_tokens,
+            req_ids=[],
+            resumed_from_preemption=[],
+            new_token_ids=[],
+            new_block_ids=[],
+            num_computed_tokens=[],
         )
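The `resumed_from_preemption` comment defines the consumer-side contract for `new_block_ids`: append for a request that kept running, replace for one resumed after preemption, since its old blocks were freed. A minimal sketch of that rule (`apply_new_blocks` is a hypothetical helper, not code from the commit; one inner list per KV cache group, matching the tuple layout):

```python
# Hypothetical helper illustrating the comment above.
def apply_new_blocks(
    block_table: list[list[int]],          # existing block IDs, per KV cache group
    new_block_ids: tuple[list[int], ...],  # newly allocated blocks, per group
    resumed_from_preemption: bool,
) -> list[list[int]]:
    if resumed_from_preemption:
        # The request's blocks were freed on preemption: start from scratch.
        return [list(ids) for ids in new_block_ids]
    # Otherwise the new blocks extend the existing table, group by group.
    for group, ids in zip(block_table, new_block_ids):
        group.extend(ids)
    return block_table
```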
@@ -119,7 +117,7 @@ class SchedulerOutput:
     # list of the requests that have been scheduled before.
     # Since the request's data is already cached in the worker processes,
     # we only send the diff to minimize the communication cost.
-    scheduled_cached_reqs: list[CachedRequestData]
+    scheduled_cached_reqs: CachedRequestData
 
     # req_id -> num_scheduled_tokens
     # Number of tokens scheduled for each request.
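`scheduled_cached_reqs` goes from `list[CachedRequestData]` to a single shared `CachedRequestData`, so consumers switch from iterating objects to indexing parallel lists. A hedged sketch of the new access pattern (`update_states` is a hypothetical stand-in for the worker-side bookkeeping, not the actual model-runner code):

```python
# Illustrative consumer loop over the shared instance.
def update_states(scheduler_output) -> None:
    cached = scheduler_output.scheduled_cached_reqs
    for i in range(cached.num_reqs):
        req_id = cached.req_ids[i]
        resumed = cached.resumed_from_preemption[i]
        tokens = cached.new_token_ids[i]
        blocks = cached.new_block_ids[i]
        n_computed = cached.num_computed_tokens[i]
        # ... apply the per-request diff for req_id here ...
```

The remaining hunks below are all on the scheduler side.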


@@ -3,8 +3,9 @@
 from __future__ import annotations
 
+import itertools
 import time
-from collections import defaultdict, deque
+from collections import defaultdict
 from collections.abc import Iterable
 from typing import Any, Optional, Union
@@ -117,12 +118,6 @@ class Scheduler(SchedulerInterface):
         # KV Connector: requests in process of async KV loading or recving
         self.finished_recving_kv_req_ids: set[str] = set()
 
-        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
-        # them at each scheduling step.
-        # Request id -> deque of CachedRequestData
-        self._cached_reqs_data: dict[
-            str, deque[CachedRequestData]] = defaultdict(deque)
-
         # Encoder-related.
         # Calculate encoder cache size if applicable
         # NOTE: For now we use the same budget for both compute and space.
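With one shared instance built per step, the per-request pool deleted above (a dict of deques keyed by request ID) no longer pays for itself. A rough, self-contained micro-benchmark sketch of the allocation pattern being traded away (class names and sizes are made up; run it to compare on your own machine):

```python
# Made-up micro-benchmark: N small dataclass instances per step vs. one
# batched instance with parallel lists. Not from the commit.
import timeit
from dataclasses import dataclass, field

@dataclass
class PerRequest:
    req_id: str
    num_computed_tokens: int

@dataclass
class Batched:
    req_ids: list[str] = field(default_factory=list)
    num_computed_tokens: list[int] = field(default_factory=list)

N = 1024  # hypothetical number of cached requests in a step

def per_request_objects() -> list[PerRequest]:
    return [PerRequest(str(i), i) for i in range(N)]

def one_batched_object() -> Batched:
    out = Batched()
    for i in range(N):
        out.req_ids.append(str(i))
        out.num_computed_tokens.append(i)
    return out

print("per-request:", timeit.timeit(per_request_objects, number=100))
print("batched:    ", timeit.timeit(one_batched_object, number=100))
```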
@@ -547,27 +542,16 @@ class Scheduler(SchedulerInterface):
                                         req_to_new_block_ids[req.request_id])
             for req in scheduled_new_reqs
         ]
-        resumed_reqs_data = [
-            self._make_cached_request_data(
-                req,
-                num_scheduled_tokens[req.request_id],
-                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
-                req_to_new_block_ids[req.request_id],
-                resumed_from_preemption=True,
-            ) for req in scheduled_resumed_reqs
-        ]
-        running_reqs_data = [
-            self._make_cached_request_data(
-                req,
-                num_scheduled_tokens[req.request_id],
-                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
-                req_to_new_block_ids[req.request_id],
-                resumed_from_preemption=False,
-            ) for req in scheduled_running_reqs
-        ]
+        cached_reqs_data = self._make_cached_request_data(
+            scheduled_running_reqs,
+            scheduled_resumed_reqs,
+            num_scheduled_tokens,
+            scheduled_spec_decode_tokens,
+            req_to_new_block_ids,
+        )
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=new_reqs_data,
-            scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
+            scheduled_cached_reqs=cached_reqs_data,
             num_scheduled_tokens=num_scheduled_tokens,
             total_num_scheduled_tokens=total_num_scheduled_tokens,
             scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
@@ -613,34 +597,39 @@ class Scheduler(SchedulerInterface):
     def _make_cached_request_data(
         self,
-        request: Request,
-        num_scheduled_tokens: int,
-        num_scheduled_spec_tokens: int,
-        new_block_ids: tuple[list[int], ...],
-        resumed_from_preemption: bool,
+        running_reqs: list[Request],
+        resumed_reqs: list[Request],
+        num_scheduled_tokens: dict[str, int],
+        spec_decode_tokens: dict[str, list[int]],
+        req_to_new_block_ids: dict[str, tuple[list[int], ...]],
     ) -> CachedRequestData:
-        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
-        # them at each scheduling step.
-        num_computed_tokens = request.num_computed_tokens
-        num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
-        new_token_ids = request.all_token_ids[
-            num_computed_tokens:num_computed_tokens + num_regular_tokens]
+        req_ids: list[str] = []
+        new_token_ids: list[list[int]] = []
+        new_block_ids: list[tuple[list[int], ...]] = []
+        num_computed_tokens: list[int] = []
 
-        req_data_queue = self._cached_reqs_data.get(request.request_id)
-        if req_data_queue:
-            req_data = req_data_queue.popleft()
-            req_data.resumed_from_preemption = resumed_from_preemption
-            req_data.new_token_ids = new_token_ids
-            req_data.new_block_ids = new_block_ids
-            req_data.num_computed_tokens = num_computed_tokens
-        else:
-            # No cached request data, or all cached request data has been
-            # used by the scheduled requests.
-            req_data = CachedRequestData.from_request(request,
-                                                      resumed_from_preemption,
-                                                      new_token_ids,
-                                                      new_block_ids)
-        return req_data
+        for req in itertools.chain(running_reqs, resumed_reqs):
+            req_id = req.request_id
+            req_ids.append(req_id)
+            num_tokens = (num_scheduled_tokens[req_id] -
+                          len(spec_decode_tokens.get(req_id, ())))
+            token_ids = req.all_token_ids[req.num_computed_tokens:req.
+                                          num_computed_tokens + num_tokens]
+            new_token_ids.append(token_ids)
+            new_block_ids.append(req_to_new_block_ids[req_id])
+            num_computed_tokens.append(req.num_computed_tokens)
+        # Because resumed_reqs is usually empty, it is more efficient to do
+        # in-place appending so that we don't need to allocate a new list.
+        resumed_from_preemption = [False] * len(running_reqs)
+        resumed_from_preemption += [True] * len(resumed_reqs)
+        return CachedRequestData(
+            req_ids=req_ids,
+            resumed_from_preemption=resumed_from_preemption,
+            new_token_ids=new_token_ids,
+            new_block_ids=new_block_ids,
+            num_computed_tokens=num_computed_tokens,
+        )
 
     def _try_schedule_encoder_inputs(
         self,
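The in-place-append comment at the end of the new `_make_cached_request_data` relies on standard list semantics, demonstrated below (snippet is illustrative, not from the commit):

```python
# `a += b` on a list calls list.__iadd__, which extends `a` in place,
# while `a + b` always allocates a third list. With an empty right-hand
# side (the common case here: no resumed requests), `+=` copies nothing.
running, resumed = 3, 0  # a typical step: no requests resumed

flags = [False] * running
before = id(flags)
flags += [True] * resumed      # extends in place
assert id(flags) == before     # still the same list object

concat = [False] * running + [True] * resumed  # builds a new list
assert id(concat) != before
```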
@@ -870,19 +859,11 @@ class Scheduler(SchedulerInterface):
             if not stopped:
                 new_running.append(request)
 
+        self.running = new_running
+
         # KV Connector: update state for finished KV Transfers.
         self._update_from_kv_xfer_finished(model_runner_output)
 
-        # Return the cached request data to the queue so they can be reused.
-        for req_data in scheduler_output.scheduled_cached_reqs:
-            # NOTE(rob): since we free stopped reqs above, adding stopped reqs
-            # to _cached_reqs_data will cause a memory leak.
-            if req_data.req_id not in self.finished_req_ids:
-                self._cached_reqs_data[req_data.req_id].append(req_data)
-
-        self.running = new_running
-
         # Create EngineCoreOutputs for all clients that have requests with
         # outputs in this step.
@@ -965,13 +946,11 @@ class Scheduler(SchedulerInterface):
             self._free_request(request)
 
     def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
         assert request.is_finished()
 
         delay_free_blocks, kv_xfer_params = self._connector_finished(request)
         self.encoder_cache_manager.free(request)
         request_id = request.request_id
-        self._cached_reqs_data.pop(request_id, None)
         self.finished_req_ids.add(request_id)
         if self.finished_req_ids_dict is not None:
             self.finished_req_ids_dict[request.client_index].add(request_id)
@@ -983,7 +962,6 @@ class Scheduler(SchedulerInterface):
     def _free_blocks(self, request: Request):
         assert request.is_finished()
-        assert request.request_id not in self._cached_reqs_data
         self.kv_cache_manager.free(request)
         self.kv_cache_manager.free_block_hashes(request)
         del self.requests[request.request_id]