Compatible with CUDA 13

This commit is contained in:
Chenggang Zhao
2025-08-22 17:29:10 +08:00
parent affdb1cd90
commit f20256fd50
4 changed files with 6 additions and 10 deletions

View File

@@ -27,9 +27,8 @@ public:
std::string get_arch() {
const auto& [major, minor] = get_arch_pair();
if (major == 10 && minor != 1) {
if (major == 10 and minor != 1)
return "100f";
}
return std::to_string(major * 10 + minor) + "a";
}

View File

@@ -32,6 +32,7 @@ sm100_bf16_gemm_impl(int* grouped_layout,
const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
using Barrier = cutlass::arch::ClusterTransactionBarrier;
using Allocator = cute::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
// GEMM with accumulation must have FP32 output
if constexpr (kWithAccumulation)
@@ -141,7 +142,7 @@ sm100_bf16_gemm_impl(int* grouped_layout,
cutlass::arch::fence_barrier_init();
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
// Allocate tensor memory
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
@@ -472,15 +473,13 @@ sm100_bf16_gemm_impl(int* grouped_layout,
}
// Flush all stages in the pipeline to make TMA stores visible to the next kernel
// TODO: do we actually need this?
if (epilogue_thread_idx == 0)
cute::tma_store_wait<0>();
// Deallocate tensor memory by warp 1
// NOTES: warp 0 is waiting TMA store
// TODO: do we need 2 SM allocation?
if (epilogue_warp_idx == 1)
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
Allocator().free(0, kNumTmemCols);
}
// To safely deconstruct all barriers, we need a cluster sync

View File

@@ -33,7 +33,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
using Barrier = cutlass::arch::ClusterTransactionBarrier;
using Allocator = std::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
using Allocator = cute::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
// GEMM with accumulation must have FP32 output
if constexpr (kWithAccumulation)
@@ -578,13 +578,11 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
}
// Flush all stages in the pipeline to make TMA stores visible to the next kernel
// TODO: do we actually need this?
if (epilogue_thread_idx == 0)
cute::tma_store_wait<0>();
// Deallocate tensor memory by warp 1
// NOTES: warp 0 is waiting TMA store
// TODO: do we need 2 SM allocation?
if (epilogue_warp_idx == 1)
Allocator().free(0, kNumTmemCols);
}

View File

@@ -32,7 +32,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
using Barrier = cutlass::arch::ClusterTransactionBarrier;
using Allocator = std::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
using Allocator = cute::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
// Scaling checks
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");