[Kernel] Update CUTLASS to 3.5.1 (#7085)

This commit is contained in:
Tyler Michael Smith
2024-08-05 15:13:43 -04:00
committed by GitHub
parent 997cf78308
commit 8571ac4672
4 changed files with 127 additions and 105 deletions

View File

@@ -18,8 +18,6 @@
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
@@ -72,13 +70,9 @@ struct ScaledEpilogueBase {
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
Stride<Int<1>, Int<0>, Int<0>>>;
using ScaleBDescriptor =
cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
EpilogueDescriptor, float>;
using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
Stride<Int<0>, Int<1>, Int<0>>>;
};
/*
@@ -154,12 +148,8 @@ struct ScaledEpilogueBias
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using BiasDescriptor =
cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
EpilogueDescriptor, ElementD>;
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
BiasDescriptor::Stages, typename EpilogueDescriptor::TileShape, ElementD,
0 /*Stages*/, typename EpilogueDescriptor::TileShape, ElementD,
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<ElementD>, false>;
public:
@@ -251,12 +241,12 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
int64_t ldb = b.stride(1);
int64_t ldc = out.stride(0);
using StrideA = Stride<int64_t, Int<1>, Int<0>>;
using StrideB = Stride<int64_t, Int<1>, Int<0>>;
using StrideA = Stride<int64_t, Int<1>, int64_t>;
using StrideB = Stride<int64_t, Int<1>, int64_t>;
using StrideC = typename Gemm::StrideC;
StrideA a_stride{lda, Int<1>{}, Int<0>{}};
StrideB b_stride{ldb, Int<1>{}, Int<0>{}};
StrideA a_stride{lda, Int<1>{}, 0};
StrideB b_stride{ldb, Int<1>{}, 0};
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
using GemmKernel = typename Gemm::GemmKernel;
@@ -282,11 +272,13 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
CUTLASS_CHECK(gemm_op.can_implement(args));
size_t workspace_size = gemm_op.get_workspace_size(args);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status);
}