[Kernel] Update CUTLASS to 3.5.1 (#7085)

This commit is contained in:
Tyler Michael Smith
2024-08-05 15:13:43 -04:00
committed by GitHub
parent 997cf78308
commit 8571ac4672
4 changed files with 127 additions and 105 deletions

View File

@@ -10,8 +10,6 @@
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
@@ -301,12 +299,14 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
// Launch the CUTLASS GEMM kernel.
typename Gemm::Op gemm_op;
size_t workspace_size = gemm_op.get_workspace_size(args);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
CUTLASS_CHECK(gemm_op.can_implement(args));
cutlass::Status status = gemm_op(args, workspace.get(), stream);
cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status);
}