diff --git a/csrc/jit/device_runtime.hpp b/csrc/jit/device_runtime.hpp index 5c14597..310942d 100644 --- a/csrc/jit/device_runtime.hpp +++ b/csrc/jit/device_runtime.hpp @@ -27,9 +27,8 @@ public: std::string get_arch() { const auto& [major, minor] = get_arch_pair(); - if (major == 10 && minor != 1) { + if (major == 10 and minor != 1) return "100f"; - } return std::to_string(major * 10 + minor) + "a"; } diff --git a/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh b/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh index 789e220..46a668d 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh @@ -32,6 +32,7 @@ sm100_bf16_gemm_impl(int* grouped_layout, const __grid_constant__ cute::TmaDescriptor tensor_map_d) { #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::conditional_t; // GEMM with accumulation must have FP32 output if constexpr (kWithAccumulation) @@ -141,7 +142,7 @@ sm100_bf16_gemm_impl(int* grouped_layout, cutlass::arch::fence_barrier_init(); } else if (threadIdx.x >= 32 and threadIdx.x < 64) { // Allocate tensor memory - cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); } kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); @@ -472,15 +473,13 @@ sm100_bf16_gemm_impl(int* grouped_layout, } // Flush all stages in the pipeline to make TMA stores visible to the next kernel - // TODO: do we actually need this? if (epilogue_thread_idx == 0) cute::tma_store_wait<0>(); // Deallocate tensor memory by warp 1 // NOTES: warp 0 is waiting TMA store - // TODO: do we need 2 SM allocation? if (epilogue_warp_idx == 1) - cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + Allocator().free(0, kNumTmemCols); } // To safely deconstruct all barriers, we need a cluster sync diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh index 850ff9f..03c44cd 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -33,7 +33,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, const __grid_constant__ cute::TmaDescriptor tensor_map_d) { #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) using Barrier = cutlass::arch::ClusterTransactionBarrier; - using Allocator = std::conditional_t; + using Allocator = cute::conditional_t; // GEMM with accumulation must have FP32 output if constexpr (kWithAccumulation) @@ -578,13 +578,11 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, } // Flush all stages in the pipeline to make TMA stores visible to the next kernel - // TODO: do we actually need this? if (epilogue_thread_idx == 0) cute::tma_store_wait<0>(); // Deallocate tensor memory by warp 1 // NOTES: warp 0 is waiting TMA store - // TODO: do we need 2 SM allocation? if (epilogue_warp_idx == 1) Allocator().free(0, kNumTmemCols); } diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh index e0212b0..455e600 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh @@ -32,7 +32,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) { #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) using Barrier = cutlass::arch::ClusterTransactionBarrier; - using Allocator = std::conditional_t; + using Allocator = cute::conditional_t; // Scaling checks DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");