[Core] Consolidate prompt arguments to LLM engines (#4328)
Co-authored-by: Roger Wang <ywang@roblox.com>
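This change collapses the separate prompt, prompt_token_ids and multi_modal_data arguments of the engine entry points into a single inputs argument of type PromptInputs. A minimal sketch of the call-site difference (illustrative only: the engine construction, request ids and prompt values are assumptions, not part of this diff):

from vllm import EngineArgs, LLMEngine, SamplingParams

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

# Before: prompt data was spread across several keyword arguments.
# engine.add_request("req-0", "Hello, my name is", SamplingParams(),
#                    prompt_token_ids=None, multi_modal_data=None)

# After: one `inputs` argument carries everything about the prompt.
engine.add_request("req-0", "Hello, my name is", SamplingParams())
engine.add_request("req-1", {"prompt_token_ids": [1, 15043]}, SamplingParams())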
@@ -1,5 +1,8 @@
 import time
-from typing import Iterable, List, Optional, Type, Union
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional
+from typing import Sequence as GenericSequence
+from typing import Type, TypeVar, Union
 
 from transformers import GenerationConfig, PreTrainedTokenizer
 
@@ -18,6 +21,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.engine.output_processor.util import create_output_by_sequence_group
 from vllm.executor.executor_base import ExecutorBase
 from vllm.executor.ray_utils import initialize_ray_cluster
+from vllm.inputs import LLMInputs, PromptInputs
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
@@ -25,8 +29,8 @@ from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
-                           MultiModalData, PoolerOutput, SamplerOutput,
-                           Sequence, SequenceGroup, SequenceGroupMetadata,
+                           PoolerOutput, SamplerOutput, Sequence,
+                           SequenceGroup, SequenceGroupMetadata,
                            SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
@@ -50,6 +54,9 @@ def _load_generation_config_dict(model_config: ModelConfig):
     return {}
 
 
+_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
+
+
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.
 
@@ -60,11 +67,11 @@ class LLMEngine:
     iteration-level scheduling and efficient memory management to maximize the
     serving throughput.
 
-    The `LLM` class wraps this class for offline batched inference and the
-    `AsyncLLMEngine` class wraps this class for online serving.
+    The :class:`~vllm.LLM` class wraps this class for offline batched inference
+    and the :class:`AsyncLLMEngine` class wraps this class for online serving.
 
-    NOTE: The config arguments are derived from the `EngineArgs` class. For the
-    comprehensive list of arguments, see `EngineArgs`.
+    NOTE: The config arguments are derived from the :class:`~vllm.EngineArgs`
+    class. For the comprehensive list of arguments, see :ref:`engine_args`.
 
     Args:
         model_config: The configuration related to the LLM model.
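For context on the wrappers the docstring refers to, a hedged sketch of offline use through vllm.LLM (model name and prompt are placeholders, not taken from this diff); AsyncLLMEngine wraps the same engine behind an asyncio loop for online serving:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # owns an LLMEngine internally
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=16))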
@@ -81,9 +88,60 @@ class LLMEngine:
         executor_class: The model executor class for managing distributed
             execution.
         log_stats: Whether to log statistics.
-        usage_context: Specified entry point, used for usage info collection
+        usage_context: Specified entry point, used for usage info collection.
     """
 
+    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
+    """A flag to toggle whether to validate the type of request output."""
+
+    @classmethod
+    @contextmanager
+    def enable_output_validation(cls):
+        cls.DO_VALIDATE_OUTPUT = True
+
+        yield
+
+        cls.DO_VALIDATE_OUTPUT = False
+
+    @classmethod
+    def validate_output(
+        cls,
+        output: object,
+        output_type: Type[_O],
+    ) -> _O:
+        do_validate = cls.DO_VALIDATE_OUTPUT
+
+        if ((TYPE_CHECKING or do_validate)
+                and not isinstance(output, output_type)):
+            raise TypeError(f"Expected output of type {output_type}, "
+                            f"but found type {type(output)}")
+
+        return output
+
+    @classmethod
+    def validate_outputs(
+        cls,
+        outputs: GenericSequence[object],
+        output_type: Type[_O],
+    ) -> List[_O]:
+        do_validate = cls.DO_VALIDATE_OUTPUT
+
+        outputs_: List[_O]
+        if TYPE_CHECKING or do_validate:
+            outputs_ = []
+            for output in outputs:
+                if not isinstance(output, output_type):
+                    raise TypeError(f"Expected output of type {output_type}, "
+                                    f"but found type {type(output)}")
+
+                outputs_.append(output)
+        else:
+            outputs_ = outputs
+
+        return outputs_
+
+    tokenizer: Optional[BaseTokenizerGroup]
+
     def __init__(
         self,
         model_config: ModelConfig,
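A hedged sketch of how the new validation hooks added above might be exercised, for example from a test; step_outputs is an assumed variable, not part of the diff. With DO_VALIDATE_OUTPUT left at False, validate_outputs only narrows the type for static checkers; inside enable_output_validation it isinstance-checks every item:

from vllm import LLMEngine, RequestOutput

# Cheap pass-through in normal operation.
request_outputs = LLMEngine.validate_outputs(step_outputs, RequestOutput)

with LLMEngine.enable_output_validation():
    # Now a wrongly typed element raises TypeError instead of passing through.
    request_outputs = LLMEngine.validate_outputs(step_outputs, RequestOutput)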
@@ -151,12 +209,11 @@ class LLMEngine:
         self.log_stats = log_stats
 
         if not self.model_config.skip_tokenizer_init:
-            self.tokenizer: BaseTokenizerGroup
-            self._init_tokenizer()
+            self.tokenizer = self._init_tokenizer()
             self.detokenizer = Detokenizer(self.tokenizer)
         else:
-            self.detokenizer = None
             self.tokenizer = None
+            self.detokenizer = None
 
         self.seq_counter = Counter()
         self.generation_config_fields = _load_generation_config_dict(
@@ -318,14 +375,26 @@ class LLMEngine:
         if model_executor := getattr(self, "model_executor", None):
             model_executor.shutdown()
 
+    MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
+                                   "skip_tokenizer_init is True")
+
+    def get_tokenizer_group(
+            self,
+            fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
+        if self.tokenizer is None:
+            raise ValueError(fail_msg)
+
+        return self.tokenizer
+
     def get_tokenizer(self) -> "PreTrainedTokenizer":
-        return self.tokenizer.get_lora_tokenizer(None)
+        return self.get_tokenizer_group().get_lora_tokenizer(None)
 
     def get_tokenizer_for_seq(self,
                               sequence: Sequence) -> "PreTrainedTokenizer":
-        return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
+        return self.get_tokenizer_group().get_lora_tokenizer(
+            sequence.lora_request)
 
-    def _init_tokenizer(self, **tokenizer_init_kwargs):
+    def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
         init_kwargs = dict(
             tokenizer_id=self.model_config.tokenizer,
             enable_lora=bool(self.lora_config),
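A short sketch of the new accessor pattern (engine is assumed to be an initialized LLMEngine): the check for skip_tokenizer_init now lives in one place instead of each caller touching self.tokenizer directly.

# Raises ValueError with MISSING_TOKENIZER_GROUP_MSG if the engine was built
# with skip_tokenizer_init=True; otherwise returns the BaseTokenizerGroup.
tokenizer_group = engine.get_tokenizer_group()

# Convenience helpers now route through the guarded accessor.
hf_tokenizer = engine.get_tokenizer()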
@@ -335,8 +404,9 @@ class LLMEngine:
             trust_remote_code=self.model_config.trust_remote_code,
             revision=self.model_config.tokenizer_revision)
         init_kwargs.update(tokenizer_init_kwargs)
-        self.tokenizer = get_tokenizer_group(
-            self.parallel_config.tokenizer_pool_config, **init_kwargs)
+
+        return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
+                                   **init_kwargs)
 
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
@@ -346,29 +416,85 @@ class LLMEngine:
             self.lora_config.verify_with_scheduler_config(
                 self.scheduler_config)
 
-    def encode_request(
+    def _get_eos_token_id(
+            self, lora_request: Optional[LoRARequest]) -> Optional[int]:
+        if self.tokenizer is None:
+            logger.warning("Using None for EOS token id because tokenizer "
+                           "is not initialized")
+            return None
+
+        return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
+
+    def _add_processed_request(
         self,
-        request_id: str,  # pylint: disable=unused-argument
-        prompt: Optional[str],
-        prompt_token_ids: Optional[List[int]] = None,
+        request_id: str,
+        processed_inputs: LLMInputs,
+        params: Union[SamplingParams, PoolingParams],
+        arrival_time: float,
+        lora_request: Optional[LoRARequest],
+    ) -> None:
+        # Create the sequences.
+        block_size = self.cache_config.block_size
+        seq_id = next(self.seq_counter)
+        eos_token_id = self._get_eos_token_id(lora_request)
+
+        seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
+                       lora_request)
+
+        # Create a SequenceGroup based on SamplingParams or PoolingParams
+        if isinstance(params, SamplingParams):
+            seq_group = self._create_sequence_group_with_sampling(
+                request_id,
+                seq,
+                params,
+                arrival_time=arrival_time,
+                lora_request=lora_request,
+            )
+        elif isinstance(params, PoolingParams):
+            seq_group = self._create_sequence_group_with_pooling(
+                request_id,
+                seq,
+                params,
+                arrival_time=arrival_time,
+                lora_request=lora_request,
+            )
+        else:
+            raise ValueError(
+                "Either SamplingParams or PoolingParams must be provided.")
+
+        # Add the sequence group to the scheduler.
+        self.scheduler.add_seq_group(seq_group)
+
+    def process_model_inputs(
+        self,
+        request_id: str,
+        inputs: PromptInputs,
         lora_request: Optional[LoRARequest] = None,
-    ):
-        if prompt_token_ids is None:
-            assert prompt is not None
-            prompt_token_ids = self.tokenizer.encode(request_id=request_id,
-                                                     prompt=prompt,
-                                                     lora_request=lora_request)
-        return prompt_token_ids
+    ) -> LLMInputs:
+        if isinstance(inputs, str):
+            inputs = {"prompt": inputs}
+
+        if "prompt_token_ids" not in inputs:
+            tokenizer = self.get_tokenizer_group("prompts must be None if "
+                                                 "skip_tokenizer_init is True")
+
+            prompt_token_ids = tokenizer.encode(request_id=request_id,
+                                                prompt=inputs["prompt"],
+                                                lora_request=lora_request)
+        else:
+            prompt_token_ids = inputs["prompt_token_ids"]
+
+        return LLMInputs(prompt_token_ids=prompt_token_ids,
+                         prompt=inputs.get("prompt"),
+                         multi_modal_data=inputs.get("multi_modal_data"))
 
     def add_request(
         self,
         request_id: str,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
-        prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
     ) -> None:
         """Add a request to the engine's request pool.
 
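A hedged sketch of the normalization performed by the new process_model_inputs (engine and the example values are assumptions): a bare string is wrapped as {"prompt": ...} and tokenized, while an input that already carries prompt_token_ids skips tokenization; either way the result is an LLMInputs dict that _add_processed_request turns into a Sequence and hands to the scheduler.

# String prompt: tokenized through the engine's tokenizer group.
llm_inputs = engine.process_model_inputs("req-0", "Hello, my name is")

# Pre-tokenized prompt: used as-is, so no tokenizer is required.
llm_inputs = engine.process_model_inputs("req-1", {"prompt_token_ids": [1, 15043]})

# llm_inputs carries prompt_token_ids, prompt and multi_modal_data.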
@@ -378,15 +504,14 @@ class LLMEngine:
 
         Args:
             request_id: The unique ID of the request.
-            prompt: The prompt string. Can be None if prompt_token_ids is
-                provided.
-            params: Parameters for sampling or pooling. SamplingParams
-                for text generation. PoolingParams for pooling.
-            prompt_token_ids: The token IDs of the prompt. If None, we
-                use the tokenizer to convert the prompts to token IDs.
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
+                for more details about the format of each input.
+            params: Parameters for sampling or pooling.
+                :class:`~vllm.SamplingParams` for text generation.
+                :class:`~vllm.PoolingParams` for pooling.
             arrival_time: The arrival time of the request. If None, we use
                 the current monotonic time.
-            multi_modal_data: Multi modal data per request.
 
         Details:
             - Set arrival_time to the current time if it is None.
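Per the updated Args section, params selects the request type; a hedged sketch (identifiers other than the two params classes are assumptions):

# Text generation: routed to _create_sequence_group_with_sampling.
engine.add_request("gen-0", {"prompt": "Hello, my name is"},
                   SamplingParams(temperature=0.8, max_tokens=16))

# Embedding / pooling: routed to _create_sequence_group_with_pooling.
engine.add_request("emb-0", {"prompt": "Hello, my name is"}, PoolingParams())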
@@ -417,59 +542,26 @@ class LLMEngine:
                              "not enabled!")
         if arrival_time is None:
             arrival_time = time.time()
-        prompt_token_ids = self.encode_request(
+
+        processed_inputs = self.process_model_inputs(request_id=request_id,
+                                                     inputs=inputs,
+                                                     lora_request=lora_request)
+
+        self._add_processed_request(
             request_id=request_id,
-            prompt=prompt,
-            prompt_token_ids=prompt_token_ids,
-            lora_request=lora_request)
-
-        # Create the sequences.
-        block_size = self.cache_config.block_size
-        seq_id = next(self.seq_counter)
-        eos_token_id = None
-        if self.tokenizer:
-            eos_token_id = self.tokenizer.get_lora_tokenizer(
-                lora_request).eos_token_id
-        else:
-            logger.warning("Use None for EOS token id because tokenizer is "
-                           "not initialized")
-        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
-                       eos_token_id, lora_request)
-
-        # Create a SequenceGroup based on SamplingParams or PoolingParams
-        if isinstance(params, SamplingParams):
-            seq_group = self._create_sequence_group_with_sampling(
-                request_id,
-                seq,
-                params,
-                arrival_time,
-                lora_request,
-                multi_modal_data,
-            )
-        elif isinstance(params, PoolingParams):
-            seq_group = self._create_sequence_group_with_pooling(
-                request_id,
-                seq,
-                params,
-                arrival_time,
-                lora_request,
-                multi_modal_data,
-            )
-        else:
-            raise ValueError(
-                "Either SamplingParams or PoolingParams must be provided.")
-
-        # Add the sequence group to the scheduler.
-        self.scheduler.add_seq_group(seq_group)
+            processed_inputs=processed_inputs,
+            params=params,
+            arrival_time=arrival_time,
+            lora_request=lora_request,
+        )
 
     def _create_sequence_group_with_sampling(
         self,
         request_id: str,
         seq: Sequence,
         sampling_params: SamplingParams,
-        arrival_time: Optional[float] = None,
-        lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
+        arrival_time: float,
+        lora_request: Optional[LoRARequest],
     ) -> SequenceGroup:
         """Creates a SequenceGroup with SamplingParams."""
         max_logprobs = self.get_model_config().max_logprobs
@@ -495,8 +587,7 @@ class LLMEngine:
                                   seqs=[seq],
                                   arrival_time=arrival_time,
                                   sampling_params=sampling_params,
-                                  lora_request=lora_request,
-                                  multi_modal_data=multi_modal_data)
+                                  lora_request=lora_request)
 
         return seq_group
 
@@ -505,9 +596,8 @@ class LLMEngine:
         request_id: str,
         seq: Sequence,
         pooling_params: PoolingParams,
-        arrival_time: Optional[float] = None,
-        lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
+        arrival_time: float,
+        lora_request: Optional[LoRARequest],
     ) -> SequenceGroup:
         """Creates a SequenceGroup with PoolingParams."""
         # Defensive copy of PoolingParams, which are used by the pooler
@@ -517,7 +607,6 @@ class LLMEngine:
                                   seqs=[seq],
                                   arrival_time=arrival_time,
                                   lora_request=lora_request,
-                                  multi_modal_data=multi_modal_data,
                                   pooling_params=pooling_params)
         return seq_group
 
@@ -570,7 +659,7 @@ class LLMEngine:
 
     def _process_model_outputs(
         self,
-        output: List[Union[SamplerOutput, PoolerOutput]],
+        output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
         scheduled_seq_groups: List[ScheduledSequenceGroup],
         ignored_seq_groups: List[SequenceGroup],
         seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -585,7 +674,7 @@ class LLMEngine:
         # Organize outputs by [sequence group][step] instead of
         # [step][sequence group].
        output_by_sequence_group = create_output_by_sequence_group(
-            sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
+            output, num_seq_groups=len(scheduled_seq_groups))
 
         # Update the scheduled sequence groups with the model outputs.
         for scheduled_seq_group, outputs, seq_group_meta in zip(