[Kernel][Misc] register ops to prevent graph breaks (#6917)
Co-authored-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
@@ -4,6 +4,7 @@ from typing import List, Tuple
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
|
||||
@@ -87,6 +88,11 @@ def test_copy_blocks(
|
||||
block_mapping_tensor = torch.tensor(block_mapping,
|
||||
dtype=torch.int64,
|
||||
device=device).view(-1, 2)
|
||||
|
||||
opcheck(torch.ops._C_cache_ops.copy_blocks,
|
||||
(key_caches, value_caches, block_mapping_tensor),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
cond=(head_size == HEAD_SIZES[0]))
|
||||
ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
|
||||
|
||||
# Run the reference implementation.
|
||||
@@ -162,6 +168,10 @@ def test_reshape_and_cache(
|
||||
k_scale = v_scale = 1.0
|
||||
|
||||
# Call the reshape_and_cache kernel.
|
||||
opcheck(torch.ops._C_cache_ops.reshape_and_cache,
|
||||
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
|
||||
k_scale, v_scale),
|
||||
cond=(head_size == HEAD_SIZES[0]))
|
||||
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
|
||||
kv_cache_dtype, k_scale, v_scale)
|
||||
|
||||
@@ -269,6 +279,10 @@ def test_reshape_and_cache_flash(
|
||||
k_scale = v_scale = 1.0
|
||||
|
||||
# Call the reshape_and_cache_flash kernel.
|
||||
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
|
||||
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
|
||||
k_scale, v_scale),
|
||||
cond=(head_size == HEAD_SIZES[0]))
|
||||
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
|
||||
slot_mapping, kv_cache_dtype, k_scale, v_scale)
|
||||
|
||||
@@ -366,6 +380,14 @@ def test_swap_blocks(
|
||||
src_value_caches_clone = src_value_caches[0].clone()
|
||||
|
||||
# Call the swap_blocks kernel.
|
||||
do_opcheck = (head_size == HEAD_SIZES[0])
|
||||
opcheck(torch.ops._C_cache_ops.swap_blocks,
|
||||
(src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
|
||||
cond=do_opcheck)
|
||||
opcheck(torch.ops._C_cache_ops.swap_blocks,
|
||||
(src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
|
||||
cond=do_opcheck)
|
||||
|
||||
ops.swap_blocks(src_key_caches[0], dist_key_caches[0],
|
||||
block_mapping_tensor)
|
||||
ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
|
||||
|
||||
Reference in New Issue
Block a user