[Kernel] Update cutlass_scaled_mm to support 2d group (blockwise) scaling (#11868)
This commit is contained in:
@@ -272,6 +272,10 @@ struct MacheteCollectiveMma {
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
|
||||
// One thread per CTA is a producer (1 for operand tile)
|
||||
static constexpr int NumProducerThreadEvents = 1;
|
||||
|
||||
using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}),
|
||||
shape<1>(SmemLayoutAtomScale{})));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user