[Kernel] Update Cutlass fp8 configs (#5144)

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Author:    Varun Sundar Rabindranath
Date:      2024-06-01 14:16:07 +05:30
Committer: GitHub
Parent:    260d119e86
Commit:    f081c3ce4b

4 changed files with 480 additions and 15 deletions


@@ -51,6 +51,11 @@ using namespace cute;
 namespace {

+uint32_t next_pow_2(uint32_t const num) {
+  if (num <= 1) return num;
+  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+}
+
 template <typename ElementAB_, typename ElementD_, typename TileShape,
           typename ClusterShape, typename KernelSchedule,
           typename EpilogueSchedule>
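The new next_pow_2 helper rounds its argument up to the nearest power of two: for num > 1, CHAR_BIT * sizeof(num) - __builtin_clz(num - 1) is the bit width of num - 1, so shifting 1 by that amount yields the smallest power of two greater than or equal to num. A minimal standalone sketch of the behavior (GCC/Clang only, since it relies on the __builtin_clz intrinsic; the main harness is added here purely for illustration):

#include <cassert>
#include <climits>
#include <cstdint>

// Same bit trick as in the patch: round num up to the next power of two.
// For num > 1, (CHAR_BIT * sizeof(num)) - __builtin_clz(num - 1) is the
// bit width of num - 1, so 1 << width is the smallest power of two >= num.
static uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

int main() {
  assert(next_pow_2(0) == 0);    // 0 and 1 are returned unchanged
  assert(next_pow_2(1) == 1);
  assert(next_pow_2(2) == 2);    // exact powers of two map to themselves
  assert(next_pow_2(3) == 4);
  assert(next_pow_2(64) == 64);
  assert(next_pow_2(65) == 128); // anything above a power rounds up
  assert(next_pow_2(100) == 128);
  return 0;
}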
@@ -188,8 +193,89 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
   cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
   CUTLASS_CHECK(status);
 }
+
+template <typename InType, typename OutType, int32_t M>
+struct sm90_fp8_config {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_128, _128, _128>;
+  using ClusterShape = Shape<_2, _1, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
+                      EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType>
+struct sm90_fp8_config<InType, OutType, 128> {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64, _128, _128>;
+  using ClusterShape = Shape<_2, _1, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
+                      EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType>
+struct sm90_fp8_config<InType, OutType, 64> {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64, _64, _128>;
+  using ClusterShape = Shape<_1, _8, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, TileShape, ClusterShape, KernelSchedule,
+                      EpilogueSchedule>;
+};
+
 }  // namespace
+
+template <typename InType, typename OutType>
+void cutlass_scaled_mm_dq_sm90_fp8_dispatch(torch::Tensor& out,
+                                            torch::Tensor const& a,
+                                            torch::Tensor const& b,
+                                            torch::Tensor const& a_scales,
+                                            torch::Tensor const& b_scales) {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
+  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+
+  using Cutlass3xGemmDefault =
+      typename sm90_fp8_config<InType, OutType, 0>::Cutlass3xGemm;
+  using Cutlass3xGemmM64 =
+      typename sm90_fp8_config<InType, OutType, 64>::Cutlass3xGemm;
+  using Cutlass3xGemmM128 =
+      typename sm90_fp8_config<InType, OutType, 128>::Cutlass3xGemm;
+
+  uint32_t const m = a.size(0);
+  uint32_t const mp2 =
+      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2
+
+  if (mp2 <= 64) {
+    // m in [1, 64]
+    return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM64>(
+        out, a, b, a_scales, b_scales);
+  } else if (mp2 <= 128) {
+    // m in (64, 128]
+    return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmM128>(
+        out, a, b, a_scales, b_scales);
+  } else {
+    // m in (128, inf)
+    return cutlass_scaled_mm_dq_dispatcher<Cutlass3xGemmDefault>(
+        out, a, b, a_scales, b_scales);
+  }
+}
+
 void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
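The selection pattern the new code uses — a primary sm90_fp8_config template carrying the default tile and cluster shapes, explicit specializations for the M=64 and M=128 buckets, and a runtime switch on next_pow_2(m) — can be exercised without CUTLASS or torch. Below is a minimal C++17 sketch of the same idea; TileDesc, sm90_fp8_config_demo, and pick_tile are invented names for illustration, with tile shapes copied from the commit:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Hypothetical stand-in for a CUTLASS tile shape, for illustration only.
struct TileDesc { int m, n, k; };

// Primary template: default config, selected via the "0" key for large M.
template <int32_t M>
struct sm90_fp8_config_demo {
  static constexpr TileDesc tile{128, 128, 128};
};

// Specialization for M in (64, 128]: halve the M tile.
template <>
struct sm90_fp8_config_demo<128> {
  static constexpr TileDesc tile{64, 128, 128};
};

// Specialization for M in [1, 64]: smallest M/N tile.
template <>
struct sm90_fp8_config_demo<64> {
  static constexpr TileDesc tile{64, 64, 128};
};

static uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  return 1 << (32 - __builtin_clz(num - 1));
}

// Runtime bucketing, same shape as cutlass_scaled_mm_dq_sm90_fp8_dispatch:
// round m up to a power of two, clamp to at least 64, then pick the config.
TileDesc pick_tile(uint32_t m) {
  uint32_t const mp2 = std::max(64u, next_pow_2(m));
  if (mp2 <= 64) return sm90_fp8_config_demo<64>::tile;
  if (mp2 <= 128) return sm90_fp8_config_demo<128>::tile;
  return sm90_fp8_config_demo<0>::tile;
}

int main() {
  for (uint32_t m : {16u, 64u, 65u, 128u, 4096u}) {
    TileDesc t = pick_tile(m);
    std::printf("m=%4u -> tile %dx%dx%d\n", m, t.m, t.n, t.k);
  }
  return 0;
}

Running this prints the 64x64x128 tile for m up to 64, the 64x128x128 tile for m in (64, 128], and the default 128x128x128 tile beyond that — the same bucketing the dispatch function performs before instantiating the corresponding Cutlass3xGemm.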
@@ -223,24 +309,14 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
     TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

-    using TileShape = Shape<_128, _128, _128>;
-    using ClusterShape = Shape<_1, _2, _1>;
-    using KernelSchedule =
-        typename cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative;
-    using EpilogueSchedule =
-        typename cutlass::epilogue::TmaWarpSpecializedCooperative;
-
     if (out.dtype() == torch::kBFloat16) {
-      return cutlass_scaled_mm_dq_dispatcher<
-          cutlass_3x_gemm<cutlass::float_e4m3_t, cutlass::bfloat16_t, TileShape,
-                          ClusterShape, KernelSchedule, EpilogueSchedule>>(
+      return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t,
+                                                    cutlass::bfloat16_t>(
           out, a, b, a_scales, b_scales);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
-      return cutlass_scaled_mm_dq_dispatcher<
-          cutlass_3x_gemm<cutlass::float_e4m3_t, cutlass::half_t, TileShape,
-                          ClusterShape, KernelSchedule, EpilogueSchedule>>(
+      return cutlass_scaled_mm_dq_sm90_fp8_dispatch<cutlass::float_e4m3_t,
+                                                    cutlass::half_t>(
           out, a, b, a_scales, b_scales);
     }
   }
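Design note: all three fp8 configs share the KernelTmaWarpSpecializedPingpongFP8FastAccum schedule and differ only in tile and cluster shape. Shrinking the tile from 128x128x128 to 64x128x128 and then 64x64x128 as M drops keeps small, skinny GEMMs from launching too few thread blocks to occupy the GPU, and the wide Shape<_1, _8, _1> cluster in the M-64 config likely recovers some reuse along N; the code removed here used a single KernelCpAsyncWarpSpecializedCooperative configuration regardless of problem shape.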