[Bugfix] Fix compute datatype for cutlass 3.x epilogues (#5931)

This commit is contained in:
Tyler Michael Smith
2024-06-28 13:10:34 -04:00
committed by GitHub
parent b2c620230a
commit 6a2d659d28
2 changed files with 70 additions and 59 deletions

View File

@@ -144,14 +144,14 @@ struct ScaledEpilogueBias
using ScaleB = typename SUPER::ScaleB;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, ElementD, ElementD,
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, ElementD,
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using BiasDescriptor =