[Core] Async_output_proc: Add virtual engine support (towards pipeline parallel) (#7911)

Author:    Alexander Matveev
Date:      2024-08-28 03:02:30 -04:00
Committer: GitHub
Parent:    51f86bf487
Commit:    f508e03e7f

6 changed files with 123 additions and 68 deletions
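
For context, the diff below assumes one scheduler context per virtual engine (one per pipeline-parallel stage), each carrying its own async-postprocessor queue and its own request outputs. A minimal sketch of that shape, with simplified field types (the exact definition lives elsewhere in this commit):

    from collections import deque
    from dataclasses import dataclass, field
    from typing import Any, Deque, List


    @dataclass
    class SchedulerContext:
        # Pending (output, seq_group_metadata_list, scheduler_outputs)
        # tuples awaiting async postprocessing for this virtual engine.
        output_queue: Deque[Any] = field(default_factory=deque)
        # Request outputs produced by the current scheduler iteration.
        request_outputs: List[Any] = field(default_factory=list)


    # One context per virtual engine, e.g. in the engine constructor:
    # self.scheduler_contexts = [SchedulerContext()
    #                            for _ in range(pipeline_parallel_size)]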

vllm/engine/async_llm_engine.py

@@ -279,10 +279,16 @@ class _AsyncLLMEngine(LLMEngine):
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc
 
+        ctx = self.scheduler_contexts[virtual_engine]
+
         # skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
         if not self._has_remaining_steps(seq_group_metadata_list):
+
+            # Clear outputs on scheduler iteration start
+            ctx.request_outputs.clear()
+
             (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc
              ) = self.scheduler[virtual_engine].schedule()
@@ -290,8 +296,9 @@ class _AsyncLLMEngine(LLMEngine):
         # If current scheduler iteration has no async postprocessor,
         # then we need first to drain the pending async postprocessor
         # before moving forward
-        if not allow_async_output_proc and len(self.output_queue) > 0:
-            self._process_model_outputs(is_async=True)
+        if not allow_async_output_proc and len(ctx.output_queue) > 0:
+            self._process_model_outputs(virtual_engine=virtual_engine,
+                                        is_async=True)
 
         if (self.scheduler_config.is_multi_step
                 and scheduler_outputs.num_lookahead_slots > 0):
@@ -332,8 +339,8 @@ class _AsyncLLMEngine(LLMEngine):
                 last_sampled_token_ids=last_sampled_token_ids)
 
             if allow_async_output_proc:
-                execute_model_req.output_proc_callback_fn = \
-                    self._process_model_outputs
+                execute_model_req.async_callback = self.async_callback[
+                    virtual_engine]
 
             # Execute the model.
             output = await self.model_executor.execute_model_async(
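
Note that the raw output_proc_callback_fn handed to the workers is replaced by a per-virtual-engine callback, so the postprocessor always drains the context it belongs to. A hedged sketch of how such a self.async_callback list could be built in the engine constructor, assuming the _process_model_outputs(virtual_engine, is_async) signature used throughout this diff:

    import functools

    def _make_async_callbacks(self, pipeline_parallel_size: int) -> None:
        # Illustrative only: bind each callback to its virtual engine so
        # workers never need to know which pipeline stage they serve.
        self.async_callback = [
            functools.partial(self._process_model_outputs,
                              virtual_engine=v,
                              is_async=True)
            for v in range(pipeline_parallel_size)
        ]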
@@ -343,9 +350,10 @@ class _AsyncLLMEngine(LLMEngine):
         if self.scheduler_config.is_multi_step:
             self._update_cached_scheduler_output(virtual_engine, output)
         else:
-            if len(self.output_queue) > 0:
+            if len(ctx.output_queue) > 0:
                 assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(is_async=True)
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)
             output = []
 
         # Finish the current step for all the sequence groups.
@@ -360,7 +368,7 @@ class _AsyncLLMEngine(LLMEngine):
                     virtual_engine] = SchedulerOutputState()
 
             # Cache results in engine
-            self.output_queue.append(
+            ctx.output_queue.append(
                 (output, seq_group_metadata_list, scheduler_outputs))
 
             if output and allow_async_output_proc:
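
Each queue entry is the (output, seq_group_metadata_list, scheduler_outputs) tuple appended above, one per engine step. A sketch of the consuming side under the same assumptions (the real method also builds RequestOutputs, runs stop checks, and frees finished sequences):

    def _process_model_outputs(self, virtual_engine: int,
                               is_async: bool) -> None:
        # Sketch only: pop the oldest pending step for this virtual
        # engine and turn it into request outputs on the same context.
        ctx = self.scheduler_contexts[virtual_engine]
        if len(ctx.output_queue) == 0:
            return
        (output, seq_group_metadata_list,
         scheduler_outputs) = ctx.output_queue.popleft()
        ...  # postprocess sequences, then extend ctx.request_outputs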
@@ -372,7 +380,8 @@ class _AsyncLLMEngine(LLMEngine):
                     scheduler_outputs.scheduled_seq_groups)
 
             if not allow_async_output_proc:
-                self._process_model_outputs(is_async=False)
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=False)
 
             # Log stats.
             self.do_log_stats(scheduler_outputs, output)
@@ -381,9 +390,17 @@ class _AsyncLLMEngine(LLMEngine):
 
             self.do_tracing(scheduler_outputs)
         else:
-            self.request_outputs = []
+            ctx.request_outputs = []
 
-        return self.request_outputs
+        if not self.has_unfinished_requests():
+            # Drain async postprocessor (if exists)
+            if len(ctx.output_queue) > 0:
+                assert not self.scheduler_config.is_multi_step
+                self._process_model_outputs(virtual_engine=virtual_engine,
+                                            is_async=True)
+            assert len(ctx.output_queue) == 0
+
+        return ctx.request_outputs
 
     async def stop_remote_worker_execution_loop_async(self) -> None:
         """Stop the remote worker execution loop."""