[Kernel] Flash Attention 3 Support (#12093)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Authored by Lucas Wilkinson on 2025-01-23 09:45:48 -05:00; committed by GitHub
parent c5b4b11d7f
commit 978b45f399
8 changed files with 151 additions and 83 deletions


@@ -9,8 +9,11 @@ import triton.language as tl
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
+from vllm.envs import VLLM_FLASH_ATTN_VERSION
+from vllm.platforms import current_platform
 from vllm.utils import cdiv
-from vllm.vllm_flash_attn import flash_attn_varlen_func
+from vllm.vllm_flash_attn import (flash_attn_varlen_func,
+                                  is_fa_version_supported)
 class FlashAttentionBackend(AttentionBackend):
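Note: the new `is_fa_version_supported` import exposes a capability probe from the bundled vllm_flash_attn build, which the backend uses further down when picking a kernel version. A quick interactive check (a sketch only; it assumes a vLLM install whose vllm_flash_attn wheel may or may not ship FA3 kernels):

    from vllm.vllm_flash_attn import is_fa_version_supported

    # Probe which FlashAttention kernel generations this build provides.
    print("FA2 supported:", is_fa_version_supported(2))
    print("FA3 supported:", is_fa_version_supported(3))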
@@ -63,7 +66,7 @@ class FlashAttentionMetadata:
     max_query_len: int
     query_start_loc: torch.Tensor
     max_seq_len: int
-    seq_start_loc: torch.Tensor
+    seq_lens: torch.Tensor
     block_table: torch.Tensor
     slot_mapping: torch.Tensor
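Note: the `seq_start_loc` -> `seq_lens` rename tracks an interface difference at the call sites below, which now pass per-sequence KV lengths via `seqused_k` instead of cumulative offsets via `cu_seqlens_k`. A minimal sketch of how the two layouts relate (tensor values are made up for illustration):

    import torch

    # Hypothetical per-sequence KV lengths for a batch of three requests.
    seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

    # Old layout: cumulative start offsets, as consumed by cu_seqlens_k.
    seq_start_loc = torch.nn.functional.pad(
        torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))
    # -> tensor([ 0,  5,  8, 15], dtype=torch.int32)

    # New layout: the lengths themselves, as consumed by seqused_k.
    # The two are interconvertible by prefix-summing / differencing.
    assert torch.equal(seq_start_loc[1:] - seq_start_loc[:-1], seq_lens)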
@@ -71,8 +74,8 @@ class FlashAttentionMetadata:
     use_cascade: bool
     common_prefix_len: int
     cu_prefix_query_lens: Optional[torch.Tensor]
-    cu_prefix_kv_lens: Optional[torch.Tensor]
-    cu_suffix_kv_lens: Optional[torch.Tensor]
+    prefix_kv_lens: Optional[torch.Tensor]
+    suffix_kv_lens: Optional[torch.Tensor]
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.
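Note: the cascade-attention metadata shifts the same way. The prefix pass treats the shared prefix as a single virtual sequence and the suffix pass covers each request's remainder, so these fields become plain length tensors. A sketch of how a caller might populate them (illustrative only; the actual builder presumably lives in the runner code updated elsewhere in this commit):

    import torch

    # Hypothetical batch in which every request shares a 4-token prefix.
    seq_lens = torch.tensor([9, 6, 12], dtype=torch.int32)
    common_prefix_len = 4

    # One virtual sequence covers the shared prefix for the whole batch...
    prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32)

    # ...and each request keeps its remaining KV length for the suffix pass.
    suffix_kv_lens = seq_lens - common_prefix_len  # tensor([5, 2, 8])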
@@ -128,6 +131,20 @@ class FlashAttentionImpl(AttentionImpl):
"are not implemented for "
"FlashAttentionImpl")
# if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both
if current_platform.get_device_capability()[0] >= 9:
self.fa_version = 3 if is_fa_version_supported(3) else 2
else:
self.fa_version = 2
if VLLM_FLASH_ATTN_VERSION is not None:
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
self.fa_version = VLLM_FLASH_ATTN_VERSION
assert is_fa_version_supported(self.fa_version)
def forward(
self,
layer: torch.nn.Module,
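Note: the version-selection logic added above is small enough to restate on its own. A standalone sketch of the resolution order (a hypothetical helper, not part of this commit), assuming `is_fa_version_supported` reports what the installed vllm_flash_attn build ships and `VLLM_FLASH_ATTN_VERSION` is the optional integer override presumably read from the environment variable of the same name:

    from typing import Optional

    def resolve_fa_version(device_major: int,
                           fa3_supported: bool,
                           override: Optional[int]) -> int:
        # Hopper (SM 9.x) defaults to FA3 when the build has it;
        # everything else stays on FA2 for now.
        version = 3 if device_major >= 9 and fa3_supported else 2
        # An explicit override always wins, but must be a known version.
        if override is not None:
            assert override in (2, 3)
            version = override
        return version

    assert resolve_fa_version(9, True, None) == 3   # H100, FA3-capable build
    assert resolve_fa_version(8, True, None) == 2   # A100 stays on FA2
    assert resolve_fa_version(9, True, 2) == 2      # explicit override to FA2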
@@ -196,7 +213,7 @@ class FlashAttentionImpl(AttentionImpl):
                 out=output[:num_actual_tokens],
                 cu_seqlens_q=attn_metadata.query_start_loc,
                 max_seqlen_q=attn_metadata.max_query_len,
-                cu_seqlens_k=attn_metadata.seq_start_loc,
+                seqused_k=attn_metadata.seq_lens,
                 max_seqlen_k=attn_metadata.max_seq_len,
                 softmax_scale=self.scale,
                 causal=True,
@@ -204,6 +221,7 @@ class FlashAttentionImpl(AttentionImpl):
                 window_size=self.sliding_window,
                 block_table=attn_metadata.block_table,
                 softcap=self.logits_soft_cap,
+                fa_version=self.fa_version,
             )
             return output
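Note: with a paged KV cache addressed through `block_table`, a sequence's keys are not contiguous in one packed buffer, so a per-sequence count of valid KV slots (`seqused_k`) is the natural description: only the first `seqused_k[i]` slots of sequence i's blocks hold real keys, and the tail of its last block is padding. A toy, CPU-only illustration of that convention (head and hidden dimensions omitted; this is not the kernel's actual gather):

    import torch

    block_size = 4
    key_cache = torch.arange(8 * block_size).reshape(8, block_size)  # 8 blocks
    block_table = torch.tensor([[0, 1, 2],    # seq 0 owns blocks 0, 1, 2
                                [3, 4, 5]])   # seq 1 owns blocks 3, 4, 5
    seqused_k = torch.tensor([9, 6])          # valid KV tokens per sequence

    def gather_valid_keys(seq_idx: int) -> torch.Tensor:
        # Flatten the sequence's blocks, then keep only the used slots.
        blocks = key_cache[block_table[seq_idx]].reshape(-1)
        return blocks[: seqused_k[seq_idx]]

    assert gather_valid_keys(0).numel() == 9
    assert gather_valid_keys(1).numel() == 6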
@@ -216,8 +234,8 @@ class FlashAttentionImpl(AttentionImpl):
             cu_query_lens=attn_metadata.query_start_loc,
             max_query_len=attn_metadata.max_query_len,
             cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
-            cu_prefix_kv_lens=attn_metadata.cu_prefix_kv_lens,
-            cu_suffix_kv_lens=attn_metadata.cu_suffix_kv_lens,
+            prefix_kv_lens=attn_metadata.prefix_kv_lens,
+            suffix_kv_lens=attn_metadata.suffix_kv_lens,
             max_kv_len=attn_metadata.max_seq_len,
             softmax_scale=self.scale,
             alibi_slopes=self.alibi_slopes,
@@ -225,6 +243,7 @@ class FlashAttentionImpl(AttentionImpl):
             logits_soft_cap=self.logits_soft_cap,
             block_table=attn_metadata.block_table,
             common_prefix_len=attn_metadata.common_prefix_len,
+            fa_version=self.fa_version,
         )
         return output
@@ -305,8 +324,8 @@ def cascade_attention(
     cu_query_lens: torch.Tensor,
     max_query_len: int,
     cu_prefix_query_lens: torch.Tensor,
-    cu_prefix_kv_lens: torch.Tensor,
-    cu_suffix_kv_lens: torch.Tensor,
+    prefix_kv_lens: torch.Tensor,
+    suffix_kv_lens: torch.Tensor,
     max_kv_len: int,
     softmax_scale: float,
     alibi_slopes: Optional[torch.Tensor],
@@ -314,6 +333,7 @@ def cascade_attention(
     logits_soft_cap: float,
     block_table: torch.Tensor,
     common_prefix_len: int,
+    fa_version: int,
 ) -> torch.Tensor:
     assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
     # TODO: Support sliding window.
@@ -332,7 +352,7 @@ def cascade_attention(
         k=key_cache,
         v=value_cache,
         cu_seqlens_q=cu_prefix_query_lens,
-        cu_seqlens_k=cu_prefix_kv_lens,
+        seqused_k=prefix_kv_lens,
         max_seqlen_q=num_tokens,
         max_seqlen_k=common_prefix_len,
         softmax_scale=softmax_scale,
@@ -341,6 +361,7 @@ def cascade_attention(
         block_table=block_table[:1],
         softcap=logits_soft_cap,
         return_softmax_lse=True,
+        fa_version=fa_version,
     )
     # Process suffix per query.
@@ -349,7 +370,7 @@ def cascade_attention(
         k=key_cache,
         v=value_cache,
         cu_seqlens_q=cu_query_lens,
-        cu_seqlens_k=cu_suffix_kv_lens,
+        seqused_k=suffix_kv_lens,
         max_seqlen_q=max_query_len,
         max_seqlen_k=max_kv_len - common_prefix_len,
         softmax_scale=softmax_scale,
@@ -358,6 +379,7 @@ def cascade_attention(
         block_table=block_table[:, num_common_kv_blocks:],
         softcap=logits_soft_cap,
         return_softmax_lse=True,
+        fa_version=fa_version,
     )
     # Merge prefix and suffix outputs, and store the result in output.
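Note: the merge referenced in this last comment combines the prefix and suffix partial results using their log-sum-exp (LSE) terms, which is why both calls above request `return_softmax_lse=True`; the file itself does this with a Triton kernel (hence the `triton.language` import visible in the first hunk). As a reference, the same reduction in plain PyTorch looks roughly like this (a numerically stable sketch; tensor shapes in the docstring are illustrative, not the kernel's exact layout):

    import torch

    def merge_attn_outputs(prefix_out: torch.Tensor, prefix_lse: torch.Tensor,
                           suffix_out: torch.Tensor, suffix_lse: torch.Tensor):
        """Combine two attention results computed over disjoint KV ranges.

        prefix_out, suffix_out: [num_tokens, num_heads, head_dim]
        prefix_lse, suffix_lse: [num_tokens, num_heads], log-sum-exp of the
            unnormalized scores from each partial computation.
        """
        # Rescale both partial softmax denominators to a common base for
        # numerical stability, then take the convex combination.
        max_lse = torch.maximum(prefix_lse, suffix_lse)
        p_weight = torch.exp(prefix_lse - max_lse)
        s_weight = torch.exp(suffix_lse - max_lse)
        denom = p_weight + s_weight
        p_weight = (p_weight / denom).unsqueeze(-1)
        s_weight = (s_weight / denom).unsqueeze(-1)
        return prefix_out * p_weight + suffix_out * s_weight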