Improve configs - CacheConfig (#16835)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
122
vllm/config.py
122
vllm/config.py
@@ -1245,22 +1245,70 @@ class ModelConfig:
|
|||||||
or getattr(self.hf_config, "is_matryoshka", False))
|
or getattr(self.hf_config, "is_matryoshka", False))
|
||||||
|
|
||||||
|
|
||||||
class CacheConfig:
|
BlockSize = Literal[8, 16, 32, 64, 128]
|
||||||
"""Configuration for the KV cache.
|
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
|
||||||
|
PrefixCachingHashAlgo = Literal["builtin", "sha256"]
|
||||||
|
|
||||||
Args:
|
|
||||||
block_size: Size of a cache block in number of tokens.
|
@config
|
||||||
gpu_memory_utilization: Fraction of GPU memory to use for the
|
@dataclass
|
||||||
vLLM execution.
|
class CacheConfig:
|
||||||
swap_space: Size of the CPU swap space per GPU (in GiB).
|
"""Configuration for the KV cache."""
|
||||||
cache_dtype: Data type for kv cache storage.
|
|
||||||
is_attention_free: Whether the model is attention-free.
|
block_size: Optional[BlockSize] = None
|
||||||
num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
|
"""Size of a contiguous cache block in number of tokens. This is ignored on
|
||||||
profiled num_gpu_blocks if specified. Does nothing if None.
|
neuron devices and set to `--max-model-len`. On CUDA devices, only block
|
||||||
sliding_window: Sliding window size for the KV cache.
|
sizes up to 32 are supported. On HPU devices, block size defaults to 128.
|
||||||
enable_prefix_caching: Whether to enable prefix caching.
|
|
||||||
cpu_offload_gb: Size of the CPU offload buffer in GiB.
|
|
||||||
"""
|
"""
|
||||||
|
gpu_memory_utilization: float = 0.9
|
||||||
|
"""The fraction of GPU memory to be used for the model executor, which can
|
||||||
|
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
|
||||||
|
utilization. If unspecified, will use the default value of 0.9. This is a
|
||||||
|
per-instance limit, and only applies to the current vLLM instance. It does
|
||||||
|
not matter if you have another vLLM instance running on the same GPU. For
|
||||||
|
example, if you have two vLLM instances running on the same GPU, you can
|
||||||
|
set the GPU memory utilization to 0.5 for each instance."""
|
||||||
|
swap_space: float = 4
|
||||||
|
"""Size of the CPU swap space per GPU (in GiB)."""
|
||||||
|
cache_dtype: CacheDType = "auto"
|
||||||
|
"""Data type for kv cache storage. If "auto", will use model data type.
|
||||||
|
CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
|
||||||
|
fp8 (=fp8_e4m3)."""
|
||||||
|
is_attention_free: bool = False
|
||||||
|
"""Whether the model is attention-free. This is primarily set in
|
||||||
|
`ModelConfig` and that value should be manually duplicated here."""
|
||||||
|
num_gpu_blocks_override: Optional[int] = None
|
||||||
|
"""Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks`
|
||||||
|
if specified. Does nothing if `None`. Used for testing preemption."""
|
||||||
|
sliding_window: Optional[int] = None
|
||||||
|
"""Sliding window size for the KV cache. This is primarily set in
|
||||||
|
`ModelConfig` and that value should be manually duplicated here."""
|
||||||
|
enable_prefix_caching: Optional[bool] = None
|
||||||
|
"""Whether to enable prefix caching. Disabled by default for V0. Enabled by
|
||||||
|
default for V1."""
|
||||||
|
prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
|
||||||
|
"""Set the hash algorithm for prefix caching:\n
|
||||||
|
- "builtin" is Python's built-in hash.\n
|
||||||
|
- "sha256" is collision resistant but with certain overheads."""
|
||||||
|
cpu_offload_gb: float = 0
|
||||||
|
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
|
||||||
|
no offloading. Intuitively, this argument can be seen as a virtual way to
|
||||||
|
increase the GPU memory size. For example, if you have one 24 GB GPU and
|
||||||
|
set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
|
||||||
|
load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
|
||||||
|
Note that this requires fast CPU-GPU interconnect, as part of the model is
|
||||||
|
loaded from CPU memory to GPU memory on the fly in each model forward pass.
|
||||||
|
"""
|
||||||
|
calculate_kv_scales: bool = False
|
||||||
|
"""This enables dynamic calculation of `k_scale` and `v_scale` when
|
||||||
|
kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
|
||||||
|
checkpoint if available. Otherwise, the scales will default to 1.0."""
|
||||||
|
|
||||||
|
# Will be set after profiling.
|
||||||
|
num_gpu_blocks: Optional[int] = field(default=None, init=False)
|
||||||
|
"""The number of blocks to allocate for GPU memory."""
|
||||||
|
num_cpu_blocks: Optional[int] = field(default=None, init=False)
|
||||||
|
"""The number of blocks to allocate for CPU memory."""
|
||||||
|
|
||||||
def compute_hash(self) -> str:
|
def compute_hash(self) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -1281,43 +1329,13 @@ class CacheConfig:
|
|||||||
usedforsecurity=False).hexdigest()
|
usedforsecurity=False).hexdigest()
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
def __init__(
|
def __post_init__(self) -> None:
|
||||||
self,
|
self.swap_space_bytes = self.swap_space * GiB_bytes
|
||||||
block_size: int,
|
|
||||||
gpu_memory_utilization: float,
|
|
||||||
swap_space: float,
|
|
||||||
cache_dtype: str,
|
|
||||||
is_attention_free: bool = False,
|
|
||||||
num_gpu_blocks_override: Optional[int] = None,
|
|
||||||
sliding_window: Optional[int] = None,
|
|
||||||
enable_prefix_caching: bool = False,
|
|
||||||
prefix_caching_hash_algo: str = "builtin",
|
|
||||||
cpu_offload_gb: float = 0,
|
|
||||||
calculate_kv_scales: Optional[bool] = None,
|
|
||||||
) -> None:
|
|
||||||
self.block_size = block_size
|
|
||||||
self.gpu_memory_utilization = gpu_memory_utilization
|
|
||||||
self.swap_space_bytes = swap_space * GiB_bytes
|
|
||||||
self.num_gpu_blocks_override = num_gpu_blocks_override
|
|
||||||
self.cache_dtype = cache_dtype
|
|
||||||
self.is_attention_free = is_attention_free
|
|
||||||
self.sliding_window = sliding_window
|
|
||||||
self.enable_prefix_caching = enable_prefix_caching
|
|
||||||
self.prefix_caching_hash_algo = prefix_caching_hash_algo
|
|
||||||
self.cpu_offload_gb = cpu_offload_gb
|
|
||||||
self.calculate_kv_scales = calculate_kv_scales
|
|
||||||
self._verify_args()
|
self._verify_args()
|
||||||
self._verify_cache_dtype()
|
self._verify_cache_dtype()
|
||||||
self._verify_prefix_caching()
|
self._verify_prefix_caching()
|
||||||
|
|
||||||
# Will be set after profiling.
|
|
||||||
self.num_gpu_blocks: Optional[int] = None
|
|
||||||
self.num_cpu_blocks: Optional[int] = None
|
|
||||||
|
|
||||||
# Set calculate_kv_scales to False if the value is unset.
|
|
||||||
if self.calculate_kv_scales is None:
|
|
||||||
self.calculate_kv_scales = False
|
|
||||||
|
|
||||||
def metrics_info(self):
|
def metrics_info(self):
|
||||||
# convert cache_config to dict(key: str, value: str) for prometheus
|
# convert cache_config to dict(key: str, value: str) for prometheus
|
||||||
# metrics info
|
# metrics info
|
||||||
@@ -1336,7 +1354,7 @@ class CacheConfig:
|
|||||||
def _verify_cache_dtype(self) -> None:
|
def _verify_cache_dtype(self) -> None:
|
||||||
if self.cache_dtype == "auto":
|
if self.cache_dtype == "auto":
|
||||||
pass
|
pass
|
||||||
elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
|
elif self.cache_dtype in get_args(CacheDType):
|
||||||
logger.info(
|
logger.info(
|
||||||
"Using fp8 data type to store kv cache. It reduces the GPU "
|
"Using fp8 data type to store kv cache. It reduces the GPU "
|
||||||
"memory footprint and boosts the performance. "
|
"memory footprint and boosts the performance. "
|
||||||
@@ -1354,12 +1372,12 @@ class CacheConfig:
|
|||||||
"Prefix caching is not supported with sliding window. "
|
"Prefix caching is not supported with sliding window. "
|
||||||
"Run with --disable-sliding-window to use prefix caching.")
|
"Run with --disable-sliding-window to use prefix caching.")
|
||||||
|
|
||||||
if self.enable_prefix_caching and self.prefix_caching_hash_algo not in (
|
if (self.enable_prefix_caching and self.prefix_caching_hash_algo
|
||||||
"builtin", "sha256"):
|
not in get_args(PrefixCachingHashAlgo)):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unknown prefix caching hash algorithm: "
|
"Unknown prefix caching hash algorithm: "
|
||||||
f"{self.prefix_caching_hash_algo}. Must be either "
|
f"{self.prefix_caching_hash_algo}. Must be one of "
|
||||||
"'builtin' or 'sha256'.")
|
f"{get_args(PrefixCachingHashAlgo)}.")
|
||||||
|
|
||||||
def verify_with_parallel_config(
|
def verify_with_parallel_config(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -16,16 +16,16 @@ from typing_extensions import TypeIs
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm import version
|
from vllm import version
|
||||||
from vllm.config import (CacheConfig, CompilationConfig, Config, ConfigFormat,
|
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
|
||||||
DecodingConfig, Device, DeviceConfig,
|
Config, ConfigFormat, DecodingConfig, Device,
|
||||||
DistributedExecutorBackend, HfOverrides,
|
DeviceConfig, DistributedExecutorBackend, HfOverrides,
|
||||||
KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
|
KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
|
||||||
ModelConfig, ModelImpl, MultiModalConfig,
|
ModelConfig, ModelImpl, MultiModalConfig,
|
||||||
ObservabilityConfig, ParallelConfig, PoolerConfig,
|
ObservabilityConfig, ParallelConfig, PoolerConfig,
|
||||||
PoolType, PromptAdapterConfig, SchedulerConfig,
|
PoolType, PrefixCachingHashAlgo, PromptAdapterConfig,
|
||||||
SchedulerPolicy, SpeculativeConfig, TaskOption,
|
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
|
||||||
TokenizerPoolConfig, VllmConfig, get_attr_docs,
|
TaskOption, TokenizerPoolConfig, VllmConfig,
|
||||||
get_field)
|
get_attr_docs, get_field)
|
||||||
from vllm.executor.executor_base import ExecutorBase
|
from vllm.executor.executor_base import ExecutorBase
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||||
@@ -138,7 +138,7 @@ class EngineArgs:
|
|||||||
load_format: str = LoadConfig.load_format
|
load_format: str = LoadConfig.load_format
|
||||||
config_format: ConfigFormat = ConfigFormat.AUTO
|
config_format: ConfigFormat = ConfigFormat.AUTO
|
||||||
dtype: str = 'auto'
|
dtype: str = 'auto'
|
||||||
kv_cache_dtype: str = 'auto'
|
kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
|
||||||
seed: Optional[int] = None
|
seed: Optional[int] = None
|
||||||
max_model_len: Optional[int] = None
|
max_model_len: Optional[int] = None
|
||||||
# Note: Specifying a custom executor backend by passing a class
|
# Note: Specifying a custom executor backend by passing a class
|
||||||
@@ -154,15 +154,16 @@ class EngineArgs:
|
|||||||
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
|
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
|
||||||
max_parallel_loading_workers: Optional[
|
max_parallel_loading_workers: Optional[
|
||||||
int] = ParallelConfig.max_parallel_loading_workers
|
int] = ParallelConfig.max_parallel_loading_workers
|
||||||
block_size: Optional[int] = None
|
block_size: Optional[BlockSize] = CacheConfig.block_size
|
||||||
enable_prefix_caching: Optional[bool] = None
|
enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
|
||||||
prefix_caching_hash_algo: str = "builtin"
|
prefix_caching_hash_algo: PrefixCachingHashAlgo = \
|
||||||
|
CacheConfig.prefix_caching_hash_algo
|
||||||
disable_sliding_window: bool = False
|
disable_sliding_window: bool = False
|
||||||
disable_cascade_attn: bool = False
|
disable_cascade_attn: bool = False
|
||||||
use_v2_block_manager: bool = True
|
use_v2_block_manager: bool = True
|
||||||
swap_space: float = 4 # GiB
|
swap_space: float = CacheConfig.swap_space
|
||||||
cpu_offload_gb: float = 0 # GiB
|
cpu_offload_gb: float = CacheConfig.cpu_offload_gb
|
||||||
gpu_memory_utilization: float = 0.90
|
gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
|
||||||
max_num_batched_tokens: Optional[
|
max_num_batched_tokens: Optional[
|
||||||
int] = SchedulerConfig.max_num_batched_tokens
|
int] = SchedulerConfig.max_num_batched_tokens
|
||||||
max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
|
max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
|
||||||
@@ -211,7 +212,8 @@ class EngineArgs:
|
|||||||
num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
|
num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
|
||||||
multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
|
multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
|
||||||
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
|
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
|
||||||
num_gpu_blocks_override: Optional[int] = None
|
num_gpu_blocks_override: Optional[
|
||||||
|
int] = CacheConfig.num_gpu_blocks_override
|
||||||
num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
|
num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
|
||||||
model_loader_extra_config: dict = \
|
model_loader_extra_config: dict = \
|
||||||
get_field(LoadConfig, "model_loader_extra_config")
|
get_field(LoadConfig, "model_loader_extra_config")
|
||||||
@@ -250,7 +252,7 @@ class EngineArgs:
|
|||||||
enable_sleep_mode: bool = False
|
enable_sleep_mode: bool = False
|
||||||
model_impl: str = "auto"
|
model_impl: str = "auto"
|
||||||
|
|
||||||
calculate_kv_scales: Optional[bool] = None
|
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
|
||||||
|
|
||||||
additional_config: Optional[Dict[str, Any]] = None
|
additional_config: Optional[Dict[str, Any]] = None
|
||||||
enable_reasoning: Optional[bool] = None
|
enable_reasoning: Optional[bool] = None
|
||||||
@@ -306,12 +308,19 @@ class EngineArgs:
|
|||||||
cls_docs = get_attr_docs(cls)
|
cls_docs = get_attr_docs(cls)
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
for field in fields(cls):
|
for field in fields(cls):
|
||||||
name = field.name
|
# Get the default value of the field
|
||||||
default = field.default
|
default = field.default
|
||||||
# This will only be True if default is MISSING
|
|
||||||
if field.default_factory is not MISSING:
|
if field.default_factory is not MISSING:
|
||||||
default = field.default_factory()
|
default = field.default_factory()
|
||||||
kwargs[name] = {"default": default, "help": cls_docs[name]}
|
|
||||||
|
# Get the help text for the field
|
||||||
|
name = field.name
|
||||||
|
help = cls_docs[name]
|
||||||
|
# Escape % for argparse
|
||||||
|
help = help.replace("%", "%%")
|
||||||
|
|
||||||
|
# Initialise the kwargs dictionary for the field
|
||||||
|
kwargs[name] = {"default": default, "help": help}
|
||||||
|
|
||||||
# Make note of if the field is optional and get the actual
|
# Make note of if the field is optional and get the actual
|
||||||
# type of the field if it is
|
# type of the field if it is
|
||||||
@@ -319,6 +328,8 @@ class EngineArgs:
|
|||||||
field_type = get_args(
|
field_type = get_args(
|
||||||
field.type)[0] if optional else field.type
|
field.type)[0] if optional else field.type
|
||||||
|
|
||||||
|
# Set type, action and choices for the field depending on the
|
||||||
|
# type of the field
|
||||||
if can_be_type(field_type, bool):
|
if can_be_type(field_type, bool):
|
||||||
# Creates --no-<name> and --<name> flags
|
# Creates --no-<name> and --<name> flags
|
||||||
kwargs[name]["action"] = argparse.BooleanOptionalAction
|
kwargs[name]["action"] = argparse.BooleanOptionalAction
|
||||||
@@ -463,14 +474,6 @@ class EngineArgs:
|
|||||||
'* "bfloat16" for a balance between precision and range.\n'
|
'* "bfloat16" for a balance between precision and range.\n'
|
||||||
'* "float" is shorthand for FP32 precision.\n'
|
'* "float" is shorthand for FP32 precision.\n'
|
||||||
'* "float32" for FP32 precision.')
|
'* "float32" for FP32 precision.')
|
||||||
parser.add_argument(
|
|
||||||
'--kv-cache-dtype',
|
|
||||||
type=str,
|
|
||||||
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
|
|
||||||
default=EngineArgs.kv_cache_dtype,
|
|
||||||
help='Data type for kv cache storage. If "auto", will use model '
|
|
||||||
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
|
|
||||||
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
|
|
||||||
parser.add_argument('--max-model-len',
|
parser.add_argument('--max-model-len',
|
||||||
type=human_readable_int,
|
type=human_readable_int,
|
||||||
default=EngineArgs.max_model_len,
|
default=EngineArgs.max_model_len,
|
||||||
@@ -544,33 +547,30 @@ class EngineArgs:
|
|||||||
parallel_group.add_argument(
|
parallel_group.add_argument(
|
||||||
'--disable-custom-all-reduce',
|
'--disable-custom-all-reduce',
|
||||||
**parallel_kwargs["disable_custom_all_reduce"])
|
**parallel_kwargs["disable_custom_all_reduce"])
|
||||||
# KV cache arguments
|
|
||||||
parser.add_argument('--block-size',
|
|
||||||
type=int,
|
|
||||||
default=EngineArgs.block_size,
|
|
||||||
choices=[8, 16, 32, 64, 128],
|
|
||||||
help='Token block size for contiguous chunks of '
|
|
||||||
'tokens. This is ignored on neuron devices and '
|
|
||||||
'set to ``--max-model-len``. On CUDA devices, '
|
|
||||||
'only block sizes up to 32 are supported. '
|
|
||||||
'On HPU devices, block size defaults to 128.')
|
|
||||||
|
|
||||||
parser.add_argument(
|
# KV cache arguments
|
||||||
"--enable-prefix-caching",
|
cache_kwargs = get_kwargs(CacheConfig)
|
||||||
action=argparse.BooleanOptionalAction,
|
cache_group = parser.add_argument_group(
|
||||||
default=EngineArgs.enable_prefix_caching,
|
title="CacheConfig",
|
||||||
help="Enables automatic prefix caching. "
|
description=CacheConfig.__doc__,
|
||||||
"Use ``--no-enable-prefix-caching`` to disable explicitly.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--prefix-caching-hash-algo",
|
|
||||||
type=str,
|
|
||||||
choices=["builtin", "sha256"],
|
|
||||||
default=EngineArgs.prefix_caching_hash_algo,
|
|
||||||
help="Set the hash algorithm for prefix caching. "
|
|
||||||
"Options are 'builtin' (Python's built-in hash) or 'sha256' "
|
|
||||||
"(collision resistant but with certain overheads).",
|
|
||||||
)
|
)
|
||||||
|
cache_group.add_argument('--block-size', **cache_kwargs["block_size"])
|
||||||
|
cache_group.add_argument('--gpu-memory-utilization',
|
||||||
|
**cache_kwargs["gpu_memory_utilization"])
|
||||||
|
cache_group.add_argument('--swap-space', **cache_kwargs["swap_space"])
|
||||||
|
cache_group.add_argument('--kv-cache-dtype',
|
||||||
|
**cache_kwargs["cache_dtype"])
|
||||||
|
cache_group.add_argument('--num-gpu-blocks-override',
|
||||||
|
**cache_kwargs["num_gpu_blocks_override"])
|
||||||
|
cache_group.add_argument("--enable-prefix-caching",
|
||||||
|
**cache_kwargs["enable_prefix_caching"])
|
||||||
|
cache_group.add_argument("--prefix-caching-hash-algo",
|
||||||
|
**cache_kwargs["prefix_caching_hash_algo"])
|
||||||
|
cache_group.add_argument('--cpu-offload-gb',
|
||||||
|
**cache_kwargs["cpu_offload_gb"])
|
||||||
|
cache_group.add_argument('--calculate-kv-scales',
|
||||||
|
**cache_kwargs["calculate_kv_scales"])
|
||||||
|
|
||||||
parser.add_argument('--disable-sliding-window',
|
parser.add_argument('--disable-sliding-window',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='Disables sliding window, '
|
help='Disables sliding window, '
|
||||||
@@ -588,43 +588,6 @@ class EngineArgs:
|
|||||||
type=int,
|
type=int,
|
||||||
default=EngineArgs.seed,
|
default=EngineArgs.seed,
|
||||||
help='Random seed for operations.')
|
help='Random seed for operations.')
|
||||||
parser.add_argument('--swap-space',
|
|
||||||
type=float,
|
|
||||||
default=EngineArgs.swap_space,
|
|
||||||
help='CPU swap space size (GiB) per GPU.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--cpu-offload-gb',
|
|
||||||
type=float,
|
|
||||||
default=0,
|
|
||||||
help='The space in GiB to offload to CPU, per GPU. '
|
|
||||||
'Default is 0, which means no offloading. Intuitively, '
|
|
||||||
'this argument can be seen as a virtual way to increase '
|
|
||||||
'the GPU memory size. For example, if you have one 24 GB '
|
|
||||||
'GPU and set this to 10, virtually you can think of it as '
|
|
||||||
'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
|
|
||||||
'which requires at least 26GB GPU memory. Note that this '
|
|
||||||
'requires fast CPU-GPU interconnect, as part of the model is '
|
|
||||||
'loaded from CPU memory to GPU memory on the fly in each '
|
|
||||||
'model forward pass.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--gpu-memory-utilization',
|
|
||||||
type=float,
|
|
||||||
default=EngineArgs.gpu_memory_utilization,
|
|
||||||
help='The fraction of GPU memory to be used for the model '
|
|
||||||
'executor, which can range from 0 to 1. For example, a value of '
|
|
||||||
'0.5 would imply 50%% GPU memory utilization. If unspecified, '
|
|
||||||
'will use the default value of 0.9. This is a per-instance '
|
|
||||||
'limit, and only applies to the current vLLM instance.'
|
|
||||||
'It does not matter if you have another vLLM instance running '
|
|
||||||
'on the same GPU. For example, if you have two vLLM instances '
|
|
||||||
'running on the same GPU, you can set the GPU memory utilization '
|
|
||||||
'to 0.5 for each instance.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--num-gpu-blocks-override',
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help='If specified, ignore GPU profiling result and use this number'
|
|
||||||
' of GPU blocks. Used for testing preemption.')
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--max-logprobs',
|
'--max-logprobs',
|
||||||
type=int,
|
type=int,
|
||||||
@@ -994,15 +957,6 @@ class EngineArgs:
|
|||||||
help="Enable sleep mode for the engine. "
|
help="Enable sleep mode for the engine. "
|
||||||
"(only cuda platform is supported)")
|
"(only cuda platform is supported)")
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'--calculate-kv-scales',
|
|
||||||
action='store_true',
|
|
||||||
help='This enables dynamic calculation of '
|
|
||||||
'k_scale and v_scale when kv-cache-dtype is fp8. '
|
|
||||||
'If calculate-kv-scales is false, the scales will '
|
|
||||||
'be loaded from the model checkpoint if available. '
|
|
||||||
'Otherwise, the scales will default to 1.0.')
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--additional-config",
|
"--additional-config",
|
||||||
type=json.loads,
|
type=json.loads,
|
||||||
@@ -1625,9 +1579,7 @@ class EngineArgs:
|
|||||||
self.enable_prefix_caching = False
|
self.enable_prefix_caching = False
|
||||||
|
|
||||||
# VLLM_V0 only supports builtin hash algo for prefix caching.
|
# VLLM_V0 only supports builtin hash algo for prefix caching.
|
||||||
if self.prefix_caching_hash_algo is None:
|
if self.prefix_caching_hash_algo == "sha256":
|
||||||
self.prefix_caching_hash_algo = "builtin"
|
|
||||||
elif self.prefix_caching_hash_algo == "sha256":
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"sha256 is not supported for prefix caching in V0 engine. "
|
"sha256 is not supported for prefix caching in V0 engine. "
|
||||||
"Please use 'builtin'.")
|
"Please use 'builtin'.")
|
||||||
@@ -1646,10 +1598,6 @@ class EngineArgs:
|
|||||||
if self.enable_prefix_caching is None:
|
if self.enable_prefix_caching is None:
|
||||||
self.enable_prefix_caching = True
|
self.enable_prefix_caching = True
|
||||||
|
|
||||||
# if using prefix caching, we must set a hash algo
|
|
||||||
if self.enable_prefix_caching and self.prefix_caching_hash_algo is None:
|
|
||||||
self.prefix_caching_hash_algo = "builtin"
|
|
||||||
|
|
||||||
# V1 should use the new scheduler by default.
|
# V1 should use the new scheduler by default.
|
||||||
# Swap it only if this arg is set to the original V0 default
|
# Swap it only if this arg is set to the original V0 default
|
||||||
if self.scheduler_cls == EngineArgs.scheduler_cls:
|
if self.scheduler_cls == EngineArgs.scheduler_cls:
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ class NeuronPlatform(Platform):
|
|||||||
if cache_config:
|
if cache_config:
|
||||||
# neuron needs block_size = max_model_len
|
# neuron needs block_size = max_model_len
|
||||||
vllm_config.cache_config.block_size = \
|
vllm_config.cache_config.block_size = \
|
||||||
vllm_config.model_config.max_model_len
|
vllm_config.model_config.max_model_len # type: ignore
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_pin_memory_available(cls) -> bool:
|
def is_pin_memory_available(cls) -> bool:
|
||||||
|
|||||||
Reference in New Issue
Block a user