[V1] V1 Enablement Oracle (#13726)
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
@@ -223,15 +223,6 @@ class EngineArgs:
        if not self.tokenizer:
            self.tokenizer = self.model

        # Override the default value of enable_prefix_caching if it's not set
        # by user.
        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = bool(envs.VLLM_USE_V1)

        # Override max_num_seqs if it's not set by user.
        if self.max_num_seqs is None:
            self.max_num_seqs = 256 if not envs.VLLM_USE_V1 else 1024

        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
        # CompilationConfig object
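With the `__post_init__` overrides above removed, these fields now stay unset at construction time and are resolved later during engine-config creation. A minimal sketch of the resulting behavior (the model name is only a placeholder):

from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(model="facebook/opt-125m")
# No env-based defaulting happens in __post_init__ anymore; both fields
# remain None until create_engine_config() applies V0 or V1 defaults.
assert args.enable_prefix_caching is None
assert args.max_num_seqs is None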
@@ -246,7 +237,6 @@ class EngineArgs:
    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Shared CLI arguments for vLLM engine."""

        # Model arguments
        parser.add_argument(
            '--model',
@@ -1191,24 +1181,51 @@ class EngineArgs:
            use_tqdm_on_load=self.use_tqdm_on_load,
        )

-    def create_engine_config(self,
-                             usage_context: Optional[UsageContext] = None
-                             ) -> VllmConfig:
+    def create_engine_config(
+        self,
+        usage_context: Optional[UsageContext] = None,
+    ) -> VllmConfig:
        """
        Create the VllmConfig.

        NOTE: for autoselection of V0 vs V1 engine, we need to
        create the ModelConfig first, since ModelConfig's attrs
        (e.g. the model arch) are needed to make the decision.

        This function sets VLLM_USE_V1=X if VLLM_USE_V1 is
        unspecified by the user.

        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update()

        if envs.VLLM_USE_V1:
            self._override_v1_engine_args(usage_context)

        device_config = DeviceConfig(device=self.device)
        model_config = self.create_model_config()

        if (model_config.is_multimodal_model and not envs.VLLM_USE_V1
                and self.enable_prefix_caching):
            logger.warning("--enable-prefix-caching is currently not "
                           "supported for multimodal models in v0 and "
                           "has been disabled.")
            self.enable_prefix_caching = False

        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
        #   and fall back to V0 for experimental or unsupported features.
        # * If VLLM_USE_V1=1, we enable V1 for supported + experimental
        #   features and raise error for unsupported features.
        # * If VLLM_USE_V1=0, we disable V1.
        use_v1 = False
        try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
        if try_v1 and self._is_v1_supported_oracle(model_config):
            use_v1 = True

        # If user explicitly set VLLM_USE_V1, sanity check we respect it.
        if envs.is_set("VLLM_USE_V1"):
            assert use_v1 == envs.VLLM_USE_V1
        # Otherwise, set the VLLM_USE_V1 variable globally.
        else:
            envs.set_vllm_use_v1(use_v1)

        # Set default arguments for V0 or V1 Engine.
        if use_v1:
            self._set_default_args_v1(usage_context)
        else:
            self._set_default_args_v0(model_config)

        cache_config = CacheConfig(
            block_size=self.block_size,
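The VLLM_USE_V1 handling above is a tri-state switch. The sketch below restates its semantics; choose_engine is a hypothetical helper written for illustration, not part of vLLM:

import os

def choose_engine(v1_supported: bool) -> bool:
    """Hypothetical mirror of the VLLM_USE_V1 tri-state selection."""
    env = os.environ.get("VLLM_USE_V1")  # None when unset
    if env == "0":
        # Explicit opt-out: always V0.
        return False
    if env == "1" and not v1_supported:
        # Explicit opt-in with an unsupported config: hard error,
        # raised inside _is_v1_supported_oracle() via _raise_or_fallback().
        raise NotImplementedError(
            "VLLM_USE_V1=1 is not supported with this configuration.")
    # Unset, or "1" with a supported config: V1 exactly when supported.
    return v1_supported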
@@ -1239,50 +1256,6 @@ class EngineArgs:
            worker_extension_cls=self.worker_extension_cls,
        )

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # If not explicitly set, enable chunked prefill by default for
            # long context (> 32K) models. This is to avoid OOM errors in the
            # initial memory profiling phase.

            # For multimodal models and models with MLA, chunked prefill is
            # disabled by default in V0, but enabled by design in V1
            if model_config.is_multimodal_model or model_config.use_mla:
                self.enable_chunked_prefill = bool(envs.VLLM_USE_V1)

            elif use_long_context:
                is_gpu = device_config.device_type == "cuda"
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None
                from vllm.platforms import current_platform
                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and model_config.runner_type != "pooling"
                        and not current_platform.is_rocm()):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models with "
                        "max_model_len > 32K. Currently, chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable chunked prefill "
                        "by setting --enable-chunked-prefill=False.")
            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                "The model has a long context length (%s). This may cause OOM "
                "errors during the initial memory profiling phase, or result "
                "in low performance due to small KV cache space. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
        elif (self.enable_chunked_prefill
              and model_config.runner_type == "pooling"):
            msg = "Chunked prefill is not supported for pooling models"
            raise ValueError(msg)

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
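The auto-enable rule above reduces to a single conjunction over the deployment's properties. A hypothetical standalone restatement, for illustration only:

def should_default_chunked_prefill(max_model_len: int, is_cuda_gpu: bool,
                                   has_sliding_window: bool,
                                   uses_spec_decode: bool, uses_lora: bool,
                                   uses_prompt_adapter: bool,
                                   runner_type: str, is_rocm: bool) -> bool:
    # Mirrors the condition above: only long-context CUDA deployments
    # without these features get chunked prefill enabled by default.
    return (max_model_len > 32768 and is_cuda_gpu
            and not has_sliding_window and not uses_spec_decode
            and not uses_lora and not uses_prompt_adapter
            and runner_type != "pooling" and not is_rocm)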
@@ -1425,18 +1398,282 @@ class EngineArgs:
            additional_config=self.additional_config,
        )

        if envs.VLLM_USE_V1:
            self._override_v1_engine_config(config)
        return config

    def _override_v1_engine_args(self, usage_context: UsageContext) -> None:
        """
        Override the EngineArgs's args based on the usage context for V1.
        """
        assert envs.VLLM_USE_V1, "V1 is not enabled"

    def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
        """Oracle for whether to use V0 or V1 Engine by default."""

        #############################################################
        # Unsupported Feature Flags on V1.

        if (self.load_format == LoadFormat.TENSORIZER.value
                or self.load_format == LoadFormat.SHARDED_STATE.value):
            _raise_or_fallback(
                feature_name=f"--load_format {self.load_format}",
                recommend_to_remove=False)
            return False

        if (self.logits_processor_pattern
                != EngineArgs.logits_processor_pattern):
            _raise_or_fallback(feature_name="--logits-processor-pattern",
                               recommend_to_remove=False)
            return False

        if self.preemption_mode != EngineArgs.preemption_mode:
            _raise_or_fallback(feature_name="--preemption-mode",
                               recommend_to_remove=True)
            return False

        if (self.disable_async_output_proc
                != EngineArgs.disable_async_output_proc):
            _raise_or_fallback(feature_name="--disable-async-output-proc",
                               recommend_to_remove=True)
            return False

        if self.scheduling_policy != EngineArgs.scheduling_policy:
            _raise_or_fallback(feature_name="--scheduling-policy",
                               recommend_to_remove=False)
            return False

        if self.worker_cls != EngineArgs.worker_cls:
            _raise_or_fallback(feature_name="--worker-cls",
                               recommend_to_remove=False)
            return False

        if self.worker_extension_cls != EngineArgs.worker_extension_cls:
            _raise_or_fallback(feature_name="--worker-extension-cls",
                               recommend_to_remove=False)
            return False

        if self.num_scheduler_steps != EngineArgs.num_scheduler_steps:
            _raise_or_fallback(feature_name="--num-scheduler-steps",
                               recommend_to_remove=True)
            return False

        if self.scheduler_delay_factor != EngineArgs.scheduler_delay_factor:
            _raise_or_fallback(feature_name="--scheduler-delay-factor",
                               recommend_to_remove=True)
            return False

        if self.additional_config != EngineArgs.additional_config:
            _raise_or_fallback(feature_name="--additional-config",
                               recommend_to_remove=False)
            return False

        # Only support Xgrammar for guided decoding so far.
        SUPPORTED_GUIDED_DECODING = ["xgrammar", "xgrammar:nofallback"]
        if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
            _raise_or_fallback(feature_name="--guided-decoding-backend",
                               recommend_to_remove=False)
            return False

        # Need at least Ampere for now (FA support required).
        from vllm.platforms import current_platform
        if (current_platform.is_cuda()
                and current_platform.get_device_capability().major < 8):
            _raise_or_fallback(feature_name="Compute Capability < 8.0",
                               recommend_to_remove=False)
            return False

        # No Fp8 KV cache so far.
        if self.kv_cache_dtype != "auto":
            _raise_or_fallback(feature_name="--kv-cache-dtype",
                               recommend_to_remove=False)
            return False

        # No Prompt Adapter so far.
        if self.enable_prompt_adapter:
            _raise_or_fallback(feature_name="--enable-prompt-adapter",
                               recommend_to_remove=False)
            return False

        # No MistralTokenizer support so far (not compatible
        # with xgrammar)
        if model_config.tokenizer_mode == "mistral":
            _raise_or_fallback(feature_name="--tokenizer-mode mistral",
                               recommend_to_remove=False)
            return False

        # No CPU offloading yet.
        if self.cpu_offload_gb != EngineArgs.cpu_offload_gb:
            _raise_or_fallback(feature_name="--cpu-offload-gb",
                               recommend_to_remove=False)
            return False

        # Only Fp16 and Bf16 dtypes since we only support FA.
        V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
        if model_config.dtype not in V1_SUPPORTED_DTYPES:
            _raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
                               recommend_to_remove=False)
            return False

        # Some quantization is not compatible with torch.compile.
        V1_UNSUPPORTED_QUANT = ["bitsandbytes", "gguf"]
        if model_config.quantization in V1_UNSUPPORTED_QUANT:
            _raise_or_fallback(
                feature_name=f"--quantization {model_config.quantization}",
                recommend_to_remove=False)
            return False

        # No Embedding Models so far.
        if model_config.task not in ["generate"]:
            _raise_or_fallback(feature_name=f"--task {model_config.task}",
                               recommend_to_remove=False)
            return False

        # No Mamba or Encoder-Decoder so far.
        if not model_config.is_v1_compatible:
            _raise_or_fallback(feature_name=model_config.architectures,
                               recommend_to_remove=False)
            return False

        # No TransformersModel support so far.
        if (model_config.model_impl == ModelImpl.TRANSFORMERS
                or model_config.model_impl == "transformers"):
            _raise_or_fallback(
                feature_name=f"model_impl={model_config.model_impl}",
                recommend_to_remove=False)
            return False

        # No Concurrent Partial Prefills so far.
        if (self.max_num_partial_prefills
                != EngineArgs.max_num_partial_prefills
                or self.max_long_partial_prefills
                != EngineArgs.max_long_partial_prefills
                or self.long_prefill_token_threshold
                != EngineArgs.long_prefill_token_threshold):
            _raise_or_fallback(feature_name="Concurrent Partial Prefill",
                               recommend_to_remove=False)
            return False

        # No OTLP observability so far.
        if (self.otlp_traces_endpoint or self.collect_detailed_traces):
            _raise_or_fallback(feature_name="--otlp-traces-endpoint",
                               recommend_to_remove=False)
            return False

        # Only Ngram speculative decoding so far.
        if (self.speculative_model is not None
                or self.num_speculative_tokens is not None):
            # This is supported but experimental (handled below).
            if self.speculative_model == "[ngram]":
                pass
            else:
                _raise_or_fallback(feature_name="Speculative Decoding",
                                   recommend_to_remove=False)
                return False

        # No Disaggregated Prefill so far.
        if self.kv_transfer_config != EngineArgs.kv_transfer_config:
            _raise_or_fallback(feature_name="--kv-transfer-config",
                               recommend_to_remove=False)
            return False

        # No FlashInfer or XFormers so far.
        V1_BACKENDS = [
            "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
            "TRITON_MLA", "FLASHMLA"
        ]
        if (envs.is_set("VLLM_ATTENTION_BACKEND")
                and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
            name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}"
            _raise_or_fallback(feature_name=name, recommend_to_remove=True)
            return False

        #############################################################
        # Experimental Features - allow users to opt in.

        # MLA is supported on V1, but off by default for now.
        if model_config.use_mla and _warn_or_fallback("MLA"):
            return False

        # LoRA is supported on V1, but off by default for now.
        if self.enable_lora and _warn_or_fallback("LORA"):
            return False

        # PP is supported on V1, but off by default for now.
        if self.pipeline_parallel_size > 1 and _warn_or_fallback("PP"):
            return False

        # ngram is supported on V1, but off by default for now.
        if self.speculative_model == "[ngram]" and _warn_or_fallback("ngram"):
            return False

        # Non-CUDA is supported on V1, but off by default for now.
        not_cuda = not current_platform.is_cuda()
        if not_cuda and _warn_or_fallback(  # noqa: SIM103
                current_platform.device_type):
            return False
        #############################################################

        return True
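Most of the checks above rely on a dataclass idiom: reading a field through the class (e.g. EngineArgs.preemption_mode) yields its declared default, so self.x != EngineArgs.x detects a user-supplied value with no extra bookkeeping. A minimal sketch with a made-up field (this only works for plain defaults, not default_factory fields):

from dataclasses import dataclass
from typing import Optional

@dataclass
class Args:
    preemption_mode: Optional[str] = None  # made-up stand-in field

# Class attribute access returns the declared default...
assert Args.preemption_mode is None
# ...so an instance/class comparison flags any non-default value.
assert Args(preemption_mode="swap").preemption_mode != Args.preemption_mode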
    def _set_default_args_v0(self, model_config: ModelConfig) -> None:
        """Set Default Arguments for V0 Engine."""

        max_model_len = model_config.max_model_len
        use_long_context = max_model_len > 32768
        if self.enable_chunked_prefill is None:
            # Chunked prefill not supported for Multimodal or MLA in V0.
            if model_config.is_multimodal_model or model_config.use_mla:
                self.enable_chunked_prefill = False

            # Enable chunked prefill by default for long context (> 32K)
            # models to avoid OOM errors in initial memory profiling phase.
            elif use_long_context:
                from vllm.platforms import current_platform
                is_gpu = current_platform.is_cuda()
                use_sliding_window = (model_config.get_sliding_window()
                                      is not None)
                use_spec_decode = self.speculative_model is not None

                if (is_gpu and not use_sliding_window and not use_spec_decode
                        and not self.enable_lora
                        and not self.enable_prompt_adapter
                        and model_config.runner_type != "pooling"):
                    self.enable_chunked_prefill = True
                    logger.warning(
                        "Chunked prefill is enabled by default for models "
                        "with max_model_len > 32K. Chunked prefill might "
                        "not work with some features or models. If you "
                        "encounter any issues, please disable by launching "
                        "with --enable-chunked-prefill=False.")

            if self.enable_chunked_prefill is None:
                self.enable_chunked_prefill = False

        if not self.enable_chunked_prefill and use_long_context:
            logger.warning(
                "The model has a long context length (%s). This may cause "
                "OOM during the initial memory profiling phase, or result "
                "in low performance due to small KV cache size. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)
        elif (self.enable_chunked_prefill
              and model_config.runner_type == "pooling"):
            msg = "Chunked prefill is not supported for pooling models"
            raise ValueError(msg)

        # Disable prefix caching for multimodal models for VLLM_V0.
        if (model_config.is_multimodal_model and self.enable_prefix_caching):
            logger.warning(
                "--enable-prefix-caching is not supported for multimodal "
                "models in V0 and has been disabled.")
            self.enable_prefix_caching = False

        # Set max_num_seqs to 256 for VLLM_V0.
        if self.max_num_seqs is None:
            self.max_num_seqs = 256

    def _set_default_args_v1(self, usage_context: UsageContext) -> None:
        """Set Default Arguments for V1 Engine."""

        # V1 always uses chunked prefills.
        self.enable_chunked_prefill = True

        # V1 enables prefix caching by default.
        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = True

        # V1 should use the new scheduler by default.
        # Swap it only if this arg is set to the original V0 default.
        if self.scheduler_cls == EngineArgs.scheduler_cls:
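Note the asymmetry in the V1 defaults above: chunked prefill is forced on unconditionally, while prefix caching is only defaulted when left unset, so an explicit user choice survives. A small illustration (placeholder model name, assuming the V1 path is taken):

from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(model="facebook/opt-125m", enable_prefix_caching=False)
# After _set_default_args_v1() runs:
#   args.enable_chunked_prefill -> True  (set unconditionally)
#   args.enable_prefix_caching  -> False (only None is overridden)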
@@ -1471,19 +1708,21 @@ class EngineArgs:
            UsageContext.OPENAI_API_SERVER: 2048,
        }

        use_context_value = usage_context.value if usage_context else None
        if (self.max_num_batched_tokens is None
                and usage_context in default_max_num_batched_tokens):
            self.max_num_batched_tokens = default_max_num_batched_tokens[
                usage_context]
-            logger.warning(
+            logger.debug(
                "Setting max_num_batched_tokens to %d for %s usage context.",
-                self.max_num_batched_tokens, usage_context.value)
+                self.max_num_batched_tokens, use_context_value)

    def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
        """
        Override the EngineConfig's configs based on the usage context for V1.
        """
        assert envs.VLLM_USE_V1, "V1 is not enabled"

        default_max_num_seqs = 1024
        if self.max_num_seqs is None:
            self.max_num_seqs = default_max_num_seqs

            logger.debug("Setting max_num_seqs to %d for %s usage context.",
                         self.max_num_seqs, use_context_value)


@dataclass
@@ -1508,6 +1747,33 @@ class AsyncEngineArgs(EngineArgs):
        return parser


def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
    if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}.")
    msg = f"{feature_name} is not supported by the V1 Engine. "
    msg += "Falling back to V0. "
    if recommend_to_remove:
        msg += f"We recommend to remove {feature_name} from your config "
        msg += "in favor of the V1 Engine."
    logger.warning(msg)


def _warn_or_fallback(feature_name: str) -> bool:
    if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
        logger.warning(
            "Detected VLLM_USE_V1=1 with %s. Usage should "
            "be considered experimental. Please report any "
            "issues on Github.", feature_name)
        should_exit = False
    else:
        logger.info(
            "%s is experimental on VLLM_USE_V1=1. "
            "Falling back to V0 Engine.", feature_name)
        should_exit = True
    return should_exit


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    return EngineArgs.add_cli_args(FlexibleArgumentParser())
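These two helpers give the oracle its raise-versus-fallback behavior. A sketch of how _raise_or_fallback reacts to the env var, assuming vLLM's lazy environment lookup so changes to os.environ are seen at call time:

import os

from vllm.engine.arg_utils import _raise_or_fallback

# Unset: the helper only logs "... Falling back to V0." and returns.
os.environ.pop("VLLM_USE_V1", None)
_raise_or_fallback(feature_name="--kv-cache-dtype", recommend_to_remove=False)

# Explicit opt-in turns the same call into a hard error.
os.environ["VLLM_USE_V1"] = "1"
try:
    _raise_or_fallback(feature_name="--kv-cache-dtype",
                       recommend_to_remove=False)
except NotImplementedError as exc:
    print(exc)  # VLLM_USE_V1=1 is not supported with --kv-cache-dtype.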