[Attention][UX][1/N] Add AttentionConfig and change attention env vars to CLI arguments (#26315)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
@@ -34,6 +34,7 @@ from typing_extensions import TypeIs
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import (
|
||||
AttentionConfig,
|
||||
CacheConfig,
|
||||
CompilationConfig,
|
||||
ConfigType,
|
||||
@@ -527,6 +528,7 @@ class EngineArgs:
|
||||
|
||||
pooler_config: PoolerConfig | None = ModelConfig.pooler_config
|
||||
compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
|
||||
attention_config: AttentionConfig = get_field(VllmConfig, "attention_config")
|
||||
worker_cls: str = ParallelConfig.worker_cls
|
||||
worker_extension_cls: str = ParallelConfig.worker_extension_cls
|
||||
|
||||
@@ -542,6 +544,7 @@ class EngineArgs:
|
||||
)
|
||||
model_impl: str = ModelConfig.model_impl
|
||||
override_attention_dtype: str = ModelConfig.override_attention_dtype
|
||||
attention_backend: AttentionBackendEnum | None = AttentionConfig.backend
|
||||
|
||||
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
|
||||
mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
|
||||
@@ -580,6 +583,8 @@ class EngineArgs:
|
||||
# CompilationConfig object
|
||||
if isinstance(self.compilation_config, dict):
|
||||
self.compilation_config = CompilationConfig(**self.compilation_config)
|
||||
if isinstance(self.attention_config, dict):
|
||||
self.attention_config = AttentionConfig(**self.attention_config)
|
||||
if isinstance(self.eplb_config, dict):
|
||||
self.eplb_config = EPLBConfig(**self.eplb_config)
|
||||
# Setup plugins
|
||||
@@ -717,6 +722,16 @@ class EngineArgs:
|
||||
"--pt-load-map-location", **load_kwargs["pt_load_map_location"]
|
||||
)
|
||||
|
||||
# Attention arguments
|
||||
attention_kwargs = get_kwargs(AttentionConfig)
|
||||
attention_group = parser.add_argument_group(
|
||||
title="AttentionConfig",
|
||||
description=AttentionConfig.__doc__,
|
||||
)
|
||||
attention_group.add_argument(
|
||||
"--attention-backend", **attention_kwargs["backend"]
|
||||
)
|
||||
|
||||
# Structured outputs arguments
|
||||
structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
|
||||
structured_outputs_group = parser.add_argument_group(
|
||||
@@ -1140,6 +1155,9 @@ class EngineArgs:
|
||||
vllm_group.add_argument(
|
||||
"--compilation-config", "-cc", **vllm_kwargs["compilation_config"]
|
||||
)
|
||||
vllm_group.add_argument(
|
||||
"--attention-config", "-ac", **vllm_kwargs["attention_config"]
|
||||
)
|
||||
vllm_group.add_argument(
|
||||
"--additional-config", **vllm_kwargs["additional_config"]
|
||||
)
|
||||
@@ -1693,6 +1711,16 @@ class EngineArgs:
|
||||
if model_config.quantization == "bitsandbytes":
|
||||
self.quantization = self.load_format = "bitsandbytes"
|
||||
|
||||
# Attention config overrides
|
||||
attention_config = copy.deepcopy(self.attention_config)
|
||||
if self.attention_backend is not None:
|
||||
if attention_config.backend is not None:
|
||||
raise ValueError(
|
||||
"attention_backend and attention_config.backend "
|
||||
"are mutually exclusive"
|
||||
)
|
||||
attention_config.backend = self.attention_backend
|
||||
|
||||
load_config = self.create_load_config()
|
||||
|
||||
# Pass reasoning_parser into StructuredOutputsConfig
|
||||
@@ -1750,9 +1778,10 @@ class EngineArgs:
|
||||
parallel_config=parallel_config,
|
||||
scheduler_config=scheduler_config,
|
||||
device_config=device_config,
|
||||
load_config=load_config,
|
||||
attention_config=attention_config,
|
||||
lora_config=lora_config,
|
||||
speculative_config=speculative_config,
|
||||
load_config=load_config,
|
||||
structured_outputs_config=self.structured_outputs_config,
|
||||
observability_config=observability_config,
|
||||
compilation_config=compilation_config,
|
||||
|
||||
Reference in New Issue
Block a user