[Kernel] Update CUTLASS to 3.5.1 (#7085)
This commit is contained in:
committed by
GitHub
parent
997cf78308
commit
8571ac4672
@@ -18,8 +18,6 @@
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/util/device_memory.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
@@ -72,13 +70,9 @@ struct ScaledEpilogueBase {
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
|
||||
Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
using ScaleBDescriptor =
|
||||
cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
|
||||
EpilogueDescriptor, float>;
|
||||
|
||||
using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
|
||||
ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
|
||||
typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
|
||||
Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
};
|
||||
|
||||
/*
|
||||
@@ -154,12 +148,8 @@ struct ScaledEpilogueBias
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using BiasDescriptor =
|
||||
cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
|
||||
EpilogueDescriptor, ElementD>;
|
||||
|
||||
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
BiasDescriptor::Stages, typename EpilogueDescriptor::TileShape, ElementD,
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, ElementD,
|
||||
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<ElementD>, false>;
|
||||
|
||||
public:
|
||||
@@ -251,12 +241,12 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
int64_t ldb = b.stride(1);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
using StrideA = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using StrideB = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using StrideA = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideB = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideC = typename Gemm::StrideC;
|
||||
|
||||
StrideA a_stride{lda, Int<1>{}, Int<0>{}};
|
||||
StrideB b_stride{ldb, Int<1>{}, Int<0>{}};
|
||||
StrideA a_stride{lda, Int<1>{}, 0};
|
||||
StrideB b_stride{ldb, Int<1>{}, 0};
|
||||
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
||||
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
@@ -282,11 +272,13 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user