[Kernel][Quantization][MoE] add marlin kernel support for turing (sm75) (#29901)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Jinzhen Lin
2025-12-17 06:35:28 +08:00
committed by GitHub
parent eaa82a709a
commit ce96857fdd
16 changed files with 729 additions and 513 deletions

View File

@@ -1,17 +1,19 @@
#pragma once
#include <torch/all.h>
#ifndef _marlin_cuh
#define _marlin_cuh
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
namespace MARLIN_NAMESPACE_NAME {
@@ -51,9 +53,51 @@ using I4 = Vec<int, 4>;
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async
#else
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
if (pred) {
reinterpret_cast<int32_t*>(smem_ptr)[0] =
reinterpret_cast<const int32_t*>(glob_ptr)[0];
}
}
__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
if (pred) {
reinterpret_cast<int64_t*>(smem_ptr)[0] =
reinterpret_cast<const int64_t*>(glob_ptr)[0];
}
}
__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
if (pred) {
reinterpret_cast<int4*>(smem_ptr)[0] =
reinterpret_cast<const int4*>(glob_ptr)[0];
}
}
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
if (pred) {
reinterpret_cast<int4*>(smem_ptr)[0] =
reinterpret_cast<const int4*>(glob_ptr)[0];
}
}
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
reinterpret_cast<int4*>(smem_ptr)[0] =
reinterpret_cast<const int4*>(glob_ptr)[0];
}
__device__ inline void cp_async_fence() {}
template <int n>
__device__ inline void cp_async_wait() {}
#else
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
@@ -126,6 +170,8 @@ __device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
#endif
#endif
} // namespace MARLIN_NAMESPACE_NAME
#endif