[Attention] FA4 integration (#32974)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -9,6 +9,7 @@ steps:
|
||||
- tests/v1
|
||||
commands:
|
||||
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
# split the test to avoid interference
|
||||
- pytest -v -s -m 'not cpu_test' v1/core
|
||||
- pytest -v -s v1/executor
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -3,6 +3,8 @@
|
||||
|
||||
# vllm-flash-attn built from source
|
||||
vllm/vllm_flash_attn/*
|
||||
!vllm/vllm_flash_attn/__init__.py
|
||||
!vllm/vllm_flash_attn/flash_attn_interface.py
|
||||
|
||||
# OpenAI triton kernels copied from source
|
||||
vllm/third_party/triton_kernels/*
|
||||
|
||||
@@ -17,7 +17,8 @@ endif()
|
||||
# They should be identical but if they aren't, this is a massive footgun.
|
||||
#
|
||||
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
|
||||
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
|
||||
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2), --component _vllm_fa3_C (for FA3),
|
||||
# or --component _vllm_fa4_cutedsl_C (for FA4 CuteDSL Python files).
|
||||
# If no component is specified, vllm-flash-attn is still installed.
|
||||
|
||||
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
|
||||
@@ -38,7 +39,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 5824e6e2008271063c3229ab3e7032bd74abbbc6
|
||||
GIT_TAG 140c00c0241bb60cc6e44e7c1be9998d4b20d8d2
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
@@ -46,38 +47,62 @@ else()
|
||||
endif()
|
||||
|
||||
|
||||
# Ensure the vllm/vllm_flash_attn directory exists before installation
|
||||
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" ALL_COMPONENTS)
|
||||
|
||||
# Make sure vllm-flash-attn install rules are nested under vllm/
|
||||
# This is here to support installing all components under the same prefix with cmake --install.
|
||||
# setup.py installs every component separately but uses the same prefix for all.
|
||||
# ALL_COMPONENTS is used to avoid duplication for FA2 and FA3,
|
||||
# and these statements don't hurt when installing neither component.
|
||||
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" ALL_COMPONENTS)
|
||||
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS)
|
||||
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" ALL_COMPONENTS)
|
||||
# Install rules for FA components need the install prefix nested under vllm/
|
||||
# These run at install time, before the FA library's own install rules
|
||||
foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C)
|
||||
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT ${_FA_COMPONENT})
|
||||
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT ${_FA_COMPONENT})
|
||||
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT ${_FA_COMPONENT})
|
||||
endforeach()
|
||||
|
||||
# Fetch the vllm-flash-attn library
|
||||
FetchContent_MakeAvailable(vllm-flash-attn)
|
||||
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
|
||||
|
||||
# Restore the install prefix
|
||||
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS)
|
||||
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
|
||||
# Restore the install prefix after FA's install rules
|
||||
foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C)
|
||||
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT ${_FA_COMPONENT})
|
||||
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT ${_FA_COMPONENT})
|
||||
endforeach()
|
||||
|
||||
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
|
||||
# case only one is built, in the case both are built redundant work is done)
|
||||
install(
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm/vllm_flash_attn
|
||||
COMPONENT _vllm_fa2_C
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
# Install shared Python files for both FA2 and FA3 components
|
||||
foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C)
|
||||
# Ensure the vllm/vllm_flash_attn directory exists before installation
|
||||
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")"
|
||||
COMPONENT ${_FA_COMPONENT})
|
||||
|
||||
install(
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm/vllm_flash_attn
|
||||
COMPONENT _vllm_fa3_C
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
# Copy vllm_flash_attn python files (except __init__.py and flash_attn_interface.py
|
||||
# which are source-controlled in vllm)
|
||||
install(
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm/vllm_flash_attn
|
||||
COMPONENT ${_FA_COMPONENT}
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
PATTERN "__init__.py" EXCLUDE
|
||||
PATTERN "flash_attn_interface.py" EXCLUDE
|
||||
)
|
||||
|
||||
endforeach()
|
||||
|
||||
#
|
||||
# FA4 CuteDSL component
|
||||
# This is a Python-only component that copies the flash_attn/cute directory
|
||||
# and transforms imports to match our package structure.
|
||||
#
|
||||
# Command-less target named after the FA4 CuteDSL install component.
# NOTE(review): presumably exists so the build driver has a target to build
# for this Python-only component before installing it -- confirm in setup.py.
add_custom_target(_vllm_fa4_cutedsl_C)

# Install rule for the _vllm_fa4_cutedsl_C component.
#
# At install time every *.py file under flash_attn/cute in the vllm-flash-attn
# source tree is written to <prefix>/vllm/vllm_flash_attn/cute/ (preserving the
# relative directory layout), with every occurrence of the literal text
# "flash_attn.cute" replaced by "vllm.vllm_flash_attn.cute" so that the copied
# modules import each other under vLLM's package structure.
#
# ${vllm-flash-attn_SOURCE_DIR} is expanded at configure time; the escaped
# \${...} references are evaluated inside the generated install script.
#
# The destination is prefixed with \$ENV{DESTDIR} because install(CODE) scripts
# do not get DESTDIR applied automatically the way install(DIRECTORY) rules do;
# when DESTDIR is unset it expands to nothing and the path is unchanged.
# Path arguments are quoted so each is always passed as exactly one argument.
install(CODE "
  file(GLOB_RECURSE CUTE_PY_FILES \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute/*.py\")
  foreach(SRC_FILE IN LISTS CUTE_PY_FILES)
    file(RELATIVE_PATH REL_PATH \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute\" \"\${SRC_FILE}\")
    set(DST_FILE \"\$ENV{DESTDIR}\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn/cute/\${REL_PATH}\")
    get_filename_component(DST_DIR \"\${DST_FILE}\" DIRECTORY)
    file(MAKE_DIRECTORY \"\${DST_DIR}\")
    file(READ \"\${SRC_FILE}\" FILE_CONTENTS)
    string(REPLACE \"flash_attn.cute\" \"vllm.vllm_flash_attn.cute\" FILE_CONTENTS \"\${FILE_CONTENTS}\")
    file(WRITE \"\${DST_FILE}\" \"\${FILE_CONTENTS}\")
  endforeach()
" COMPONENT _vllm_fa4_cutedsl_C)
|
||||
|
||||
@@ -168,6 +168,7 @@ Priority is **1 = highest** (tried first).
|
||||
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
|
||||
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
|
||||
| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x |
|
||||
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
|
||||
| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
|
||||
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
|
||||
| `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
|
||||
@@ -178,7 +179,7 @@ Priority is **1 = highest** (tried first).
|
||||
|
||||
> **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`.
|
||||
>
|
||||
> **\*** Specify the FlashAttention version via `--attention-config.flash_attn_version=2` or `3`. Default is FA3 on SM90, FA2 otherwise.
|
||||
> **\*** Specify the FlashAttention version via `--attention-config.flash_attn_version=2`, `3`, or `4`. Default is FA4 on SM100+ (Blackwell), FA3 on SM90 (Hopper), FA2 otherwise.
|
||||
|
||||
## MLA (Multi-head Latent Attention) Backends
|
||||
|
||||
|
||||
@@ -11,3 +11,7 @@ torchaudio==2.10.0
|
||||
torchvision==0.25.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
|
||||
# FlashInfer should be updated together with the Dockerfile
|
||||
flashinfer-python==0.6.4
|
||||
|
||||
# QuACK and Cutlass DSL for FA4 (cute-DSL implementation)
|
||||
nvidia-cutlass-dsl>=4.4.0.dev1
|
||||
quack-kernels>=0.2.7
|
||||
|
||||
5
setup.py
5
setup.py
@@ -976,6 +976,11 @@ if _is_cuda():
|
||||
):
|
||||
# FA3 requires CUDA 12.3 or later
|
||||
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
|
||||
# FA4 CuteDSL - Python-only component for FA4's cute DSL support
|
||||
# Optional since this doesn't produce a .so file, just copies Python files
|
||||
ext_modules.append(
|
||||
CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa4_cutedsl_C", optional=True)
|
||||
)
|
||||
if envs.VLLM_USE_PRECOMPILED or (
|
||||
CUDA_HOME and get_nvcc_cuda_version() >= Version("12.9")
|
||||
):
|
||||
|
||||
@@ -563,14 +563,53 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Special backend variant parsers (FA2/FA3, FlashInfer TRTLLM, MLA prefill)
|
||||
# Special backend variant parsers (FA2/FA3/FA4, FlashInfer TRTLLM, MLA prefill)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
|
||||
"""Parse fa_utils.py to detect FA2 vs FA3 feature differences.
|
||||
def _parse_fa4_supported_caps() -> str | None:
|
||||
"""Parse flash_attn_interface.py for FA4 supported compute capabilities.
|
||||
|
||||
Returns a dict with 'fa2' and 'fa3' keys containing their respective
|
||||
Looks for `cc not in [9, 10, 11]` pattern in _is_fa4_supported().
|
||||
"""
|
||||
fa_interface_file = (
|
||||
REPO_ROOT / "vllm" / "vllm_flash_attn" / "flash_attn_interface.py"
|
||||
)
|
||||
if not fa_interface_file.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
tree = ast.parse(fa_interface_file.read_text())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if not isinstance(node, ast.FunctionDef) or node.name != "_is_fa4_supported":
|
||||
continue
|
||||
for n in ast.walk(node):
|
||||
if not (
|
||||
isinstance(n, ast.Compare)
|
||||
and len(n.ops) == 1
|
||||
and isinstance(n.ops[0], ast.NotIn)
|
||||
and isinstance(n.comparators[0], ast.List)
|
||||
):
|
||||
continue
|
||||
caps: list[int] = [
|
||||
e.value
|
||||
for e in n.comparators[0].elts
|
||||
if isinstance(e, ast.Constant) and isinstance(e.value, int)
|
||||
]
|
||||
if caps:
|
||||
caps.sort()
|
||||
return f"{caps[0]}.x-{caps[-1]}.x"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
|
||||
"""Parse fa_utils.py to detect FA2 vs FA3 vs FA4 feature differences.
|
||||
|
||||
Returns a dict with 'fa2', 'fa3', and 'fa4' keys containing their respective
|
||||
feature overrides for compute capability, KV cache dtypes, and sink support.
|
||||
"""
|
||||
if not FA_UTILS_FILE.exists():
|
||||
@@ -585,6 +624,7 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
|
||||
fa3_supports_fp8 = False
|
||||
fa3_supports_sinks = False
|
||||
fa3_compute_cap: str | None = None
|
||||
fa4_compute_cap: str | None = None
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if not isinstance(node, ast.FunctionDef):
|
||||
@@ -614,14 +654,12 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
|
||||
fa3_supports_sinks = True
|
||||
break
|
||||
|
||||
# Check get_flash_attn_version for FA3 compute capability
|
||||
# Look for the ternary: 3 if (device_capability.major == 9 ...) else 2
|
||||
# Check get_flash_attn_version for FA3/FA4 compute capability
|
||||
if node.name == "get_flash_attn_version":
|
||||
for n in ast.walk(node):
|
||||
# Look for IfExp (ternary) with `device_capability.major == 9`
|
||||
# Handle IfExp (ternary) with `device_capability.major == 9`
|
||||
if isinstance(n, ast.IfExp):
|
||||
test = n.test
|
||||
# Check if test is a BoolOp (and) containing the major check
|
||||
if isinstance(test, ast.BoolOp):
|
||||
for val in test.values:
|
||||
if (
|
||||
@@ -634,6 +672,38 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
|
||||
fa3_compute_cap = f"{val.comparators[0].value}.x"
|
||||
break
|
||||
|
||||
# Handle If statements for FA3/FA4 detection
|
||||
# e.g. `if device_capability.major == 9` -> FA3
|
||||
# `elif device_capability.major >= 10` -> FA4
|
||||
if isinstance(n, ast.If):
|
||||
test = n.test
|
||||
comparisons = (
|
||||
[v for v in test.values if isinstance(v, ast.Compare)]
|
||||
if isinstance(test, ast.BoolOp)
|
||||
else [test]
|
||||
if isinstance(test, ast.Compare)
|
||||
else []
|
||||
)
|
||||
for comp in comparisons:
|
||||
if not (
|
||||
isinstance(comp.left, ast.Attribute)
|
||||
and comp.left.attr == "major"
|
||||
and comp.comparators
|
||||
and isinstance(comp.comparators[0], ast.Constant)
|
||||
and isinstance(comp.comparators[0].value, int)
|
||||
):
|
||||
continue
|
||||
op = comp.ops[0]
|
||||
val = comp.comparators[0].value
|
||||
if isinstance(op, ast.Eq) and fa3_compute_cap is None:
|
||||
fa3_compute_cap = f"{val}.x"
|
||||
elif isinstance(op, ast.GtE) and fa4_compute_cap is None:
|
||||
fa4_compute_cap = f"≥{val}.0"
|
||||
|
||||
# Fallback: try to parse FA4 compute caps from flash_attn_interface.py
|
||||
if fa4_compute_cap is None:
|
||||
fa4_compute_cap = _parse_fa4_supported_caps()
|
||||
|
||||
return {
|
||||
"fa2": {
|
||||
"supports_fp8": False,
|
||||
@@ -644,6 +714,11 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
|
||||
"supports_fp8": fa3_supports_fp8,
|
||||
"supports_sink": fa3_supports_sinks,
|
||||
},
|
||||
"fa4": {
|
||||
"compute_capability": fa4_compute_cap,
|
||||
"supports_fp8": False,
|
||||
"supports_sink": False,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -760,7 +835,7 @@ def parse_mla_prefill_backends() -> list[dict[str, Any]]:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backend variant expansion (FA2/FA3, FlashInfer native/TRTLLM)
|
||||
# Backend variant expansion (FA2/FA3/FA4, FlashInfer native/TRTLLM)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -768,7 +843,7 @@ def _expand_flash_attn_variants(
|
||||
all_backends: list[dict[str, Any]],
|
||||
fa_features: dict[str, dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Expand FLASH_ATTN into FA2 and FA3 variants with different capabilities."""
|
||||
"""Expand FLASH_ATTN into FA2, FA3, and FA4 variants."""
|
||||
expanded = []
|
||||
for backend in all_backends:
|
||||
if backend["name"] != "FLASH_ATTN":
|
||||
@@ -801,6 +876,18 @@ def _expand_flash_attn_variants(
|
||||
|
||||
expanded.append(fa2)
|
||||
expanded.append(fa3)
|
||||
|
||||
# Create FA4 entry if FA4 features are available
|
||||
if "fa4" in fa_features:
|
||||
fa4 = backend.copy()
|
||||
fa4["version"] = "FA4*"
|
||||
fa4["_sort_key"] = "FLASH_ATTN"
|
||||
fa4["_sort_order"] = 2
|
||||
if fa_features["fa4"].get("compute_capability"):
|
||||
fa4["compute_capability"] = fa_features["fa4"]["compute_capability"]
|
||||
fa4["supports_sink"] = fa_features["fa4"]["supports_sink"]
|
||||
expanded.append(fa4)
|
||||
|
||||
return expanded
|
||||
|
||||
|
||||
@@ -1360,7 +1447,8 @@ def generate_docs() -> str:
|
||||
if fa_features:
|
||||
footnotes.append(
|
||||
"> **\\*** Specify the FlashAttention version via "
|
||||
"`--attention-config.flash_attn_version=2` or `3`. Default is FA3 on SM90, "
|
||||
"`--attention-config.flash_attn_version=2`, `3`, or `4`. "
|
||||
"Default is FA4 on SM100+ (Blackwell), FA3 on SM90 (Hopper), "
|
||||
"FA2 otherwise."
|
||||
)
|
||||
if footnotes:
|
||||
|
||||
@@ -16,8 +16,8 @@ class AttentionConfig:
|
||||
backend: AttentionBackendEnum | None = None
|
||||
"""Attention backend to use. If None, will be selected automatically."""
|
||||
|
||||
flash_attn_version: Literal[2, 3] | None = None
|
||||
"""Force vllm to use a specific flash-attention version (2 or 3).
|
||||
flash_attn_version: Literal[2, 3, 4] | None = None
|
||||
"""Force vllm to use a specific flash-attention version (2, 3, or 4).
|
||||
Only valid when using the flash-attention backend."""
|
||||
|
||||
use_prefill_decode_attention: bool = False
|
||||
|
||||
@@ -2014,7 +2014,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
# ROCm and the latter has an additional parameter to control
|
||||
# FA2 vs FA3
|
||||
self.flash_attn_varlen_func = flash_attn_varlen_func
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
self.vllm_flash_attn_version = get_flash_attn_version(
|
||||
head_size=self.qk_head_dim
|
||||
)
|
||||
if self.vllm_flash_attn_version is not None:
|
||||
self.flash_attn_varlen_func = functools.partial(
|
||||
flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version
|
||||
|
||||
@@ -204,7 +204,9 @@ class MMEncoderAttention(CustomOp):
|
||||
}
|
||||
|
||||
self._fa_version = (
|
||||
get_flash_attn_version() if self.is_flash_attn_backend else None
|
||||
get_flash_attn_version(head_size=head_size)
|
||||
if self.is_flash_attn_backend
|
||||
else None
|
||||
)
|
||||
|
||||
if self.attn_backend == AttentionBackendEnum.FLASHINFER:
|
||||
|
||||
@@ -52,7 +52,9 @@ elif current_platform.is_rocm():
|
||||
reshape_and_cache_flash = ops.reshape_and_cache_flash
|
||||
|
||||
|
||||
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
|
||||
def get_flash_attn_version(
|
||||
requires_alibi: bool = False, head_size: int | None = None
|
||||
) -> int | None:
|
||||
# import here to avoid circular dependencies
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -72,9 +74,15 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
|
||||
assert device_capability is not None
|
||||
|
||||
# 1. default version depending on platform
|
||||
fa_version = (
|
||||
3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
|
||||
)
|
||||
if device_capability.major == 9 and is_fa_version_supported(3):
|
||||
# Hopper (SM90): prefer FA3
|
||||
fa_version = 3
|
||||
elif device_capability.major == 10 and is_fa_version_supported(4):
|
||||
# Blackwell (SM100+, restrict to SM100 for now): prefer FA4
|
||||
fa_version = 4
|
||||
else:
|
||||
# Fallback to FA2
|
||||
fa_version = 2
|
||||
|
||||
# 2. override if passed by environment or config
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
@@ -87,12 +95,12 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
|
||||
fa_version = vllm_config.attention_config.flash_attn_version
|
||||
|
||||
# 3. fallback for unsupported combinations
|
||||
if device_capability.major == 10 and fa_version == 3:
|
||||
if device_capability.major >= 10 and fa_version == 3:
|
||||
logger.warning_once(
|
||||
"Cannot use FA version 3 on Blackwell platform, "
|
||||
"defaulting to FA version 2."
|
||||
"defaulting to FA version 4 if supported, otherwise FA2."
|
||||
)
|
||||
fa_version = 2
|
||||
fa_version = 4 if is_fa_version_supported(4) else 2
|
||||
|
||||
if requires_alibi and fa_version == 3:
|
||||
logger.warning_once(
|
||||
@@ -100,6 +108,28 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
|
||||
)
|
||||
fa_version = 2
|
||||
|
||||
if requires_alibi and fa_version == 4:
|
||||
logger.warning_once(
|
||||
"Cannot use FA version 4 with ALiBi, defaulting to FA version 2."
|
||||
)
|
||||
fa_version = 2
|
||||
|
||||
# FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict
|
||||
# supported head dimensions.
|
||||
# See: https://github.com/Dao-AILab/flash-attention/issues/1959
|
||||
if (
|
||||
fa_version == 4
|
||||
and device_capability.major >= 10
|
||||
and head_size is not None
|
||||
and head_size > 128
|
||||
):
|
||||
logger.warning_once(
|
||||
"FA4 on Blackwell does not support head_size=%d due to TMEM "
|
||||
"capacity limits, defaulting to FA version 2.",
|
||||
head_size,
|
||||
)
|
||||
fa_version = 2
|
||||
|
||||
if not is_fa_version_supported(fa_version):
|
||||
logger.error(
|
||||
"Cannot use FA version %d is not supported due to %s",
|
||||
@@ -139,6 +169,10 @@ def flash_attn_supports_mla():
|
||||
return is_fa_version_supported(
|
||||
3
|
||||
) and current_platform.is_device_capability_family(90)
|
||||
|
||||
# NOTE(Lucas): FA4 CuteDSL does NOT currently support MLA's non-standard
|
||||
# head dimensions (576 for qk, 512 for v) due to TMEM capacity limits.
|
||||
|
||||
except (ImportError, AssertionError):
|
||||
pass
|
||||
return False
|
||||
|
||||
@@ -580,7 +580,15 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
self.attn_type = attn_type
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
self.vllm_flash_attn_version = get_flash_attn_version(
|
||||
requires_alibi=alibi_slopes is not None,
|
||||
head_size=head_size,
|
||||
)
|
||||
logger.info_once(
|
||||
"Using FlashAttention version %s",
|
||||
self.vllm_flash_attn_version,
|
||||
scope="local",
|
||||
)
|
||||
# Cache the batch invariant result for use in forward passes
|
||||
self.batch_invariant_enabled = vllm_is_batch_invariant()
|
||||
|
||||
|
||||
@@ -137,7 +137,7 @@ class CudagraphDispatcher:
|
||||
num_tokens_padded = self._bs_to_padded_graph_size[num_tokens]
|
||||
|
||||
if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
|
||||
num_reqs = num_tokens_padded // uniform_decode_query_len
|
||||
num_reqs = min(num_tokens_padded // uniform_decode_query_len, max_num_seqs)
|
||||
assert num_tokens_padded % uniform_decode_query_len == 0
|
||||
else:
|
||||
uniform_decode = False
|
||||
|
||||
24
vllm/vllm_flash_attn/__init__.py
Normal file
24
vllm/vllm_flash_attn/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.vllm_flash_attn.flash_attn_interface import (
|
||||
FA2_AVAILABLE,
|
||||
FA3_AVAILABLE,
|
||||
fa_version_unsupported_reason,
|
||||
flash_attn_varlen_func,
|
||||
get_scheduler_metadata,
|
||||
is_fa_version_supported,
|
||||
)
|
||||
|
||||
# Importing this package is refused when neither compiled extension loaded.
if not FA2_AVAILABLE and not FA3_AVAILABLE:
    raise ImportError(
        "vllm.vllm_flash_attn requires the CUDA flash attention extensions "
        "(_vllm_fa2_C or _vllm_fa3_C). On ROCm, use upstream flash_attn."
    )
|
||||
|
||||
__all__ = [
|
||||
"fa_version_unsupported_reason",
|
||||
"flash_attn_varlen_func",
|
||||
"get_scheduler_metadata",
|
||||
"is_fa_version_supported",
|
||||
]
|
||||
567
vllm/vllm_flash_attn/flash_attn_interface.py
Normal file
567
vllm/vllm_flash_attn/flash_attn_interface.py
Normal file
@@ -0,0 +1,567 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
# ruff: noqa: E501
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
# isort: off
|
||||
# We need to import the CUDA kernels after importing torch
|
||||
# Use relative import to support build-from-source installation in vLLM
|
||||
|
||||
# FA2: assume the compiled extension loads, and record the import error text
# if it does not.
FA2_UNAVAILABLE_REASON = None
FA2_AVAILABLE = True
try:
    from . import _vllm_fa2_C  # type: ignore[attr-defined]  # noqa: F401
except ImportError as e:
    FA2_UNAVAILABLE_REASON = str(e)
    FA2_AVAILABLE = False

# FA3: same probing scheme as FA2.
FA3_UNAVAILABLE_REASON = None
FA3_AVAILABLE = True
try:
    from . import _vllm_fa3_C  # type: ignore[attr-defined]  # noqa: F401
except ImportError as e:
    FA3_UNAVAILABLE_REASON = str(e)
    FA3_AVAILABLE = False

# FA4 has no compiled extension to import here; it counts as available when
# the file cute/interface.py exists next to this module.
import os  # noqa: E402

_cute_interface_path = os.path.join(
    os.path.dirname(__file__), "cute", "interface.py"
)
FA4_AVAILABLE = os.path.exists(_cute_interface_path)
FA4_UNAVAILABLE_REASON = (
    None if FA4_AVAILABLE else "vllm.vllm_flash_attn.cute.interface not found"
)
|
||||
|
||||
# isort: on
|
||||
|
||||
DEFAULT_FA_VERSION = 2
|
||||
|
||||
|
||||
def _is_fa2_supported() -> tuple[bool, str | None]:
    """Report whether FlashAttention 2 can be used.

    Returns:
        ``(True, None)`` when FA2 is usable, otherwise ``(False, reason)``
        where ``reason`` is a human-readable explanation.
    """
    if FA2_AVAILABLE:
        # Deferred import: the platform is only consulted once the extension
        # is known to have loaded.
        from vllm.platforms import current_platform

        if current_platform.has_device_capability(80):
            return True, None
        return False, "FA2 is only supported on devices with compute capability >= 8"
    return False, f"FA2 is unavailable due to: {FA2_UNAVAILABLE_REASON}"
|
||||
|
||||
|
||||
def _is_fa3_supported() -> tuple[bool, str | None]:
    """Report whether FlashAttention 3 can be used.

    Returns:
        ``(True, None)`` when FA3 is usable, otherwise ``(False, reason)``
        where ``reason`` is a human-readable explanation.
    """
    if FA3_AVAILABLE:
        # Deferred import: the platform is only consulted once the extension
        # is known to have loaded.
        from vllm.platforms import current_platform

        if current_platform.is_device_capability_family(90):
            return True, None
        return False, "FA3 is only supported on devices with compute capability 9.x"
    return False, f"FA3 is unavailable due to: {FA3_UNAVAILABLE_REASON}"
|
||||
|
||||
|
||||
def _is_fa4_supported() -> tuple[bool, str | None]:
    """Report whether FlashAttention 4 can be used.

    Returns:
        ``(True, None)`` when FA4 is usable, otherwise ``(False, reason)``
        where ``reason`` is a human-readable explanation.
    """
    if not FA4_AVAILABLE:
        return False, f"FA4 is unavailable due to: {FA4_UNAVAILABLE_REASON}"

    from vllm.platforms import current_platform

    # Device families accepted for FA4, checked in this order; the generator
    # stops at the first match.
    accepted_families = (90, 100, 110)
    on_accepted_family = any(
        current_platform.is_device_capability_family(family)
        for family in accepted_families
    )
    if on_accepted_family:
        return True, None
    return (
        False,
        "FA4 is only supported on devices with compute capability 9.x, 10.x, or 11.x",
    )
|
||||
|
||||
|
||||
def is_fa_version_supported(fa_version: int) -> bool:
    """Tell whether the given FlashAttention version can be used.

    Args:
        fa_version: FlashAttention major version; 2, 3 and 4 are recognised.

    Returns:
        The boolean half of the matching ``_is_faN_supported()`` result.

    Raises:
        ValueError: If ``fa_version`` is not 2, 3 or 4.
    """
    # Select the probe first, then call it once below.
    if fa_version == 2:
        probe = _is_fa2_supported
    elif fa_version == 3:
        probe = _is_fa3_supported
    elif fa_version == 4:
        probe = _is_fa4_supported
    else:
        raise ValueError(f"Unsupported FA version: {fa_version}")
    supported, _ = probe()
    return supported
|
||||
|
||||
|
||||
def fa_version_unsupported_reason(fa_version: int) -> str | None:
|
||||
if fa_version == 2:
|
||||
return _is_fa2_supported()[1]
|
||||
elif fa_version == 3:
|
||||
return _is_fa3_supported()[1]
|
||||
elif fa_version == 4:
|
||||
return _is_fa4_supported()[1]
|
||||
else:
|
||||
raise ValueError(f"Unsupported FA version: {fa_version}")
|
||||
|
||||
|
||||
#
|
||||
# For vLLM we only care about `flash_attn_varlen_func` and
|
||||
# `flash_attn_with_kvcache` so we only maintain wrappers for these two.
|
||||
#
|
||||
|
||||
|
||||
def maybe_contiguous(x):
    """Return ``x`` unchanged, or ``x.contiguous()`` if its last stride is not 1.

    ``None`` is passed through as-is.
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
|
||||
|
||||
|
||||
# NOTE only used in FA3
|
||||
def get_scheduler_metadata(
    batch_size,
    max_seqlen_q,
    max_seqlen_k,
    num_heads_q,
    num_heads_kv,
    headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: torch.Tensor | None = None,
    cu_seqlens_k_new: torch.Tensor | None = None,
    cache_leftpad: torch.Tensor | None = None,
    page_size: int | None = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Forward the arguments to ``torch.ops._vllm_fa3_C.get_scheduler_metadata``.

    Before the call, ``cache_seqlens`` is run through ``maybe_contiguous`` and
    ``headdim_v`` falls back to ``headdim`` when it is ``None``. The op is
    called positionally; ``cu_seqlens_k`` and ``seqused_q`` are always passed
    as ``None``, and ``window_size`` is split into its first two elements.

    Returns:
        Whatever the underlying op returns.
    """
    seqlens = maybe_contiguous(cache_seqlens)
    v_dim = headdim if headdim_v is None else headdim_v
    return torch.ops._vllm_fa3_C.get_scheduler_metadata(
        batch_size,
        max_seqlen_q,
        max_seqlen_k,
        num_heads_q,
        num_heads_kv,
        headdim,
        v_dim,
        qkv_dtype,
        seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0],
        window_size[1],
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
|
||||
|
||||
|
||||
def flash_attn_varlen_func(
    q,
    k,
    v,
    max_seqlen_q,
    cu_seqlens_q,
    max_seqlen_k,
    cu_seqlens_k=None,  # only used for non-paged prefill
    seqused_k=None,
    q_v=None,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size: list[int] | None = None,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
    return_softmax_lse=False,
    out=None,
    # FA3 Only
    scheduler_metadata=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    num_splits: int = 0,
    # Version selector
    fa_version: int = DEFAULT_FA_VERSION,
    s_aux=None,
    cp_world_size=1,
    cp_rank=0,
    cp_tot_seqused_k=None,
):
    """Variable-length FlashAttention forward, dispatched on ``fa_version``.

    dropout_p should be set to 0.0 during evaluation.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Exactly one of ``cu_seqlens_k`` and ``seqused_k`` must be given, and
    ``seqused_k`` is mandatory whenever ``block_table`` is given.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv. Only used for non-paged
           prefill; mutually exclusive with seqused_k.
        seqused_k: per-sequence key lengths, used instead of cu_seqlens_k
           (presumably (batch_size,) int32 -- confirm against the kernels).
        q_v: forwarded to the FA3 kernel only; ignored by FA2 and FA4.
        dropout_p: float. Dropout probability. Forwarded to the FA2 kernel only.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
            ``None`` is treated as (-1, -1). Must have exactly two elements otherwise.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
            Must be None for FA3 and FA4.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
            Not read by this wrapper.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling). Not read by this wrapper.
        block_table: optional page table for paged KV; passed to every backend
            (as ``page_table`` for FA4). FA4 rejects it on the SM90 device family.
        return_softmax_lse: bool. When True, return ``(out, softmax_lse)`` instead of ``out``.
        out: optional pre-allocated output tensor handed to the kernel.
        scheduler_metadata: forwarded to the FA3 kernel only.
        q_descale, k_descale, v_descale: forwarded to the FA3 kernel only.
        num_splits: int. Forwarded to every backend; FA2 rejects values > 1.
        fa_version: int. Backend selector: 2, 3 or 4.
        s_aux: forwarded to the FA3 kernel only; FA2 raises if it is set.
        cp_world_size, cp_rank, cp_tot_seqused_k: forwarded to the FA3 kernel only.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    Raises:
        AssertionError: on an invalid cu_seqlens_k / seqused_k / block_table combination,
            a window_size whose length is not 2, or alibi_slopes with FA3/FA4.
        NotImplementedError: for features the selected backend does not support.
        ValueError: if fa_version is not 2, 3 or 4.
    """
    assert cu_seqlens_k is not None or seqused_k is not None, (
        "cu_seqlens_k or seqused_k must be provided"
    )
    assert cu_seqlens_k is None or seqused_k is None, (
        "cu_seqlens_k and seqused_k cannot be provided at the same time"
    )
    assert block_table is None or seqused_k is not None, (
        "seqused_k must be provided if block_table is provided"
    )

    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    # custom op does not support non-tuple input
    real_window_size: tuple[int, int]
    if window_size is None:
        real_window_size = (-1, -1)
    else:
        assert len(window_size) == 2
        real_window_size = (window_size[0], window_size[1])
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]

    if fa_version == 2:
        # NOTE(review): this guard fires only when all four arguments are set
        # together. Confirm whether callers rely on passing descale tensors
        # with FA2 before tightening it to an `or`.
        if (
            scheduler_metadata is not None
            and q_descale is not None
            and k_descale is not None
            and v_descale is not None
        ):
            raise NotImplementedError(
                "FA2 does not support scheduler_metadata, q_descale, "
                "k_descale, v_descale"
            )
        if s_aux is not None:
            raise NotImplementedError("FA2 does not support s_aux")
        if num_splits > 1:
            raise NotImplementedError("FA2 does not support num_splits > 1")

        # cu_seqlens_k is not used when seqused_k is given, but flash_api.cpp
        # still wants a tensor in that slot, so pass an *uninitialized*
        # placeholder shaped like cu_seqlens_q. It is allocated only here, on
        # the one path that needs it, rather than on every call.
        if cu_seqlens_k is None:
            fa2_cu_seqlens_k = torch.empty_like(cu_seqlens_q)
        else:
            fa2_cu_seqlens_k = cu_seqlens_k

        # Positional call: the order must match the op's schema exactly.
        out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens_q,
            fa2_cu_seqlens_k,
            seqused_k,
            None,
            block_table,
            alibi_slopes,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            False,
            causal,
            real_window_size[0],
            real_window_size[1],
            softcap,
            return_softmax_lse and dropout_p > 0,
            num_splits,
            None,
        )
    elif fa_version == 3:
        assert alibi_slopes is None, "Alibi is not supported in FA3"
        # Positional call: the order must match the op's schema exactly.
        out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
            q,
            k,
            v,
            None,
            None,  # k_new, v_new
            q_v,
            out,
            cu_seqlens_q,
            cu_seqlens_k,  # cu_seqlens_k
            None,  # cu_seqlens_k_new
            None,
            seqused_k,  # seqused_q, seqused_k
            max_seqlen_q,
            max_seqlen_k,
            block_table,
            None,  # kv_batch_idx
            None,  # leftpad_k
            None,
            None,
            None,  # rotary_cos, rotary_sin, seqlens_rotary
            q_descale,
            k_descale,
            v_descale,
            softmax_scale,
            causal,
            real_window_size[0],
            real_window_size[1],
            softcap,
            True,  # rotary_interleaved
            scheduler_metadata,
            num_splits,
            None,  # pack_gqa
            0,  # sm_margin
            s_aux,  # s_aux
            cp_world_size,
            cp_rank,
            cp_tot_seqused_k,
        )
    elif fa_version == 4:
        assert alibi_slopes is None, "Alibi is not supported in FA4"
        # FA4 on SM90 doesn't support paged KV; SM100+ does
        from vllm.platforms import current_platform

        if block_table is not None and current_platform.is_device_capability_family(90):
            raise NotImplementedError(
                "FA4 with paged KV is not supported on SM90 (Hopper). "
                "Use FA3 or upgrade to Blackwell (SM100+)."
            )
        # Imported lazily so the CuteDSL interface is only loaded when FA4 is
        # actually selected.
        from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd

        out, softmax_lse = _flash_attn_fwd(
            q,
            k,
            v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            seqused_k=seqused_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            page_table=block_table,
            softmax_scale=softmax_scale,
            causal=causal,
            softcap=softcap,
            # The FA4 interface takes None, not a negative value, for an
            # unbounded window side.
            window_size_left=real_window_size[0] if real_window_size[0] >= 0 else None,
            window_size_right=real_window_size[1] if real_window_size[1] >= 0 else None,
            num_splits=num_splits,
            return_lse=return_softmax_lse,
            out=out,
        )
    else:
        raise ValueError(f"Unsupported FA version: {fa_version}")
    return (out, softmax_lse) if return_softmax_lse else out
|
||||
|
||||
|
||||
def sparse_attn_func(
    q,
    k,
    v,
    block_count,
    block_offset,
    column_count,
    column_index,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    return_softmax_lse=False,
    out=None,
):
    """Attention restricted to a vertical-and-slash sparsity pattern (FA2 kernel).

    The interface mirrors flash_attn_func, plus four tensors describing the
    sparsity: block_count / block_offset encode the slash pattern and
    column_count / column_index encode the vertical pattern.
    See Appendix C.4.2 of https://arxiv.org/abs/2407.02490 for details.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scale applied to QK^T before the softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Apply a causal attention mask (e.g. auto-regressive modeling).
        softcap: float. 0.0 means deactivated.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. Adds a bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            to the score of query i and key j.
        deterministic: bool. Selects the deterministic backward implementation
            (slightly slower, more memory); the forward pass is always
            deterministic. Not read by this wrapper.
        return_attn_probs: bool. Testing only; probabilities are not guaranteed
            to be correct (they might not have the right scaling). The kernel
            is only asked for them when dropout_p > 0 as well.
        return_softmax_lse: bool (keyword-only). Also return the logsumexp.
        out: optional output tensor (keyword-only) handed to the kernel.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen).
            The logsumexp of each row of QK^T * scaling (i.e. the log of the
            softmax normalization factor).
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    # The kernel needs a unit stride along the head dimension.
    q = maybe_contiguous(q)
    k = maybe_contiguous(k)
    v = maybe_contiguous(v)

    want_probs = return_attn_probs and dropout_p > 0

    # Positional call: the order must match the op's schema exactly.
    attn_out, lse = torch.ops._vllm_fa2_C.fwd_sparse(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        out,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        softcap,
        want_probs,
        None,  # trailing optional slot left unset
    )

    if return_softmax_lse:
        return (attn_out, lse)
    return attn_out
|
||||
|
||||
|
||||
def sparse_attn_varlen_func(
    q,
    k,
    v,
    block_count,
    block_offset,
    column_count,
    column_index,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    return_softmax_lse=False,
    out=None,
):
    """Variable-length attention with a vertical-and-slash sparsity pattern (FA2 kernel).

    The interface mirrors flash_attn_varlen_func, plus four tensors describing
    the sparsity: block_count / block_offset encode the slash pattern and
    column_count / column_index encode the vertical pattern.
    See Appendix C.4.2 of https://arxiv.org/abs/2407.02490 for details.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. Cumulative sequence
            lengths of the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. Cumulative sequence
            lengths of the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. Scale applied to QK^T before the softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Apply a causal attention mask (e.g. auto-regressive modeling).
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. Adds a bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            to the score of query i and key j.
        deterministic: bool. Selects the deterministic backward implementation
            (slightly slower, more memory); the forward pass is always
            deterministic. Not read by this wrapper.
        return_attn_probs: bool. Testing only; probabilities are not guaranteed
            to be correct (they might not have the right scaling). The kernel
            is only asked for them when dropout_p > 0 as well.
        return_softmax_lse: bool (keyword-only). Also return the logsumexp.
        out: optional output tensor (keyword-only) handed to the kernel.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen).
            The logsumexp of each row of QK^T * scaling (i.e. the log of the
            softmax normalization factor).
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    # The kernel needs a unit stride along the head dimension.
    q = maybe_contiguous(q)
    k = maybe_contiguous(k)
    v = maybe_contiguous(v)

    want_probs = return_attn_probs and dropout_p > 0

    # Positional call: the order must match the op's schema exactly.
    attn_out, lse = torch.ops._vllm_fa2_C.varlen_fwd_sparse(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        None,  # optional slot left unset (presumably seqused_k -- confirm)
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # presumably zero_tensors -- confirm against the op schema
        causal,
        softcap,
        want_probs,
        None,  # trailing optional slot left unset
    )

    if return_softmax_lse:
        return (attn_out, lse)
    return attn_out
|
||||
Reference in New Issue
Block a user