Refactor AsyncLLMEngine (#880)

This commit is contained in:
Antoni Baum
2023-09-03 21:43:43 -07:00
committed by GitHub
parent bf87484efa
commit ce741ba3e4
3 changed files with 271 additions and 149 deletions

View File

@@ -1,17 +1,18 @@
import time
import copy
import time
from functools import partial
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.core.scheduler import Scheduler
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.ray_utils import initialize_cluster, ray, RayWorker
from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
get_tokenizer)
from vllm.utils import Counter
@@ -135,7 +136,8 @@ class LLMEngine:
get_all_outputs=True,
)
def _init_workers_ray(self, placement_group: "PlacementGroup"):
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel
@@ -150,6 +152,7 @@ class LLMEngine:
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True),
**ray_remote_kwargs,
)(RayWorker).remote()
self.workers.append(worker)
@@ -268,11 +271,11 @@ class LLMEngine:
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
def abort_request(self, request_id: str) -> None:
"""Aborts a request with the given ID.
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts a request(s) with the given ID.
Args:
request_id: The ID of the request to abort.
request_id: The ID(s) of the request to abort.
"""
self.scheduler.abort_seq_group(request_id)
@@ -288,35 +291,21 @@ class LLMEngine:
"""Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs()
def step(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
This function performs one decoding iteration of the engine. It first
schedules the sequences to be executed in the next iteration and the
token blocks to be swapped in/out/copy. Then, it executes the model
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
def _schedule(
self
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
Optional[List[RequestOutput]]]:
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if scheduler_outputs.is_empty():
if not scheduler_outputs.ignored_seq_groups:
# Nothing to do.
return []
# If there are ignored seq groups, we need to return them as the
# request outputs.
return [
return seq_group_metadata_list, scheduler_outputs, [
RequestOutput.from_seq_group(seq_group)
for seq_group in scheduler_outputs.ignored_seq_groups
]
return seq_group_metadata_list, scheduler_outputs, None
# Execute the model.
output = self._run_workers(
"execute_model",
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
)
def _process_worker_outputs(
self, output,
scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
# Update the scheduler with the model outputs.
seq_groups = self.scheduler.update(output)
@@ -339,6 +328,31 @@ class LLMEngine:
scheduler_outputs.num_batched_tokens)
return request_outputs
def step(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
This function performs one decoding iteration of the engine. It first
schedules the sequences to be executed in the next iteration and the
token blocks to be swapped in/out/copy. Then, it executes the model
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
(seq_group_metadata_list, scheduler_outputs,
early_return) = self._schedule()
if early_return is not None:
return early_return
# Execute the model.
output = self._run_workers(
"execute_model",
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
)
return self._process_worker_outputs(output, scheduler_outputs)
def _log_system_stats(
self,
prompt_run: bool,