[Misc] Add indirection layer for custom ops (#3913)
This commit is contained in:
@@ -7,7 +7,7 @@ from allclose_default import get_default_atol, get_default_rtol
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
||||
|
||||
from vllm._C import cache_ops, ops
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils import get_max_shared_memory_bytes, is_hip
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
@@ -237,14 +237,14 @@ def test_paged_attention(
|
||||
dequantized_key_cache = torch.empty(size=key_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
cache_ops.convert_fp8(key_cache, dequantized_key_cache)
|
||||
ops.convert_fp8(key_cache, dequantized_key_cache)
|
||||
key_cache = dequantized_key_cache
|
||||
|
||||
value_cache_shape = value_cache.shape
|
||||
dequantized_value_cache = torch.empty(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
cache_ops.convert_fp8(value_cache, dequantized_value_cache)
|
||||
ops.convert_fp8(value_cache, dequantized_value_cache)
|
||||
value_cache = dequantized_value_cache
|
||||
|
||||
ref_output = torch.empty_like(query)
|
||||
|
||||
Reference in New Issue
Block a user