[DOC] Add additional comments for LLMEngine and AsyncLLMEngine (#1011)
@@ -257,7 +257,26 @@ class LLMEngine:
         self.cache_config.verify_with_parallel_config(self.parallel_config)
 
     def _init_cache(self) -> None:
-        """Profiles the memory usage and initializes the KV cache."""
+        """Profiles the memory usage and initializes the KV cache.
+
+        The engine first profiles the existing memory usage.
+        Then, it calculates the maximum possible number of GPU and CPU blocks
+        that can be allocated with the remaining free memory.
+        More details can be found in the
+        :meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
+        from class :class:`~vllm.worker.worker.Worker`.
+
+        Afterwards, as there may be multiple workers,
+        we take the minimum number of blocks across all workers
+        to ensure this can be applied to all of them.
+
+        Finally, the engine initializes the KV cache
+        with the calculated number of blocks.
+
+        .. tip::
+            You may limit the usage of GPU memory
+            by adjusting the `gpu_memory_utilization` parameter.
+        """
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
         num_blocks = self._run_workers(
             "profile_num_available_blocks",
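The "minimum number of blocks across all workers" rule in the new docstring is easy to state in code. A minimal sketch of that aggregation step, assuming each worker reports a (num_gpu_blocks, num_cpu_blocks) pair; `aggregate_block_counts` is an illustrative helper, not part of vLLM:

    from typing import List, Tuple

    def aggregate_block_counts(
            per_worker: List[Tuple[int, int]]) -> Tuple[int, int]:
        # Take the minimum over all workers so the chosen cache size
        # fits on every worker, as the docstring describes.
        num_gpu_blocks = min(counts[0] for counts in per_worker)
        num_cpu_blocks = min(counts[1] for counts in per_worker)
        return num_gpu_blocks, num_cpu_blocks

    # Two workers with slightly different free memory: the smaller
    # GPU block count wins.
    assert aggregate_block_counts([(8192, 1024), (8000, 1024)]) == (8000, 1024)

Lowering `gpu_memory_utilization` in the engine arguments shrinks the memory budget profiled here and therefore the resulting block counts.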
@@ -334,6 +353,30 @@ class LLMEngine:
                 use the tokenizer to convert the prompts to token IDs.
             arrival_time: The arrival time of the request. If None, we use
                 the current monotonic time.
+
+        Details:
+            - Set arrival_time to the current time if it is None.
+            - Set prompt_token_ids to the encoded prompt if it is None.
+            - Create `best_of` number of :class:`~vllm.Sequence` objects.
+            - Create a :class:`~vllm.SequenceGroup` object
+              from the list of :class:`~vllm.Sequence`.
+            - Add the :class:`~vllm.SequenceGroup` object to the scheduler.
+
+        Example:
+            >>> # initialize engine
+            >>> engine = LLMEngine.from_engine_args(engine_args)
+            >>> # set request arguments
+            >>> example_prompt = "Who is the president of the United States?"
+            >>> sampling_params = SamplingParams(temperature=0.0)
+            >>> request_id = 0
+            >>>
+            >>> # add the request to the engine
+            >>> engine.add_request(
+            >>>     str(request_id),
+            >>>     example_prompt,
+            >>>     sampling_params)
+            >>> # continue the request processing
+            >>> ...
         """
         if arrival_time is None:
             arrival_time = time.monotonic()
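The Details list above maps onto a short pipeline. A schematic sketch of that flow, where `make_sequence` and `make_sequence_group` are hypothetical stand-ins for the real Sequence/SequenceGroup constructors (their exact signatures differ):

    import time

    def add_request_sketch(engine, request_id: str, prompt: str,
                           sampling_params, prompt_token_ids=None,
                           arrival_time=None) -> None:
        if arrival_time is None:
            arrival_time = time.monotonic()  # default to current monotonic time
        if prompt_token_ids is None:
            prompt_token_ids = engine.tokenizer.encode(prompt)
        # One candidate sequence per sample requested via best_of.
        seqs = [make_sequence(prompt, prompt_token_ids)  # hypothetical helper
                for _ in range(sampling_params.best_of)]
        # Group the sequences so they are scheduled (and freed) together.
        seq_group = make_sequence_group(  # hypothetical helper
            request_id, seqs, sampling_params, arrival_time)
        engine.scheduler.add_seq_group(seq_group)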
@@ -358,6 +401,17 @@ class LLMEngine:
 
         Args:
             request_id: The ID(s) of the request to abort.
+
+        Details:
+            - Refer to the
+              :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group` method
+              from class :class:`~vllm.core.scheduler.Scheduler`.
+
+        Example:
+            >>> # initialize engine and add a request with request_id
+            >>> request_id = str(0)
+            >>> # abort the request
+            >>> engine.abort_request(request_id)
         """
         self.scheduler.abort_seq_group(request_id)
 
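The "ID(s)" wording suggests the method accepts either a single ID or an iterable of IDs (an assumption here, matching the signature in later vLLM versions). Usage would then look like:

    engine.abort_request("0")             # abort one request
    engine.abort_request(["1", "2"])      # abort several in one call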
@@ -617,11 +671,53 @@ class LLMEngine:
 
     def step(self) -> List[RequestOutput]:
         """Performs one decoding iteration and returns newly generated results.
 
-        This function performs one decoding iteration of the engine. It first
-        schedules the sequences to be executed in the next iteration and the
-        token blocks to be swapped in/out/copy. Then, it executes the model
-        and updates the scheduler with the model outputs. Finally, it decodes
-        the sequences and returns the newly generated results.
+        .. figure:: https://i.imgur.com/sv2HssD.png
+            :alt: Overview of the step function
+            :align: center
+
+            Overview of the step function.
+
+        Details:
+            - Step 1: Schedules the sequences to be executed in the next
+              iteration and the token blocks to be swapped in/out/copied.
+
+                - Depending on the scheduling policy,
+                  sequences may be `preempted/reordered`.
+                - A Sequence Group (SG) refers to a group of sequences
+                  that are generated from the same prompt.
+
+            - Step 2: Calls the workers to execute the model.
+            - Step 3: Processes the model output. This mainly includes:
+
+                - Decoding the relevant outputs.
+                - Updating the scheduled sequence groups with model outputs
+                  based on their `sampling parameters` (`use_beam_search` or not).
+                - Freeing the finished sequence groups.
+
+            - Finally, it creates and returns the newly generated results.
+
+        Example:
+            >>> # Please see the example/ folder for more detailed examples.
+            >>>
+            >>> # initialize engine and request arguments
+            >>> engine = LLMEngine.from_engine_args(engine_args)
+            >>> example_inputs = [(0, "What is LLM?",
+            >>>    SamplingParams(temperature=0.0))]
+            >>>
+            >>> # Start the engine with an event loop
+            >>> while True:
+            >>>     if example_inputs:
+            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
+            >>>         engine.add_request(str(req_id), prompt, sampling_params)
+            >>>
+            >>>     # continue the request processing
+            >>>     request_outputs = engine.step()
+            >>>     for request_output in request_outputs:
+            >>>         if request_output.finished:
+            >>>             # return or show the request output
+            >>>
+            >>>     if not (engine.has_unfinished_requests() or example_inputs):
+            >>>         break
         """
         seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
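The three steps in the Details list correspond to a compact internal pipeline. A simplified sketch follows; the method and attribute names mirror the vLLM code of this period, but treat the exact signatures as illustrative rather than definitive:

    def step_sketch(engine):
        # Step 1: pick the sequences to run this iteration and the
        # KV-cache block swaps (in/out/copy) needed to make room for them.
        seq_group_metadata_list, scheduler_outputs = engine.scheduler.schedule()
        # Step 2: run the model on all workers for the scheduled sequences.
        output = engine._run_workers(
            "execute_model",
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
            blocks_to_copy=scheduler_outputs.blocks_to_copy)
        # Step 3: decode new tokens, update or free sequence groups, and
        # turn the updated groups into RequestOutput objects.
        return engine._process_model_outputs(output, scheduler_outputs)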