[Core] Add multi-step support to LLMEngine (#7789)
This commit is contained in:
committed by
GitHub
parent
09c7792610
commit
9db93de20c
@@ -1,11 +1,9 @@
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
|
||||
Mapping, Optional, Set, Tuple, Type, Union)
|
||||
|
||||
import torch
|
||||
from typing_extensions import assert_never
|
||||
|
||||
import vllm.envs as envs
|
||||
@@ -15,7 +13,7 @@ from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_timeout import asyncio_timeout
|
||||
from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
|
||||
PromptComponents)
|
||||
PromptComponents, SchedulerOutputState)
|
||||
from vllm.engine.metrics_types import StatLoggerBase
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster, ray
|
||||
@@ -28,8 +26,7 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import print_warning_once
|
||||
@@ -257,24 +254,11 @@ class RequestTracker:
|
||||
return not self._new_requests.empty()
|
||||
|
||||
|
||||
@dataclass
class SchedulerOutputState:
    """Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
    # Sampler output from the most recent step; None until a step completes
    # (and reset to None whenever a fresh schedule is cached).
    last_output: Optional[SamplerOutput] = None
    # Sequence-group metadata the cached schedule was built from.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # The scheduler outputs carried across the steps of a multi-step batch.
    scheduler_outputs: Optional[SchedulerOutputs] = None
|
||||
|
||||
|
||||
class _AsyncLLMEngine(LLMEngine):
|
||||
"""Extension of LLMEngine to add async methods."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
    """Initialize the engine and allocate one scheduler-output cache slot
    per virtual engine (i.e. per pipeline-parallel stage), as used by
    multi-step decoding."""
    super().__init__(*args, **kwargs)
    num_virtual_engines = self.parallel_config.pipeline_parallel_size
    self.cached_scheduler_outputs = [
        SchedulerOutputState() for _ in range(num_virtual_engines)
    ]
|
||||
|
||||
async def step_async(
|
||||
self, virtual_engine: int
|
||||
@@ -367,60 +351,6 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
|
||||
return request_outputs
|
||||
|
||||
def _has_remaining_steps(
    self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
) -> bool:
    """Return True iff the current multi-step batch has steps left to run.

    Returns False immediately when multi-step scheduling is disabled or
    when there are no in-flight sequence groups.

    Raises:
        AssertionError: if the in-flight sequence groups disagree on their
            remaining step count — they are expected to advance in lockstep.
    """
    if (not self.scheduler_config.is_multi_step
            or not seq_group_metadata_list):
        return False

    # TODO(will) this is a sanity check for now to make sure that all the
    # seqs are on the same steps. Eventually we will want to do some sort of
    # dynamic scheduling when doing multi-step decoding.
    ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
    # Generator (not a list) inside any(): short-circuits on first mismatch.
    if any(seq_group.state.remaining_steps != ref_remaining_steps
           for seq_group in seq_group_metadata_list[1:]):
        raise AssertionError(("All running sequence groups should "
                              "have the same remaining steps."))

    return ref_remaining_steps > 0
|
||||
|
||||
def _cache_scheduler_outputs_for_multi_step(
        self, virtual_engine: int,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        scheduler_outputs: SchedulerOutputs) -> None:
    """Stash this step's scheduler artifacts in the per-virtual-engine
    cache so the remaining steps of a multi-step batch can reuse them."""
    cache_slot = self.cached_scheduler_outputs[virtual_engine]
    cache_slot.seq_group_metadata_list = seq_group_metadata_list
    cache_slot.scheduler_outputs = scheduler_outputs
    # A freshly cached schedule invalidates any previous sampler output.
    cache_slot.last_output = None
|
||||
|
||||
def _get_last_sampled_token_ids(
        self, virtual_engine: int) -> Optional[torch.Tensor]:
    """Return the cached CPU tensor of last-sampled token ids for this
    virtual engine, or None.

    Only yields a tensor when multi-step decoding is enabled AND pipeline
    parallelism is in use AND a sampler output has been cached.
    """
    if not self.scheduler_config.is_multi_step:
        return None
    if self.parallel_config.pipeline_parallel_size <= 1:
        return None
    last_output = self.cached_scheduler_outputs[virtual_engine].last_output
    if last_output is None:
        return None
    # If sampled_token_ids_cpu is unset this is None, matching the
    # fall-through result.
    return last_output.sampled_token_ids_cpu
|
||||
|
||||
def _update_cached_scheduler_output(
        self, virtual_engine: int,
        output: List[Optional[SamplerOutput]]) -> None:
    """Record the final sampler output of a step in the per-virtual-engine
    cache.

    Only applies under pipeline parallelism with a non-empty, populated
    output list; otherwise this is a no-op.
    """
    if (self.parallel_config.pipeline_parallel_size <= 1 or not output
            or output[0] is None):
        return
    last_output = output[-1]
    assert last_output is not None
    # Expect only the CPU-side token ids to be populated here; the
    # GPU-side ids/probs should have been cleared.
    assert last_output.sampled_token_ids_cpu is not None
    assert last_output.sampled_token_ids is None
    assert last_output.sampled_token_probs is None
    self.cached_scheduler_outputs[virtual_engine].last_output = last_output
|
||||
|
||||
async def stop_remote_worker_execution_loop_async(self) -> None:
    """Stop the remote worker execution loop.

    Delegates to the async model executor and awaits its completion.
    """
    await self.model_executor.stop_remote_worker_execution_loop_async()
|
||||
|
||||
Reference in New Issue
Block a user