[Kernel] Flash Attention 3 Support (#12093)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Lucas Wilkinson
2025-01-23 09:45:48 -05:00
committed by GitHub
parent c5b4b11d7f
commit 978b45f399
8 changed files with 151 additions and 83 deletions

vllm/envs.py

@@ -11,6 +11,7 @@ if TYPE_CHECKING:
     VLLM_NCCL_SO_PATH: Optional[str] = None
     LD_LIBRARY_PATH: Optional[str] = None
     VLLM_USE_TRITON_FLASH_ATTN: bool = False
+    VLLM_FLASH_ATTN_VERSION: Optional[int] = None
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: Optional[str] = None
     VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60
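The `if TYPE_CHECKING:` block above only declares a typed module attribute for static checkers; nothing is read from the environment at import time. At runtime the value comes from the `environment_variables` dict further down, looked up lazily through a module-level `__getattr__` (PEP 562). A minimal sketch of that pattern, using a hypothetical `EXAMPLE_FLAG` rather than the real vLLM wiring:

```python
import os
from typing import Any, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers only; this branch never runs at runtime.
    EXAMPLE_FLAG: Optional[int] = None


def __getattr__(name: str) -> Any:
    # PEP 562: invoked when normal module attribute lookup fails, so the
    # environment variable is read lazily, on first attribute access.
    if name == "EXAMPLE_FLAG":
        raw = os.environ.get("EXAMPLE_FLAG")
        return None if raw is None else int(raw)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```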
@@ -90,6 +91,12 @@ def get_default_config_root():
     )
 
 
+def maybe_convert_int(value: Optional[str]) -> Optional[int]:
+    if value is None:
+        return None
+    return int(value)
+
+
 # The begin-* and end* here are used by the documentation generator
 # to extract the used env vars.
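The new `maybe_convert_int` helper exists because `os.environ.get` returns either a string or `None`, and calling `int(None)` would raise a `TypeError`; an unset variable has to stay `None` so the backend can fall back to automatic version selection. A quick illustration of the intended behavior (the asserts are mine, not part of the commit):

```python
from typing import Optional


def maybe_convert_int(value: Optional[str]) -> Optional[int]:
    if value is None:
        return None
    return int(value)


assert maybe_convert_int(None) is None  # unset -> no version forced
assert maybe_convert_int("3") == 3      # VLLM_FLASH_ATTN_VERSION=3 -> FA3
# maybe_convert_int("three") would raise ValueError, surfacing a bad
# setting as soon as the variable is first read.
```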
@@ -203,6 +210,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
              ("true", "1")),
 
+    # Force vllm to use a specific flash-attention version (2 or 3), only valid
+    # when using the flash-attention backend.
+    "VLLM_FLASH_ATTN_VERSION":
+    lambda: maybe_convert_int(os.environ.get("VLLM_FLASH_ATTN_VERSION", None)),
+
     # Internal flag to enable Dynamo fullgraph capture
     "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
     lambda: bool(
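Putting the pieces together, a hedged usage sketch: export the variable before vLLM resolves its attention backend, then read it back through `vllm.envs`, which converts the raw string to an `int`. The assert reflects my reading of the lazy-lookup machinery above, not output captured from the commit itself:

```python
import os

# Must be set before the attention backend is resolved; only honored when
# the flash-attention backend is in use (per the comment in the diff).
os.environ["VLLM_FLASH_ATTN_VERSION"] = "3"  # "2" would force FA2 instead

import vllm.envs as envs

# Attribute access triggers the lambda registered in environment_variables,
# which applies maybe_convert_int to the raw string value.
assert envs.VLLM_FLASH_ATTN_VERSION == 3
```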