[Kernel][Quantization][MoE] add marlin kernel support for turing (sm75) (#29901)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -1,17 +1,19 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
#ifndef _marlin_cuh
|
||||
#define _marlin_cuh
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
@@ -51,9 +53,51 @@ using I4 = Vec<int, 4>;
|
||||
|
||||
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
// No support for async
|
||||
#else
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
if (pred) {
|
||||
reinterpret_cast<int32_t*>(smem_ptr)[0] =
|
||||
reinterpret_cast<const int32_t*>(glob_ptr)[0];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
if (pred) {
|
||||
reinterpret_cast<int64_t*>(smem_ptr)[0] =
|
||||
reinterpret_cast<const int64_t*>(glob_ptr)[0];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
if (pred) {
|
||||
reinterpret_cast<int4*>(smem_ptr)[0] =
|
||||
reinterpret_cast<const int4*>(glob_ptr)[0];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
if (pred) {
|
||||
reinterpret_cast<int4*>(smem_ptr)[0] =
|
||||
reinterpret_cast<const int4*>(glob_ptr)[0];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
|
||||
reinterpret_cast<int4*>(smem_ptr)[0] =
|
||||
reinterpret_cast<const int4*>(glob_ptr)[0];
|
||||
}
|
||||
|
||||
__device__ inline void cp_async_fence() {}
|
||||
|
||||
template <int n>
|
||||
__device__ inline void cp_async_wait() {}
|
||||
|
||||
#else
|
||||
|
||||
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
@@ -126,6 +170,8 @@ __device__ inline void cp_async_wait() {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user