[Kernel] Flash Attention 3 Support (#12093)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
@@ -9,8 +9,11 @@ import triton.language as tl
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.envs import VLLM_FLASH_ATTN_VERSION
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
||||
is_fa_version_supported)
|
||||
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
@@ -63,7 +66,7 @@ class FlashAttentionMetadata:
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_start_loc: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
@@ -71,8 +74,8 @@ class FlashAttentionMetadata:
|
||||
use_cascade: bool
|
||||
common_prefix_len: int
|
||||
cu_prefix_query_lens: Optional[torch.Tensor]
|
||||
cu_prefix_kv_lens: Optional[torch.Tensor]
|
||||
cu_suffix_kv_lens: Optional[torch.Tensor]
|
||||
prefix_kv_lens: Optional[torch.Tensor]
|
||||
suffix_kv_lens: Optional[torch.Tensor]
|
||||
|
||||
# For logging.
|
||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
||||
@@ -128,6 +131,20 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
"are not implemented for "
|
||||
"FlashAttentionImpl")
|
||||
|
||||
# On Hopper or newer (device capability major >= 9), default to FA3 when it is supported; otherwise stick to FA2 for now.
|
||||
# TODO(lucas): profile FA3 on Ampere to see if it makes sense to
|
||||
# use FA3 as default for both
|
||||
if current_platform.get_device_capability()[0] >= 9:
|
||||
self.fa_version = 3 if is_fa_version_supported(3) else 2
|
||||
else:
|
||||
self.fa_version = 2
|
||||
|
||||
if VLLM_FLASH_ATTN_VERSION is not None:
|
||||
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
|
||||
self.fa_version = VLLM_FLASH_ATTN_VERSION
|
||||
|
||||
assert is_fa_version_supported(self.fa_version)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -196,7 +213,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
out=output[:num_actual_tokens],
|
||||
cu_seqlens_q=attn_metadata.query_start_loc,
|
||||
max_seqlen_q=attn_metadata.max_query_len,
|
||||
cu_seqlens_k=attn_metadata.seq_start_loc,
|
||||
seqused_k=attn_metadata.seq_lens,
|
||||
max_seqlen_k=attn_metadata.max_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
@@ -204,6 +221,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
window_size=self.sliding_window,
|
||||
block_table=attn_metadata.block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
fa_version=self.fa_version,
|
||||
)
|
||||
return output
|
||||
|
||||
@@ -216,8 +234,8 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
cu_query_lens=attn_metadata.query_start_loc,
|
||||
max_query_len=attn_metadata.max_query_len,
|
||||
cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
|
||||
cu_prefix_kv_lens=attn_metadata.cu_prefix_kv_lens,
|
||||
cu_suffix_kv_lens=attn_metadata.cu_suffix_kv_lens,
|
||||
prefix_kv_lens=attn_metadata.prefix_kv_lens,
|
||||
suffix_kv_lens=attn_metadata.suffix_kv_lens,
|
||||
max_kv_len=attn_metadata.max_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
@@ -225,6 +243,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
block_table=attn_metadata.block_table,
|
||||
common_prefix_len=attn_metadata.common_prefix_len,
|
||||
fa_version=self.fa_version,
|
||||
)
|
||||
return output
|
||||
|
||||
@@ -305,8 +324,8 @@ def cascade_attention(
|
||||
cu_query_lens: torch.Tensor,
|
||||
max_query_len: int,
|
||||
cu_prefix_query_lens: torch.Tensor,
|
||||
cu_prefix_kv_lens: torch.Tensor,
|
||||
cu_suffix_kv_lens: torch.Tensor,
|
||||
prefix_kv_lens: torch.Tensor,
|
||||
suffix_kv_lens: torch.Tensor,
|
||||
max_kv_len: int,
|
||||
softmax_scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
@@ -314,6 +333,7 @@ def cascade_attention(
|
||||
logits_soft_cap: float,
|
||||
block_table: torch.Tensor,
|
||||
common_prefix_len: int,
|
||||
fa_version: int,
|
||||
) -> torch.Tensor:
|
||||
assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
|
||||
# TODO: Support sliding window.
|
||||
@@ -332,7 +352,7 @@ def cascade_attention(
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=cu_prefix_query_lens,
|
||||
cu_seqlens_k=cu_prefix_kv_lens,
|
||||
seqused_k=prefix_kv_lens,
|
||||
max_seqlen_q=num_tokens,
|
||||
max_seqlen_k=common_prefix_len,
|
||||
softmax_scale=softmax_scale,
|
||||
@@ -341,6 +361,7 @@ def cascade_attention(
|
||||
block_table=block_table[:1],
|
||||
softcap=logits_soft_cap,
|
||||
return_softmax_lse=True,
|
||||
fa_version=fa_version,
|
||||
)
|
||||
|
||||
# Process suffix per query.
|
||||
@@ -349,7 +370,7 @@ def cascade_attention(
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
cu_seqlens_k=cu_suffix_kv_lens,
|
||||
seqused_k=suffix_kv_lens,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_kv_len - common_prefix_len,
|
||||
softmax_scale=softmax_scale,
|
||||
@@ -358,6 +379,7 @@ def cascade_attention(
|
||||
block_table=block_table[:, num_common_kv_blocks:],
|
||||
softcap=logits_soft_cap,
|
||||
return_softmax_lse=True,
|
||||
fa_version=fa_version,
|
||||
)
|
||||
|
||||
# Merge prefix and suffix outputs, and store the result in output.
|
||||
|
||||
Reference in New Issue
Block a user