[Kernel][Misc] register ops to prevent graph breaks (#6917)
Co-authored-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
@@ -6,6 +6,7 @@ import torch
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils import get_max_shared_memory_bytes, is_hip
|
||||
|
||||
@@ -198,6 +199,13 @@ def test_paged_attention(
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
opcheck(torch.ops._C.paged_attention_v1,
|
||||
(output, query, key_cache, value_cache, num_kv_heads, scale,
|
||||
block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
|
||||
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
|
||||
cond=(head_size == HEAD_SIZES[0]))
|
||||
|
||||
elif version == "v2":
|
||||
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
|
||||
assert PARTITION_SIZE % block_size == 0
|
||||
@@ -230,6 +238,14 @@ def test_paged_attention(
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
opcheck(torch.ops._C.paged_attention_v2,
|
||||
(output, exp_sums, max_logits, tmp_output, query, key_cache,
|
||||
value_cache, num_kv_heads, scale, block_tables, seq_lens,
|
||||
block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
|
||||
k_scale, v_scale, 0, 0, 0, 64, 0),
|
||||
cond=(head_size == HEAD_SIZES[0]))
|
||||
|
||||
else:
|
||||
raise AssertionError(f"Unknown version: {version}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user