[Attention][UX][1/N] Add AttentionConfig and change attention env vars to CLI arguments (#26315)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.config.attention import AttentionConfig
|
||||
from vllm.config.cache import CacheConfig
|
||||
from vllm.config.compilation import (
|
||||
CompilationConfig,
|
||||
@@ -46,6 +47,8 @@ from vllm.config.vllm import (
|
||||
# __all__ should only contain classes and functions.
|
||||
# Types and globals should be imported from their respective modules.
|
||||
__all__ = [
|
||||
# From vllm.config.attention
|
||||
"AttentionConfig",
|
||||
# From vllm.config.cache
|
||||
"CacheConfig",
|
||||
# From vllm.config.compilation
|
||||
|
||||
114
vllm/config/attention.py
Normal file
114
vllm/config/attention.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import field_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class AttentionConfig:
|
||||
"""Configuration for attention mechanisms in vLLM."""
|
||||
|
||||
backend: AttentionBackendEnum | None = None
|
||||
"""Attention backend to use. If None, will be selected automatically."""
|
||||
|
||||
flash_attn_version: Literal[2, 3] | None = None
|
||||
"""Force vllm to use a specific flash-attention version (2 or 3).
|
||||
Only valid when using the flash-attention backend."""
|
||||
|
||||
use_prefill_decode_attention: bool = False
|
||||
"""Use separate prefill and decode kernels for attention instead of
|
||||
the unified triton kernel."""
|
||||
|
||||
flash_attn_max_num_splits_for_cuda_graph: int = 32
|
||||
"""Flash Attention max number splits for cuda graph decode."""
|
||||
|
||||
use_cudnn_prefill: bool = False
|
||||
"""Whether to use cudnn prefill."""
|
||||
|
||||
use_trtllm_ragged_deepseek_prefill: bool = False
|
||||
"""Whether to use TRTLLM ragged deepseek prefill."""
|
||||
|
||||
use_trtllm_attention: bool | None = None
|
||||
"""If set to True/False, use or don't use the TRTLLM attention backend
|
||||
in flashinfer. If None, auto-detect the attention backend in flashinfer."""
|
||||
|
||||
disable_flashinfer_prefill: bool = False
|
||||
"""Whether to disable flashinfer prefill."""
|
||||
|
||||
disable_flashinfer_q_quantization: bool = False
|
||||
"""If set, when using fp8 kv, do not quantize Q to fp8."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
from vllm.config.utils import get_hash_factors, hash_factors
|
||||
|
||||
ignored_factors: list[str] = []
|
||||
factors = get_hash_factors(self, ignored_factors)
|
||||
return hash_factors(factors)
|
||||
|
||||
@field_validator("backend", mode="before")
|
||||
@classmethod
|
||||
def validate_backend_before(cls, value: Any) -> Any:
|
||||
"""Enable parsing of the `backend` enum type from string."""
|
||||
if isinstance(value, str):
|
||||
return AttentionBackendEnum[value.upper()]
|
||||
return value
|
||||
|
||||
def _set_from_env_if_set(self, field_name: str, env_var_name: str) -> None:
|
||||
"""Set field from env var if set, with deprecation warning."""
|
||||
from vllm import envs
|
||||
|
||||
if envs.is_set(env_var_name):
|
||||
value = getattr(envs, env_var_name)
|
||||
if field_name == "backend":
|
||||
value = self.validate_backend_before(value)
|
||||
setattr(self, field_name, value)
|
||||
logger.warning_once(
|
||||
"Using %s environment variable is deprecated and will be removed in "
|
||||
"v0.14.0 or v1.0.0, whichever is soonest. Please use "
|
||||
"--attention-config.%s command line argument or "
|
||||
"AttentionConfig(%s=...) config field instead.",
|
||||
env_var_name,
|
||||
field_name,
|
||||
field_name,
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._set_from_env_if_set("backend", "VLLM_ATTENTION_BACKEND")
|
||||
self._set_from_env_if_set("flash_attn_version", "VLLM_FLASH_ATTN_VERSION")
|
||||
self._set_from_env_if_set(
|
||||
"use_prefill_decode_attention", "VLLM_V1_USE_PREFILL_DECODE_ATTENTION"
|
||||
)
|
||||
self._set_from_env_if_set(
|
||||
"flash_attn_max_num_splits_for_cuda_graph",
|
||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
|
||||
)
|
||||
self._set_from_env_if_set("use_cudnn_prefill", "VLLM_USE_CUDNN_PREFILL")
|
||||
self._set_from_env_if_set(
|
||||
"use_trtllm_ragged_deepseek_prefill",
|
||||
"VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL",
|
||||
)
|
||||
self._set_from_env_if_set("use_trtllm_attention", "VLLM_USE_TRTLLM_ATTENTION")
|
||||
self._set_from_env_if_set(
|
||||
"disable_flashinfer_prefill", "VLLM_DISABLE_FLASHINFER_PREFILL"
|
||||
)
|
||||
self._set_from_env_if_set(
|
||||
"disable_flashinfer_q_quantization",
|
||||
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
|
||||
)
|
||||
@@ -4,7 +4,6 @@
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from dataclasses import InitVar, field
|
||||
from importlib.util import find_spec
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast, get_args
|
||||
|
||||
import torch
|
||||
@@ -467,18 +466,6 @@ class ModelConfig:
|
||||
|
||||
self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer)
|
||||
|
||||
if (
|
||||
(backend := envs.VLLM_ATTENTION_BACKEND)
|
||||
and backend == "FLASHINFER"
|
||||
and find_spec("flashinfer") is None
|
||||
):
|
||||
raise ValueError(
|
||||
"VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer "
|
||||
"module was not found. See "
|
||||
"https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501
|
||||
"for instructions on how to install it."
|
||||
)
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if self.override_attention_dtype is not None and not current_platform.is_rocm():
|
||||
|
||||
@@ -27,6 +27,7 @@ from vllm.transformers_utils.runai_utils import is_runai_obj_uri
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
from .attention import AttentionConfig
|
||||
from .cache import CacheConfig
|
||||
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
|
||||
from .device import DeviceConfig
|
||||
@@ -192,6 +193,8 @@ class VllmConfig:
|
||||
"""Device configuration."""
|
||||
load_config: LoadConfig = Field(default_factory=LoadConfig)
|
||||
"""Load configuration."""
|
||||
attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
|
||||
"""Attention configuration."""
|
||||
lora_config: LoRAConfig | None = None
|
||||
"""LoRA configuration."""
|
||||
speculative_config: SpeculativeConfig | None = None
|
||||
@@ -279,6 +282,10 @@ class VllmConfig:
|
||||
vllm_factors.append(self.load_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.attention_config:
|
||||
vllm_factors.append(self.attention_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.lora_config:
|
||||
vllm_factors.append(self.lora_config.compute_hash())
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user