[Kernel] Update cutlass_scaled_mm to support 2d group (blockwise) scaling (#11868)

This commit is contained in:
Lucas Wilkinson
2025-01-30 21:33:00 -05:00
committed by GitHub
parent 4078052f09
commit 9798b2fb00
25 changed files with 1924 additions and 346 deletions

View File

@@ -272,6 +272,10 @@ struct MacheteCollectiveMma {
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
// One threads per CTA are producers (1 for operand tile)
static constexpr int NumProducerThreadEvents = 1;
using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}),
shape<1>(SmemLayoutAtomScale{})));