From affdb1cd9024c5eceb55a1821c09f7067de023dd Mon Sep 17 00:00:00 2001 From: xiweny <13230610+VALLIS-NERIA@users.noreply.github.com> Date: Fri, 22 Aug 2025 17:19:32 +0800 Subject: [PATCH] Add sm_100f support and make nvcc 13 happy (#157) Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com> --- csrc/jit/compiler.hpp | 4 ++-- csrc/jit/device_runtime.hpp | 7 +++++-- deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh | 5 +++-- deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh | 5 +++-- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/csrc/jit/compiler.hpp b/csrc/jit/compiler.hpp index 46e92b6..09c3087 100644 --- a/csrc/jit/compiler.hpp +++ b/csrc/jit/compiler.hpp @@ -155,7 +155,7 @@ public: signature = fmt::format("NVCC{}.{}", nvcc_major, nvcc_minor); // The override the compiler flags - flags = fmt::format("{} -I{} --gpu-architecture=sm_{}a " + flags = fmt::format("{} -I{} --gpu-architecture=sm_{} " "--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi " "-cubin -O3 --expt-relaxed-constexpr --expt-extended-lambda", flags, library_include_path.c_str(), device_runtime->get_arch()); @@ -205,7 +205,7 @@ public: } // Override the compiler flags - flags = fmt::format("{} {}--gpu-architecture=sm_{}a -default-device {}", + flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {}", flags, include_dirs, device_runtime->get_arch(), pch_flags); } diff --git a/csrc/jit/device_runtime.hpp b/csrc/jit/device_runtime.hpp index 7cd1882..5c14597 100644 --- a/csrc/jit/device_runtime.hpp +++ b/csrc/jit/device_runtime.hpp @@ -25,9 +25,12 @@ public: return {prop->major, prop->minor}; } - int get_arch() { + std::string get_arch() { const auto& [major, minor] = get_arch_pair(); - return major * 10 + minor; + if (major == 10 && minor != 1) { + return "100f"; + } + return std::to_string(major * 10 + minor) + "a"; } int get_arch_major() { diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh index 85c01ab..850ff9f 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -33,6 +33,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, const __grid_constant__ cute::TmaDescriptor tensor_map_d) { #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = std::conditional_t; // GEMM with accumulation must have FP32 output if constexpr (kWithAccumulation) @@ -169,7 +170,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, cutlass::arch::fence_barrier_init(); } else if (threadIdx.x >= 32 and threadIdx.x < 64) { // Allocate tensor memory - cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); } kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); @@ -585,7 +586,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // NOTES: warp 0 is waiting TMA store // TODO: do we need 2 SM allocation? if (epilogue_warp_idx == 1) - cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + Allocator().free(0, kNumTmemCols); } // To safely deconstruct all barriers, we need a cluster sync diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh index 88b6b50..e0212b0 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh @@ -32,6 +32,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) { #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = std::conditional_t; // Scaling checks DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); @@ -152,7 +153,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, cutlass::arch::fence_barrier_init(); } else if (threadIdx.x >= 32 and threadIdx.x < 64) { // Allocate tensor memory - cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); } kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); @@ -520,7 +521,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // NOTES: warp 0 is waiting TMA store // TODO: do we need 2 SM allocation? if (epilogue_warp_idx == 1) - cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + Allocator().free(0, kNumTmemCols); } // To safely deconstruct all barriers, we need a cluster sync