Signed-off-by: Lucia Fang <fanglu@fb.com> Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com>
This commit is contained in:
@@ -15,15 +15,6 @@
|
|||||||
cutlassGetStatusString(error)); \
|
cutlassGetStatusString(error)); \
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Panic wrapper for unwinding CUDA runtime errors
|
|
||||||
*/
|
|
||||||
#define CUDA_CHECK(status) \
|
|
||||||
{ \
|
|
||||||
cudaError_t error = status; \
|
|
||||||
TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \
|
|
||||||
}
|
|
||||||
|
|
||||||
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
||||||
int max_shared_mem_per_block_opt_in = 0;
|
int max_shared_mem_per_block_opt_in = 0;
|
||||||
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
|
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
|
||||||
|
|||||||
@@ -8,6 +8,8 @@
|
|||||||
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
#include "cuda_utils.h"
|
||||||
|
|
||||||
#include "cutlass/cutlass.h"
|
#include "cutlass/cutlass.h"
|
||||||
|
|
||||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||||
@@ -95,9 +97,9 @@ struct cutlass_sparse_3x_gemm {
|
|||||||
// clang-format off
|
// clang-format off
|
||||||
using CollectiveMainloop =
|
using CollectiveMainloop =
|
||||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
|
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
|
||||||
ElementAB, cutlass::layout::RowMajor, AlignmentAB,
|
ElementAB, cutlass::layout::RowMajor, AlignmentAB,
|
||||||
ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
|
ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
|
||||||
ElementAcc, TileShape, ClusterShape,
|
ElementAcc, TileShape, ClusterShape,
|
||||||
Stages,
|
Stages,
|
||||||
KernelSchedule>::CollectiveOp;
|
KernelSchedule>::CollectiveOp;
|
||||||
|
|||||||
Reference in New Issue
Block a user