[Core] Consolidate prompt arguments to LLM engines (#4328)

Author: Cyrus Leung
Date: 2024-05-29 04:29:31 +08:00 (committed by GitHub)
Co-authored-by: Roger Wang <ywang@roblox.com>
parent 290f4ada2b
commit 5ae5ed1e60
43 changed files with 1407 additions and 442 deletions
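
At a high level, this commit replaces the separate `prompt`, `prompt_token_ids`, and `multi_modal_data` arguments of the engine entry points with a single `inputs: PromptInputs` parameter. A minimal sketch of the call-site migration, assuming an illustrative engine setup (the model id and token IDs below are examples, not part of this diff):

```python
from vllm import EngineArgs, LLMEngine, SamplingParams

# Illustrative setup; any model served by vLLM works here.
engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
params = SamplingParams(temperature=0.8)

# Before: prompt text, token IDs, and multi-modal data were separate
# keyword arguments of add_request.
#   engine.add_request("req-0", prompt="Hello", params=params)

# After: a single `inputs` argument accepts either a plain string or a
# dict; pre-tokenized IDs ride in the same structure.
engine.add_request("req-0", inputs="Hello", params=params)
engine.add_request("req-1",
                   inputs={"prompt_token_ids": [2, 31414]},  # illustrative IDs
                   params=params)
```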

vllm/engine/llm_engine.py

@@ -1,5 +1,8 @@
 import time
-from typing import Iterable, List, Optional, Type, Union
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional
+from typing import Sequence as GenericSequence
+from typing import Type, TypeVar, Union
 
 from transformers import GenerationConfig, PreTrainedTokenizer
@@ -18,6 +21,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.engine.output_processor.util import create_output_by_sequence_group
 from vllm.executor.executor_base import ExecutorBase
 from vllm.executor.ray_utils import initialize_ray_cluster
+from vllm.inputs import LLMInputs, PromptInputs
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
@@ -25,8 +29,8 @@ from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
-                           MultiModalData, PoolerOutput, SamplerOutput,
-                           Sequence, SequenceGroup, SequenceGroupMetadata,
+                           PoolerOutput, SamplerOutput, Sequence,
+                           SequenceGroup, SequenceGroupMetadata,
                            SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
@@ -50,6 +54,9 @@ def _load_generation_config_dict(model_config: ModelConfig):
     return {}
 
 
+_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
+
+
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.
@@ -60,11 +67,11 @@ class LLMEngine:
     iteration-level scheduling and efficient memory management to maximize the
     serving throughput.
 
-    The `LLM` class wraps this class for offline batched inference and the
-    `AsyncLLMEngine` class wraps this class for online serving.
+    The :class:`~vllm.LLM` class wraps this class for offline batched inference
+    and the :class:`AsyncLLMEngine` class wraps this class for online serving.
 
-    NOTE: The config arguments are derived from the `EngineArgs` class. For the
-    comprehensive list of arguments, see `EngineArgs`.
+    NOTE: The config arguments are derived from the :class:`~vllm.EngineArgs`
+    class. For the comprehensive list of arguments, see :ref:`engine_args`.
 
     Args:
         model_config: The configuration related to the LLM model.
@@ -81,9 +88,60 @@ class LLMEngine:
         executor_class: The model executor class for managing distributed
             execution.
         log_stats: Whether to log statistics.
-        usage_context: Specified entry point, used for usage info collection
+        usage_context: Specified entry point, used for usage info collection.
     """
 
+    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
+    """A flag to toggle whether to validate the type of request output."""
+
+    @classmethod
+    @contextmanager
+    def enable_output_validation(cls):
+        cls.DO_VALIDATE_OUTPUT = True
+
+        yield
+
+        cls.DO_VALIDATE_OUTPUT = False
+
+    @classmethod
+    def validate_output(
+        cls,
+        output: object,
+        output_type: Type[_O],
+    ) -> _O:
+        do_validate = cls.DO_VALIDATE_OUTPUT
+
+        if ((TYPE_CHECKING or do_validate)
+                and not isinstance(output, output_type)):
+            raise TypeError(f"Expected output of type {output_type}, "
+                            f"but found type {type(output)}")
+
+        return output
+
+    @classmethod
+    def validate_outputs(
+        cls,
+        outputs: GenericSequence[object],
+        output_type: Type[_O],
+    ) -> List[_O]:
+        do_validate = cls.DO_VALIDATE_OUTPUT
+
+        outputs_: List[_O]
+        if TYPE_CHECKING or do_validate:
+            outputs_ = []
+            for output in outputs:
+                if not isinstance(output, output_type):
+                    raise TypeError(f"Expected output of type {output_type}, "
+                                    f"but found type {type(output)}")
+
+                outputs_.append(output)
+        else:
+            outputs_ = outputs
+
+        return outputs_
+
+    tokenizer: Optional[BaseTokenizerGroup]
+
     def __init__(
         self,
         model_config: ModelConfig,
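
These validators are effectively free in production: unless `enable_output_validation` is active (or a static type checker is evaluating the branch), `validate_outputs` returns its argument unchanged. A minimal usage sketch; the empty `raw_outputs` list is a stand-in for real engine step results:

```python
from vllm import LLMEngine
from vllm.outputs import RequestOutput

raw_outputs: list = []  # stand-in for the outputs of an engine step

with LLMEngine.enable_output_validation():
    # DO_VALIDATE_OUTPUT is True here, so every element is
    # isinstance-checked against RequestOutput; a mismatch raises TypeError.
    checked = LLMEngine.validate_outputs(raw_outputs, RequestOutput)

# Outside the context manager the same call skips the loop entirely and
# returns the input sequence as-is.
unchecked = LLMEngine.validate_outputs(raw_outputs, RequestOutput)
```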
@@ -151,12 +209,11 @@ class LLMEngine:
         self.log_stats = log_stats
 
         if not self.model_config.skip_tokenizer_init:
-            self.tokenizer: BaseTokenizerGroup
-            self._init_tokenizer()
+            self.tokenizer = self._init_tokenizer()
             self.detokenizer = Detokenizer(self.tokenizer)
         else:
-            self.detokenizer = None
             self.tokenizer = None
+            self.detokenizer = None
 
         self.seq_counter = Counter()
         self.generation_config_fields = _load_generation_config_dict(
@@ -318,14 +375,26 @@ class LLMEngine:
         if model_executor := getattr(self, "model_executor", None):
             model_executor.shutdown()
 
+    MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
+                                   "skip_tokenizer_init is True")
+
+    def get_tokenizer_group(
+            self,
+            fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
+        if self.tokenizer is None:
+            raise ValueError(fail_msg)
+
+        return self.tokenizer
+
     def get_tokenizer(self) -> "PreTrainedTokenizer":
-        return self.tokenizer.get_lora_tokenizer(None)
+        return self.get_tokenizer_group().get_lora_tokenizer(None)
 
     def get_tokenizer_for_seq(self,
                               sequence: Sequence) -> "PreTrainedTokenizer":
-        return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
+        return self.get_tokenizer_group().get_lora_tokenizer(
+            sequence.lora_request)
 
-    def _init_tokenizer(self, **tokenizer_init_kwargs):
+    def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
         init_kwargs = dict(
             tokenizer_id=self.model_config.tokenizer,
             enable_lora=bool(self.lora_config),
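
The new `get_tokenizer_group` gives every tokenizer consumer one place to fail when the engine was built with `skip_tokenizer_init=True`, instead of each call site dereferencing a possibly-None `self.tokenizer`. A sketch of the calling pattern, assuming `engine` is an `LLMEngine` built as in the first example:

```python
# With the default message: raises ValueError("Unable to get tokenizer
# because skip_tokenizer_init is True") if no tokenizer was initialized.
tok_group = engine.get_tokenizer_group()

# Call sites can explain their own requirement, as process_model_inputs
# does further down for prompts that still need tokenization.
tok_group = engine.get_tokenizer_group(
    "prompts must be None if skip_tokenizer_init is True")
```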
@@ -335,8 +404,9 @@ class LLMEngine:
             trust_remote_code=self.model_config.trust_remote_code,
             revision=self.model_config.tokenizer_revision)
         init_kwargs.update(tokenizer_init_kwargs)
-        self.tokenizer = get_tokenizer_group(
-            self.parallel_config.tokenizer_pool_config, **init_kwargs)
+
+        return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
+                                   **init_kwargs)
 
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
@@ -346,29 +416,85 @@ class LLMEngine:
             self.lora_config.verify_with_scheduler_config(
                 self.scheduler_config)
 
-    def encode_request(
+    def _get_eos_token_id(
+            self, lora_request: Optional[LoRARequest]) -> Optional[int]:
+        if self.tokenizer is None:
+            logger.warning("Using None for EOS token id because tokenizer "
+                           "is not initialized")
+            return None
+
+        return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
+
+    def _add_processed_request(
         self,
-        request_id: str,  # pylint: disable=unused-argument
-        prompt: Optional[str],
-        prompt_token_ids: Optional[List[int]] = None,
+        request_id: str,
+        processed_inputs: LLMInputs,
+        params: Union[SamplingParams, PoolingParams],
+        arrival_time: float,
+        lora_request: Optional[LoRARequest],
+    ) -> None:
+        # Create the sequences.
+        block_size = self.cache_config.block_size
+        seq_id = next(self.seq_counter)
+        eos_token_id = self._get_eos_token_id(lora_request)
+
+        seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
+                       lora_request)
+
+        # Create a SequenceGroup based on SamplingParams or PoolingParams
+        if isinstance(params, SamplingParams):
+            seq_group = self._create_sequence_group_with_sampling(
+                request_id,
+                seq,
+                params,
+                arrival_time=arrival_time,
+                lora_request=lora_request,
+            )
+        elif isinstance(params, PoolingParams):
+            seq_group = self._create_sequence_group_with_pooling(
+                request_id,
+                seq,
+                params,
+                arrival_time=arrival_time,
+                lora_request=lora_request,
+            )
+        else:
+            raise ValueError(
+                "Either SamplingParams or PoolingParams must be provided.")
+
+        # Add the sequence group to the scheduler.
+        self.scheduler.add_seq_group(seq_group)
+
+    def process_model_inputs(
+        self,
+        request_id: str,
+        inputs: PromptInputs,
         lora_request: Optional[LoRARequest] = None,
-    ):
-        if prompt_token_ids is None:
-            assert prompt is not None
-            prompt_token_ids = self.tokenizer.encode(request_id=request_id,
-                                                     prompt=prompt,
-                                                     lora_request=lora_request)
-
-        return prompt_token_ids
+    ) -> LLMInputs:
+        if isinstance(inputs, str):
+            inputs = {"prompt": inputs}
+
+        if "prompt_token_ids" not in inputs:
+            tokenizer = self.get_tokenizer_group("prompts must be None if "
+                                                 "skip_tokenizer_init is True")
+
+            prompt_token_ids = tokenizer.encode(request_id=request_id,
+                                                prompt=inputs["prompt"],
+                                                lora_request=lora_request)
+        else:
+            prompt_token_ids = inputs["prompt_token_ids"]
+
+        return LLMInputs(prompt_token_ids=prompt_token_ids,
+                         prompt=inputs.get("prompt"),
+                         multi_modal_data=inputs.get("multi_modal_data"))
 
     def add_request(
         self,
         request_id: str,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
-        prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
     ) -> None:
         """Add a request to the engine's request pool.
@@ -378,15 +504,14 @@ class LLMEngine:
         Args:
             request_id: The unique ID of the request.
-            prompt: The prompt string. Can be None if prompt_token_ids is
-                provided.
-            params: Parameters for sampling or pooling. SamplingParams
-                for text generation. PoolingParams for pooling.
-            prompt_token_ids: The token IDs of the prompt. If None, we
-                use the tokenizer to convert the prompts to token IDs.
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
+                for more details about the format of each input.
+            params: Parameters for sampling or pooling.
+                :class:`~vllm.SamplingParams` for text generation.
+                :class:`~vllm.PoolingParams` for pooling.
             arrival_time: The arrival time of the request. If None, we use
                 the current monotonic time.
-            multi_modal_data: Multi modal data per request.
 
         Details:
             - Set arrival_time to the current time if it is None.
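
The docstring change reflects that multi-modal payloads now arrive inside `inputs` rather than as a separate keyword. A hedged sketch; `image_data` stands in for whatever `MultiModalData` value the target model expects and is not defined in this diff:

```python
from vllm import SamplingParams

engine.add_request(
    request_id="req-3",
    inputs={
        "prompt": "What is shown in this image?",
        "multi_modal_data": image_data,  # assumed MultiModalData payload
    },
    params=SamplingParams(max_tokens=64),
)
```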
@@ -417,59 +542,26 @@ class LLMEngine:
"not enabled!")
if arrival_time is None:
arrival_time = time.time()
prompt_token_ids = self.encode_request(
processed_inputs = self.process_model_inputs(request_id=request_id,
inputs=inputs,
lora_request=lora_request)
self._add_processed_request(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = None
if self.tokenizer:
eos_token_id = self.tokenizer.get_lora_tokenizer(
lora_request).eos_token_id
else:
logger.warning("Use None for EOS token id because tokenizer is "
"not initialized")
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
eos_token_id, lora_request)
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling(
request_id,
seq,
params,
arrival_time,
lora_request,
multi_modal_data,
)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
params,
arrival_time,
lora_request,
multi_modal_data,
)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
processed_inputs=processed_inputs,
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
)
def _create_sequence_group_with_sampling(
self,
request_id: str,
seq: Sequence,
sampling_params: SamplingParams,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
arrival_time: float,
lora_request: Optional[LoRARequest],
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs
@@ -495,8 +587,7 @@ class LLMEngine:
             seqs=[seq],
             arrival_time=arrival_time,
             sampling_params=sampling_params,
-            lora_request=lora_request,
-            multi_modal_data=multi_modal_data)
+            lora_request=lora_request)
 
         return seq_group
@@ -505,9 +596,8 @@ class LLMEngine:
         request_id: str,
         seq: Sequence,
         pooling_params: PoolingParams,
-        arrival_time: Optional[float] = None,
-        lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
+        arrival_time: float,
+        lora_request: Optional[LoRARequest],
     ) -> SequenceGroup:
         """Creates a SequenceGroup with PoolingParams."""
         # Defensive copy of PoolingParams, which are used by the pooler
@@ -517,7 +607,6 @@ class LLMEngine:
             seqs=[seq],
             arrival_time=arrival_time,
             lora_request=lora_request,
-            multi_modal_data=multi_modal_data,
             pooling_params=pooling_params)
 
         return seq_group
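
`multi_modal_data` disappears from `SequenceGroup` because `Sequence` is now built from the consolidated `LLMInputs` (see `_add_processed_request` above), so the payload travels with the sequence itself. A minimal construction sketch with placeholder values:

```python
from vllm.inputs import LLMInputs
from vllm.sequence import Sequence

# Placeholder values; the real ones come from process_model_inputs and
# the engine's cache config.
processed_inputs = LLMInputs(prompt_token_ids=[2, 31414],
                             prompt="Hello",
                             multi_modal_data=None)
seq_id, block_size, eos_token_id, lora_request = 0, 16, 2, None

# Positional arguments mirror the call in _add_processed_request.
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
               lora_request)
```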
@@ -570,7 +659,7 @@ class LLMEngine:
     def _process_model_outputs(
         self,
-        output: List[Union[SamplerOutput, PoolerOutput]],
+        output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
         scheduled_seq_groups: List[ScheduledSequenceGroup],
         ignored_seq_groups: List[SequenceGroup],
         seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -585,7 +674,7 @@ class LLMEngine:
         # Organize outputs by [sequence group][step] instead of
         # [step][sequence group].
         output_by_sequence_group = create_output_by_sequence_group(
-            sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
+            output, num_seq_groups=len(scheduled_seq_groups))
 
         # Update the scheduled sequence groups with the model outputs.
         for scheduled_seq_group, outputs, seq_group_meta in zip(
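
A side note on the typing changes: `typing.Sequence` is imported as `GenericSequence` because this module already uses the bare name `Sequence` for vLLM's own sequence class, and the widened annotation lets `_process_model_outputs` accept tuples and other read-only sequences, not just lists. A standalone sketch of the pattern, independent of vLLM:

```python
from typing import List, Union
from typing import Sequence as GenericSequence

class Sequence:  # domain class occupying the bare name, as in vllm.sequence
    pass

def collect(outputs: GenericSequence[Union[int, str]]) -> List[str]:
    # Any read-only sequence (list, tuple, ...) is acceptable here; a
    # stricter List annotation would make type checkers reject tuples.
    return [str(o) for o in outputs]

print(collect((1, "a", 2)))  # ['1', 'a', '2']
```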