[1/n] Migrate permute_cols to libtorch stable ABI (#31509)
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
This commit is contained in:
@@ -340,7 +340,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/quantization/awq/gemm_kernels.cu"
|
||||
"csrc/permute_cols.cu"
|
||||
"csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
|
||||
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
|
||||
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
||||
@@ -986,6 +985,48 @@ define_extension_target(
|
||||
# Setting this variable sidesteps the issue by calling the driver directly.
|
||||
target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
|
||||
|
||||
# add OR VLLM_GPU_LANG STREQUAL "HIP" here once
|
||||
# https://github.com/vllm-project/vllm/issues/35163 is resolved
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
#
|
||||
# _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY)
|
||||
#
|
||||
set(VLLM_STABLE_EXT_SRC
|
||||
"csrc/libtorch_stable/torch_bindings.cpp")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC "csrc/libtorch_stable/permute_cols.cu")
|
||||
endif()
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_STABLE_EXT_SRC}"
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
endif()
|
||||
|
||||
message(STATUS "Enabling C_stable extension.")
|
||||
define_extension_target(
|
||||
_C_stable_libtorch
|
||||
DESTINATION vllm
|
||||
LANGUAGE ${VLLM_GPU_LANG}
|
||||
SOURCES ${VLLM_STABLE_EXT_SRC}
|
||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
# Set TORCH_TARGET_VERSION for stable ABI compatibility.
|
||||
# This ensures we only use C-shim APIs available in PyTorch 2.10.
|
||||
# _C_stable_libtorch is abi compatible with PyTorch >= TORCH_TARGET_VERSION
|
||||
# which is currently set to 2.10.
|
||||
target_compile_definitions(_C_stable_libtorch PRIVATE
|
||||
TORCH_TARGET_VERSION=0x020A000000000000ULL)
|
||||
|
||||
# Needed to use cuda APIs from C-shim
|
||||
target_compile_definitions(_C_stable_libtorch PRIVATE
|
||||
USE_CUDA)
|
||||
endif()
|
||||
|
||||
#
|
||||
# _moe_C extension
|
||||
#
|
||||
|
||||
9
csrc/libtorch_stable/ops.h
Normal file
9
csrc/libtorch_stable/ops.h
Normal file
@@ -0,0 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/stable/library.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::stable::Tensor permute_cols(torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& perm);
|
||||
#endif
|
||||
@@ -1,10 +1,13 @@
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/csrc/stable/library.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/csrc/stable/accelerator.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "torch_utils.h"
|
||||
|
||||
static constexpr int default_threads = 256;
|
||||
static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
@@ -64,19 +67,22 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
||||
|
||||
// More efficient version of A[..., perm]
|
||||
// taken from gptq_marlin.cu
|
||||
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||
auto dev = A.get_device();
|
||||
auto stream = at::cuda::getCurrentCUDAStream(dev);
|
||||
torch::stable::Tensor permute_cols(torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& perm) {
|
||||
const int32_t dev = A.get_device_index();
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(dev);
|
||||
const auto stream = get_current_cuda_stream(dev);
|
||||
|
||||
TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16,
|
||||
"Currently only 16bit types are supported");
|
||||
TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
|
||||
TORCH_CHECK(A.size(-1) % 8 == 0,
|
||||
"A columns must be a multiple of 8 (128bits)");
|
||||
auto A_2d = A.view({-1, A.size(-1)});
|
||||
STD_TORCH_CHECK(
|
||||
A.scalar_type() == torch::headeronly::ScalarType::Half ||
|
||||
A.scalar_type() == torch::headeronly::ScalarType::BFloat16,
|
||||
"Currently only 16bit types are supported");
|
||||
STD_TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
|
||||
STD_TORCH_CHECK(A.size(-1) % 8 == 0,
|
||||
"A columns must be a multiple of 8 (128bits)");
|
||||
auto A_2d = torch::stable::view(A, {-1, A.size(-1)});
|
||||
|
||||
torch::Tensor D = torch::empty_like(A);
|
||||
torch::stable::Tensor D = torch::stable::empty_like(A);
|
||||
int sms;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
|
||||
int block_rows = div_ceil(A_2d.size(0), sms);
|
||||
21
csrc/libtorch_stable/torch_bindings.cpp
Normal file
21
csrc/libtorch_stable/torch_bindings.cpp
Normal file
@@ -0,0 +1,21 @@
|
||||
#include "ops.h"
|
||||
#include "core/registration.h"
|
||||
|
||||
#include <torch/csrc/stable/library.h>
|
||||
|
||||
// Register ops with STABLE_TORCH_LIBRARY for libtorch stable ABI compatibility.
|
||||
// Note: We register under namespace "_C" so ops are accessible as
|
||||
// torch.ops._C.<op_name> for compatibility with existing code.
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(_C, m) {
|
||||
#ifndef USE_ROCM
|
||||
m.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
|
||||
#endif
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
|
||||
#ifndef USE_ROCM
|
||||
m.impl("permute_cols", TORCH_BOX(&permute_cols));
|
||||
#endif
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(_C_stable_libtorch)
|
||||
13
csrc/libtorch_stable/torch_utils.h
Normal file
13
csrc/libtorch_stable/torch_utils.h
Normal file
@@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
// Utility to get the current CUDA stream for a given device using stable APIs.
|
||||
// Returns a cudaStream_t for use in kernel launches.
|
||||
inline cudaStream_t get_current_cuda_stream(int32_t device_index) {
|
||||
void* stream_ptr = nullptr;
|
||||
TORCH_ERROR_CODE_CHECK(
|
||||
aoti_torch_get_current_cuda_stream(device_index, &stream_ptr));
|
||||
return reinterpret_cast<cudaStream_t>(stream_ptr);
|
||||
}
|
||||
@@ -201,7 +201,6 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
|
||||
torch::Tensor _zeros, int64_t split_k_iters,
|
||||
int64_t thx, int64_t thy);
|
||||
|
||||
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
|
||||
#endif
|
||||
|
||||
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
|
||||
|
||||
@@ -303,9 +303,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
") -> Tensor");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
|
||||
ops.impl("permute_cols", torch::kCUDA, &permute_cols);
|
||||
|
||||
// Marlin Optimized Quantized GEMM (supports GPTQ, AWQ, FP8, NVFP4, MXFP4).
|
||||
ops.def(
|
||||
"marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
|
||||
|
||||
5
setup.py
5
setup.py
@@ -597,6 +597,7 @@ class precompiled_wheel_utils:
|
||||
with zipfile.ZipFile(wheel_path) as wheel:
|
||||
files_to_copy = [
|
||||
"vllm/_C.abi3.so",
|
||||
"vllm/_C_stable_libtorch.abi3.so",
|
||||
"vllm/_moe_C.abi3.so",
|
||||
"vllm/_flashmla_C.abi3.so",
|
||||
"vllm/_flashmla_extension_C.abi3.so",
|
||||
@@ -932,6 +933,10 @@ if _is_cpu():
|
||||
|
||||
if _build_custom_ops():
|
||||
ext_modules.append(CMakeExtension(name="vllm._C"))
|
||||
# also _is_hip() once https://github.com/vllm-project/vllm/issues/35163 is
|
||||
# fixed
|
||||
if _is_cuda():
|
||||
ext_modules.append(CMakeExtension(name="vllm._C_stable_libtorch"))
|
||||
|
||||
package_data = {
|
||||
"vllm": [
|
||||
|
||||
@@ -19,6 +19,7 @@ from typing_extensions import ParamSpec
|
||||
|
||||
# import custom ops, trigger op registration
|
||||
import vllm._C # noqa
|
||||
import vllm._C_stable_libtorch # noqa
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import import_pynvml
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
Reference in New Issue
Block a user