register custom op for flash attn and use from torch.ops (#7536)
This commit is contained in:
@@ -2,13 +2,16 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||
|
||||
NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
|
||||
import vllm.attention.backends.flash_attn # noqa: F401
|
||||
|
||||
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
|
||||
HEAD_SIZES = [128, 256]
|
||||
BLOCK_SIZES = [16, 32]
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
||||
# one value large enough to test overflow in index calculation.
|
||||
# one value small enough to test the schema op check
|
||||
NUM_BLOCKS = [32768, 2048]
|
||||
|
||||
|
||||
def ref_paged_attn(
|
||||
@@ -72,6 +75,7 @@ def ref_paged_attn(
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@torch.inference_mode()
|
||||
def test_flash_attn_with_paged_kv(
|
||||
kv_lens: List[int],
|
||||
@@ -80,6 +84,7 @@ def test_flash_attn_with_paged_kv(
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: Optional[float],
|
||||
num_blocks: int,
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
@@ -91,7 +96,7 @@ def test_flash_attn_with_paged_kv(
|
||||
scale = head_size**-0.5
|
||||
|
||||
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
|
||||
key_cache = torch.randn(NUM_BLOCKS,
|
||||
key_cache = torch.randn(num_blocks,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
@@ -101,14 +106,14 @@ def test_flash_attn_with_paged_kv(
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
NUM_BLOCKS,
|
||||
num_blocks,
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
|
||||
output = flash_attn_with_kvcache(
|
||||
q=query.unsqueeze(1),
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
output = torch.ops.vllm.flash_attn_with_kvcache(
|
||||
decode_query=query.unsqueeze(1),
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
softmax_scale=scale,
|
||||
causal=True,
|
||||
block_table=block_tables,
|
||||
@@ -116,6 +121,25 @@ def test_flash_attn_with_paged_kv(
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
).squeeze(1)
|
||||
|
||||
if num_blocks <= 2048:
|
||||
test_utils = ["test_faketensor", "test_schema"]
|
||||
else:
|
||||
test_utils = ["test_faketensor"]
|
||||
|
||||
torch.library.opcheck(torch.ops.vllm.flash_attn_with_kvcache,
|
||||
args=tuple(),
|
||||
kwargs=dict(
|
||||
decode_query=query.unsqueeze(1),
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
softmax_scale=scale,
|
||||
causal=True,
|
||||
block_table=block_tables,
|
||||
cache_seqlens=kv_lens_tensor,
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
),
|
||||
test_utils=test_utils)
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
@@ -137,6 +161,7 @@ def test_flash_attn_with_paged_kv(
|
||||
@pytest.mark.parametrize("sliding_window", [None])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@torch.inference_mode()
|
||||
def test_varlen_with_paged_kv(
|
||||
seq_lens: List[Tuple[int, int]],
|
||||
@@ -146,6 +171,7 @@ def test_varlen_with_paged_kv(
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: Optional[float],
|
||||
num_blocks: int,
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
@@ -166,7 +192,7 @@ def test_varlen_with_paged_kv(
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
key_cache = torch.randn(NUM_BLOCKS,
|
||||
key_cache = torch.randn(num_blocks,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
@@ -181,11 +207,11 @@ def test_varlen_with_paged_kv(
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
NUM_BLOCKS,
|
||||
num_blocks,
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
|
||||
output = flash_attn_varlen_func(
|
||||
output = torch.ops.vllm.flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
@@ -200,6 +226,29 @@ def test_varlen_with_paged_kv(
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
)
|
||||
|
||||
if num_blocks <= 2048:
|
||||
test_utils = ["test_faketensor", "test_schema"]
|
||||
else:
|
||||
test_utils = ["test_faketensor"]
|
||||
|
||||
torch.library.opcheck(torch.ops.vllm.flash_attn_varlen_func,
|
||||
args=tuple(),
|
||||
kwargs=dict(
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
cu_seqlens_k=cu_kv_lens,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_kv_len,
|
||||
softmax_scale=scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
block_table=block_tables,
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
),
|
||||
test_utils=test_utils)
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
|
||||
Reference in New Issue
Block a user