[Attention] Use FA3 for MLA on Hopper (#12807)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson
2025-02-06 06:43:12 -05:00
committed by GitHub
parent cefd56ee35
commit c786e757fa
4 changed files with 51 additions and 59 deletions

View File

@@ -10,13 +10,10 @@ import triton.language as tl
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.envs import VLLM_FLASH_ATTN_VERSION
from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_varlen_func,
is_fa_version_supported)
from vllm.vllm_flash_attn import flash_attn_varlen_func
logger = init_logger(__name__)
@@ -136,25 +133,6 @@ class FlashAttentionImpl(AttentionImpl):
"are not implemented for "
"FlashAttentionImpl")
# if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both
if current_platform.get_device_capability()[0] >= 9:
self.fa_version = 3 if is_fa_version_supported(3) else 2
else:
self.fa_version = 2
if VLLM_FLASH_ATTN_VERSION is not None:
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
self.fa_version = VLLM_FLASH_ATTN_VERSION
if not is_fa_version_supported(self.fa_version):
logger.error("Cannot use FA version %d is not supported due to %s",
self.fa_version,
fa_version_unsupported_reason(self.fa_version))
assert is_fa_version_supported(self.fa_version)
def forward(
self,
layer: torch.nn.Module,
@@ -227,7 +205,7 @@ class FlashAttentionImpl(AttentionImpl):
window_size=self.sliding_window,
block_table=attn_metadata.block_table,
softcap=self.logits_soft_cap,
fa_version=self.fa_version,
fa_version=VLLM_FLASH_ATTN_VERSION,
)
return output
@@ -249,7 +227,7 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
fa_version=self.fa_version,
fa_version=VLLM_FLASH_ATTN_VERSION,
)
return output