[Misc] Add indirection layer for custom ops (#3913)

Kunshang Ji
2024-04-11 03:26:07 +00:00
committed by GitHub
parent e42df7227d
commit e9da5a40c6
14 changed files with 224 additions and 32 deletions
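The hunks below show call sites migrating from the compiled extension (`vllm._C`) to a new Python-level indirection module, `vllm._custom_ops`. A minimal sketch of what such a module can look like, assuming it simply forwards to the compiled kernels; the wrapper name and signature are inferred from the call sites in this diff, not the module's full contents:

```python
# vllm/_custom_ops.py (sketch). Callers import this module instead of the
# compiled extension directly, so the layout of the native vllm._C module
# can change without touching every call site.
import torch

from vllm._C import cache_ops as _cache_ops
from vllm._C import ops as _ops  # noqa: F401  (other wrappers omitted)


def convert_fp8(src: torch.Tensor, dst: torch.Tensor) -> None:
    # Convert an FP8 KV-cache tensor, writing into dst in place. Argument
    # order matches the call sites below: source first, then the
    # preallocated destination buffer (assumed; verify against vllm._C).
    _cache_ops.convert_fp8(src, dst)
```

The payoff is that tests and model code depend only on a stable Python namespace, while the binding to the native extension lives in one file.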


@@ -7,7 +7,7 @@ from allclose_default import get_default_atol, get_default_rtol
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
-from vllm._C import cache_ops, ops
+from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
@@ -237,14 +237,14 @@ def test_paged_attention(
dequantized_key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device=device)
-cache_ops.convert_fp8(key_cache, dequantized_key_cache)
+ops.convert_fp8(key_cache, dequantized_key_cache)
key_cache = dequantized_key_cache
value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device=device)
-cache_ops.convert_fp8(value_cache, dequantized_value_cache)
+ops.convert_fp8(value_cache, dequantized_value_cache)
value_cache = dequantized_value_cache
ref_output = torch.empty_like(query)
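For context, a self-contained illustration of the dequantization pattern the updated test exercises through the new namespace. The shape, dtype, and CUDA device here are assumptions for illustration; the real test derives them from the paged-attention cache layout:

```python
import torch

from vllm import _custom_ops as ops

# Hypothetical FP8 cache stored as raw bytes; the real shape comes from
# num_blocks, block_size, num_kv_heads, and head_size.
key_cache = torch.randint(0, 256, (128, 8, 16, 16),
                          dtype=torch.uint8, device="cuda")

# Preallocate the destination, then dequantize via the indirection layer.
dequantized_key_cache = torch.empty(key_cache.shape,
                                    dtype=torch.float16, device="cuda")
ops.convert_fp8(key_cache, dequantized_key_cache)
key_cache = dequantized_key_cache
```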