[Attention] Flash Attention 3 - fp8 (#14570)

Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Mickaël Seznec
2025-03-20 06:14:20 +01:00
committed by GitHub
parent ae65f3e237
commit a597a57595
15 changed files with 272 additions and 76 deletions


@@ -6,11 +6,12 @@ from typing import TYPE_CHECKING, Any, Optional
import numpy as np
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.utils import get_flash_attn_version
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.fa_utils import get_flash_attn_version
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
@@ -226,6 +227,9 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
NOTE: For FP8 quantization, flash-attn expects the
{q,k,v}_descale tensors to have shape (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating the values.
"""
assert output is not None, "Output tensor must be provided."
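
The NOTE above leans on torch.Tensor.expand(): a per-tensor scale stored as a one-element tensor can be broadcast to the (num_sequences, num_kv_heads) layout flash-attn expects without allocating or copying anything. A minimal sketch, with purely illustrative shapes and values (not taken from the diff):

import torch

# Hypothetical per-tensor descale factor, stored as a one-element tensor.
q_scale = torch.tensor([0.02], dtype=torch.float32)

num_sequences, num_kv_heads = 4, 8  # illustrative sizes
descale_shape = (num_sequences, num_kv_heads)

# expand() returns a broadcast view: no copy is made, every element
# aliases the single stored scale value.
q_descale = q_scale.expand(descale_shape)
assert q_descale.shape == descale_shape
assert q_descale.data_ptr() == q_scale.data_ptr()  # same storage, stride-0 view
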
@@ -259,6 +263,17 @@ class FlashAttentionImpl(AttentionImpl):
layer._k_scale,
layer._v_scale,
)
descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
key.shape[1])
if self.kv_cache_dtype.startswith("fp8"):
    key_cache = key_cache.view(torch.float8_e4m3fn)
    value_cache = value_cache.view(torch.float8_e4m3fn)
    num_tokens, num_heads, head_size = query.shape
    query, _ = ops.scaled_fp8_quant(
        query.reshape(
            (num_tokens, num_heads * head_size)).contiguous(),
        layer._q_scale)
    query = query.reshape((num_tokens, num_heads, head_size))
# Compute attention and update output up to `num_actual_tokens`.
if not attn_metadata.use_cascade:
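
When the KV cache dtype is fp8, the cache tensors are reinterpreted as float8_e4m3fn and the query is quantized per-tensor with ops.scaled_fp8_quant before the attention call. A rough CPU-side sketch of that per-tensor quantization, assuming the quantized value is x / scale clamped to the e4m3 range (the real op is a fused CUDA kernel; names and shapes here are illustrative):

import torch

def quant_fp8_per_tensor(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Divide by the scale, clamp to the representable e4m3 range, then cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    x_scaled = (x.float() / scale).clamp(finfo.min, finfo.max)
    return x_scaled.to(torch.float8_e4m3fn)

num_tokens, num_heads, head_size = 16, 8, 128  # illustrative
query = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)
q_scale = torch.tensor(0.05)

# Same reshape-to-2D / reshape-back pattern as in the diff above.
q_fp8 = quant_fp8_per_tensor(
    query.reshape(num_tokens, num_heads * head_size), q_scale
).reshape(num_tokens, num_heads, head_size)
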
@@ -279,6 +294,9 @@ class FlashAttentionImpl(AttentionImpl):
block_table=attn_metadata.block_table,
softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
return output
@@ -301,6 +319,9 @@ class FlashAttentionImpl(AttentionImpl):
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
)
return output
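
In the cascade path the raw per-tensor scales are passed through unexpanded: cascade_attention (below) has to expand them twice with different leading dimensions, once for the single shared-prefix call and once for the per-request suffix calls. A small sketch of the two descale shapes, using illustrative sizes:

import torch

num_seqs, num_kv_heads = 4, 8  # illustrative
cu_prefix_query_lens = torch.tensor([0, 100])    # one shared-prefix "sequence"
cu_query_lens = torch.arange(num_seqs + 1) * 25  # per-request cumulative lengths

prefix_descale_shape = (cu_prefix_query_lens.shape[0] - 1, num_kv_heads)  # (1, 8)
suffix_descale_shape = (cu_query_lens.shape[0] - 1, num_kv_heads)         # (4, 8)
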
@@ -391,6 +412,9 @@ def cascade_attention(
block_table: torch.Tensor,
common_prefix_len: int,
fa_version: int,
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
# TODO: Support sliding window.
@@ -402,6 +426,7 @@ def cascade_attention(
assert common_prefix_len % block_size == 0
num_common_kv_blocks = common_prefix_len // block_size
assert num_common_kv_blocks > 0
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
# Process shared prefix.
prefix_output, prefix_lse = flash_attn_varlen_func(
@@ -419,8 +444,16 @@ def cascade_attention(
softcap=logits_soft_cap,
return_softmax_lse=True,
fa_version=fa_version,
q_descale=q_descale.expand(descale_shape)
if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape)
if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape)
if v_descale is not None else None,
)
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
# Process suffix per query.
suffix_output, suffix_lse = flash_attn_varlen_func(
q=query,
@@ -437,6 +470,12 @@ def cascade_attention(
softcap=logits_soft_cap,
return_softmax_lse=True,
fa_version=fa_version,
q_descale=q_descale.expand(descale_shape)
if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape)
if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape)
if v_descale is not None else None,
)
# Merge prefix and suffix outputs, and store the result in output.
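
The diff ends at the merge step: merge_attn_states (imported at the top of the file) combines the shared-prefix result and the per-suffix result using their log-sum-exp values, so the merged output matches attention over the full KV range. A minimal reference version of that merge in plain PyTorch, assuming LSE tensors laid out as [num_heads, num_tokens] (the real implementation is a Triton kernel; names are illustrative):

import torch

def merge_attn_states_ref(prefix_out, prefix_lse, suffix_out, suffix_lse):
    # prefix_out, suffix_out: [num_tokens, num_heads, head_size]
    # prefix_lse, suffix_lse: [num_heads, num_tokens] (assumed layout)
    p_lse = prefix_lse.transpose(0, 1).unsqueeze(-1)  # -> [num_tokens, num_heads, 1]
    s_lse = suffix_lse.transpose(0, 1).unsqueeze(-1)
    max_lse = torch.maximum(p_lse, s_lse)
    p_w = torch.exp(p_lse - max_lse)
    s_w = torch.exp(s_lse - max_lse)
    # Softmax-consistent weighted average of the two partial outputs.
    return (prefix_out * p_w + suffix_out * s_w) / (p_w + s_w)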