Replace FlashAttention with xformers (#70)
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
import random
|
||||
from typing import List, Optional
|
||||
|
||||
from flash_attn.flash_attn_interface import _flash_attn_forward
|
||||
import torch
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
||||
|
||||
from cacheflow import attention_ops
|
||||
|
||||
@@ -81,8 +82,10 @@ def ref_multi_query_kv_attention(
|
||||
end_idx = cu_seq_lens[i + 1]
|
||||
seq_len = end_idx - start_idx
|
||||
|
||||
# Create attention mask
|
||||
attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
|
||||
# Create attention mask.
|
||||
attn_mask = torch.triu(
|
||||
torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
|
||||
attn_mask = attn_mask * torch.finfo(dtype).min
|
||||
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
|
||||
|
||||
ref_output = ref_masked_attention(
|
||||
@@ -160,21 +163,20 @@ def test_single_query_cached_kv_attention(
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
qkv = torch.randn(
|
||||
qkv = torch.empty(
|
||||
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
qkv.uniform_(-1e-3, 1e-3)
|
||||
query, _, _ = qkv.unbind(dim=1)
|
||||
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_block_shape = (num_heads, head_size // x, block_size, x)
|
||||
key_cache = torch.randn(
|
||||
key_cache = torch.empty(
|
||||
size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
|
||||
key_cache.uniform_(-1e-3, 1e-3)
|
||||
value_block_shape = (num_heads, head_size, block_size)
|
||||
value_cache = torch.randn(
|
||||
value_cache = torch.empty(
|
||||
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
|
||||
|
||||
# Adjust the range of the values to reduce precision errors.
|
||||
query = query / (head_size ** 0.5)
|
||||
key_cache = key_cache / (head_size ** 0.5)
|
||||
value_cache = value_cache / (head_size ** 0.5)
|
||||
value_cache.uniform_(-1e-3, 1e-3)
|
||||
|
||||
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
|
||||
max_context_len = max(context_lens)
|
||||
@@ -228,39 +230,30 @@ def test_multi_query_kv_attention(
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
|
||||
max_seq_len = max(seq_lens)
|
||||
num_tokens = sum(seq_lens)
|
||||
|
||||
scale = float(1.0 / (head_size ** 0.5))
|
||||
qkv = torch.empty(
|
||||
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
qkv.uniform_(-1e-3, 1e-3)
|
||||
query, key, value = qkv.unbind(dim=1)
|
||||
|
||||
attn_op = xops.fmha.cutlass.FwOp()
|
||||
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
|
||||
output = xops.memory_efficient_attention_forward(
|
||||
query.unsqueeze(0),
|
||||
key.unsqueeze(0),
|
||||
value.unsqueeze(0),
|
||||
attn_bias=attn_bias,
|
||||
p=0.0,
|
||||
scale=scale,
|
||||
op=attn_op,
|
||||
)
|
||||
output = output.squeeze(0)
|
||||
|
||||
cu_seq_lens = [0]
|
||||
for seq_len in seq_lens:
|
||||
cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
|
||||
cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')
|
||||
|
||||
scale = float(1.0 / (head_size ** 0.5))
|
||||
qkv = torch.randn(
|
||||
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
# Adjust the range of the values to reduce precision errors.
|
||||
qkv = qkv / (head_size ** 0.5)
|
||||
|
||||
query, key, value = qkv.unbind(dim=1)
|
||||
output = torch.empty(
|
||||
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
_flash_attn_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
cu_seq_lens,
|
||||
cu_seq_lens,
|
||||
max_seq_len,
|
||||
max_seq_len,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=scale,
|
||||
causal=True,
|
||||
return_softmax=False,
|
||||
)
|
||||
|
||||
cu_seq_lens = cu_seq_lens.cpu().tolist()
|
||||
ref_output = ref_multi_query_kv_attention(
|
||||
cu_seq_lens,
|
||||
query,
|
||||
@@ -277,8 +270,8 @@ def test_attention(seed: int) -> None:
|
||||
# the test fails due to the precision issue. Re-run the test if it fails.
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
for dtype in [torch.half, torch.float]:
|
||||
for block_size in [8, 16, 32]:
|
||||
for dtype in [torch.half, torch.bfloat16]:
|
||||
for block_size in [8, 16, 32, 64]:
|
||||
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
|
||||
print(f'Testing single_query_cached_kv_attention with '
|
||||
f'dtype={dtype}, block_size={block_size}, '
|
||||
@@ -292,14 +285,12 @@ def test_attention(seed: int) -> None:
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): FlashAttention does not support FP32.
|
||||
for dtype in [torch.half]:
|
||||
# NOTE(woosuk): FlashAttention does not support head_size > 128.
|
||||
for head_size in [64, 80, 96, 128]:
|
||||
for dtype in [torch.half, torch.bfloat16]:
|
||||
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
|
||||
print(f'Testing multi_query_kv_attention with dtype={dtype}, '
|
||||
f'head_size={head_size}')
|
||||
test_multi_query_kv_attention(
|
||||
num_seqs=11,
|
||||
num_seqs=5,
|
||||
num_heads=3,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
|
||||
Reference in New Issue
Block a user