[1/n] Migrate permute_cols to libtorch stable ABI (#31509)
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
This commit is contained in:
9
csrc/libtorch_stable/ops.h
Normal file
9
csrc/libtorch_stable/ops.h
Normal file
@@ -0,0 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/stable/library.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Column permutation of A along its last dimension (A[..., perm]), exposed
// through the libtorch stable ABI tensor type. Only declared when USE_ROCM is
// not defined. The implementation checks that A is a contiguous Half/BFloat16
// tensor whose last dimension is a multiple of 8.
torch::stable::Tensor permute_cols(const torch::stable::Tensor& A,
                                   const torch::stable::Tensor& perm);
|
||||
#endif
|
||||
@@ -1,10 +1,13 @@
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/csrc/stable/library.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/csrc/stable/accelerator.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "torch_utils.h"
|
||||
|
||||
static constexpr int default_threads = 256;
|
||||
// Integer ceiling division: for a >= 0 and b > 0, returns the smallest q such
// that q * b >= a. Kept as a single return so it stays usable in constant
// expressions.
static constexpr int div_ceil(int a, int b) {
  return (a + (b - 1)) / b;
}
|
||||
|
||||
@@ -64,19 +67,22 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
||||
|
||||
// More efficient version of A[..., perm]
|
||||
// taken from gptq_marlin.cu
|
||||
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||
auto dev = A.get_device();
|
||||
auto stream = at::cuda::getCurrentCUDAStream(dev);
|
||||
torch::stable::Tensor permute_cols(torch::stable::Tensor const& A,
|
||||
torch::stable::Tensor const& perm) {
|
||||
const int32_t dev = A.get_device_index();
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(dev);
|
||||
const auto stream = get_current_cuda_stream(dev);
|
||||
|
||||
TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16,
|
||||
"Currently only 16bit types are supported");
|
||||
TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
|
||||
TORCH_CHECK(A.size(-1) % 8 == 0,
|
||||
"A columns must be a multiple of 8 (128bits)");
|
||||
auto A_2d = A.view({-1, A.size(-1)});
|
||||
STD_TORCH_CHECK(
|
||||
A.scalar_type() == torch::headeronly::ScalarType::Half ||
|
||||
A.scalar_type() == torch::headeronly::ScalarType::BFloat16,
|
||||
"Currently only 16bit types are supported");
|
||||
STD_TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
|
||||
STD_TORCH_CHECK(A.size(-1) % 8 == 0,
|
||||
"A columns must be a multiple of 8 (128bits)");
|
||||
auto A_2d = torch::stable::view(A, {-1, A.size(-1)});
|
||||
|
||||
torch::Tensor D = torch::empty_like(A);
|
||||
torch::stable::Tensor D = torch::stable::empty_like(A);
|
||||
int sms;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
|
||||
int block_rows = div_ceil(A_2d.size(0), sms);
|
||||
21
csrc/libtorch_stable/torch_bindings.cpp
Normal file
21
csrc/libtorch_stable/torch_bindings.cpp
Normal file
@@ -0,0 +1,21 @@
|
||||
#include "ops.h"
|
||||
#include "core/registration.h"
|
||||
|
||||
#include <torch/csrc/stable/library.h>
|
||||
|
||||
// Register ops with STABLE_TORCH_LIBRARY for libtorch stable ABI compatibility.
|
||||
// Note: We register under namespace "_C" so ops are accessible as
|
||||
// torch.ops._C.<op_name> for compatibility with existing code.
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(_C, m) {
|
||||
#ifndef USE_ROCM
|
||||
m.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
|
||||
#endif
|
||||
}
|
||||
|
||||
// CUDA implementations for the ops declared in the "_C" namespace.
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, lib) {
#ifndef USE_ROCM
  // TORCH_BOX wraps the typed function pointer for registration.
  // NOTE(review): presumably it adapts it to the stable ABI's boxed calling
  // convention - confirm against torch/csrc/stable/library.h.
  lib.impl("permute_cols", TORCH_BOX(&permute_cols));
#endif
}
|
||||
|
||||
REGISTER_EXTENSION(_C_stable_libtorch)
|
||||
13
csrc/libtorch_stable/torch_utils.h
Normal file
13
csrc/libtorch_stable/torch_utils.h
Normal file
@@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
// Utility to get the current CUDA stream for a given device using stable APIs.
|
||||
// Returns a cudaStream_t for use in kernel launches.
|
||||
inline cudaStream_t get_current_cuda_stream(int32_t device_index) {
|
||||
void* stream_ptr = nullptr;
|
||||
TORCH_ERROR_CODE_CHECK(
|
||||
aoti_torch_get_current_cuda_stream(device_index, &stream_ptr));
|
||||
return reinterpret_cast<cudaStream_t>(stream_ptr);
|
||||
}
|
||||
@@ -201,7 +201,6 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
|
||||
torch::Tensor _zeros, int64_t split_k_iters,
|
||||
int64_t thx, int64_t thy);
|
||||
|
||||
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
|
||||
#endif
|
||||
|
||||
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
|
||||
|
||||
@@ -303,9 +303,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
") -> Tensor");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
|
||||
ops.impl("permute_cols", torch::kCUDA, &permute_cols);
|
||||
|
||||
// Marlin Optimized Quantized GEMM (supports GPTQ, AWQ, FP8, NVFP4, MXFP4).
|
||||
ops.def(
|
||||
"marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
|
||||
|
||||
Reference in New Issue
Block a user