[Attention][UX][1/N] Add AttentionConfig and change attention env vars to CLI arguments (#26315)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
@@ -267,21 +267,16 @@ def supports_trtllm_attention() -> bool:
|
||||
return current_platform.is_device_capability(100) and has_nvidia_artifactory()
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
|
||||
"""Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
|
||||
if env_value is not None:
|
||||
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
|
||||
return env_value
|
||||
|
||||
|
||||
def force_use_trtllm_attention() -> bool | None:
|
||||
"""
|
||||
Return `None` if VLLM_USE_TRTLLM_ATTENTION is not set,
|
||||
Return `None` if --attention-config.use_trtllm_attention is not set,
|
||||
return `True` if TRTLLM attention is forced to be used,
|
||||
return `False` if TRTLLM attention is forced to be not used.
|
||||
"""
|
||||
return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
return vllm_config.attention_config.use_trtllm_attention
|
||||
|
||||
|
||||
def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
|
||||
@@ -307,7 +302,7 @@ def use_trtllm_attention(
|
||||
"""Return `True` if TRTLLM attention is used."""
|
||||
force_use_trtllm = force_use_trtllm_attention()
|
||||
|
||||
# Environment variable is set to 0 - respect it
|
||||
# CLI argument is set to 0 - respect it
|
||||
if force_use_trtllm is not None and not force_use_trtllm:
|
||||
return False
|
||||
|
||||
@@ -324,7 +319,7 @@ def use_trtllm_attention(
|
||||
if force_use_trtllm:
|
||||
logger.warning_once(
|
||||
"TRTLLM attention is not supported on this platform, "
|
||||
"but VLLM_USE_TRTLLM_ATTENTION is set to 1"
|
||||
"but --attention-config.use_trtllm_attention is set to 1"
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -333,7 +328,8 @@ def use_trtllm_attention(
|
||||
if force_use_trtllm:
|
||||
logger.warning_once(
|
||||
"TRTLLM attention is not supported for this combination of "
|
||||
"query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
|
||||
"query and key heads, but --attention-config.use_trtllm_attention is "
|
||||
"set to 1"
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -354,7 +350,7 @@ def use_trtllm_attention(
|
||||
return True
|
||||
|
||||
if force_use_trtllm is None:
|
||||
# Environment variable not set - use auto-detection
|
||||
# CLI argument not set - use auto-detection
|
||||
if is_prefill:
|
||||
# Prefill auto-detection
|
||||
use_trtllm = kv_cache_dtype == "auto"
|
||||
@@ -367,8 +363,10 @@ def use_trtllm_attention(
|
||||
logger.warning_once("Using TRTLLM decode attention (auto-detected).")
|
||||
return use_trtllm
|
||||
|
||||
# Environment variable is set to 1 - respect it
|
||||
logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
|
||||
# CLI argument is set to 1 - respect it
|
||||
logger.info_once(
|
||||
"Using TRTLLM attention (--attention-config.use_trtllm_attention is set to 1)"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
@@ -500,12 +498,6 @@ def flashinfer_scaled_fp8_mm(
|
||||
return output
|
||||
|
||||
|
||||
@functools.cache
|
||||
def flashinfer_disable_q_quantization() -> bool:
|
||||
"""Cache result which only depends on the environment"""
|
||||
return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
|
||||
|
||||
|
||||
__all__ = [
|
||||
"has_flashinfer",
|
||||
"flashinfer_trtllm_fp8_block_scale_moe",
|
||||
@@ -526,7 +518,6 @@ __all__ = [
|
||||
"supports_trtllm_attention",
|
||||
"can_use_trtllm_attention",
|
||||
"use_trtllm_attention",
|
||||
"flashinfer_disable_q_quantization",
|
||||
"flashinfer_scaled_fp4_mm",
|
||||
"flashinfer_scaled_fp8_mm",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user