[Attention] Flash Attention 3 - fp8 (#14570)
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
@@ -6,11 +6,12 @@ from typing import TYPE_CHECKING, Any, Optional
import numpy as np
import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.backends.utils import get_flash_attn_version
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.fa_utils import get_flash_attn_version
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
@@ -226,6 +227,9 @@ class FlashAttentionImpl(AttentionImpl):
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        NOTE: For FP8 quantization, flash-attn expects the size of
            {q,k,v}_descale to be (num_sequences, num_kv_heads).
            We use torch's .expand() to avoid duplicating values.
        """
        assert output is not None, "Output tensor must be provided."
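As a side note on the docstring above, here is a minimal torch-only sketch (sizes and the scale value are hypothetical, not taken from the diff) of why .expand() avoids duplicating the per-layer scale:

import torch

# Hypothetical sizes for illustration only.
num_sequences, num_kv_heads = 4, 8

# A per-layer scale is a single value; flash-attn wants one entry per
# (sequence, kv head), i.e. shape (num_sequences, num_kv_heads).
q_scale = torch.tensor(0.02)

# .expand() returns a broadcast view with stride 0: no element is copied,
# every position aliases the same underlying scalar.
q_descale = q_scale.expand(num_sequences, num_kv_heads)
assert q_descale.shape == (num_sequences, num_kv_heads)
assert q_descale.data_ptr() == q_scale.data_ptr()  # same storage, no duplication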
@@ -259,6 +263,17 @@ class FlashAttentionImpl(AttentionImpl):
            layer._k_scale,
            layer._v_scale,
        )
        descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
                         key.shape[1])
        if self.kv_cache_dtype.startswith("fp8"):
            key_cache = key_cache.view(torch.float8_e4m3fn)
            value_cache = value_cache.view(torch.float8_e4m3fn)
            num_tokens, num_heads, head_size = query.shape
            query, _ = ops.scaled_fp8_quant(
                query.reshape(
                    (num_tokens, num_heads * head_size)).contiguous(),
                layer._q_scale)
            query = query.reshape((num_tokens, num_heads, head_size))

        # Compute attention and update output up to `num_actual_tokens`.
        if not attn_metadata.use_cascade:
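For intuition, a minimal torch-only sketch of what the FP8 query-quantization step above amounts to. The real code calls the fused ops.scaled_fp8_quant kernel; the helper below is an illustrative stand-in, and all names and shapes here are assumptions:

import torch


def quantize_fp8_with_scale(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Per-tensor FP8 (e4m3) quantization with a precomputed scale:
    # divide by the scale, clamp to the representable range, then cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    x_scaled = (x.float() / scale).clamp(min=finfo.min, max=finfo.max)
    return x_scaled.to(torch.float8_e4m3fn)


# Hypothetical shapes mirroring the diff: (num_tokens, num_heads, head_size).
num_tokens, num_heads, head_size = 16, 8, 128
query = torch.randn(num_tokens, num_heads, head_size)
q_scale = torch.tensor(0.05)

# Quantize over a contiguous 2-D view, then restore the original shape,
# following the same reshape pattern as the diff.
query_fp8 = quantize_fp8_with_scale(
    query.reshape(num_tokens, num_heads * head_size).contiguous(), q_scale)
query_fp8 = query_fp8.reshape(num_tokens, num_heads, head_size)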
@@ -279,6 +294,9 @@ class FlashAttentionImpl(AttentionImpl):
                block_table=attn_metadata.block_table,
                softcap=self.logits_soft_cap,
                fa_version=self.vllm_flash_attn_version,
                q_descale=layer._q_scale.expand(descale_shape),
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )
            return output
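The descale_shape used in the expanded arguments above comes from the varlen metadata: query_start_loc holds cumulative token offsets with a leading zero, so its length minus one is the number of sequences, and key.shape[1] is the number of KV heads. A small sketch with hypothetical values:

import torch

# Hypothetical metadata for 3 sequences of lengths 5, 2 and 9 tokens.
query_start_loc = torch.tensor([0, 5, 7, 16])   # len == num_sequences + 1
key = torch.randn(16, 8, 128)                   # (num_tokens, num_kv_heads, head_size)
q_scale = torch.tensor(0.02)                    # per-layer scale (illustrative)

# Same derivation as in the diff: one descale row per sequence,
# one column per KV head.
descale_shape = (query_start_loc.shape[0] - 1, key.shape[1])
assert descale_shape == (3, 8)

q_descale = q_scale.expand(descale_shape)       # broadcast view, no copy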
@@ -301,6 +319,9 @@ class FlashAttentionImpl(AttentionImpl):
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
            fa_version=self.vllm_flash_attn_version,
            q_descale=layer._q_scale,
            k_descale=layer._k_scale,
            v_descale=layer._v_scale,
        )
        return output
@@ -391,6 +412,9 @@ def cascade_attention(
    block_table: torch.Tensor,
    common_prefix_len: int,
    fa_version: int,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
    # TODO: Support sliding window.
@@ -402,6 +426,7 @@ def cascade_attention(
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    assert num_common_kv_blocks > 0
    descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])

    # Process shared prefix.
    prefix_output, prefix_lse = flash_attn_varlen_func(
@@ -419,8 +444,16 @@ def cascade_attention(
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        fa_version=fa_version,
        q_descale=q_descale.expand(descale_shape)
        if q_descale is not None else None,
        k_descale=k_descale.expand(descale_shape)
        if k_descale is not None else None,
        v_descale=v_descale.expand(descale_shape)
        if v_descale is not None else None,
    )

    descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])

    # Process suffix per query.
    suffix_output, suffix_lse = flash_attn_varlen_func(
        q=query,
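Worth noting: descale_shape is recomputed between the two passes because the prefix pass batches all query tokens as one "sequence" over the shared prefix, while the suffix pass keeps one row per query sequence. A sketch with hypothetical cascade metadata (the key_cache layout shown is an assumption for illustration):

import torch

# Hypothetical cascade metadata: 3 query sequences sharing one common prefix.
cu_prefix_query_lens = torch.tensor([0, 16])    # all query tokens as one "sequence"
cu_query_lens = torch.tensor([0, 5, 7, 16])     # per-sequence cumulative offsets
num_kv_heads = 8
key_cache = torch.empty(32, 16, num_kv_heads, 128)  # assumed (num_blocks, block_size, num_kv_heads, head_size)

# Prefix pass: a single descale row shared by every query token.
prefix_descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
assert prefix_descale_shape == (1, num_kv_heads)

# Suffix pass: one descale row per query sequence.
suffix_descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
assert suffix_descale_shape == (3, num_kv_heads)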
@@ -437,6 +470,12 @@ def cascade_attention(
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        fa_version=fa_version,
        q_descale=q_descale.expand(descale_shape)
        if q_descale is not None else None,
        k_descale=k_descale.expand(descale_shape)
        if k_descale is not None else None,
        v_descale=v_descale.expand(descale_shape)
        if v_descale is not None else None,
    )

    # Merge prefix and suffix outputs, and store the result in output.
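After these two passes, merge_attn_states (imported at the top of the file) combines the prefix and suffix partial results using their log-sum-exp values. A non-Triton reference sketch of that math, assuming lse tensors shaped (num_heads, num_tokens); the shapes and the function name below are assumptions for illustration, not the library's API:

import torch


def merge_attn_states_reference(
    prefix_output: torch.Tensor,   # (num_tokens, num_heads, head_size)
    prefix_lse: torch.Tensor,      # (num_heads, num_tokens)
    suffix_output: torch.Tensor,   # (num_tokens, num_heads, head_size)
    suffix_lse: torch.Tensor,      # (num_heads, num_tokens)
) -> torch.Tensor:
    # Each partial output is already softmax-normalized over its own key
    # range; the log-sum-exp values let us re-weight the two so the result
    # equals attention over the union of both key ranges.
    merged_lse = torch.logaddexp(prefix_lse, suffix_lse)
    prefix_w = torch.exp(prefix_lse - merged_lse).transpose(0, 1).unsqueeze(-1)
    suffix_w = torch.exp(suffix_lse - merged_lse).transpose(0, 1).unsqueeze(-1)
    return prefix_output * prefix_w + suffix_output * suffix_w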