[1/n] Migrate permute_cols to libtorch stable ABI (#31509)

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
This commit is contained in:
mikaylagawarecki
2026-03-19 11:27:26 -04:00
committed by GitHub
parent 104605cbf2
commit 8b10e4fb31
9 changed files with 112 additions and 20 deletions

View File

@@ -340,7 +340,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/permute_cols.cu"
"csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
@@ -986,6 +985,48 @@ define_extension_target(
# Setting this variable sidesteps the issue by calling the driver directly.
target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
# add OR VLLM_GPU_LANG STREQUAL "HIP" here once
# https://github.com/vllm-project/vllm/issues/35163 is resolved
if(VLLM_GPU_LANG STREQUAL "CUDA")
#
# _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY)
#
set(VLLM_STABLE_EXT_SRC
"csrc/libtorch_stable/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC "csrc/libtorch_stable/permute_cols.cu")
endif()
if(VLLM_GPU_LANG STREQUAL "CUDA")
set_gencode_flags_for_srcs(
SRCS "${VLLM_STABLE_EXT_SRC}"
CUDA_ARCHS "${CUDA_ARCHS}")
endif()
message(STATUS "Enabling C_stable extension.")
define_extension_target(
_C_stable_libtorch
DESTINATION vllm
LANGUAGE ${VLLM_GPU_LANG}
SOURCES ${VLLM_STABLE_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
USE_SABI 3
WITH_SOABI)
# Set TORCH_TARGET_VERSION for stable ABI compatibility.
# This ensures we only use C-shim APIs available in PyTorch 2.10.
# _C_stable_libtorch is abi compatible with PyTorch >= TORCH_TARGET_VERSION
# which is currently set to 2.10.
target_compile_definitions(_C_stable_libtorch PRIVATE
TORCH_TARGET_VERSION=0x020A000000000000ULL)
# Needed to use cuda APIs from C-shim
target_compile_definitions(_C_stable_libtorch PRIVATE
USE_CUDA)
endif()
#
# _moe_C extension
#

View File

@@ -0,0 +1,9 @@
#pragma once
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#ifndef USE_ROCM
torch::stable::Tensor permute_cols(torch::stable::Tensor const& A,
torch::stable::Tensor const& perm);
#endif

View File

@@ -1,10 +1,13 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/core/ScalarType.h>
#include <cuda_fp16.h>
#include "torch_utils.h"
static constexpr int default_threads = 256;
static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
@@ -64,19 +67,22 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
// More efficient version of A[..., perm]
// taken from gptq_marlin.cu
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
auto dev = A.get_device();
auto stream = at::cuda::getCurrentCUDAStream(dev);
torch::stable::Tensor permute_cols(torch::stable::Tensor const& A,
torch::stable::Tensor const& perm) {
const int32_t dev = A.get_device_index();
const torch::stable::accelerator::DeviceGuard device_guard(dev);
const auto stream = get_current_cuda_stream(dev);
TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16,
"Currently only 16bit types are supported");
TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
TORCH_CHECK(A.size(-1) % 8 == 0,
"A columns must be a multiple of 8 (128bits)");
auto A_2d = A.view({-1, A.size(-1)});
STD_TORCH_CHECK(
A.scalar_type() == torch::headeronly::ScalarType::Half ||
A.scalar_type() == torch::headeronly::ScalarType::BFloat16,
"Currently only 16bit types are supported");
STD_TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
STD_TORCH_CHECK(A.size(-1) % 8 == 0,
"A columns must be a multiple of 8 (128bits)");
auto A_2d = torch::stable::view(A, {-1, A.size(-1)});
torch::Tensor D = torch::empty_like(A);
torch::stable::Tensor D = torch::stable::empty_like(A);
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
int block_rows = div_ceil(A_2d.size(0), sms);

View File

@@ -0,0 +1,21 @@
#include "ops.h"
#include "core/registration.h"
#include <torch/csrc/stable/library.h>
// Register ops with STABLE_TORCH_LIBRARY for libtorch stable ABI compatibility.
// Note: We register under namespace "_C" so ops are accessible as
// torch.ops._C.<op_name> for compatibility with existing code.
STABLE_TORCH_LIBRARY_FRAGMENT(_C, m) {
#ifndef USE_ROCM
m.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
#endif
}
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
#ifndef USE_ROCM
m.impl("permute_cols", TORCH_BOX(&permute_cols));
#endif
}
REGISTER_EXTENSION(_C_stable_libtorch)

View File

@@ -0,0 +1,13 @@
#pragma once
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <cuda_runtime.h>
// Utility to get the current CUDA stream for a given device using stable APIs.
// Returns a cudaStream_t for use in kernel launches.
inline cudaStream_t get_current_cuda_stream(int32_t device_index) {
void* stream_ptr = nullptr;
TORCH_ERROR_CODE_CHECK(
aoti_torch_get_current_cuda_stream(device_index, &stream_ptr));
return reinterpret_cast<cudaStream_t>(stream_ptr);
}

View File

@@ -201,7 +201,6 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _zeros, int64_t split_k_iters,
int64_t thx, int64_t thy);
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
#endif
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,

View File

@@ -303,9 +303,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
") -> Tensor");
// conditionally compiled so impl registration is in source file
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
ops.impl("permute_cols", torch::kCUDA, &permute_cols);
// Marlin Optimized Quantized GEMM (supports GPTQ, AWQ, FP8, NVFP4, MXFP4).
ops.def(
"marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "

View File

@@ -597,6 +597,7 @@ class precompiled_wheel_utils:
with zipfile.ZipFile(wheel_path) as wheel:
files_to_copy = [
"vllm/_C.abi3.so",
"vllm/_C_stable_libtorch.abi3.so",
"vllm/_moe_C.abi3.so",
"vllm/_flashmla_C.abi3.so",
"vllm/_flashmla_extension_C.abi3.so",
@@ -932,6 +933,10 @@ if _is_cpu():
if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C"))
# also _is_hip() once https://github.com/vllm-project/vllm/issues/35163 is
# fixed
if _is_cuda():
ext_modules.append(CMakeExtension(name="vllm._C_stable_libtorch"))
package_data = {
"vllm": [

View File

@@ -19,6 +19,7 @@ from typing_extensions import ParamSpec
# import custom ops, trigger op registration
import vllm._C # noqa
import vllm._C_stable_libtorch # noqa
from vllm.logger import init_logger
from vllm.utils.import_utils import import_pynvml
from vllm.utils.torch_utils import cuda_device_count_stateless