[Attention] Flash Attention 3 - fp8 (#14570)
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
@@ -15,6 +15,7 @@ NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 256]
 BLOCK_SIZES = [16, 32]
 DTYPES = [torch.float16, torch.bfloat16]
+QDTYPES = [None, torch.float8_e4m3fn]
 # one value large enough to test overflow in index calculation.
 # one value small enough to test the schema op check
 NUM_BLOCKS = [32768, 2048]
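A note on the new QDTYPES knob: torch.float8_e4m3fn is an 8-bit float with a 4-bit exponent and 3-bit mantissa (max normal value 448), so values drawn from N(0, 1) fit its range and can be cast directly, with no scaling factor. In the general case, fp8 quantization pairs the cast with a float32 "descale" factor that restores magnitude after the low-precision matmul. A minimal sketch of that round trip in plain PyTorch (the quantize_fp8 helper is illustrative, not part of this diff):

    import torch

    def quantize_fp8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Hypothetical helper: scale x into e4m3's representable range,
        # cast, and return the float32 descale that undoes the scaling.
        fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
        scale = fp8_max / x.abs().max().clamp(min=1e-12)
        x_fp8 = (x * scale).to(torch.float8_e4m3fn)
        return x_fp8, (1.0 / scale).to(torch.float32)

    x = torch.randn(16, 128)
    x_fp8, descale = quantize_fp8(x)
    x_roundtrip = x_fp8.to(torch.float32) * descale
    print((x - x_roundtrip).abs().max())  # small relative to |x|

With N(0, 1) inputs the tests below skip this scaling entirely and pass unit descales, which is why q_descale/k_descale/v_descale are filled with torch.ones.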
@@ -85,6 +86,7 @@ def ref_paged_attn(
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("sliding_window", [None, 256])
 @pytest.mark.parametrize("fa_version", [2, 3])
+@pytest.mark.parametrize("q_dtype", QDTYPES)
 @torch.inference_mode()
 def test_flash_attn_with_paged_kv(
     use_out: bool,
@@ -97,11 +99,15 @@ def test_flash_attn_with_paged_kv(
     num_blocks: int,
     sliding_window: Optional[int],
     fa_version: int,
+    q_dtype: Optional[torch.dtype],
 ) -> None:
     torch.set_default_device("cuda")
     if not is_fa_version_supported(fa_version):
         pytest.skip(f"Flash attention version {fa_version} not supported due "
                     f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
+    if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
+        pytest.skip("Flash attention with quantized inputs is only "
+                    "supported on version 3 with bfloat16 base type")
 
     current_platform.seed_everything(0)
     num_seqs = len(kv_lens)
@@ -130,10 +136,28 @@ def test_flash_attn_with_paged_kv(
 
     q = query.unsqueeze(1)
     out = torch.empty_like(q) if use_out else None
+
+    maybe_quantized_query = q
+    maybe_quantized_key_cache = key_cache
+    maybe_quantized_value_cache = value_cache
+    q_descale = None
+    k_descale = None
+    v_descale = None
+    if q_dtype is not None:
+        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
+        maybe_quantized_query = q.to(q_dtype)
+        maybe_quantized_key_cache = key_cache.to(q_dtype)
+        maybe_quantized_value_cache = value_cache.to(q_dtype)
+
+        scale_shape = (num_seqs, num_kv_heads)
+        q_descale = torch.ones(scale_shape, dtype=torch.float32)
+        k_descale = torch.ones(scale_shape, dtype=torch.float32)
+        v_descale = torch.ones(scale_shape, dtype=torch.float32)
+
     output = flash_attn_with_kvcache(
-        q=q,
-        k_cache=key_cache,
-        v_cache=value_cache,
+        q=maybe_quantized_query,
+        k_cache=maybe_quantized_key_cache,
+        v_cache=maybe_quantized_value_cache,
         out=out,
         softmax_scale=scale,
         causal=True,
@@ -142,10 +166,17 @@ def test_flash_attn_with_paged_kv(
         softcap=soft_cap if soft_cap is not None else 0,
         window_size=window_size,
         fa_version=fa_version,
+        q_descale=q_descale,
+        k_descale=k_descale,
+        v_descale=v_descale,
     )
     output = output if not use_out else out
     output = output.squeeze(1)
 
+    atol, rtol = 1.5e-2, 1e-2
+    if q_dtype is not None:
+        atol, rtol = 1.5e-1, 1.5e-1
+
     ref_output = ref_paged_attn(query=query,
                                 key_cache=key_cache,
                                 value_cache=value_cache,
@@ -155,7 +186,7 @@ def test_flash_attn_with_paged_kv(
                                 scale=scale,
                                 soft_cap=soft_cap,
                                 sliding_window=sliding_window)
-    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
+    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
         f"{torch.max(torch.abs(output - ref_output))}"
 
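The descale tensors above are shaped (num_seqs, num_kv_heads): one float32 factor per sequence and KV head, which the FA3 kernel applies when dequantizing partial products. torch.ones makes that a no-op here, but a hedged sketch of what non-trivial per-head factors could look like (the contiguous cache layout is assumed for illustration; the test's paged cache is blocked):

    import torch

    num_seqs, seq_len, num_kv_heads, head_size = 4, 512, 2, 128
    fp8_max = torch.finfo(torch.float8_e4m3fn).max

    # Assumed layout: (num_seqs, seq_len, num_kv_heads, head_size).
    key = torch.randn(num_seqs, seq_len, num_kv_heads, head_size)

    # One scale per (sequence, KV head), reduced over tokens and channels.
    k_amax = key.abs().amax(dim=(1, 3))           # (num_seqs, num_kv_heads)
    k_scale = fp8_max / k_amax.clamp(min=1e-12)
    k_quant = (key * k_scale[:, None, :, None]).to(torch.float8_e4m3fn)
    k_descale = k_scale.reciprocal().to(torch.float32)  # pass as k_descale=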
@@ -171,6 +202,7 @@ def test_flash_attn_with_paged_kv(
 @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("fa_version", [2, 3])
+@pytest.mark.parametrize("q_dtype", QDTYPES)
 @torch.inference_mode()
 def test_varlen_with_paged_kv(
     use_out: bool,
@@ -183,11 +215,15 @@ def test_varlen_with_paged_kv(
     soft_cap: Optional[float],
     num_blocks: int,
     fa_version: int,
+    q_dtype: Optional[torch.dtype],
 ) -> None:
     torch.set_default_device("cuda")
     if not is_fa_version_supported(fa_version):
         pytest.skip(f"Flash attention version {fa_version} not supported due "
                     f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
+    if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
+        pytest.skip("Flash attention with quantized inputs is only "
+                    "supported on version 3 with bfloat16 base type")
     current_platform.seed_everything(0)
     num_seqs = len(seq_lens)
     query_lens = [x[0] for x in seq_lens]
@@ -223,10 +259,28 @@ def test_varlen_with_paged_kv(
                                  dtype=torch.int32)
 
     out = torch.empty_like(query) if use_out else None
+
+    maybe_quantized_query = query
+    maybe_quantized_key_cache = key_cache
+    maybe_quantized_value_cache = value_cache
+    q_descale = None
+    k_descale = None
+    v_descale = None
+    if q_dtype is not None:
+        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
+        maybe_quantized_query = query.to(q_dtype)
+        maybe_quantized_key_cache = key_cache.to(q_dtype)
+        maybe_quantized_value_cache = value_cache.to(q_dtype)
+
+        scale_shape = (num_seqs, num_kv_heads)
+        q_descale = torch.ones(scale_shape, dtype=torch.float32)
+        k_descale = torch.ones(scale_shape, dtype=torch.float32)
+        v_descale = torch.ones(scale_shape, dtype=torch.float32)
+
     output = flash_attn_varlen_func(
-        q=query,
-        k=key_cache,
-        v=value_cache,
+        q=maybe_quantized_query,
+        k=maybe_quantized_key_cache,
+        v=maybe_quantized_value_cache,
         out=out,
         cu_seqlens_q=cu_query_lens,
         seqused_k=kv_lens,
@@ -238,6 +292,9 @@ def test_varlen_with_paged_kv(
         block_table=block_tables,
         softcap=soft_cap if soft_cap is not None else 0,
         fa_version=fa_version,
+        q_descale=q_descale,
+        k_descale=k_descale,
+        v_descale=v_descale,
     )
     output = output if not use_out else out
 
@@ -252,5 +309,8 @@ def test_varlen_with_paged_kv(
         sliding_window=sliding_window,
         soft_cap=soft_cap,
     )
-    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
+    atol, rtol = 1.5e-2, 1e-2
+    if q_dtype is not None:
+        atol, rtol = 1.5e-1, 1.5e-1
+    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
         f"{torch.max(torch.abs(output - ref_output))}"
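The looser fp8 tolerances (atol = rtol = 1.5e-1, versus 1.5e-2/1e-2 on the unquantized path) follow from e4m3's 3 mantissa bits. A quick way to see the magnitude of the input quantization error alone, under the same unit-descale setup as the tests (the attention accumulation adds further error on top):

    import torch

    x = torch.randn(1024, 128)
    err = (x - x.to(torch.float8_e4m3fn).to(torch.float32)).abs().max()
    print(err)  # typically around 1e-1 to 2.5e-1 at the tails of N(0, 1)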