[CI/Build] Enforce style for C++ and CUDA code with clang-format (#4722)
This commit is contained in:
@@ -32,12 +32,15 @@
|
||||
|
||||
#else
|
||||
|
||||
#include "common/mem.h"
|
||||
#include "common/mma.h"
|
||||
#include "common/mem.h"
|
||||
#include "common/mma.h"
|
||||
|
||||
#endif
|
||||
|
||||
template <typename T> inline std::string str(T x) { return std::to_string(x); }
|
||||
template <typename T>
|
||||
inline std::string str(T x) {
|
||||
return std::to_string(x);
|
||||
}
|
||||
|
||||
namespace marlin_24 {
|
||||
|
||||
@@ -45,7 +48,7 @@ namespace marlin_24 {
|
||||
// than 1 warp per schedule allows some more latency hiding. At the same time,
|
||||
// we want relatively few warps to have many registers per warp and small tiles.
|
||||
static constexpr int THREADS = 256;
|
||||
static constexpr int STAGES = 4; // 4 pipeline stages fit into shared memory
|
||||
static constexpr int STAGES = 4; // 4 pipeline stages fit into shared memory
|
||||
|
||||
static constexpr int min_thread_n = 128;
|
||||
|
||||
@@ -54,35 +57,36 @@ static constexpr int max_par = 16;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
template <const int num_bits, // weight bits
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the threadblock
|
||||
const int thread_n_blocks, // same for n dimension (output)
|
||||
const int thread_k_blocks, // same for k dimension (reduction)
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks with
|
||||
// a separate quantization scale
|
||||
template <const int num_bits, // weight bits
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
// threadblock
|
||||
const int thread_n_blocks, // same for n dimension (output)
|
||||
const int thread_k_blocks, // same for k dimension (reduction)
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
>
|
||||
__global__ void Marlin_24(
|
||||
const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||
const int4
|
||||
*__restrict__ meta, // 2bit metadata information about 2:4 format on B
|
||||
int4 *__restrict__ C, // fp16 output buffer of shape mxn
|
||||
const int4
|
||||
*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
int *locks // extra global storage for barrier synchronization
|
||||
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
||||
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||
const int4* __restrict__ meta, // 2bit metadata information about 2:4
|
||||
// format on B
|
||||
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
||||
const int4* __restrict__ s, // fp16 quantization scales of shape
|
||||
// (k/groupsize)xn
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
int* locks // extra global storage for barrier synchronization
|
||||
) {}
|
||||
|
||||
torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
|
||||
torch::Tensor &b_meta,
|
||||
torch::Tensor &b_scales,
|
||||
torch::Tensor &workspace, int64_t num_bits,
|
||||
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_meta,
|
||||
torch::Tensor& b_scales,
|
||||
torch::Tensor& workspace, int64_t num_bits,
|
||||
int64_t size_m, int64_t size_n,
|
||||
int64_t size_k) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
@@ -92,29 +96,30 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
|
||||
|
||||
#else
|
||||
|
||||
template <const int num_bits, // weight bits
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the threadblock
|
||||
const int thread_n_blocks, // same for n dimension (output)
|
||||
const int thread_k_blocks, // same for k dimension (reduction)
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks with
|
||||
// a separate quantization scale
|
||||
template <const int num_bits, // weight bits
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
// threadblock
|
||||
const int thread_n_blocks, // same for n dimension (output)
|
||||
const int thread_k_blocks, // same for k dimension (reduction)
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
>
|
||||
__global__ void Marlin_24(
|
||||
const int4 *__restrict__ A, // fp16 input matrix of shape mxk
|
||||
const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||
const int4
|
||||
*__restrict__ meta, // 2bit metadata information about 2:4 format on B
|
||||
int4 *__restrict__ C, // fp16 output buffer of shape mxn
|
||||
const int4
|
||||
*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
int *locks // extra global storage for barrier synchronization
|
||||
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
||||
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||
const int4* __restrict__ meta, // 2bit metadata information about 2:4
|
||||
// format on B
|
||||
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
||||
const int4* __restrict__ s, // fp16 quantization scales of shape
|
||||
// (k/groupsize)xn
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
int* locks // extra global storage for barrier synchronization
|
||||
) {
|
||||
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
|
||||
// same size, which might involve multiple column "slices" (of width 16 *
|
||||
@@ -174,27 +179,22 @@ __global__ void Marlin_24(
|
||||
auto init_slice = [&]() {
|
||||
slice_iters =
|
||||
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
|
||||
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel)
|
||||
slice_iters = 0;
|
||||
if (slice_iters == 0)
|
||||
return;
|
||||
if (slice_row + slice_iters > k_tiles)
|
||||
slice_iters = k_tiles - slice_row;
|
||||
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
|
||||
if (slice_iters == 0) return;
|
||||
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
|
||||
slice_count = 1;
|
||||
slice_idx = 0;
|
||||
int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
|
||||
if (col_first <= k_tiles * (slice_col_par + 1)) {
|
||||
int col_off = col_first - k_tiles * slice_col_par;
|
||||
slice_count = ceildiv(k_tiles - col_off, iters);
|
||||
if (col_off > 0)
|
||||
slice_count++;
|
||||
if (col_off > 0) slice_count++;
|
||||
int delta_first = iters * blockIdx.x - col_first;
|
||||
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
|
||||
slice_idx = slice_count - 1;
|
||||
else {
|
||||
slice_idx = slice_count - 1 - delta_first / iters;
|
||||
if (col_off > 0)
|
||||
slice_idx--;
|
||||
if (col_off > 0) slice_idx--;
|
||||
}
|
||||
}
|
||||
if (slice_col == n_tiles) {
|
||||
@@ -207,7 +207,7 @@ __global__ void Marlin_24(
|
||||
init_slice();
|
||||
|
||||
// RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements
|
||||
int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
|
||||
int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
|
||||
|
||||
// stride of an A matrix tile in shared memory
|
||||
constexpr int a_sh_stride = 32 * thread_k_blocks / 8;
|
||||
@@ -239,9 +239,9 @@ __global__ void Marlin_24(
|
||||
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
|
||||
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
|
||||
|
||||
int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16
|
||||
int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16
|
||||
constexpr int m_sh_stride =
|
||||
(16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp
|
||||
(16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp
|
||||
int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks;
|
||||
int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride);
|
||||
constexpr int m_sh_wr_delta = threads / 2;
|
||||
@@ -305,7 +305,7 @@ __global__ void Marlin_24(
|
||||
// needed if there are more threads than required for a certain tilesize or
|
||||
// when the batchsize is not a multiple of 16.
|
||||
bool a_sh_wr_pred[a_sh_wr_iters];
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < a_sh_wr_iters; i++) {
|
||||
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
|
||||
}
|
||||
@@ -325,13 +325,13 @@ __global__ void Marlin_24(
|
||||
// loop unrolls, all shared memory accesses are static, we simply precompute
|
||||
// both transformed reads and writes.
|
||||
int a_sh_wr_trans[a_sh_wr_iters];
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < a_sh_wr_iters; i++)
|
||||
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
|
||||
int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks];
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < b_sh_wr_iters; i++) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int j = 0; j < thread_m_blocks; j++) {
|
||||
a_sh_rd_trans[0][i][j] =
|
||||
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
|
||||
@@ -344,23 +344,23 @@ __global__ void Marlin_24(
|
||||
// runtime; we break dependencies between subsequent accesses with a tile by
|
||||
// maintining multiple pointers (we have enough registers), a tiny
|
||||
// optimization.
|
||||
const int4 *B_ptr[b_sh_wr_iters];
|
||||
#pragma unroll
|
||||
const int4* B_ptr[b_sh_wr_iters];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < b_sh_wr_iters; i++)
|
||||
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
|
||||
|
||||
bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta;
|
||||
const int4 *meta_ptr[m_sh_iters];
|
||||
#pragma unroll
|
||||
const int4* meta_ptr[m_sh_iters];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < m_sh_iters; i++)
|
||||
meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd;
|
||||
|
||||
extern __shared__ int4 sh[];
|
||||
// Shared memory storage for global fetch pipelines.
|
||||
int4 *sh_a = sh;
|
||||
int4 *sh_b = sh_a + (stages * a_sh_stage);
|
||||
int4 *sh_s = sh_b + (stages * b_sh_stage);
|
||||
int4 *sh_m = sh_s + (stages * s_sh_stage);
|
||||
int4* sh_a = sh;
|
||||
int4* sh_b = sh_a + (stages * a_sh_stage);
|
||||
int4* sh_s = sh_b + (stages * b_sh_stage);
|
||||
int4* sh_m = sh_s + (stages * s_sh_stage);
|
||||
// Register storage for double buffer of shared memory reads.
|
||||
FragA frag_a[2][thread_m_blocks][2];
|
||||
I4 frag_b_quant[2][b_thread_vecs];
|
||||
@@ -370,46 +370,43 @@ __global__ void Marlin_24(
|
||||
|
||||
// Zero accumulators.
|
||||
auto zero_accums = [&]() {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
|
||||
reinterpret_cast<float *>(frag_c)[i] = 0;
|
||||
reinterpret_cast<float*>(frag_c)[i] = 0;
|
||||
};
|
||||
|
||||
// Asynchronously fetch the next A, B and s tile from global to the next
|
||||
// shared memory pipeline location.
|
||||
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
|
||||
if (pred) {
|
||||
int4 *sh_a_stage = sh_a + a_sh_stage * pipe;
|
||||
#pragma unroll
|
||||
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < a_sh_wr_iters; i++) {
|
||||
cp_async4_pred(
|
||||
&sh_a_stage[a_sh_wr_trans[i]],
|
||||
&A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
|
||||
a_sh_wr_pred[i]);
|
||||
}
|
||||
int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
|
||||
#pragma unroll
|
||||
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < b_sh_wr_iters; i++) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int j = 0; j < b_thread_vecs; j++) {
|
||||
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j],
|
||||
B_ptr[i] + j);
|
||||
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
|
||||
}
|
||||
B_ptr[i] += b_gl_rd_delta_o;
|
||||
}
|
||||
int4 *sh_meta_stage = sh_m + m_sh_stage * pipe;
|
||||
#pragma unroll
|
||||
int4* sh_meta_stage = sh_m + m_sh_stage * pipe;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < m_sh_iters; i++) {
|
||||
if (m_sh_wr_pred)
|
||||
cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr],
|
||||
meta_ptr[i]);
|
||||
cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]);
|
||||
meta_ptr[i] += m_gl_rd_delta_o;
|
||||
}
|
||||
// Only fetch scales if this tile starts a new group
|
||||
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
|
||||
int4 *sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||
if (s_sh_wr_pred)
|
||||
cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
|
||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||
if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
|
||||
s_gl_rd += s_gl_rd_delta;
|
||||
}
|
||||
}
|
||||
@@ -436,13 +433,13 @@ __global__ void Marlin_24(
|
||||
// theoretically better attempts have lead to bad instruction ordering by
|
||||
// the compiler and correspondingly a noticeable drop in performance.
|
||||
if (group_blocks != -1) {
|
||||
int4 *sh_s_stage =
|
||||
int4* sh_s_stage =
|
||||
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
|
||||
(pipe / (group_blocks / thread_k_blocks)));
|
||||
reinterpret_cast<int4 *>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
|
||||
}
|
||||
int4 *sh_a_stage = sh_a + a_sh_stage * pipe;
|
||||
#pragma unroll
|
||||
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
ldsm4(frag_a[k % 2][i][0],
|
||||
&sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]);
|
||||
@@ -450,24 +447,24 @@ __global__ void Marlin_24(
|
||||
&sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]);
|
||||
}
|
||||
|
||||
int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
|
||||
#pragma unroll
|
||||
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < b_thread_vecs; i++) {
|
||||
frag_b_quant[k % 2][i] = *reinterpret_cast<I4 *>(
|
||||
frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
|
||||
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
|
||||
}
|
||||
|
||||
// Load meta with ldsm4
|
||||
int4 *sh_m_stage = sh_m + m_sh_stage * pipe;
|
||||
int4* sh_m_stage = sh_m + m_sh_stage * pipe;
|
||||
ldsm4_m(frag_m[k % 2][0],
|
||||
&sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]);
|
||||
};
|
||||
|
||||
// Execute the actual tensor core matmul of a sub-tile.
|
||||
auto matmul = [&](int k) {
|
||||
// We have the m dimension as the inner loop in order to encourage overlapping
|
||||
// dequantization and matmul operations.
|
||||
#pragma unroll
|
||||
// We have the m dimension as the inner loop in order to encourage overlapping
|
||||
// dequantization and matmul operations.
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++) {
|
||||
FragB frag_b0;
|
||||
FragB frag_b1;
|
||||
@@ -480,7 +477,7 @@ __global__ void Marlin_24(
|
||||
frag_b1 = dequant_4bit(b_quant_shift);
|
||||
|
||||
} else {
|
||||
int *frag_b_quant_ptr = reinterpret_cast<int *>(frag_b_quant[k % 2]);
|
||||
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
|
||||
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
|
||||
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
|
||||
|
||||
@@ -497,7 +494,7 @@ __global__ void Marlin_24(
|
||||
scale(frag_b1, frag_s[k % 2][j], 1);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0],
|
||||
frag_m[k % 2][j / 2], j % 2);
|
||||
@@ -518,41 +515,41 @@ __global__ void Marlin_24(
|
||||
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
|
||||
(threadIdx.x % b_sh_stride_threads);
|
||||
|
||||
// Parallel logarithmic shared memory reduction. We make sure to avoid any
|
||||
// unnecessary read or write iterations, e.g., for two warps we write only once
|
||||
// by warp 1 and read only once by warp 0.
|
||||
#pragma unroll
|
||||
// Parallel logarithmic shared memory reduction. We make sure to avoid any
|
||||
// unnecessary read or write iterations, e.g., for two warps we write only
|
||||
// once by warp 1 and read only once by warp 0.
|
||||
#pragma unroll
|
||||
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = red_off; i > 0; i /= 2) {
|
||||
if (i <= red_idx && red_idx < 2 * i) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4 * 2; j++) {
|
||||
int red_sh_wr =
|
||||
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
|
||||
if (i < red_off) {
|
||||
float *c_rd = reinterpret_cast<float *>(
|
||||
&sh[red_sh_delta * j + red_sh_rd]);
|
||||
float *c_wr = reinterpret_cast<float *>(&sh[red_sh_wr]);
|
||||
#pragma unroll
|
||||
float* c_rd =
|
||||
reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
|
||||
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
|
||||
#pragma unroll
|
||||
for (int k = 0; k < 4; k++)
|
||||
reinterpret_cast<FragC *>(frag_c)[4 * 2 * m_block + j][k] +=
|
||||
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
|
||||
c_rd[k] + c_wr[k];
|
||||
}
|
||||
sh[red_sh_wr] =
|
||||
reinterpret_cast<int4 *>(&frag_c)[4 * 2 * m_block + j];
|
||||
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (red_idx == 0) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4 * 2; i++) {
|
||||
float *c_rd =
|
||||
reinterpret_cast<float *>(&sh[red_sh_delta * i + red_sh_rd]);
|
||||
#pragma unroll
|
||||
float* c_rd =
|
||||
reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
reinterpret_cast<FragC *>(frag_c)[4 * 2 * m_block + i][j] +=
|
||||
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
|
||||
c_rd[j];
|
||||
}
|
||||
}
|
||||
@@ -562,9 +559,9 @@ __global__ void Marlin_24(
|
||||
};
|
||||
|
||||
// Since multiple threadblocks may process parts of the same column slice, we
|
||||
// finally have to globally reduce over the results. As the striped partitioning
|
||||
// minimizes the number of such reductions and our outputs are usually rather
|
||||
// small, we perform this reduction serially in L2 cache.
|
||||
// finally have to globally reduce over the results. As the striped
|
||||
// partitioning minimizes the number of such reductions and our outputs are
|
||||
// usually rather small, we perform this reduction serially in L2 cache.
|
||||
auto global_reduce = [&](bool first = false, bool last = false) {
|
||||
// We are very careful here to reduce directly in the output buffer to
|
||||
// maximize L2 cache utilization in this step. To do this, we write out
|
||||
@@ -574,7 +571,7 @@ __global__ void Marlin_24(
|
||||
int c_gl_stride = prob_n / 8;
|
||||
int c_gl_wr_delta_o = 2 * 4 * c_gl_stride;
|
||||
int c_gl_wr_delta_i =
|
||||
c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28)
|
||||
c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28)
|
||||
int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) +
|
||||
8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4;
|
||||
c_gl_wr += (2 * thread_n_blocks) * slice_col;
|
||||
@@ -584,10 +581,10 @@ __global__ void Marlin_24(
|
||||
int col = 2 * ((threadIdx.x % 32) % 4);
|
||||
|
||||
if (!first) {
|
||||
// Interestingly, doing direct global accesses here really seems to mess up the
|
||||
// compiler and lead to slowdowns, hence we also use async-copies even though
|
||||
// these fetches are not actually asynchronous.
|
||||
#pragma unroll
|
||||
// Interestingly, doing direct global accesses here really seems to mess up
|
||||
// the compiler and lead to slowdowns, hence we also use async-copies even
|
||||
// though these fetches are not actually asynchronous.
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
||||
cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
|
||||
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
|
||||
@@ -599,32 +596,32 @@ __global__ void Marlin_24(
|
||||
cp_async_wait<0>();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
||||
if (i < (thread_m_blocks - 1) * 4 ||
|
||||
8 * (i / 2) + col + (i % 2) < prob_m) {
|
||||
if (!first) {
|
||||
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int j2 = 0; j2 < 2; j2++) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < 4; j1++) {
|
||||
reinterpret_cast<float *>(
|
||||
reinterpret_cast<float*>(
|
||||
&frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 +
|
||||
4 * ((i % 4) / 2) + i % 2] +=
|
||||
__half2float(
|
||||
reinterpret_cast<__half *>(&c_red)[(j2 * 4 + j1)]);
|
||||
reinterpret_cast<__half*>(&c_red)[(j2 * 4 + j1)]);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!last) {
|
||||
int4 c;
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int j2 = 0; j2 < 2; j2++) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < 4; j1++) {
|
||||
reinterpret_cast<__half *>(&c)[(j2 * 4 + j1)] =
|
||||
__float2half(reinterpret_cast<float *>(
|
||||
reinterpret_cast<__half*>(&c)[(j2 * 4 + j1)] =
|
||||
__float2half(reinterpret_cast<float*>(
|
||||
&frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 +
|
||||
4 * ((i % 4) / 2) + i % 2]);
|
||||
}
|
||||
@@ -643,9 +640,9 @@ __global__ void Marlin_24(
|
||||
auto write_result = [&]() {
|
||||
int c_gl_stride = prob_n / 8;
|
||||
|
||||
constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC:
|
||||
constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC:
|
||||
constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC:
|
||||
constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC:
|
||||
constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC:
|
||||
constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC:
|
||||
|
||||
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
|
||||
|
||||
@@ -654,22 +651,22 @@ __global__ void Marlin_24(
|
||||
c_gl_wr += (2 * thread_n_blocks) * slice_col;
|
||||
|
||||
int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) +
|
||||
((threadIdx.x % 32) / 4); // RLC:
|
||||
c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4)
|
||||
((threadIdx.x % 32) / 4); // RLC:
|
||||
c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4)
|
||||
|
||||
constexpr int c_sh_rd_delta =
|
||||
c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC:
|
||||
c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC:
|
||||
int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) +
|
||||
(threadIdx.x % (2 * 2 * thread_n_blocks));
|
||||
|
||||
int c_gl_wr_end = c_gl_stride * prob_m;
|
||||
|
||||
auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS &s0,
|
||||
float c4, float c5, float c6, float c7, FragS &s1) {
|
||||
auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS& s0,
|
||||
float c4, float c5, float c6, float c7, FragS& s1) {
|
||||
uint2 res[2];
|
||||
res[0] = to_half4(c0, c1, c2, c3);
|
||||
res[1] = to_half4(c4, c5, c6, c7);
|
||||
half2 *tmp = (half2 *)&res;
|
||||
half2* tmp = (half2*)&res;
|
||||
// for per-column quantization we finally apply the scale here
|
||||
if constexpr (group_blocks == -1 && num_bits == 4) {
|
||||
tmp[0] = __hmul2(tmp[0], s0[0]);
|
||||
@@ -677,12 +674,12 @@ __global__ void Marlin_24(
|
||||
tmp[2] = __hmul2(tmp[2], s1[0]);
|
||||
tmp[3] = __hmul2(tmp[3], s1[1]);
|
||||
}
|
||||
((int4 *)sh)[idx] = *((int4 *)&res[0]);
|
||||
((int4*)sh)[idx] = *((int4*)&res[0]);
|
||||
};
|
||||
|
||||
// RLC: only warp 0 and 1 baseline example
|
||||
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
int wr = c_sh_wr;
|
||||
write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0],
|
||||
@@ -707,7 +704,7 @@ __global__ void Marlin_24(
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0;
|
||||
i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
|
||||
i++) {
|
||||
@@ -721,9 +718,8 @@ __global__ void Marlin_24(
|
||||
|
||||
// Start global fetch and register load pipelines.
|
||||
auto start_pipes = [&]() {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < stages - 1; i++)
|
||||
fetch_to_shared(i, i, i < slice_iters);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
|
||||
zero_accums();
|
||||
wait_for_stage();
|
||||
fetch_to_registers(0, 0);
|
||||
@@ -733,10 +729,10 @@ __global__ void Marlin_24(
|
||||
|
||||
// Main loop.
|
||||
while (slice_iters) {
|
||||
// We unroll over both the global fetch and the register load pipeline to ensure
|
||||
// all shared memory accesses are static. Note that both pipelines have even
|
||||
// length meaning that the next iteration will always start at index 0.
|
||||
#pragma unroll
|
||||
// We unroll over both the global fetch and the register load pipeline to
|
||||
// ensure all shared memory accesses are static. Note that both pipelines have
|
||||
// even length meaning that the next iteration will always start at index 0.
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < stages;) {
|
||||
fetch_to_shared((pipe + stages - 1) % stages, pipe,
|
||||
slice_iters >= stages);
|
||||
@@ -747,8 +743,7 @@ __global__ void Marlin_24(
|
||||
|
||||
pipe++;
|
||||
slice_iters--;
|
||||
if (slice_iters == 0)
|
||||
break;
|
||||
if (slice_iters == 0) break;
|
||||
}
|
||||
a_gl_rd += a_gl_rd_delta_o * stages;
|
||||
|
||||
@@ -762,13 +757,11 @@ __global__ void Marlin_24(
|
||||
// write-out
|
||||
if constexpr (group_blocks == -1) {
|
||||
if constexpr (num_bits == 8) {
|
||||
if (s_sh_wr_pred)
|
||||
cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
|
||||
if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
|
||||
cp_async_fence();
|
||||
} else {
|
||||
if (last) {
|
||||
if (s_sh_wr_pred)
|
||||
cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
|
||||
if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
|
||||
cp_async_fence();
|
||||
}
|
||||
}
|
||||
@@ -780,14 +773,14 @@ __global__ void Marlin_24(
|
||||
cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
||||
*(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]);
|
||||
*(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]);
|
||||
}
|
||||
} else {
|
||||
if (last) {
|
||||
cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
||||
*(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]);
|
||||
*(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -798,7 +791,7 @@ __global__ void Marlin_24(
|
||||
// overflow in fp16)
|
||||
if constexpr (group_blocks == -1 && num_bits == 8) {
|
||||
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0],
|
||||
&frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0],
|
||||
@@ -827,13 +820,13 @@ __global__ void Marlin_24(
|
||||
}
|
||||
}
|
||||
|
||||
if (slice_count > 1) { // only globally reduce if there is more than one
|
||||
// block in a slice
|
||||
if (slice_count > 1) { // only globally reduce if there is more than one
|
||||
// block in a slice
|
||||
barrier_acquire(&locks[slice_col], slice_idx);
|
||||
global_reduce(slice_idx == 0, last);
|
||||
barrier_release(&locks[slice_col], last);
|
||||
}
|
||||
if (last) // only the last block in a slice actually writes the result
|
||||
if (last) // only the last block in a slice actually writes the result
|
||||
write_result();
|
||||
|
||||
slice_row = 0;
|
||||
@@ -843,19 +836,17 @@ __global__ void Marlin_24(
|
||||
if (slice_iters) {
|
||||
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
||||
(threadIdx.x % a_gl_rd_delta_o);
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < b_sh_wr_iters; i++)
|
||||
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < m_sh_iters; i++)
|
||||
meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles;
|
||||
if (slice_col == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < b_sh_wr_iters; i++)
|
||||
B_ptr[i] -= b_gl_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < m_sh_iters; i++)
|
||||
meta_ptr[i] -= m_gl_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= m_gl_stride;
|
||||
}
|
||||
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
||||
start_pipes();
|
||||
@@ -866,26 +857,26 @@ __global__ void Marlin_24(
|
||||
|
||||
#endif
|
||||
|
||||
#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
||||
THREAD_K_BLOCKS, GROUP_BLOCKS) \
|
||||
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
group_blocks == GROUP_BLOCKS) { \
|
||||
cudaFuncSetAttribute( \
|
||||
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
|
||||
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
|
||||
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS> \
|
||||
<<<blocks, THREADS, max_shared_mem, stream>>>(A_ptr, B_ptr, meta_ptr, \
|
||||
C_ptr, s_ptr, prob_n, \
|
||||
prob_m, prob_k, locks); \
|
||||
#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
||||
THREAD_K_BLOCKS, GROUP_BLOCKS) \
|
||||
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
group_blocks == GROUP_BLOCKS) { \
|
||||
cudaFuncSetAttribute( \
|
||||
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
|
||||
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
|
||||
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS> \
|
||||
<<<blocks, THREADS, max_shared_mem, stream>>>(A_ptr, B_ptr, meta_ptr, \
|
||||
C_ptr, s_ptr, prob_n, \
|
||||
prob_m, prob_k, locks); \
|
||||
}
|
||||
|
||||
void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
|
||||
void *s, int prob_m, int prob_n, int prob_k,
|
||||
void *workspace, int num_bits, int groupsize = -1,
|
||||
void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
|
||||
void* s, int prob_m, int prob_n, int prob_k,
|
||||
void* workspace, int num_bits, int groupsize = -1,
|
||||
int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
|
||||
int thread_m = -1, int sms = -1, int max_par = 16) {
|
||||
int tot_n = prob_n;
|
||||
@@ -904,8 +895,8 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
|
||||
|
||||
if (thread_k == -1 || thread_m == -1) {
|
||||
if (prob_n <= 16) {
|
||||
// For small batchizes, better partitioningif is slightly more important than
|
||||
// better compute utilization
|
||||
// For small batchizes, better partitioningif is slightly more important
|
||||
// than better compute utilization
|
||||
thread_k = 128;
|
||||
thread_m = 128;
|
||||
} else {
|
||||
@@ -914,7 +905,7 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
|
||||
}
|
||||
}
|
||||
|
||||
int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction
|
||||
int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction
|
||||
int thread_m_blocks = thread_m / 16;
|
||||
int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
|
||||
int blocks = sms;
|
||||
@@ -931,13 +922,13 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
", ", prob_n, ", ", prob_k, "]");
|
||||
|
||||
const int4 *A_ptr = (const int4 *)A;
|
||||
const int4 *B_ptr = (const int4 *)B;
|
||||
const int4 *meta_ptr = (const int4 *)meta;
|
||||
int4 *C_ptr = (int4 *)C;
|
||||
const int4 *s_ptr = (const int4 *)s;
|
||||
const int4* A_ptr = (const int4*)A;
|
||||
const int4* B_ptr = (const int4*)B;
|
||||
const int4* meta_ptr = (const int4*)meta;
|
||||
int4* C_ptr = (int4*)C;
|
||||
const int4* s_ptr = (const int4*)s;
|
||||
|
||||
int *locks = (int *)workspace;
|
||||
int* locks = (int*)workspace;
|
||||
for (int i = 0; i < tot_n_blocks; i += 4) {
|
||||
int thread_n_blocks = tot_n_blocks - i;
|
||||
prob_n = tot_n - 16 * i;
|
||||
@@ -946,8 +937,7 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
|
||||
// Note that parallel > 1 currently only works for inputs without any
|
||||
// padding
|
||||
par = (16 * thread_n_blocks - pad) / 64;
|
||||
if (par > max_par)
|
||||
par = max_par;
|
||||
if (par > max_par) par = max_par;
|
||||
prob_n = 64 * par;
|
||||
i += 4 * (par - 1);
|
||||
thread_n_blocks = 4;
|
||||
@@ -956,16 +946,16 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
|
||||
// For compilation speed, we only define the kernel configurations that have
|
||||
// seemed useful (in terms of performance) in our testing, however many more
|
||||
// are, in principle, possible.
|
||||
|
||||
|
||||
// the false is start of the CALL_IF macros
|
||||
if (false) {
|
||||
} // BMxBNxBK, group
|
||||
if (false) {
|
||||
} // BMxBNxBK, group
|
||||
// 4-bit
|
||||
CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128
|
||||
CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64
|
||||
CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64
|
||||
CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64
|
||||
CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64
|
||||
CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128
|
||||
CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64
|
||||
CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64
|
||||
CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64
|
||||
CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64
|
||||
CALL_IF_2_4(4, 16, 2, 2, 4)
|
||||
CALL_IF_2_4(4, 16, 3, 2, -1)
|
||||
CALL_IF_2_4(4, 16, 3, 2, 4)
|
||||
@@ -973,11 +963,11 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
|
||||
CALL_IF_2_4(4, 16, 4, 2, 4)
|
||||
|
||||
// 8-bit
|
||||
CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128
|
||||
CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64
|
||||
CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64
|
||||
CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64
|
||||
CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64
|
||||
CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128
|
||||
CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64
|
||||
CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64
|
||||
CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64
|
||||
CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64
|
||||
CALL_IF_2_4(8, 16, 2, 2, 4)
|
||||
CALL_IF_2_4(8, 16, 3, 2, -1)
|
||||
CALL_IF_2_4(8, 16, 3, 2, 4)
|
||||
@@ -997,12 +987,12 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace marlin_24
|
||||
} // namespace marlin_24
|
||||
|
||||
torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
|
||||
torch::Tensor &b_meta,
|
||||
torch::Tensor &b_scales,
|
||||
torch::Tensor &workspace, int64_t num_bits,
|
||||
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_meta,
|
||||
torch::Tensor& b_scales,
|
||||
torch::Tensor& workspace, int64_t num_bits,
|
||||
int64_t size_m, int64_t size_n,
|
||||
int64_t size_k) {
|
||||
// Verify num_bits
|
||||
@@ -1037,9 +1027,9 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
|
||||
" is not divisible by tile_size = " + str(marlin_24::tile_size));
|
||||
|
||||
int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor;
|
||||
TORCH_CHECK(size_n == actual_size_n,
|
||||
"size_n = " + str(size_n) +
|
||||
", actual_size_n = " + str(actual_size_n));
|
||||
TORCH_CHECK(
|
||||
size_n == actual_size_n,
|
||||
"size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
|
||||
|
||||
// Verify meta
|
||||
TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2,
|
||||
@@ -1081,7 +1071,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
|
||||
", is not divisible by b_scales.size(0) = " +
|
||||
str(b_scales.size(0)));
|
||||
groupsize = size_k / b_scales.size(0);
|
||||
groupsize /= 2; // Because of 24
|
||||
groupsize /= 2; // Because of 24
|
||||
}
|
||||
|
||||
// Verify groupsize
|
||||
|
||||
Reference in New Issue
Block a user