[Kernel] Update CUTLASS to 3.5.1 (#7085)
This commit is contained in:
committed by
GitHub
parent
997cf78308
commit
8571ac4672
@@ -10,8 +10,6 @@
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/util/device_memory.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm_coord.h"
|
||||
#include "cutlass/arch/mma_sm75.h"
|
||||
@@ -301,12 +299,14 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
typename Gemm::Op gemm_op;
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
cutlass::Status status = gemm_op(args, workspace.get(), stream);
|
||||
cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user