[Core] Optimize Async + Multi-step (#8050)
Committed via GitHub. Parent: 95a178f861. Commit: 6d646d08a2.
@@ -280,40 +280,27 @@ class _AsyncLLMEngine(LLMEngine):
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc

-        # Detect async + multi-step
-        use_async_and_multi_step = (self.scheduler_config.is_multi_step
-                                    and allow_async_output_proc)
-
         ctx = self.scheduler_contexts[virtual_engine]

+        # Clear outputs for each new scheduler iteration
+        ctx.request_outputs.clear()
+
         # skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
         if not self._has_remaining_steps(seq_group_metadata_list):

-            # Clear outputs on scheduler iteration start
-            ctx.request_outputs.clear()
-
             # Schedule iteration
             (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

-            # Detect async + multi-step
-            use_async_and_multi_step = (self.scheduler_config.is_multi_step
-                                        and allow_async_output_proc)
-
             ctx.seq_group_metadata_list = seq_group_metadata_list
             ctx.scheduler_outputs = scheduler_outputs

             # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
-
-            # For async + multi-step, init the queue
-            if use_async_and_multi_step:
-                assert len(ctx.output_queue) == 0
-                assert seq_group_metadata_list is not None
-                ctx.output_queue.append(
-                    (None, seq_group_metadata_list, scheduler_outputs))
+                self._process_model_outputs(ctx=ctx)

             if (self.scheduler_config.is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):
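Note on this hunk: the separate async + multi-step bookkeeping (`use_async_and_multi_step`) disappears, and `_process_model_outputs` is keyed off a per-virtual-engine `SchedulerContext` instead of `(virtual_engine, is_async)` arguments. Below is a minimal, self-contained sketch of that pattern; `SchedulerContext` and `process_model_outputs` here are simplified stand-ins for illustration, not vLLM's real implementations.

from collections import deque
from dataclasses import dataclass, field
from typing import Any, Deque, List, Optional, Tuple


@dataclass
class SchedulerContext:
    # Each queue entry: (outputs, seq_group_metadata_list, scheduler_outputs,
    # is_async, is_last_step); the two flags are added by a later hunk.
    output_queue: Deque[Tuple[Any, ...]] = field(default_factory=deque)
    request_outputs: List[Any] = field(default_factory=list)
    seq_group_metadata_list: Optional[List[Any]] = None
    scheduler_outputs: Optional[Any] = None


def process_model_outputs(ctx: SchedulerContext) -> None:
    # Drain everything queued on this context; the real method also turns
    # raw model outputs into RequestOutput objects before appending.
    while ctx.output_queue:
        outputs, _seq_group_metadata_list, _scheduler_outputs, *_flags = \
            ctx.output_queue.popleft()
        if outputs:
            ctx.request_outputs.extend(outputs)


ctx = SchedulerContext()
ctx.output_queue.append((["out0"], [], None, True, True))
process_model_outputs(ctx)
assert ctx.request_outputs == ["out0"]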
@@ -351,26 +338,20 @@ class _AsyncLLMEngine(LLMEngine):
                     last_sampled_token_ids=last_sampled_token_ids)

             if allow_async_output_proc:
-                async_callback = self.async_callback_multi_step[
-                    virtual_engine] if use_async_and_multi_step \
-                        else self.async_callback[virtual_engine]
-
-                execute_model_req.async_callback = async_callback
-                execute_model_req.use_async_and_multi_step = \
-                    use_async_and_multi_step
+                execute_model_req.async_callback = self.async_callbacks[
+                    virtual_engine]

             # Execute the model.
             output = await self.model_executor.execute_model_async(
                 execute_model_req)

             # we need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
             if self.scheduler_config.is_multi_step:
                 self._update_cached_scheduler_output(virtual_engine, output)
         else:
-            if not use_async_and_multi_step and len(ctx.output_queue) > 0:
-                assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
+            if len(ctx.output_queue) > 0:
+                self._process_model_outputs(ctx=ctx)
             output = []

             # Finish the current step for all the sequence groups.
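With the context carrying all per-engine state, the two callback tables (`async_callback` and `async_callback_multi_step`) collapse into a single `async_callbacks` list with one pre-bound callback per virtual engine. A sketch of how such a table could be built, reusing `SchedulerContext` and `process_model_outputs` from the sketch above; the `functools.partial` binding is an assumption about the construction, not necessarily vLLM's exact code.

import functools
from typing import Callable, List


class EngineSketch:

    def __init__(self, num_virtual_engines: int) -> None:
        self.scheduler_contexts = [
            SchedulerContext() for _ in range(num_virtual_engines)
        ]
        # One pre-bound callback per virtual engine: each closure already
        # knows its own ctx, so callers never have to choose between
        # single-step and multi-step variants at execution time.
        self.async_callbacks: List[Callable[[], None]] = [
            functools.partial(process_model_outputs,
                              self.scheduler_contexts[ve])
            for ve in range(num_virtual_engines)
        ]


engine = EngineSketch(num_virtual_engines=2)
engine.scheduler_contexts[0].output_queue.append(
    (["x"], [], None, True, True))
engine.async_callbacks[0]()  # drains virtual engine 0's queue only
assert engine.scheduler_contexts[0].request_outputs == ["x"]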
@@ -384,24 +365,22 @@ class _AsyncLLMEngine(LLMEngine):
                 self.cached_scheduler_outputs[
                     virtual_engine] = SchedulerOutputState()

-            if use_async_and_multi_step:
-                # For async + multi-step, clear the queue
-                ctx.output_queue.clear()
-            else:
-                ctx.output_queue.append(
-                    (output, seq_group_metadata_list, scheduler_outputs))
+            is_async = allow_async_output_proc
+            is_last_step = True
+            ctx.output_queue.append(
+                (output, seq_group_metadata_list, scheduler_outputs, is_async,
+                 is_last_step))

-            if output and allow_async_output_proc:
-                assert len(
-                    output
-                ) == 1, "Multi step decoding does not work with async output processing."  # noqa: E501
-                self._advance_to_next_step(
-                    output[0], seq_group_metadata_list,
-                    scheduler_outputs.scheduled_seq_groups)
+            if output and allow_async_output_proc:
+                assert len(
+                    output
+                ) == 1, "Async postprocessor expects only a single output set"
+                self._advance_to_next_step(
+                    output[0], seq_group_metadata_list,
+                    scheduler_outputs.scheduled_seq_groups)

             if not allow_async_output_proc:
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=False)
+                self._process_model_outputs(ctx=ctx)

                 # Log stats.
                 self.do_log_stats(scheduler_outputs, output)
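Queue entries grow from 3-tuples to 5-tuples so the consumer no longer needs `use_async_and_multi_step` to interpret them: `is_async` records whether the entry was produced under async output processing, and `is_last_step` marks the step that should actually materialize request outputs. A small illustrative consumer; the unpacking order matches the diff, while the handling logic is a hypothetical simplification.

from typing import Any, List, Optional, Tuple

QueueEntry = Tuple[Optional[List[Any]], List[Any], Any, bool, bool]


def handle_entry(entry: QueueEntry) -> bool:
    """Return True if this entry should produce RequestOutputs."""
    output, seq_group_metadata_list, scheduler_outputs, is_async, \
        is_last_step = entry
    # Intermediate multi-step entries only advance sequence state;
    # only the last step of a batch is turned into RequestOutputs.
    return is_last_step


entry: QueueEntry = (["step_output"], [], None, True, True)
assert handle_entry(entry)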
@@ -411,17 +390,12 @@ class _AsyncLLMEngine(LLMEngine):

         else:
             # Multi-step case
-            if use_async_and_multi_step:
-                return []
-            else:
-                ctx.request_outputs = []
+            return ctx.request_outputs

         if not self.has_unfinished_requests():
             # Drain async postprocessor (if exists)
             if len(ctx.output_queue) > 0:
-                assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
+                self._process_model_outputs(ctx=ctx)
             assert len(ctx.output_queue) == 0

         return ctx.request_outputs
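The drain path at the end of step_async keeps its shape but now goes through the context too: once no unfinished requests remain, any queued async output work is flushed before returning, so nothing is stranded in the queue between engine steps. A condensed sketch of that invariant, using the types from the first sketch above.

def finish_step(ctx: SchedulerContext, has_unfinished_requests: bool):
    if not has_unfinished_requests:
        # Drain the async postprocessor if anything is still queued ...
        if ctx.output_queue:
            process_model_outputs(ctx)
        # ... so the queue is provably empty between engine steps.
        assert not ctx.output_queue
    return ctx.request_outputs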
@@ -640,6 +614,17 @@ class AsyncLLMEngine:
         self.log_requests = log_requests
         self.engine = self._init_engine(*args, **kwargs)

+        # This ensures quick processing of request outputs
+        # so the append to asyncio queues is not delayed,
+        # especially for multi-step.
+        #
+        # TODO: Currently, disabled for engine_use_ray, ask
+        # Cody/Will/Woosuk about this case.
+        self.use_process_request_outputs_callback = not self.engine_use_ray
+        if self.use_process_request_outputs_callback:
+            self.engine.process_request_outputs_callback = \
+                self.process_request_outputs
+
         if self.engine_use_ray:
             print_warning_once(
                 "DEPRECATED. `--engine-use-ray` is deprecated and will "
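This is the frontend half of the optimization: AsyncLLMEngine registers `process_request_outputs` as a callback on the inner engine, so outputs reach the per-request asyncio streams as soon as `_process_model_outputs` produces them rather than only after `step_async` returns. A stripped-down sketch of the wiring; both class shapes are stand-ins for illustration, not the real constructors.

from types import SimpleNamespace


class InnerEngineSketch:
    # Stand-in for LLMEngine: invokes the registered hook at the point
    # where _process_model_outputs has request outputs ready.
    process_request_outputs_callback = None

    def emit(self, request_outputs):
        if self.process_request_outputs_callback is not None:
            self.process_request_outputs_callback(request_outputs)


class AsyncFrontendSketch:
    # Stand-in for AsyncLLMEngine's constructor wiring.

    def __init__(self, engine_use_ray: bool = False) -> None:
        self.engine = InnerEngineSketch()
        # Mirrors the diff: disabled under engine_use_ray (see the TODO).
        self.use_process_request_outputs_callback = not engine_use_ray
        if self.use_process_request_outputs_callback:
            self.engine.process_request_outputs_callback = \
                self.process_request_outputs

    def process_request_outputs(self, request_outputs) -> bool:
        # The real method forwards each output to its asyncio stream here.
        return all(out.finished for out in request_outputs)


frontend = AsyncFrontendSketch()
frontend.engine.emit([SimpleNamespace(finished=True)])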
@@ -883,13 +868,27 @@ class AsyncLLMEngine:
             request_outputs = await self.engine.step_async(virtual_engine)

-        # Put the outputs into the corresponding streams.
-        finished = True
+        # If used as a callback, then already invoked inside
+        # LLMEngine's _process_model_outputs
+        if not self.use_process_request_outputs_callback:
+            all_finished = self.process_request_outputs(request_outputs)
+        else:
+            # For callback case, we only need to detect when all
+            # requests are finished
+            all_finished = all(request_output.finished
+                               for request_output in request_outputs)
+
+        return not all_finished
+
+    def process_request_outputs(self, request_outputs) -> bool:
+        # Put the outputs into the corresponding streams.
+        all_finished = True
         for request_output in request_outputs:
             self._request_tracker.process_request_output(
                 request_output, verbose=self.log_requests)
-            finished = finished and request_output.finished
+            all_finished = all_finished and request_output.finished

-        return not finished
+        return all_finished

     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
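The renamed accumulator makes the return contract explicit: `process_request_outputs` returns True only when every RequestOutput is finished, and the caller keeps iterating while `not all_finished`. A quick check of the loop form against its `all()` equivalent, with RequestOutput faked via SimpleNamespace.

from types import SimpleNamespace

request_outputs = [SimpleNamespace(finished=True),
                   SimpleNamespace(finished=False)]

# The loop form used inside process_request_outputs ...
all_finished = True
for request_output in request_outputs:
    all_finished = all_finished and request_output.finished

# ... matches the all() form used on the callback path.
assert all_finished == all(o.finished for o in request_outputs)
assert all_finished is False  # one request is still running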