[feat]: add SM100 support for cutlass FP8 groupGEMM (#20447)
Signed-off-by: Duncan Moss <djm.moss@gmail.com> Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -18,7 +18,6 @@ using ProblemShape =
|
||||
cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;
|
||||
|
||||
using ElementAccumulator = float;
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
@@ -33,7 +32,7 @@ using LayoutD_Transpose =
|
||||
using LayoutC = LayoutD;
|
||||
using LayoutC_Transpose = LayoutD_Transpose;
|
||||
|
||||
template <typename ElementAB_, typename ElementC_,
|
||||
template <typename ElementAB_, typename ElementC_, typename ArchTag_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule, bool swap_ab_ = false>
|
||||
@@ -43,6 +42,7 @@ struct cutlass_3x_group_gemm {
|
||||
using ElementC = void;
|
||||
using ElementD = ElementC_;
|
||||
using ElementAccumulator = float;
|
||||
using ArchTag = ArchTag_;
|
||||
|
||||
using Epilogue = Epilogue_<ElementAccumulator, ElementD, TileShape>;
|
||||
|
||||
@@ -77,7 +77,7 @@ struct cutlass_3x_group_gemm {
|
||||
LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape,
|
||||
Stages, KernelSchedule>::CollectiveOp>;
|
||||
|
||||
using KernelType = enable_sm90_only<cutlass::gemm::kernel::GemmUniversal<
|
||||
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape, CollectiveMainloop, CollectiveEpilogue>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
@@ -156,9 +156,14 @@ void cutlass_group_gemm_caller(
|
||||
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
||||
static_cast<StrideC*>(c_strides.data_ptr())};
|
||||
|
||||
int device_id = a_tensors.device().index();
|
||||
static const cutlass::KernelHardwareInfo hw_info{
|
||||
device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
device_id)};
|
||||
|
||||
typename GemmKernel::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args,
|
||||
epilogue_args};
|
||||
epilogue_args, hw_info};
|
||||
|
||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
GemmOp gemm_op;
|
||||
|
||||
Reference in New Issue
Block a user