[Attention][UX][1/N] Add AttentionConfig and change attention env vars to CLI arguments (#26315)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Commit 66e674cdd5 (parent dff0a2b394)
Matthew Bonanni authored on 2025-12-05 12:48:43 -05:00; committed by GitHub
22 changed files with 367 additions and 325 deletions
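
The UX change in a nutshell (a hedged sketch; the exact backend names and the JSON form of --attention-config are assumptions based on the existing --compilation-config behavior, not on this diff):

    # Before: backend selected via environment variable
    #   VLLM_ATTENTION_BACKEND=FLASH_ATTN vllm serve <model>
    # After: backend selected via CLI argument
    #   vllm serve <model> --attention-backend FLASH_ATTN
    # or via the grouped config object
    #   vllm serve <model> --attention-config '{"backend": "FLASH_ATTN"}'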


@@ -34,6 +34,7 @@ from typing_extensions import TypeIs
 import vllm.envs as envs
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import (
+    AttentionConfig,
     CacheConfig,
     CompilationConfig,
     ConfigType,
@@ -527,6 +528,7 @@ class EngineArgs:
     pooler_config: PoolerConfig | None = ModelConfig.pooler_config
     compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
+    attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
     worker_cls: str = ParallelConfig.worker_cls
     worker_extension_cls: str = ParallelConfig.worker_extension_cls
@@ -542,6 +544,7 @@
     )
     model_impl: str = ModelConfig.model_impl
     override_attention_dtype: str = ModelConfig.override_attention_dtype
+    attention_backend: AttentionBackendEnum | None = AttentionConfig.backend
     calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
     mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
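
For programmatic use, the new EngineArgs field mirrors the CLI flag. A minimal sketch, assuming AttentionBackendEnum exposes a member such as FLASH_ATTN (member names are not shown in this diff):

    from vllm.attention.backends.registry import AttentionBackendEnum
    from vllm.engine.arg_utils import EngineArgs

    # Defaults to AttentionConfig.backend (i.e. None); setting it here is
    # equivalent to passing --attention-backend on the command line.
    args = EngineArgs(model="<model>", attention_backend=AttentionBackendEnum.FLASH_ATTN)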
@@ -580,6 +583,8 @@ class EngineArgs:
# CompilationConfig object
if isinstance(self.compilation_config, dict):
self.compilation_config = CompilationConfig(**self.compilation_config)
if isinstance(self.attention_config, dict):
self.attention_config = AttentionConfig(**self.attention_config)
if isinstance(self.eplb_config, dict):
self.eplb_config = EPLBConfig(**self.eplb_config)
# Setup plugins
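
This is the same dict-to-dataclass coercion already applied to compilation_config, shown standalone below (a sketch; it assumes AttentionConfig validates the backend value, and no fields beyond `backend` are shown in this diff):

    from vllm.config import AttentionConfig

    raw = {"backend": "FLASH_ATTN"}  # e.g. a parsed JSON --attention-config value
    attention_config = AttentionConfig(**raw) if isinstance(raw, dict) else raw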
@@ -717,6 +722,16 @@
             "--pt-load-map-location", **load_kwargs["pt_load_map_location"]
         )
 
+        # Attention arguments
+        attention_kwargs = get_kwargs(AttentionConfig)
+        attention_group = parser.add_argument_group(
+            title="AttentionConfig",
+            description=AttentionConfig.__doc__,
+        )
+        attention_group.add_argument(
+            "--attention-backend", **attention_kwargs["backend"]
+        )
+
         # Structured outputs arguments
         structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
         structured_outputs_group = parser.add_argument_group(
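
For readers unfamiliar with the get_kwargs pattern, the net effect is plain argparse with kwargs generated from the dataclass fields. A simplified, self-contained sketch (the real code derives defaults, types, and help text from AttentionConfig):

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(
        title="AttentionConfig",
        description="Attention settings.",  # the real code passes AttentionConfig.__doc__
    )
    group.add_argument("--attention-backend", default=None)

    args = parser.parse_args(["--attention-backend", "FLASH_ATTN"])
    assert args.attention_backend == "FLASH_ATTN"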
@@ -1140,6 +1155,9 @@
         vllm_group.add_argument(
             "--compilation-config", "-cc", **vllm_kwargs["compilation_config"]
         )
+        vllm_group.add_argument(
+            "--attention-config", "-ac", **vllm_kwargs["attention_config"]
+        )
         vllm_group.add_argument(
             "--additional-config", **vllm_kwargs["additional_config"]
         )
@@ -1693,6 +1711,16 @@
         if model_config.quantization == "bitsandbytes":
             self.quantization = self.load_format = "bitsandbytes"
 
+        # Attention config overrides
+        attention_config = copy.deepcopy(self.attention_config)
+        if self.attention_backend is not None:
+            if attention_config.backend is not None:
+                raise ValueError(
+                    "attention_backend and attention_config.backend "
+                    "are mutually exclusive"
+                )
+            attention_config.backend = self.attention_backend
+
         load_config = self.create_load_config()
 
         # Pass reasoning_parser into StructuredOutputsConfig
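
Reduced to its core, the override rule works as follows (values are hypothetical):

    attention_backend = "FLASH_ATTN"  # from --attention-backend
    config_backend = None             # from --attention-config '{"backend": ...}'

    if attention_backend is not None:
        if config_backend is not None:
            # e.g. --attention-backend FLASH_ATTN together with
            # --attention-config '{"backend": "TRITON_ATTN"}' is rejected
            raise ValueError(
                "attention_backend and attention_config.backend are mutually exclusive"
            )
        config_backend = attention_backend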
@@ -1750,9 +1778,10 @@
             parallel_config=parallel_config,
             scheduler_config=scheduler_config,
             device_config=device_config,
-            load_config=load_config,
+            attention_config=attention_config,
             lora_config=lora_config,
             speculative_config=speculative_config,
+            load_config=load_config,
             structured_outputs_config=self.structured_outputs_config,
             observability_config=observability_config,
             compilation_config=compilation_config,
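
End to end, the backend choice should now surface on the built VllmConfig. A hedged sketch (create_engine_config validates the model, so this needs a resolvable model name):

    from vllm.engine.arg_utils import EngineArgs

    args = EngineArgs(model="<model>", attention_backend=None)
    vllm_config = args.create_engine_config()
    print(vllm_config.attention_config.backend)  # None unless overridden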