[misc] add forward context for attention (#9029)
@@ -3,9 +3,9 @@ from typing import List, Optional, Tuple
 import pytest
 import torch
 
-import vllm.attention.backends.flash_attn  # noqa: F401
-from tests.kernels.utils import opcheck
 from vllm.utils import seed_everything
+from vllm.vllm_flash_attn import (flash_attn_varlen_func,
+                                  flash_attn_with_kvcache)
 
 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 256]
@@ -112,10 +112,10 @@ def test_flash_attn_with_paged_kv(
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)
 
-    output = torch.ops.vllm.flash_attn_with_kvcache(
-        decode_query=query.unsqueeze(1),
-        key_cache=key_cache,
-        value_cache=value_cache,
+    output = flash_attn_with_kvcache(
+        q=query.unsqueeze(1),
+        k_cache=key_cache,
+        v_cache=value_cache,
         softmax_scale=scale,
         causal=True,
         block_table=block_tables,
@@ -123,25 +123,6 @@ def test_flash_attn_with_paged_kv(
         softcap=soft_cap if soft_cap is not None else 0,
     ).squeeze(1)
 
-    if num_blocks <= 2048:
-        test_utils = ["test_faketensor", "test_schema"]
-    else:
-        test_utils = ["test_faketensor"]
-
-    opcheck(torch.ops.vllm.flash_attn_with_kvcache,
-            args=tuple(),
-            kwargs=dict(
-                decode_query=query.unsqueeze(1),
-                key_cache=key_cache,
-                value_cache=value_cache,
-                softmax_scale=scale,
-                causal=True,
-                block_table=block_tables,
-                cache_seqlens=kv_lens_tensor,
-                softcap=soft_cap if soft_cap is not None else 0,
-            ),
-            test_utils=test_utils)
-
     ref_output = ref_paged_attn(
         query=query,
         key_cache=key_cache,
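For reference, the hunks above switch the test from the torch.ops.vllm.flash_attn_with_kvcache wrapper (and its opcheck schema test) to the library function itself. A minimal standalone sketch of that decode-path call follows; the shapes, dtype, device, and sequence lengths are illustrative assumptions, not values taken from the test.

# Sketch only: decode-style attention over a paged KV cache, mirroring the call above.
import torch
from vllm.vllm_flash_attn import flash_attn_with_kvcache

num_seqs, num_heads, head_size = 2, 4, 128      # assumed toy sizes
num_blocks, block_size = 8, 16                  # assumed paged-KV layout

query = torch.randn(num_seqs, num_heads, head_size,
                    dtype=torch.float16, device="cuda")
key_cache = torch.randn(num_blocks, block_size, num_heads, head_size,
                        dtype=torch.float16, device="cuda")
value_cache = torch.randn_like(key_cache)
# One row of physical block indices per sequence.
block_tables = torch.randint(0, num_blocks, (num_seqs, num_blocks),
                             dtype=torch.int32, device="cuda")
# Number of tokens already cached for each sequence.
kv_lens = torch.tensor([12, 30], dtype=torch.int32, device="cuda")

output = flash_attn_with_kvcache(
    q=query.unsqueeze(1),            # one query token per sequence (decode)
    k_cache=key_cache,
    v_cache=value_cache,
    softmax_scale=head_size**-0.5,
    causal=True,
    block_table=block_tables,
    cache_seqlens=kv_lens,
).squeeze(1)                         # back to (num_seqs, num_heads, head_size)

The varlen hunks below make the same substitution for the prefill path.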
@@ -213,7 +194,7 @@ def test_varlen_with_paged_kv(
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)
 
-    output = torch.ops.vllm.flash_attn_varlen_func(
+    output = flash_attn_varlen_func(
         q=query,
         k=key_cache,
         v=value_cache,
@@ -228,29 +209,6 @@ def test_varlen_with_paged_kv(
         softcap=soft_cap if soft_cap is not None else 0,
     )
 
-    if num_blocks <= 2048:
-        test_utils = ["test_faketensor", "test_schema"]
-    else:
-        test_utils = ["test_faketensor"]
-
-    opcheck(torch.ops.vllm.flash_attn_varlen_func,
-            args=tuple(),
-            kwargs=dict(
-                q=query,
-                k=key_cache,
-                v=value_cache,
-                cu_seqlens_q=cu_query_lens,
-                cu_seqlens_k=cu_kv_lens,
-                max_seqlen_q=max_query_len,
-                max_seqlen_k=max_kv_len,
-                softmax_scale=scale,
-                causal=True,
-                window_size=window_size,
-                block_table=block_tables,
-                softcap=soft_cap if soft_cap is not None else 0,
-            ),
-            test_utils=test_utils)
-
     ref_output = ref_paged_attn(
         query=query,
         key_cache=key_cache,
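Similarly, a minimal sketch of the direct flash_attn_varlen_func call that replaces the torch.ops.vllm wrapper in the varlen test. Again, every shape and length below is an illustrative assumption; the keyword arguments are the ones the test itself passes.

# Sketch only: varlen (prefill-style) attention over a paged KV cache.
import torch
from vllm.vllm_flash_attn import flash_attn_varlen_func

num_query_heads, num_kv_heads, head_size = 8, 2, 128   # assumed toy sizes
num_blocks, block_size = 8, 16

query_lens = [5, 3]                  # new tokens per sequence
kv_lens = [20, 7]                    # total cached tokens per sequence
query = torch.randn(sum(query_lens), num_query_heads, head_size,
                    dtype=torch.float16, device="cuda")
key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size,
                        dtype=torch.float16, device="cuda")
value_cache = torch.randn_like(key_cache)
block_tables = torch.randint(0, num_blocks, (len(query_lens), num_blocks),
                             dtype=torch.int32, device="cuda")

# Cumulative offsets over the flattened token dimension.
cu_query_lens = torch.tensor([0, 5, 8], dtype=torch.int32, device="cuda")
cu_kv_lens = torch.tensor([0, 20, 27], dtype=torch.int32, device="cuda")

output = flash_attn_varlen_func(
    q=query,
    k=key_cache,
    v=value_cache,
    cu_seqlens_q=cu_query_lens,
    cu_seqlens_k=cu_kv_lens,
    max_seqlen_q=max(query_lens),
    max_seqlen_k=max(kv_lens),
    softmax_scale=head_size**-0.5,
    causal=True,
    window_size=(-1, -1),            # no sliding window
    block_table=block_tables,
)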