[Optimization] Use Shared CachedRequestData Instance Across All Requests (#20232)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
vllm/v1/core/sched/output.py

@@ -83,29 +83,27 @@ class NewRequestData:
 
 @dataclass
 class CachedRequestData:
 
-    req_id: str
+    req_ids: list[str]
     # If resumed_from_preemption is False, new_block_ids will be appended to
     # the request's block IDs. If True, new_block_ids will be used as the
     # request's block IDs instead of appending to the existing block IDs.
-    resumed_from_preemption: bool
-    new_token_ids: list[int]
-    new_block_ids: tuple[list[int], ...]
-    num_computed_tokens: int
+    resumed_from_preemption: list[bool]
+    new_token_ids: list[list[int]]
+    new_block_ids: list[tuple[list[int], ...]]
+    num_computed_tokens: list[int]
+
+    @property
+    def num_reqs(self) -> int:
+        return len(self.req_ids)
 
     @classmethod
-    def from_request(
-        cls,
-        request: Request,
-        resumed_from_preemption: bool,
-        new_token_ids: list[int],
-        new_block_ids: tuple[list[int], ...],
-    ) -> CachedRequestData:
+    def make_empty(cls) -> CachedRequestData:
         return cls(
-            req_id=request.request_id,
-            resumed_from_preemption=resumed_from_preemption,
-            new_token_ids=new_token_ids,
-            new_block_ids=new_block_ids,
-            num_computed_tokens=request.num_computed_tokens,
+            req_ids=[],
+            resumed_from_preemption=[],
+            new_token_ids=[],
+            new_block_ids=[],
+            num_computed_tokens=[],
         )
 
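Note: the hunk above turns each per-request scalar field into a parallel list, so one CachedRequestData instance now describes every cached request scheduled in a step (a structure-of-arrays layout), which is what removes the per-request object churn the old deque-based pool worked around. A minimal self-contained sketch of how the batched fields line up, using a stand-in dataclass and made-up request IDs rather than the real class:

from dataclasses import dataclass

# Stand-in mirroring the new schema above (illustration only).
@dataclass
class BatchedCachedRequestData:
    req_ids: list[str]
    resumed_from_preemption: list[bool]
    new_token_ids: list[list[int]]
    new_block_ids: list[tuple[list[int], ...]]
    num_computed_tokens: list[int]

# One instance per scheduling step; index i describes the i-th request.
reqs = BatchedCachedRequestData(
    req_ids=["req-0", "req-1"],
    resumed_from_preemption=[False, True],
    new_token_ids=[[101], [7, 8, 9]],
    new_block_ids=[([42],), ([3, 4, 5],)],
    num_computed_tokens=[17, 0],
)
for i in range(len(reqs.req_ids)):
    # Resumed requests replace their block IDs; running ones append.
    action = "replace" if reqs.resumed_from_preemption[i] else "append"
    print(reqs.req_ids[i], action, reqs.new_block_ids[i])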
@@ -119,7 +117,7 @@ class SchedulerOutput:
     # list of the requests that have been scheduled before.
     # Since the request's data is already cached in the worker processes,
     # we only send the diff to minimize the communication cost.
-    scheduled_cached_reqs: list[CachedRequestData]
+    scheduled_cached_reqs: CachedRequestData
 
     # req_id -> num_scheduled_tokens
     # Number of tokens scheduled for each request.
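Note: because scheduled_cached_reqs is now a single batched object instead of a list, any consumer that used to loop over per-request entries must index the parallel fields instead. A hedged sketch of one way to recover the old per-request view (iter_cached_reqs is illustrative, not part of the patch):

def iter_cached_reqs(reqs):
    # Reconstruct the per-request view from the batched fields: yields
    # (req_id, resumed_from_preemption, new_token_ids, new_block_ids,
    # num_computed_tokens) tuples, in scheduling order.
    return zip(reqs.req_ids, reqs.resumed_from_preemption,
               reqs.new_token_ids, reqs.new_block_ids,
               reqs.num_computed_tokens)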
vllm/v1/core/sched/scheduler.py

@@ -3,8 +3,9 @@
 
 from __future__ import annotations
 
+import itertools
 import time
-from collections import defaultdict, deque
+from collections import defaultdict
 from collections.abc import Iterable
 from typing import Any, Optional, Union
@@ -117,12 +118,6 @@ class Scheduler(SchedulerInterface):
         # KV Connector: requests in process of async KV loading or recving
         self.finished_recving_kv_req_ids: set[str] = set()
 
-        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
-        # them at each scheduling step.
-        # Request id -> deque of CachedRequestData
-        self._cached_reqs_data: dict[
-            str, deque[CachedRequestData]] = defaultdict(deque)
-
         # Encoder-related.
         # Calculate encoder cache size if applicable
         # NOTE: For now we use the same budget for both compute and space.
@@ -547,27 +542,16 @@ class Scheduler(SchedulerInterface):
                                         req_to_new_block_ids[req.request_id])
             for req in scheduled_new_reqs
         ]
-        resumed_reqs_data = [
-            self._make_cached_request_data(
-                req,
-                num_scheduled_tokens[req.request_id],
-                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
-                req_to_new_block_ids[req.request_id],
-                resumed_from_preemption=True,
-            ) for req in scheduled_resumed_reqs
-        ]
-        running_reqs_data = [
-            self._make_cached_request_data(
-                req,
-                num_scheduled_tokens[req.request_id],
-                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
-                req_to_new_block_ids[req.request_id],
-                resumed_from_preemption=False,
-            ) for req in scheduled_running_reqs
-        ]
+        cached_reqs_data = self._make_cached_request_data(
+            scheduled_running_reqs,
+            scheduled_resumed_reqs,
+            num_scheduled_tokens,
+            scheduled_spec_decode_tokens,
+            req_to_new_block_ids,
+        )
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=new_reqs_data,
-            scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
+            scheduled_cached_reqs=cached_reqs_data,
             num_scheduled_tokens=num_scheduled_tokens,
             total_num_scheduled_tokens=total_num_scheduled_tokens,
             scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
@@ -613,34 +597,39 @@ class Scheduler(SchedulerInterface):
 
     def _make_cached_request_data(
         self,
-        request: Request,
-        num_scheduled_tokens: int,
-        num_scheduled_spec_tokens: int,
-        new_block_ids: tuple[list[int], ...],
-        resumed_from_preemption: bool,
+        running_reqs: list[Request],
+        resumed_reqs: list[Request],
+        num_scheduled_tokens: dict[str, int],
+        spec_decode_tokens: dict[str, list[int]],
+        req_to_new_block_ids: dict[str, tuple[list[int], ...]],
     ) -> CachedRequestData:
-        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
-        # them at each scheduling step.
-        num_computed_tokens = request.num_computed_tokens
-        num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
-        new_token_ids = request.all_token_ids[
-            num_computed_tokens:num_computed_tokens + num_regular_tokens]
+        req_ids: list[str] = []
+        new_token_ids: list[list[int]] = []
+        new_block_ids: list[tuple[list[int], ...]] = []
+        num_computed_tokens: list[int] = []
 
-        req_data_queue = self._cached_reqs_data.get(request.request_id)
-        if req_data_queue:
-            req_data = req_data_queue.popleft()
-            req_data.resumed_from_preemption = resumed_from_preemption
-            req_data.new_token_ids = new_token_ids
-            req_data.new_block_ids = new_block_ids
-            req_data.num_computed_tokens = num_computed_tokens
-        else:
-            # No cached request data, or all cached request data has been
-            # used by the scheduled requests.
-            req_data = CachedRequestData.from_request(request,
-                                                      resumed_from_preemption,
-                                                      new_token_ids,
-                                                      new_block_ids)
-        return req_data
+        for req in itertools.chain(running_reqs, resumed_reqs):
+            req_id = req.request_id
+            req_ids.append(req_id)
+            num_tokens = (num_scheduled_tokens[req_id] -
+                          len(spec_decode_tokens.get(req_id, ())))
+            token_ids = req.all_token_ids[req.num_computed_tokens:req.
+                                          num_computed_tokens + num_tokens]
+            new_token_ids.append(token_ids)
+            new_block_ids.append(req_to_new_block_ids[req_id])
+            num_computed_tokens.append(req.num_computed_tokens)
+        # Because resumed_reqs is usually empty, it is more efficient to do
+        # in-place appending so that we don't need to allocate a new list.
+        resumed_from_preemption = [False] * len(running_reqs)
+        resumed_from_preemption += [True] * len(resumed_reqs)
+
+        return CachedRequestData(
+            req_ids=req_ids,
+            resumed_from_preemption=resumed_from_preemption,
+            new_token_ids=new_token_ids,
+            new_block_ids=new_block_ids,
+            num_computed_tokens=num_computed_tokens,
+        )
 
     def _try_schedule_encoder_inputs(
         self,
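Note: the resumed_from_preemption list relies on an ordering invariant: itertools.chain(running_reqs, resumed_reqs) visits all running requests before any resumed ones, so [False] * len(running_reqs) extended in place with [True] * len(resumed_reqs) lines up index-for-index with req_ids. A small self-contained check of that alignment, with made-up request IDs:

import itertools

running = ["req-a", "req-b"]  # stands in for scheduled_running_reqs
resumed = ["req-c"]           # stands in for scheduled_resumed_reqs

req_ids = list(itertools.chain(running, resumed))
resumed_from_preemption = [False] * len(running)
resumed_from_preemption += [True] * len(resumed)  # in-place extend

assert resumed_from_preemption == [False, False, True]
assert list(zip(req_ids, resumed_from_preemption)) == [
    ("req-a", False), ("req-b", False), ("req-c", True)]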
@@ -870,19 +859,11 @@ class Scheduler(SchedulerInterface):
 
             if not stopped:
                 new_running.append(request)
+        self.running = new_running
 
         # KV Connector: update state for finished KV Transfers.
         self._update_from_kv_xfer_finished(model_runner_output)
 
-        # Return the cached request data to the queue so they can be reused.
-        for req_data in scheduler_output.scheduled_cached_reqs:
-            # NOTE(rob): since we free stopped reqs above, adding stopped reqs
-            # to _cached_reqs_data will cause a memory leak.
-            if req_data.req_id not in self.finished_req_ids:
-                self._cached_reqs_data[req_data.req_id].append(req_data)
-
-        self.running = new_running
-
         # Create EngineCoreOutputs for all clients that have requests with
         # outputs in this step.
         engine_core_outputs = {
@@ -965,13 +946,11 @@ class Scheduler(SchedulerInterface):
             self._free_request(request)
 
     def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
         assert request.is_finished()
 
         delay_free_blocks, kv_xfer_params = self._connector_finished(request)
         self.encoder_cache_manager.free(request)
         request_id = request.request_id
-        self._cached_reqs_data.pop(request_id, None)
         self.finished_req_ids.add(request_id)
         if self.finished_req_ids_dict is not None:
             self.finished_req_ids_dict[request.client_index].add(request_id)
@@ -983,7 +962,6 @@
 
     def _free_blocks(self, request: Request):
         assert request.is_finished()
-        assert request.request_id not in self._cached_reqs_data
         self.kv_cache_manager.free(request)
         self.kv_cache_manager.free_block_hashes(request)
         del self.requests[request.request_id]