diff --git a/csrc/cpu/dnnl_kernels.cpp b/csrc/cpu/dnnl_kernels.cpp
index 6d062c71e..12be43360 100644
--- a/csrc/cpu/dnnl_kernels.cpp
+++ b/csrc/cpu/dnnl_kernels.cpp
@@ -360,13 +360,14 @@ void onednn_scaled_mm(
     const std::optional<torch::Tensor>& azp,      // [M] or [1]
     const std::optional<torch::Tensor>& azp_adj,  // [M] or [1]
     const std::optional<torch::Tensor>& bias,     // [N]
-    int64_t handler) {
+    const torch::Tensor& handler_tensor) {
   CPU_KERNEL_GUARD_IN(onednn_scaled_mm)
   TORCH_CHECK(a.dim() == 2);
   TORCH_CHECK(a.is_contiguous());
   TORCH_CHECK(c.is_contiguous());
   W8A8MatMulPrimitiveHandler* ptr =
-      reinterpret_cast<W8A8MatMulPrimitiveHandler*>(handler);
+      reinterpret_cast<W8A8MatMulPrimitiveHandler*>(
+          handler_tensor.item<int64_t>());
   const int32_t* azp_ptr = nullptr;
   if (azp.has_value()) {
     azp_ptr = azp->data_ptr<int32_t>();
@@ -519,13 +520,14 @@ int64_t create_onednn_mm_handler(const torch::Tensor& b,
 
 void onednn_mm(torch::Tensor& c,        // [M, OC], row-major
                const torch::Tensor& a,  // [M, IC], row-major
-               const std::optional<torch::Tensor>& bias, int64_t handler) {
+               const std::optional<torch::Tensor>& bias,
+               const torch::Tensor& handler_tensor) {
   CPU_KERNEL_GUARD_IN(onednn_mm)
   TORCH_CHECK(a.dim() == 2);
   TORCH_CHECK(a.stride(-1) == 1);
   TORCH_CHECK(c.stride(-1) == 1);
   MatMulPrimitiveHandler* ptr =
-      reinterpret_cast<MatMulPrimitiveHandler*>(handler);
+      reinterpret_cast<MatMulPrimitiveHandler*>(handler_tensor.item<int64_t>());
 
   // ACL matmuls expect contiguous source tensors
 #ifdef VLLM_USE_ACL
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index 19068565d..62c233892 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -19,13 +19,14 @@ void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
                       const std::optional<torch::Tensor>& azp,
                       const std::optional<torch::Tensor>& azp_adj,
                       const std::optional<torch::Tensor>& bias,
-                      int64_t handler);
+                      const torch::Tensor& handler_tensor);
 
 int64_t create_onednn_mm_handler(const torch::Tensor& b,
                                  int64_t primitive_cache_size);
 
 void onednn_mm(torch::Tensor& c, const torch::Tensor& a,
-               const std::optional<torch::Tensor>& bias, int64_t handler);
+               const std::optional<torch::Tensor>& bias,
+               const torch::Tensor& handler_tensor);
 
 bool is_onednn_acl_supported();
 
@@ -196,7 +197,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // oneDNN GEMM
   ops.def(
       "onednn_mm(Tensor! c, Tensor a, Tensor? bias, "
-      "int handler) -> ()");
+      "Tensor handler_tensor) -> ()");
   ops.impl("onednn_mm", torch::kCPU, &onednn_mm);
 
   // Check if oneDNN was built with ACL backend
@@ -212,7 +213,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // oneDNN scaled_mm for W8A8 with static per-tensor activation quantization
   ops.def(
       "onednn_scaled_mm(Tensor! c, Tensor a, Tensor a_scales, Tensor? azp, "
-      "Tensor? azp_adj, Tensor? bias, int handler) -> ()");
+      "Tensor? azp_adj, Tensor? bias, Tensor handler_tensor) -> ()");
   ops.impl("onednn_scaled_mm", torch::kCPU, &onednn_scaled_mm);
 
   // Compute int8 quantized tensor for given scaling factor.
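Note on the kernel-side change above: the op schemas move from `int handler` to `Tensor handler_tensor`, and the C++ implementations recover the opaque handler pointer by reading the address back out of the 0-dim int64 tensor with `handler_tensor.item<int64_t>()` plus a `reinterpret_cast`. A minimal sketch of the new calling convention, assuming a CPU build that registers these bindings (the shapes, dtypes, and cache size below are illustrative, not taken from vLLM's real call sites):

    import torch

    weight = torch.randn(128, 64)  # [K, N]; packed for oneDNN by the handler
    # create_onednn_mm_handler returns the native handler's address as an int
    addr = torch.ops._C.create_onednn_mm_handler(weight, 128)

    # Wrap the address in a 0-dim int64 tensor to match the new
    # "Tensor handler_tensor" slot in the op schema.
    handler_tensor = torch.tensor(addr, dtype=torch.int64)

    a = torch.randn(4, 128)  # [M, K]
    c = torch.empty(4, 64)   # [M, N], written in place by the op
    torch.ops._C.onednn_mm(c, a, None, handler_tensor)

    # Release the native handler when done, as CPUDNNLGEMMHandler.__del__ does.
    torch.ops._C.release_dnnl_matmul_handler(handler_tensor.item())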
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 20f399d7f..01067ca32 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -2845,13 +2845,13 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
 
 class CPUDNNLGEMMHandler:
     def __init__(self) -> None:
-        self.handler: int | None = None
+        self.handler_tensor: torch.Tensor | None = None
         self.n = -1
         self.k = -1
 
     def __del__(self):
-        if self.handler is not None:
-            torch.ops._C.release_dnnl_matmul_handler(self.handler)
+        if self.handler_tensor is not None:
+            torch.ops._C.release_dnnl_matmul_handler(self.handler_tensor.item())
 
 
 _supports_onednn = bool(hasattr(torch.ops._C, "create_onednn_mm_handler"))
@@ -2867,8 +2867,10 @@ def create_onednn_mm(
 ) -> CPUDNNLGEMMHandler:
     handler = CPUDNNLGEMMHandler()
     handler.k, handler.n = weight.size()
-    handler.handler = torch.ops._C.create_onednn_mm_handler(
-        weight, primitive_cache_size
+    # store the handler pointer in a tensor so it doesn't get inlined
+    handler.handler_tensor = torch.tensor(
+        torch.ops._C.create_onednn_mm_handler(weight, primitive_cache_size),
+        dtype=torch.int64,
     )
     return handler
 
@@ -2880,7 +2882,7 @@ def onednn_mm(
 ) -> torch.Tensor:
     output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype)
     torch.ops._C.onednn_mm(
-        output, x.reshape(-1, dnnl_handler.k), bias, dnnl_handler.handler
+        output, x.reshape(-1, dnnl_handler.k), bias, dnnl_handler.handler_tensor
     )
     return output
 
@@ -2896,8 +2898,17 @@ def create_onednn_scaled_mm(
 ) -> CPUDNNLGEMMHandler:
     handler = CPUDNNLGEMMHandler()
     handler.k, handler.n = weight.size()
-    handler.handler = torch.ops._C.create_onednn_scaled_mm_handler(
-        weight, weight_scales, output_type, dynamic_quant, use_azp, primitive_cache_size
+    # store the handler pointer in a tensor so it doesn't get inlined
+    handler.handler_tensor = torch.tensor(
+        torch.ops._C.create_onednn_scaled_mm_handler(
+            weight,
+            weight_scales,
+            output_type,
+            dynamic_quant,
+            use_azp,
+            primitive_cache_size,
+        ),
+        dtype=torch.int64,
     )
     return handler
 
@@ -2950,7 +2961,13 @@ def onednn_scaled_mm(
     bias: torch.Tensor | None,
 ) -> torch.Tensor:
     torch.ops._C.onednn_scaled_mm(
-        output, x, input_scale, input_zp, input_zp_adj, bias, dnnl_handler.handler
+        output,
+        x,
+        input_scale,
+        input_zp,
+        input_zp_adj,
+        bias,
+        dnnl_handler.handler_tensor,
     )
     return output
 
diff --git a/vllm/envs.py b/vllm/envs.py
index d2246f935..02f0ec8e1 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -289,16 +289,11 @@ def use_aot_compile() -> bool:
     from vllm.model_executor.layers.batch_invariant import (
         vllm_is_batch_invariant,
     )
-    from vllm.platforms import current_platform
     from vllm.utils.torch_utils import is_torch_equal_or_newer
 
     default_value = (
         "1"
-        if is_torch_equal_or_newer("2.10.0.dev")
-        and not disable_compile_cache()
-        # Disabling AOT_COMPILE for CPU
-        # See: https://github.com/vllm-project/vllm/issues/32033
-        and not current_platform.is_cpu()
+        if is_torch_equal_or_newer("2.10.0.dev") and not disable_compile_cache()
         else "0"
     )
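Note on the Python-side change above: `CPUDNNLGEMMHandler` now stores the pointer as a 0-dim int64 `handler_tensor`, and the wrappers pass that tensor straight through to the ops. Per the comment in the diff, the point of the wrapping is that a plain `int` argument would get inlined into the compiled graph as a constant, whereas a tensor stays a runtime input; a baked-in pointer is process-specific and would be stale when a cached AOT artifact is reloaded, which is presumably why AOT_COMPILE had been disabled for CPU (issue #32033) and why the `envs.py` hunk can now drop that carve-out. A sketch of how the wrapped handler behaves under compilation, assuming a CPU build of vLLM with oneDNN support and the argument order onednn_mm(dnnl_handler, x, bias) for the wrapper (not fully visible in this diff):

    import torch
    from vllm import _custom_ops as ops

    weight = torch.randn(128, 64)  # [K, N]; handler.k, handler.n = weight.size()
    dnnl_handler = ops.create_onednn_mm(weight, 128)  # 128 = primitive_cache_size

    @torch.compile
    def linear(x: torch.Tensor) -> torch.Tensor:
        # dnnl_handler.handler_tensor enters the traced graph as a tensor
        # input, not as a Python int constant, so a cached/AOT-compiled
        # graph does not hard-code the process-specific pointer value.
        return ops.onednn_mm(dnnl_handler, x, None)

    y = linear(torch.randn(4, 128))  # [M, K] @ [K, N] -> [M, N]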