[Kernel] (2/N) Machete - Integrate into CompressedTensorsWNA16 and GPTQMarlin (#7701)
Co-authored-by: mgoin <michael@neuralmagic.com> Co-authored-by: Divakar Verma <137818590+divakar-amd@users.noreply.github.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -68,7 +68,13 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
|
||||
name, ".stride(", idx, ") to be ", StrideEle::value);
|
||||
return StrideEle{};
|
||||
} else {
|
||||
return tensor.stride(idx);
|
||||
if (tensor.size(idx) == 1) {
|
||||
// use 0 stride for dim with size 1, this is easier for
|
||||
// cute/cutlass to optimize (helps the TMA code flatten dims)
|
||||
return StrideEle{0};
|
||||
} else {
|
||||
return tensor.stride(idx);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Extra strides are assumed to be 0 or 1
|
||||
|
||||
@@ -113,6 +113,8 @@ torch::Tensor prepack_B(torch::Tensor const& B,
|
||||
|
||||
}; // namespace machete
|
||||
|
||||
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
|
||||
|
||||
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_meta,
|
||||
torch::Tensor& b_scales,
|
||||
|
||||
88
csrc/permute_cols.cu
Normal file
88
csrc/permute_cols.cu
Normal file
@@ -0,0 +1,88 @@
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
static constexpr int default_threads = 256;
|
||||
static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
// For a given "a" of size [M,K] performs a permutation of the K columns based
|
||||
// on the given "perm" indices.
|
||||
// Currently only supports 16bit types (since we permute half types)
|
||||
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
||||
int const* __restrict__ perm_int_ptr,
|
||||
int4* __restrict__ out_int4_ptr, int size_m,
|
||||
int size_k, int block_rows) {
|
||||
int start_row = block_rows * blockIdx.x;
|
||||
int finish_row = start_row + block_rows;
|
||||
if (finish_row > size_m) {
|
||||
finish_row = size_m;
|
||||
}
|
||||
int cur_block_rows = std::max(finish_row - start_row, 0);
|
||||
|
||||
int row_stride = size_k * sizeof(half) / 16;
|
||||
|
||||
auto permute_row = [&](int row) {
|
||||
int iters = size_k / default_threads;
|
||||
int rest = size_k % default_threads;
|
||||
|
||||
int offset = row * row_stride;
|
||||
|
||||
half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
|
||||
half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);
|
||||
|
||||
int base_k = 0;
|
||||
|
||||
for (int i = 0; i < iters; i++) {
|
||||
int cur_k = base_k + threadIdx.x;
|
||||
int src_pos = perm_int_ptr[cur_k];
|
||||
|
||||
out_half[cur_k] = a_row_half[src_pos];
|
||||
|
||||
base_k += default_threads;
|
||||
}
|
||||
|
||||
if (rest) {
|
||||
if (threadIdx.x < rest) {
|
||||
int cur_k = base_k + threadIdx.x;
|
||||
int src_pos = perm_int_ptr[cur_k];
|
||||
|
||||
out_half[cur_k] = a_row_half[src_pos];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for (int i = 0; i < cur_block_rows; i++) {
|
||||
int cur_row = start_row + i;
|
||||
if (cur_row < size_m) {
|
||||
permute_row(cur_row);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// More efficient version of A[..., perm]
|
||||
// taken from gptq_marlin.cu
|
||||
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||
auto dev = A.get_device();
|
||||
auto stream = at::cuda::getCurrentCUDAStream(dev);
|
||||
|
||||
TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16,
|
||||
"Currently only 16bit types are supported");
|
||||
TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
|
||||
TORCH_CHECK(A.size(-1) % 8 == 0,
|
||||
"A columns must be a multiple of 8 (128bits)");
|
||||
auto A_2d = A.view({-1, A.size(-1)});
|
||||
|
||||
torch::Tensor D = torch::empty_like(A);
|
||||
int sms;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
|
||||
int block_rows = div_ceil(A_2d.size(0), sms);
|
||||
permute_cols_kernel<<<sms, default_threads, 0, stream>>>(
|
||||
reinterpret_cast<int4 const*>(A_2d.const_data_ptr()),
|
||||
perm.const_data_ptr<int>(), reinterpret_cast<int4*>(D.mutable_data_ptr()),
|
||||
A_2d.size(0), A_2d.size(1), block_rows);
|
||||
return D;
|
||||
}
|
||||
@@ -157,7 +157,7 @@ TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
|
||||
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleConfig:
|
||||
tile_shape_mn: Tuple[int, int]
|
||||
cluster_shape_mnk: Tuple[int, int, int]
|
||||
@@ -328,56 +328,137 @@ def generate():
|
||||
# about how this works
|
||||
SCRIPT_DIR = os.path.dirname(__file__)
|
||||
|
||||
schedules = [
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=tile_shape_mn,
|
||||
cluster_shape_mnk=cluster_shape_mnk,
|
||||
kernel_schedule=kernel_schedule,
|
||||
epilogue_schedule=epilogue_schedule,
|
||||
tile_scheduler=tile_scheduler,
|
||||
) for tile_shape_mn, cluster_shape_mnk in (
|
||||
((128, 16), (1, 1, 1)),
|
||||
((128, 32), (1, 1, 1)),
|
||||
((128, 64), (1, 1, 1)),
|
||||
((128, 128), (1, 1, 1)),
|
||||
) for kernel_schedule in (TmaMI, ) for epilogue_schedule in (TmaCoop, )
|
||||
for tile_scheduler in (TileSchedulerType.StreamK, )
|
||||
]
|
||||
schedule_common_params = dict(
|
||||
kernel_schedule=TmaMI,
|
||||
epilogue_schedule=TmaCoop,
|
||||
tile_scheduler=TileSchedulerType.StreamK,
|
||||
)
|
||||
|
||||
# For now we use the same heuristic for all types
|
||||
# Heuristic is currently tuned for H100s
|
||||
default_heuristic = [
|
||||
("M > 64",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 128),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
kernel_schedule=TmaMI,
|
||||
epilogue_schedule=TmaCoop,
|
||||
tile_scheduler=TileSchedulerType.StreamK,
|
||||
)),
|
||||
("M > 32",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 64),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
kernel_schedule=TmaMI,
|
||||
epilogue_schedule=TmaCoop,
|
||||
tile_scheduler=TileSchedulerType.StreamK,
|
||||
)),
|
||||
("M > 16",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 32),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
kernel_schedule=TmaMI,
|
||||
epilogue_schedule=TmaCoop,
|
||||
tile_scheduler=TileSchedulerType.StreamK,
|
||||
)),
|
||||
(None,
|
||||
ScheduleConfig(tile_shape_mn=(128, 16),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
kernel_schedule=TmaMI,
|
||||
epilogue_schedule=TmaCoop,
|
||||
tile_scheduler=TileSchedulerType.StreamK))
|
||||
#### M = 257+
|
||||
(
|
||||
"M > 256 && K <= 16384 && N <= 4096",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 128),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 256",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 256),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
#### M = 129-256
|
||||
(
|
||||
"M > 128 && K <= 4096 && N <= 4096",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 64),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 128 && K <= 8192 && N <= 8192",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 128),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 128",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 256),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
#### M = 65-128
|
||||
(
|
||||
"M > 64 && K <= 4069 && N <= 4069",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 32),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 64 && K <= 4069 && N <= 8192",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 64),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 64 && K >= 8192 && N >= 12288",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(256, 128),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 64",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 128),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
#### M = 33-64
|
||||
(
|
||||
"M > 32 && K <= 6144 && N <= 6144",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 16),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 32 && K >= 16384 && N >= 12288",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(256, 64),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 32",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 64),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
#### M = 17-32
|
||||
(
|
||||
"M > 16 && K <= 12288 && N <= 8192",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 32),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
"M > 16",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(256, 32),
|
||||
cluster_shape_mnk=(2, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
#### M = 1-16
|
||||
(
|
||||
"N >= 26624",
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(256, 16),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
(
|
||||
None,
|
||||
ScheduleConfig(
|
||||
tile_shape_mn=(128, 16),
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
**schedule_common_params # type: ignore
|
||||
)),
|
||||
]
|
||||
|
||||
schedules = list(set([x[1] for x in default_heuristic]))
|
||||
|
||||
impl_configs = []
|
||||
|
||||
GPTQ_kernel_type_configs = list(
|
||||
|
||||
@@ -152,7 +152,8 @@ struct MacheteKernelTemplate {
|
||||
|
||||
int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);
|
||||
|
||||
int const group_size = maybe_group_size.value_or(K);
|
||||
int const group_size =
|
||||
maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
|
||||
int const scale_k = (K + group_size - 1) / group_size;
|
||||
|
||||
TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
|
||||
|
||||
@@ -71,7 +71,7 @@ torch::Tensor run_impl(PyTorchArguments args) {
|
||||
auto arguments = MacheteKernel::create_arguments(
|
||||
stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr,
|
||||
layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0),
|
||||
args.group_size.value_or(K));
|
||||
args.group_size);
|
||||
TORCH_CHECK(MacheteKernel::can_implement(arguments),
|
||||
"Machete kernel cannot be run with these arguments");
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
|
||||
// clang-format on
|
||||
|
||||
// Allocate output
|
||||
torch::Tensor D = torch::empty_like(B);
|
||||
torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous);
|
||||
|
||||
prepack_B<PrepackedLayoutB>(stream, B_ptr, layout_Bt,
|
||||
static_cast<ElementB*>(D.mutable_data_ptr()));
|
||||
|
||||
@@ -192,6 +192,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"-> Tensor");
|
||||
ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);
|
||||
|
||||
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
|
||||
ops.impl("permute_cols", torch::kCUDA, &permute_cols);
|
||||
|
||||
// gptq_marlin Optimized Quantized GEMM for GPTQ.
|
||||
ops.def(
|
||||
"gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
||||
|
||||
Reference in New Issue
Block a user