diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 5283344..c000afa 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -367,6 +367,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, // Wait last TMA store to be finished if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); // Write back to shared memory using STSM and issue TMA stores DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");