[V1] V1 Enablement Oracle (#13726)
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com> Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
@@ -49,6 +50,12 @@ class AsyncLLM(EngineClient):
|
||||
log_requests: bool = True,
|
||||
start_engine_loop: bool = True,
|
||||
) -> None:
|
||||
if not envs.VLLM_USE_V1:
|
||||
raise ValueError(
|
||||
"Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
|
||||
"This should not happen. As a workaround, try using "
|
||||
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
|
||||
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
|
||||
|
||||
assert start_engine_loop
|
||||
|
||||
@@ -92,22 +99,50 @@ class AsyncLLM(EngineClient):
|
||||
|
||||
self.output_handler: Optional[asyncio.Task] = None
|
||||
|
||||
@classmethod
|
||||
def from_vllm_config(
|
||||
cls,
|
||||
vllm_config: VllmConfig,
|
||||
start_engine_loop: bool = True,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
|
||||
disable_log_requests: bool = False,
|
||||
disable_log_stats: bool = False,
|
||||
) -> "AsyncLLM":
|
||||
if not envs.VLLM_USE_V1:
|
||||
raise ValueError(
|
||||
"Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
|
||||
"This should not happen. As a workaround, try using "
|
||||
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
|
||||
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
|
||||
|
||||
# FIXME(rob): refactor VllmConfig to include the StatLoggers
|
||||
# include StatLogger in the Oracle decision.
|
||||
if stat_loggers is not None:
|
||||
raise ValueError("Custom StatLoggers are not yet supported on V1. "
|
||||
"Explicitly set VLLM_USE_V1=0 to disable V1.")
|
||||
|
||||
# Create the LLMEngine.
|
||||
return cls(
|
||||
vllm_config=vllm_config,
|
||||
executor_class=Executor.get_class(vllm_config),
|
||||
start_engine_loop=start_engine_loop,
|
||||
log_requests=not disable_log_requests,
|
||||
log_stats=not disable_log_stats,
|
||||
usage_context=usage_context,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(
|
||||
cls,
|
||||
engine_args: AsyncEngineArgs,
|
||||
engine_config: Optional[VllmConfig] = None,
|
||||
start_engine_loop: bool = True,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
) -> "AsyncLLM":
|
||||
"""Create an AsyncLLM from the EngineArgs."""
|
||||
|
||||
# Create the engine configs.
|
||||
if engine_config is None:
|
||||
vllm_config = engine_args.create_engine_config(usage_context)
|
||||
else:
|
||||
vllm_config = engine_config
|
||||
|
||||
vllm_config = engine_args.create_engine_config(usage_context)
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
# Create the AsyncLLM.
|
||||
|
||||
@@ -46,6 +46,13 @@ class LLMEngine:
|
||||
use_cached_outputs: bool = False,
|
||||
multiprocess_mode: bool = False,
|
||||
) -> None:
|
||||
if not envs.VLLM_USE_V1:
|
||||
raise ValueError(
|
||||
"Using V1 LLMEngine, but envs.VLLM_USE_V1=False. "
|
||||
"This should not happen. As a workaround, try using "
|
||||
"LLMEngine.from_vllm_config(...) or explicitly set "
|
||||
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
@@ -88,6 +95,26 @@ class LLMEngine:
|
||||
# for v0 compatibility
|
||||
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore
|
||||
|
||||
@classmethod
|
||||
def from_vllm_config(
|
||||
cls,
|
||||
vllm_config: VllmConfig,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
|
||||
disable_log_stats: bool = False,
|
||||
) -> "LLMEngine":
|
||||
if stat_loggers is not None:
|
||||
raise NotImplementedError(
|
||||
"Passing StatLoggers to V1 is not yet supported. "
|
||||
"Set VLLM_USE_V1=0 and file and issue on Github.")
|
||||
|
||||
return cls(vllm_config=vllm_config,
|
||||
executor_class=Executor.get_class(vllm_config),
|
||||
log_stats=(not disable_log_stats),
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING)
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(
|
||||
cls,
|
||||
|
||||
@@ -184,7 +184,7 @@ class Processor:
|
||||
# Only applicable to multimodal models with legacy input processor.
|
||||
processed_inputs = self.input_processor(preprocessed_inputs)
|
||||
|
||||
self._validate_model_inputs(processed_inputs)
|
||||
self._validate_model_inputs(processed_inputs, lora_request)
|
||||
|
||||
if is_encoder_decoder_inputs(processed_inputs):
|
||||
decoder_inputs = SingletonInputsAdapter(
|
||||
@@ -200,8 +200,12 @@ class Processor:
|
||||
raise NotImplementedError
|
||||
|
||||
assert isinstance(params, SamplingParams)
|
||||
# TODO: can we avoid cloning here in multiproc case
|
||||
# TODO: can we avoid cloning here in multiproc case?
|
||||
sampling_params = params.clone()
|
||||
# If unset max tokens, then generate up to the max_model_len.
|
||||
if sampling_params.max_tokens is None:
|
||||
sampling_params.max_tokens = (self.model_config.max_model_len -
|
||||
len(decoder_inputs.prompt_token_ids))
|
||||
sampling_params.update_from_generation_config(
|
||||
self.generation_config_fields, eos_token_id)
|
||||
sampling_params.update_from_tokenizer(
|
||||
@@ -296,7 +300,9 @@ class Processor:
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
def _validate_model_inputs(self, inputs: ProcessorInputs):
|
||||
def _validate_model_inputs(self,
|
||||
inputs: ProcessorInputs,
|
||||
lora_request: Optional[LoRARequest] = None):
|
||||
if is_encoder_decoder_inputs(inputs):
|
||||
# For encoder-decoder multimodal models, the max_prompt_len
|
||||
# restricts the decoder prompt length
|
||||
@@ -310,6 +316,13 @@ class Processor:
|
||||
if prompt_ids is None or len(prompt_ids) == 0:
|
||||
raise ValueError("Prompt cannot be empty")
|
||||
|
||||
max_input_id = max(prompt_ids)
|
||||
max_allowed = self.tokenizer.get_lora_tokenizer(
|
||||
lora_request).max_token_id
|
||||
if max_input_id > max_allowed:
|
||||
raise ValueError(
|
||||
"Token id {} is out of vocabulary".format(max_input_id))
|
||||
|
||||
if len(prompt_ids) >= self.model_config.max_model_len:
|
||||
raise ValueError(
|
||||
f"Prompt length of {len(prompt_ids)} is longer than the "
|
||||
|
||||
Reference in New Issue
Block a user