[Core] Asynchronous Output Processor (#7049)

Author: Megha Agarwal
Co-authored-by: Alexander Matveev <alexm@neuralmagic.com>
Committed by: GitHub on 2024-08-26 20:53:20 -07:00
Commit: 2eedede875 (parent 015e6cc252)

21 changed files with 652 additions and 214 deletions
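The diff below (one of the 21 changed files) wires the new asynchronous output processor into `_AsyncLLMEngine.step_async`: instead of building request outputs synchronously after every forward pass, raw sampler output is appended to `self.output_queue` and processed while the next step executes. A minimal, self-contained sketch of that overlap, with hypothetical names (`ToyEngine`, `forward_pass`) that are not vLLM's API:

import asyncio
from collections import deque

class ToyEngine:
    def __init__(self) -> None:
        self.output_queue: deque = deque()  # raw outputs awaiting CPU work
        self.finished: list = []            # fully processed outputs

    async def forward_pass(self, step: int) -> str:
        # Stands in for GPU execution of one decoding step.
        await asyncio.sleep(0.01)
        return f"raw-output-{step}"

    def process_outputs(self) -> None:
        # Stands in for detokenization, stop checks, building outputs.
        while self.output_queue:
            self.finished.append(self.output_queue.popleft().upper())

    async def step(self, step: int) -> None:
        task = asyncio.create_task(self.forward_pass(step))
        await asyncio.sleep(0)     # let the "GPU" task start its work
        self.process_outputs()     # overlaps with the pending forward pass
        self.output_queue.append(await task)

async def main() -> None:
    engine = ToyEngine()
    for i in range(3):
        await engine.step(i)
    engine.process_outputs()       # final drain once generation is done
    print(engine.finished)         # ['RAW-OUTPUT-0', 'RAW-OUTPUT-1', 'RAW-OUTPUT-2']

asyncio.run(main())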


@@ -277,23 +277,36 @@ class _AsyncLLMEngine(LLMEngine):
         cached_outputs = self.cached_scheduler_outputs[virtual_engine]
         seq_group_metadata_list = cached_outputs.seq_group_metadata_list
         scheduler_outputs = cached_outputs.scheduler_outputs
+        allow_async_output_proc = cached_outputs.allow_async_output_proc

         # skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
         if not self._has_remaining_steps(seq_group_metadata_list):
-            seq_group_metadata_list, scheduler_outputs = self.scheduler[
-                virtual_engine].schedule()
+            (seq_group_metadata_list, scheduler_outputs,
+             allow_async_output_proc
+             ) = self.scheduler[virtual_engine].schedule()
+
+            # If current scheduler iteration has no async postprocessor,
+            # then we need first to drain the pending async postprocessor
+            # before moving forward
+            if not allow_async_output_proc and len(self.output_queue) > 0:
+                self._process_model_outputs(is_async=True)

             if (self.scheduler_config.is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):
                 # cache the scheduler outputs for the next iteration if we have
                 # lookahead slots
                 self._cache_scheduler_outputs_for_multi_step(
-                    virtual_engine, seq_group_metadata_list, scheduler_outputs)
+                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
+                    allow_async_output_proc)

         assert seq_group_metadata_list is not None
         assert scheduler_outputs is not None

+        assert not (self.scheduler_config.is_multi_step and \
+                    allow_async_output_proc)
+
         if not scheduler_outputs.is_empty():
             finished_requests_ids = self.scheduler[
                 virtual_engine].get_and_reset_finished_requests_ids()
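One reading of the drain added above: queued outputs must be applied strictly in submission order, so a step that will process its own outputs synchronously first flushes whatever earlier async steps left pending. A toy illustration of that invariant, assuming a simple FIFO queue (names here are illustrative, not vLLM's):

from collections import deque

output_queue = deque(["step-1-raw", "step-2-raw"])

def process_model_outputs(queue: deque) -> None:
    while queue:
        print("processing", queue.popleft())   # strict FIFO drain

allow_async_output_proc = False   # this iteration cannot defer processing
if not allow_async_output_proc and len(output_queue) > 0:
    process_model_outputs(output_queue)        # flush before the sync step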
@@ -317,6 +330,11 @@ class _AsyncLLMEngine(LLMEngine):
                 # We use ExecuteModelRequest to pass the last sampled_token_ids
                 # to each of the non-last PP stages for in-place prepare_input.
                 last_sampled_token_ids=last_sampled_token_ids)
+
+            if allow_async_output_proc:
+                execute_model_req.output_proc_callback_fn = \
+                    self._process_model_outputs
+
             # Execute the model.
             output = await self.model_executor.execute_model_async(
                 execute_model_req)
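Here the engine hands `_process_model_outputs` to the executor as `output_proc_callback_fn`, so output processing for the previous step can run while model execution is still in flight. A hypothetical sketch of that hand-off (the `ToyExecuteModelRequest` type and the `execute_model_async` body are illustrative, not vLLM's implementation):

import asyncio
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class ToyExecuteModelRequest:
    prompt: str
    output_proc_callback_fn: Optional[Callable[[], None]] = None

async def execute_model_async(req: ToyExecuteModelRequest) -> str:
    # Launch the "GPU" work first...
    gpu_work = asyncio.create_task(asyncio.sleep(0.01, result=req.prompt[::-1]))
    await asyncio.sleep(0)                  # give the task a chance to start
    if req.output_proc_callback_fn is not None:
        # ...then process the previous step's outputs while it runs.
        req.output_proc_callback_fn()
    return await gpu_work

async def main() -> None:
    req = ToyExecuteModelRequest(
        prompt="hello",
        output_proc_callback_fn=lambda: print("draining previous outputs"))
    print(await execute_model_async(req))   # -> olleh

asyncio.run(main())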
@@ -325,6 +343,9 @@ class _AsyncLLMEngine(LLMEngine):
             if self.scheduler_config.is_multi_step:
                 self._update_cached_scheduler_output(virtual_engine, output)
         else:
+            if len(self.output_queue) > 0:
+                assert not self.scheduler_config.is_multi_step
+                self._process_model_outputs(is_async=True)
             output = []

         # Finish the current step for all the sequence groups.
@@ -337,19 +358,32 @@ class _AsyncLLMEngine(LLMEngine):
             if self.scheduler_config.is_multi_step:
                 self.cached_scheduler_outputs[
                     virtual_engine] = SchedulerOutputState()

-            request_outputs = self._process_model_outputs(
-                output, scheduler_outputs.scheduled_seq_groups,
-                scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
+            # Cache results in engine
+            self.output_queue.append(
+                (output, seq_group_metadata_list, scheduler_outputs))
+
+            if output and allow_async_output_proc:
+                assert len(
+                    output
+                ) == 1, "Multi step decoding does not work with async output processing."  # noqa: E501
+                self._advance_to_next_step(
+                    output[0], seq_group_metadata_list,
+                    scheduler_outputs.scheduled_seq_groups)
+
+            if not allow_async_output_proc:
+                self._process_model_outputs(is_async=False)

-            # Log stats.
-            self.do_log_stats(scheduler_outputs, output)
+                # Log stats.
+                self.do_log_stats(scheduler_outputs, output)

-            # Tracing
-            self.do_tracing(scheduler_outputs)
+                # Tracing
+                self.do_tracing(scheduler_outputs)
         else:
-            request_outputs = []
+            self.request_outputs = []

-        return request_outputs
+        return self.request_outputs

     async def stop_remote_worker_execution_loop_async(self) -> None:
         """Stop the remote worker execution loop."""