[Kernel][Misc] register ops to prevent graph breaks (#6917)

Co-authored-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
bnellnm
2024-09-11 15:52:19 -04:00
committed by GitHub
parent 7015417fd4
commit 73202dbe77
22 changed files with 528 additions and 102 deletions

View File

@@ -6,6 +6,7 @@ import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip
@@ -198,6 +199,13 @@ def test_paged_attention(
k_scale,
v_scale,
)
opcheck(torch.ops._C.paged_attention_v1,
(output, query, key_cache, value_cache, num_kv_heads, scale,
block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0]))
elif version == "v2":
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
assert PARTITION_SIZE % block_size == 0
@@ -230,6 +238,14 @@ def test_paged_attention(
k_scale,
v_scale,
)
opcheck(torch.ops._C.paged_attention_v2,
(output, exp_sums, max_logits, tmp_output, query, key_cache,
value_cache, num_kv_heads, scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0]))
else:
raise AssertionError(f"Unknown version: {version}")