[Core] Combine async postprocessor and multi-step (#7921)

Alexander Matveev authored on 2024-08-29 14:18:26 -04:00, committed by GitHub
parent f205c09854
commit 3f60f2244e
8 changed files with 215 additions and 65 deletions

vllm/engine/async_llm_engine.py

@@ -279,6 +279,10 @@ class _AsyncLLMEngine(LLMEngine):
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc

+        # Detect async + multi-step
+        use_async_and_multi_step = (self.scheduler_config.is_multi_step
+                                    and allow_async_output_proc)
+
         ctx = self.scheduler_contexts[virtual_engine]

         # skip the scheduler if there are any remaining steps in the seq groups.
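
The flag computed above is just the conjunction of two existing signals, hoisted so later branches can test one name. A minimal sketch of the same decision, assuming a stripped-down SchedulerConfig stand-in (not vLLM's real class):

from dataclasses import dataclass

@dataclass
class SchedulerConfig:
    # True when the scheduler hands the worker N decode steps per schedule() call.
    is_multi_step: bool = False

def detect_async_and_multi_step(scheduler_config: SchedulerConfig,
                                allow_async_output_proc: bool) -> bool:
    # Same conjunction as the hunk: the combined path is taken only when
    # multi-step scheduling is on AND this iteration may postprocess
    # outputs asynchronously.
    return scheduler_config.is_multi_step and allow_async_output_proc

# Multi-step alone is not enough; async output processing must be allowed too.
assert detect_async_and_multi_step(SchedulerConfig(is_multi_step=True), True)
assert not detect_async_and_multi_step(SchedulerConfig(is_multi_step=True), False)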
@@ -289,17 +293,27 @@ class _AsyncLLMEngine(LLMEngine):
             # Clear outputs on scheduler iteration start
             ctx.request_outputs.clear()

             # Schedule iteration
             (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc
              ) = self.scheduler[virtual_engine].schedule()
-            # If current scheduler iteration has no async postprocessor,
-            # then we need first to drain the pending async postprocessor
-            # before moving forward
+
+            # Detect async + multi-step
+            use_async_and_multi_step = (self.scheduler_config.is_multi_step
+                                        and allow_async_output_proc)
+
+            # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
                 self._process_model_outputs(virtual_engine=virtual_engine,
                                             is_async=True)
+
+            # For async + multi-step, init the queue
+            if use_async_and_multi_step:
+                assert len(ctx.output_queue) == 0
+                assert seq_group_metadata_list is not None
+                ctx.output_queue.append(
+                    (None, seq_group_metadata_list, scheduler_outputs))

             if (self.scheduler_config.is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):
                 # cache the scheduler outputs for the next iteration if we have
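
The queue primed here holds (output, seq_group_metadata_list, scheduler_outputs) triples; on the combined path its single entry starts with output=None, to be filled in by the async callback as steps complete. A self-contained sketch of that invariant (the bare deque and helper are illustrative, not the engine's actual scheduler context):

from collections import deque
from typing import Any, Deque, List, Optional, Tuple

# Entry layout taken from the diff: (output or None, metadata list, scheduler outputs).
OutputEntry = Tuple[Optional[Any], List[Any], Any]

def prime_output_queue(output_queue: Deque[OutputEntry],
                       seq_group_metadata_list: List[Any],
                       scheduler_outputs: Any) -> None:
    # The combined path starts each scheduler iteration with an empty
    # queue and one placeholder entry; the output slot stays None until
    # the model's async callback fills results in step by step.
    assert len(output_queue) == 0
    assert seq_group_metadata_list is not None
    output_queue.append((None, seq_group_metadata_list, scheduler_outputs))

queue: Deque[OutputEntry] = deque()
prime_output_queue(queue, ["seq-group-0"], scheduler_outputs=object())
assert queue[0][0] is None  # no output produced yet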
@@ -311,9 +325,6 @@ class _AsyncLLMEngine(LLMEngine):
         assert seq_group_metadata_list is not None
         assert scheduler_outputs is not None

-        assert not (self.scheduler_config.is_multi_step and \
-                    allow_async_output_proc)
-
         if not scheduler_outputs.is_empty():
             finished_requests_ids = self.scheduler[
                 virtual_engine].get_and_reset_finished_requests_ids()
@@ -339,8 +350,13 @@ class _AsyncLLMEngine(LLMEngine):
                 last_sampled_token_ids=last_sampled_token_ids)

             if allow_async_output_proc:
-                execute_model_req.async_callback = self.async_callback[
-                    virtual_engine]
+                async_callback = self.async_callback_multi_step[
+                    virtual_engine] if use_async_and_multi_step \
+                        else self.async_callback[virtual_engine]
+
+                execute_model_req.async_callback = async_callback
+                execute_model_req.use_async_and_multi_step = \
+                    use_async_and_multi_step

             # Execute the model.
             output = await self.model_executor.execute_model_async(
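
On the execute path, the only change is which per-virtual-engine callback rides on the request: the multi-step-aware table when the combined path is active, the original one otherwise. A sketch of that selection with toy callback tables (only the names async_callback and async_callback_multi_step come from the diff; the table construction below is an assumption):

from typing import Callable, List

def pick_async_callback(async_callback: List[Callable[[], None]],
                        async_callback_multi_step: List[Callable[[], None]],
                        virtual_engine: int,
                        use_async_and_multi_step: bool) -> Callable[[], None]:
    # Same selection as the hunk: one callback per virtual engine, with
    # the multi-step variant chosen only on the combined path.
    table = (async_callback_multi_step
             if use_async_and_multi_step else async_callback)
    return table[virtual_engine]

# Toy tables for two virtual engines.
single = [lambda ve=ve: print(f"single-step postproc, ve={ve}") for ve in range(2)]
multi = [lambda ve=ve: print(f"multi-step postproc, ve={ve}") for ve in range(2)]
pick_async_callback(single, multi, virtual_engine=1,
                    use_async_and_multi_step=True)()  # -> multi-step postproc, ve=1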
@@ -350,7 +366,7 @@ class _AsyncLLMEngine(LLMEngine):
             if self.scheduler_config.is_multi_step:
                 self._update_cached_scheduler_output(virtual_engine, output)
         else:
-            if len(ctx.output_queue) > 0:
+            if not use_async_and_multi_step and len(ctx.output_queue) > 0:
                 assert not self.scheduler_config.is_multi_step
                 self._process_model_outputs(virtual_engine=virtual_engine,
                                             is_async=True)
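
When nothing was scheduled, the drain is now skipped on the combined path so the primed queue entry survives for the async callback. A condensed sketch of the guard (a bare deque stands in for ctx.output_queue):

from collections import deque
from typing import Callable

def maybe_drain(output_queue: deque, use_async_and_multi_step: bool,
                drain: Callable[[], None]) -> None:
    # Guarded drain from the hunk: the combined path keeps its primed
    # entry for the async callback instead of flushing it here.
    if not use_async_and_multi_step and len(output_queue) > 0:
        drain()

q = deque([("placeholder",)])
maybe_drain(q, use_async_and_multi_step=True, drain=q.clear)
assert len(q) == 1  # the primed entry survives on the combined path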
@@ -362,22 +378,25 @@ class _AsyncLLMEngine(LLMEngine):
                 seq_group.finish_step()

         if not self._has_remaining_steps(seq_group_metadata_list):
-            # clear the cache if we have finished all the steps
+            # Clear the cache if we have finished all the steps
             if self.scheduler_config.is_multi_step:
                 self.cached_scheduler_outputs[
                     virtual_engine] = SchedulerOutputState()

-            # Cache results in engine
-            ctx.output_queue.append(
-                (output, seq_group_metadata_list, scheduler_outputs))
+            if use_async_and_multi_step:
+                # For async + multi-step, clear the queue
+                ctx.output_queue.clear()
+            else:
+                ctx.output_queue.append(
+                    (output, seq_group_metadata_list, scheduler_outputs))

-            if output and allow_async_output_proc:
-                assert len(
-                    output
-                ) == 1, "Multi step decoding does not work with async output processing."  # noqa: E501
-                self._advance_to_next_step(
-                    output[0], seq_group_metadata_list,
-                    scheduler_outputs.scheduled_seq_groups)
+                if output and allow_async_output_proc:
+                    assert len(
+                        output
+                    ) == 1, "Multi step decoding does not work with async output processing."  # noqa: E501
+                    self._advance_to_next_step(
+                        output[0], seq_group_metadata_list,
+                        scheduler_outputs.scheduled_seq_groups)

             if not allow_async_output_proc:
                 self._process_model_outputs(virtual_engine=virtual_engine,
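
End-of-iteration bookkeeping forks here: the combined path clears the queue, since the async callback already consumed the primed entry step by step, while the original path appends the finished triple and eagerly advances sequences so the next schedule() sees fresh token ids. A sketch of that fork (the advance callable stands in for _advance_to_next_step):

from collections import deque
from typing import Any, Callable, Deque, List, Optional, Tuple

OutputEntry = Tuple[Optional[Any], Any, Any]

def finish_iteration(output_queue: Deque[OutputEntry],
                     output: List[Any],
                     seq_group_metadata_list: Any,
                     scheduler_outputs: Any,
                     scheduled_seq_groups: Any,
                     use_async_and_multi_step: bool,
                     allow_async_output_proc: bool,
                     advance: Callable[[Any, Any, Any], None]) -> None:
    if use_async_and_multi_step:
        # Combined path: the async callback consumed the primed entry,
        # so the queue is simply reset for the next iteration.
        output_queue.clear()
    else:
        # Original async path: queue the finished triple ...
        output_queue.append((output, seq_group_metadata_list,
                             scheduler_outputs))
        # ... and advance sequences now; exactly one output is expected,
        # mirroring the assert in the hunk above.
        if output and allow_async_output_proc:
            assert len(output) == 1
            advance(output[0], seq_group_metadata_list, scheduled_seq_groups)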
@@ -390,7 +409,11 @@ class _AsyncLLMEngine(LLMEngine):
                 self.do_tracing(scheduler_outputs)
         else:
-            ctx.request_outputs = []
+            # Multi-step case
+            if use_async_and_multi_step:
+                return []
+            else:
+                ctx.request_outputs = []

         if not self.has_unfinished_requests():
             # Drain async postprocessor (if exists)
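
Finally, the empty-schedule exit: the combined path returns [] immediately, because any pending results belong to the async callback, while the original path resets and returns the per-context outputs. A condensed sketch of the two exits (a hypothetical helper; the real code mutates ctx.request_outputs in place):

from typing import Any, List

def empty_schedule_exit(ctx_request_outputs: List[Any],
                        use_async_and_multi_step: bool) -> List[Any]:
    # Combined path: short-circuit with no outputs; the async callback
    # owns whatever results are still in flight.
    if use_async_and_multi_step:
        return []
    # Original path: reset and return the per-context outputs, as the
    # replaced ctx.request_outputs = [] assignment did.
    ctx_request_outputs.clear()
    return ctx_request_outputs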