[Attention][UX][1/N] Add AttentionConfig and change attention env vars to CLI arguments (#26315)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
Matthew Bonanni
2025-12-05 12:48:43 -05:00
committed by GitHub
parent dff0a2b394
commit 66e674cdd5
22 changed files with 367 additions and 325 deletions

View File

@@ -267,21 +267,16 @@ def supports_trtllm_attention() -> bool:
return current_platform.is_device_capability(100) and has_nvidia_artifactory()
@functools.cache
def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
"""Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
if env_value is not None:
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
return env_value
def force_use_trtllm_attention() -> bool | None:
"""
Return `None` if VLLM_USE_TRTLLM_ATTENTION is not set,
Return `None` if --attention-config.use_trtllm_attention is not set,
return `True` if TRTLLM attention is forced to be used,
return `False` if TRTLLM attention is forced to be not used.
"""
return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
return vllm_config.attention_config.use_trtllm_attention
def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
@@ -307,7 +302,7 @@ def use_trtllm_attention(
"""Return `True` if TRTLLM attention is used."""
force_use_trtllm = force_use_trtllm_attention()
# Environment variable is set to 0 - respect it
# CLI argument is set to 0 - respect it
if force_use_trtllm is not None and not force_use_trtllm:
return False
@@ -324,7 +319,7 @@ def use_trtllm_attention(
if force_use_trtllm:
logger.warning_once(
"TRTLLM attention is not supported on this platform, "
"but VLLM_USE_TRTLLM_ATTENTION is set to 1"
"but --attention-config.use_trtllm_attention is set to 1"
)
return False
@@ -333,7 +328,8 @@ def use_trtllm_attention(
if force_use_trtllm:
logger.warning_once(
"TRTLLM attention is not supported for this combination of "
"query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
"query and key heads, but --attention-config.use_trtllm_attention is "
"set to 1"
)
return False
@@ -354,7 +350,7 @@ def use_trtllm_attention(
return True
if force_use_trtllm is None:
# Environment variable not set - use auto-detection
# CLI argument not set - use auto-detection
if is_prefill:
# Prefill auto-detection
use_trtllm = kv_cache_dtype == "auto"
@@ -367,8 +363,10 @@ def use_trtllm_attention(
logger.warning_once("Using TRTLLM decode attention (auto-detected).")
return use_trtllm
# Environment variable is set to 1 - respect it
logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
# CLI argument is set to 1 - respect it
logger.info_once(
"Using TRTLLM attention (--attention-config.use_trtllm_attention is set to 1)"
)
return True
@@ -500,12 +498,6 @@ def flashinfer_scaled_fp8_mm(
return output
@functools.cache
def flashinfer_disable_q_quantization() -> bool:
"""Cache result which only depends on the environment"""
return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
__all__ = [
"has_flashinfer",
"flashinfer_trtllm_fp8_block_scale_moe",
@@ -526,7 +518,6 @@ __all__ = [
"supports_trtllm_attention",
"can_use_trtllm_attention",
"use_trtllm_attention",
"flashinfer_disable_q_quantization",
"flashinfer_scaled_fp4_mm",
"flashinfer_scaled_fp8_mm",
]