Signed-off-by: Lucia Fang <fanglu@fb.com> Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com>
This commit is contained in:
@@ -8,6 +8,8 @@
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "cuda_utils.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
@@ -95,9 +97,9 @@ struct cutlass_sparse_3x_gemm {
|
||||
// clang-format off
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
|
||||
ElementAB, cutlass::layout::RowMajor, AlignmentAB,
|
||||
ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
|
||||
ElementAB, cutlass::layout::RowMajor, AlignmentAB,
|
||||
ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
|
||||
ElementAcc, TileShape, ClusterShape,
|
||||
Stages,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
Reference in New Issue
Block a user