[Core] Support resetting all running requests' KV while calling reset_prefix_cache (#28827)
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com> Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -45,6 +45,12 @@ class AsyncScheduler(Scheduler):
|
||||
request: Request,
|
||||
new_token_ids: list[int],
|
||||
) -> tuple[list[int], bool]:
|
||||
if request.discard_latest_async_tokens:
|
||||
# If the request is force preempted in reset_prefix_cache, we
|
||||
# should discard the latest async token.
|
||||
request.discard_latest_async_tokens = False
|
||||
return [], False
|
||||
|
||||
status_before_update = request.status
|
||||
new_token_ids, stopped = super()._update_request_with_output(
|
||||
request, new_token_ids
|
||||
|
||||
@@ -152,10 +152,16 @@ class SchedulerInterface(ABC):
|
||||
return self.has_unfinished_requests() or self.has_finished_requests()
|
||||
|
||||
@abstractmethod
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
    """Reset the prefix cache for KV cache.

    This is particularly required when the model weights are live-updated.

    Args:
        reset_running_requests: If True, all the running requests will be
            preempted and moved to the waiting queue. Otherwise, this method
            will only reset the KV prefix cache when there is no running
            request taking KV cache.

    Returns:
        True if the prefix cache was successfully reset.
    """
    raise NotImplementedError
|
||||
|
||||
|
||||
@@ -347,17 +347,7 @@ class Scheduler(SchedulerInterface):
|
||||
else:
|
||||
preempted_req = self.running.pop()
|
||||
|
||||
self.kv_cache_manager.free(preempted_req)
|
||||
self.encoder_cache_manager.free(preempted_req)
|
||||
preempted_req.status = RequestStatus.PREEMPTED
|
||||
preempted_req.num_computed_tokens = 0
|
||||
preempted_req.num_preemptions += 1
|
||||
if self.log_stats:
|
||||
preempted_req.record_event(
|
||||
EngineCoreEventType.PREEMPTED, scheduled_timestamp
|
||||
)
|
||||
|
||||
self.waiting.prepend_request(preempted_req)
|
||||
self._preempt_request(preempted_req, scheduled_timestamp)
|
||||
preempted_reqs.append(preempted_req)
|
||||
if preempted_req == request:
|
||||
# No more request to preempt. Cannot schedule this request.
|
||||
@@ -756,6 +746,30 @@ class Scheduler(SchedulerInterface):
|
||||
self._update_after_schedule(scheduler_output)
|
||||
return scheduler_output
|
||||
|
||||
def _preempt_request(
    self,
    request: Request,
    timestamp: float,
) -> None:
    """Move a running request back onto the waiting queue.

    The caller is responsible for popping ``request`` off the running
    queue before invoking this method; here we only release its cached
    resources, mark it preempted, and re-queue it.
    """
    assert request.status == RequestStatus.RUNNING, (
        "Only running requests can be preempted"
    )
    # Release the KV-cache and encoder-cache resources held by the request.
    for cache_manager in (self.kv_cache_manager, self.encoder_cache_manager):
        cache_manager.free(request)

    # Reset scheduling state so the request recomputes from scratch.
    request.status = RequestStatus.PREEMPTED
    request.num_computed_tokens = 0
    request.num_preemptions += 1
    if self.log_stats:
        request.record_event(EngineCoreEventType.PREEMPTED, timestamp)

    # Front of the waiting queue, so it resumes ahead of newly-arrived work.
    self.waiting.prepend_request(request)
|
||||
|
||||
def _update_after_schedule(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
@@ -1362,8 +1376,45 @@ class Scheduler(SchedulerInterface):
|
||||
def has_finished_requests(self) -> bool:
    """Return True when at least one finished request is pending pickup."""
    return bool(self.finished_req_ids)
|
||||
|
||||
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
    """Reset the KV prefix cache.

    If reset_running_requests is True, all the running requests will be
    preempted and moved to the waiting queue. Otherwise, this method will
    only reset the KV prefix cache when there is no running request taking
    KV cache.

    Args:
        reset_running_requests: force-preempt every running request first so
            all KV block ref counts drop to zero and the reset can succeed.

    Returns:
        True if the prefix cache was successfully reset.

    Raises:
        RuntimeError: if reset_running_requests is True and the reset still
            fails (e.g. requests are waiting on remote KV transfer).
    """
    if reset_running_requests:
        # For logging.
        timestamp = time.monotonic()
        # Invalidate all the current running requests KV's by pushing them to
        # the waiting queue. In this case, we can reduce the ref count of all
        # the kv blocks to 0 and thus we can make sure the reset is successful.
        # Preempt in reverse order so the requests will be added back to the
        # running queue in FIFO order.
        while self.running:
            request = self.running.pop()
            self._preempt_request(request, timestamp)
            # NOTE(zhuohan): For async scheduling, we need to discard the
            # latest output token on the fly to avoid a redundant repetitive
            # output token.
            request.num_output_placeholders = 0
            request.discard_latest_async_tokens = True

        # Clear scheduled request ids cache. Since we are forcing preemption
        # + resumption in the same step, we must act as if these requests were
        # not scheduled in the prior step. They will be flushed from the
        # persistent batch in the model runner.
        self.prev_step_scheduled_req_ids.clear()

    reset_successful = self.kv_cache_manager.reset_prefix_cache()
    if reset_running_requests and not reset_successful:
        raise RuntimeError(
            "Failed to reset KV cache even when all the running requests are "
            "preempted and moved to the waiting queue. This is likely due to "
            "the presence of running requests waiting for remote KV transfer, "
            "which is not supported yet."
        )
    return reset_successful
|
||||
|
||||
def make_stats(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user