[Kernel] Flash Attention 3 Support (#12093)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
 vllm/envs.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)
@@ -11,6 +11,7 @@ if TYPE_CHECKING:
     VLLM_NCCL_SO_PATH: Optional[str] = None
     LD_LIBRARY_PATH: Optional[str] = None
     VLLM_USE_TRITON_FLASH_ATTN: bool = False
+    VLLM_FLASH_ATTN_VERSION: Optional[int] = None
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: Optional[str] = None
     VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60
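The line added in the TYPE_CHECKING block only declares the variable's type for static checkers; the runtime value comes from the lambda table added further down in this diff. A minimal standalone sketch of the pattern (not vLLM's actual module, just the idiom):

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Visible to type checkers only: declares that the module exposes an
    # Optional[int] attribute named VLLM_FLASH_ATTN_VERSION.
    VLLM_FLASH_ATTN_VERSION: Optional[int] = None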
@@ -90,6 +91,12 @@ def get_default_config_root():
     )
 
 
+def maybe_convert_int(value: Optional[str]) -> Optional[int]:
+    if value is None:
+        return None
+    return int(value)
+
+
 # The begin-* and end* here are used by the documentation generator
 # to extract the used env vars.
 
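The new helper is small enough to verify by hand; a quick standalone check of its behavior (plain Python, mirroring the function above):

from typing import Optional

def maybe_convert_int(value: Optional[str]) -> Optional[int]:
    if value is None:
        return None
    return int(value)

assert maybe_convert_int(None) is None  # unset env var -> None
assert maybe_convert_int("3") == 3      # "VLLM_FLASH_ATTN_VERSION=3" -> 3
# A non-numeric value such as "auto" would raise ValueError here.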
@@ -203,6 +210,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
              ("true", "1")),
 
+    # Force vllm to use a specific flash-attention version (2 or 3), only valid
+    # when using the flash-attention backend.
+    "VLLM_FLASH_ATTN_VERSION":
+    lambda: maybe_convert_int(os.environ.get("VLLM_FLASH_ATTN_VERSION", None)),
+
     # Internal flag to enable Dynamo fullgraph capture
     "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
     lambda: bool(
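As a usage sketch, forcing flash-attention 3 from user code could look like the following (assumption: vllm.envs resolves entries in this table lazily on attribute access, as the rest of the file does):

import os

# Set before vLLM selects its attention backend; per the comment in the
# diff above, only 2 or 3 are meaningful, and leaving it unset lets vLLM
# pick a version automatically.
os.environ["VLLM_FLASH_ATTN_VERSION"] = "3"

import vllm.envs as envs
print(envs.VLLM_FLASH_ATTN_VERSION)  # -> 3 (an int, via maybe_convert_int)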