diff --git a/CMakeLists.txt b/CMakeLists.txt index 693070b5f..ddc9bcadb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -340,7 +340,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_EXT_SRC "csrc/quantization/awq/gemm_kernels.cu" - "csrc/permute_cols.cu" "csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu" "csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" @@ -986,6 +985,48 @@ define_extension_target( # Setting this variable sidesteps the issue by calling the driver directly. target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) +# add OR VLLM_GPU_LANG STREQUAL "HIP" here once +# https://github.com/vllm-project/vllm/issues/35163 is resolved +if(VLLM_GPU_LANG STREQUAL "CUDA") + # + # _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY) + # + set(VLLM_STABLE_EXT_SRC + "csrc/libtorch_stable/torch_bindings.cpp") + + if(VLLM_GPU_LANG STREQUAL "CUDA") + list(APPEND VLLM_STABLE_EXT_SRC "csrc/libtorch_stable/permute_cols.cu") + endif() + + if(VLLM_GPU_LANG STREQUAL "CUDA") + set_gencode_flags_for_srcs( + SRCS "${VLLM_STABLE_EXT_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + endif() + + message(STATUS "Enabling C_stable extension.") + define_extension_target( + _C_stable_libtorch + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_STABLE_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 + WITH_SOABI) + + # Set TORCH_TARGET_VERSION for stable ABI compatibility. + # This ensures we only use C-shim APIs available in PyTorch 2.10. + # _C_stable_libtorch is abi compatible with PyTorch >= TORCH_TARGET_VERSION + # which is currently set to 2.10. 
+ target_compile_definitions(_C_stable_libtorch PRIVATE + TORCH_TARGET_VERSION=0x020A000000000000ULL) + + # Needed to use cuda APIs from C-shim + target_compile_definitions(_C_stable_libtorch PRIVATE + USE_CUDA) +endif() + # # _moe_C extension # diff --git a/csrc/libtorch_stable/ops.h b/csrc/libtorch_stable/ops.h new file mode 100644 index 000000000..5fe1492b8 --- /dev/null +++ b/csrc/libtorch_stable/ops.h @@ -0,0 +1,9 @@ +#pragma once + +#include <torch/csrc/stable/library.h> +#include <torch/csrc/stable/tensor.h> + +#ifndef USE_ROCM +torch::stable::Tensor permute_cols(torch::stable::Tensor const& A, + torch::stable::Tensor const& perm); +#endif diff --git a/csrc/permute_cols.cu b/csrc/libtorch_stable/permute_cols.cu similarity index 68% rename from csrc/permute_cols.cu rename to csrc/libtorch_stable/permute_cols.cu index f51fa7329..3162ac02c 100644 --- a/csrc/permute_cols.cu +++ b/csrc/libtorch_stable/permute_cols.cu @@ -1,10 +1,13 @@ -#include <torch/all.h> - -#include <ATen/cuda/CUDAContext.h> -#include <c10/cuda/CUDAGuard.h> +#include <torch/csrc/stable/accelerator.h> +#include <torch/csrc/stable/library.h> +#include <torch/csrc/stable/ops.h> +#include <torch/csrc/stable/tensor.h> +#include <torch/headeronly/core/ScalarType.h> #include <cuda_fp16.h> +#include "torch_utils.h" + static constexpr int default_threads = 256; static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } @@ -64,19 +67,22 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, // More efficient version of A[..., perm] // taken from gptq_marlin.cu -torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); - auto dev = A.get_device(); - auto stream = at::cuda::getCurrentCUDAStream(dev); +torch::stable::Tensor permute_cols(torch::stable::Tensor const& A, + torch::stable::Tensor const& perm) { + const int32_t dev = A.get_device_index(); + const torch::stable::accelerator::DeviceGuard device_guard(dev); + const auto stream = get_current_cuda_stream(dev); - TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16, - "Currently only 16bit types are supported"); - TORCH_CHECK(A.is_contiguous(), "A must be contiguous"); - TORCH_CHECK(A.size(-1) % 
8 == 0, - "A columns must be a multiple of 8 (128bits)"); - auto A_2d = A.view({-1, A.size(-1)}); + STD_TORCH_CHECK( + A.scalar_type() == torch::headeronly::ScalarType::Half || + A.scalar_type() == torch::headeronly::ScalarType::BFloat16, + "Currently only 16bit types are supported"); + STD_TORCH_CHECK(A.is_contiguous(), "A must be contiguous"); + STD_TORCH_CHECK(A.size(-1) % 8 == 0, + "A columns must be a multiple of 8 (128bits)"); + auto A_2d = torch::stable::view(A, {-1, A.size(-1)}); - torch::Tensor D = torch::empty_like(A); + torch::stable::Tensor D = torch::stable::empty_like(A); int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); int block_rows = div_ceil(A_2d.size(0), sms); diff --git a/csrc/libtorch_stable/torch_bindings.cpp b/csrc/libtorch_stable/torch_bindings.cpp new file mode 100644 index 000000000..0c0ecaa01 --- /dev/null +++ b/csrc/libtorch_stable/torch_bindings.cpp @@ -0,0 +1,21 @@ +#include "ops.h" +#include "core/registration.h" + +#include <torch/csrc/stable/library.h> + +// Register ops with STABLE_TORCH_LIBRARY for libtorch stable ABI compatibility. +// Note: We register under namespace "_C" so ops are accessible as +// torch.ops._C.<op_name> for compatibility with existing code. +STABLE_TORCH_LIBRARY_FRAGMENT(_C, m) { +#ifndef USE_ROCM + m.def("permute_cols(Tensor A, Tensor perm) -> Tensor"); +#endif +} + +STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { +#ifndef USE_ROCM + m.impl("permute_cols", TORCH_BOX(&permute_cols)); +#endif +} + +REGISTER_EXTENSION(_C_stable_libtorch) diff --git a/csrc/libtorch_stable/torch_utils.h b/csrc/libtorch_stable/torch_utils.h new file mode 100644 index 000000000..a615768a9 --- /dev/null +++ b/csrc/libtorch_stable/torch_utils.h @@ -0,0 +1,13 @@ +#pragma once + +#include <cuda_runtime.h> +#include <torch/csrc/inductor/aoti_torch/c/shim.h> + +// Utility to get the current CUDA stream for a given device using stable APIs. +// Returns a cudaStream_t for use in kernel launches. 
+inline cudaStream_t get_current_cuda_stream(int32_t device_index) { + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK( + aoti_torch_get_current_cuda_stream(device_index, &stream_ptr)); + return reinterpret_cast<cudaStream_t>(stream_ptr); +} diff --git a/csrc/ops.h b/csrc/ops.h index 4d33d86d9..26caf7f7d 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -201,7 +201,6 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _zeros, int64_t split_k_iters, int64_t thx, int64_t thy); -torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm); #endif torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index b29e38c7c..81605d002 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -303,9 +303,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ") -> Tensor"); // conditionally compiled so impl registration is in source file - ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor"); - ops.impl("permute_cols", torch::kCUDA, &permute_cols); - // Marlin Optimized Quantized GEMM (supports GPTQ, AWQ, FP8, NVFP4, MXFP4). ops.def( "marlin_gemm(Tensor a, Tensor? 
c_or_none, Tensor b_q_weight, " diff --git a/setup.py b/setup.py index a809c66c8..68861fe4b 100644 --- a/setup.py +++ b/setup.py @@ -597,6 +597,7 @@ class precompiled_wheel_utils: with zipfile.ZipFile(wheel_path) as wheel: files_to_copy = [ "vllm/_C.abi3.so", + "vllm/_C_stable_libtorch.abi3.so", "vllm/_moe_C.abi3.so", "vllm/_flashmla_C.abi3.so", "vllm/_flashmla_extension_C.abi3.so", @@ -932,6 +933,10 @@ if _is_cpu(): if _build_custom_ops(): ext_modules.append(CMakeExtension(name="vllm._C")) + # also _is_hip() once https://github.com/vllm-project/vllm/issues/35163 is + # fixed + if _is_cuda(): + ext_modules.append(CMakeExtension(name="vllm._C_stable_libtorch")) package_data = { "vllm": [ diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 6e4eb0993..7070fd0b6 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -19,6 +19,7 @@ from typing_extensions import ParamSpec # import custom ops, trigger op registration import vllm._C # noqa +import vllm._C_stable_libtorch # noqa from vllm.logger import init_logger from vllm.utils.import_utils import import_pynvml from vllm.utils.torch_utils import cuda_device_count_stateless