Add sm_100f support and make nvcc 13 happy (#157)

Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
This commit is contained in:
xiweny
2025-08-22 17:19:32 +08:00
committed by GitHub
parent f85ec649d7
commit affdb1cd90
4 changed files with 13 additions and 8 deletions

View File

@@ -155,7 +155,7 @@ public:
signature = fmt::format("NVCC{}.{}", nvcc_major, nvcc_minor);
// Override the compiler flags
flags = fmt::format("{} -I{} --gpu-architecture=sm_{}a "
flags = fmt::format("{} -I{} --gpu-architecture=sm_{} "
"--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi "
"-cubin -O3 --expt-relaxed-constexpr --expt-extended-lambda",
flags, library_include_path.c_str(), device_runtime->get_arch());
@@ -205,7 +205,7 @@ public:
}
// Override the compiler flags
flags = fmt::format("{} {}--gpu-architecture=sm_{}a -default-device {}",
flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {}",
flags, include_dirs, device_runtime->get_arch(), pch_flags);
}

View File

@@ -25,9 +25,12 @@ public:
return {prop->major, prop->minor};
}
int get_arch() {
std::string get_arch() {
const auto& [major, minor] = get_arch_pair();
return major * 10 + minor;
if (major == 10 && minor != 1) {
return "100f";
}
return std::to_string(major * 10 + minor) + "a";
}
int get_arch_major() {

View File

@@ -33,6 +33,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
using Barrier = cutlass::arch::ClusterTransactionBarrier;
using Allocator = std::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
// GEMM with accumulation must have FP32 output
if constexpr (kWithAccumulation)
@@ -169,7 +170,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
cutlass::arch::fence_barrier_init();
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
// Allocate tensor memory
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
@@ -585,7 +586,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
// NOTES: warp 0 is waiting for the TMA store
// TODO: do we need 2 SM allocation?
if (epilogue_warp_idx == 1)
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
Allocator().free(0, kNumTmemCols);
}
// To safely destruct all barriers, we need a cluster sync

View File

@@ -32,6 +32,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
using Barrier = cutlass::arch::ClusterTransactionBarrier;
using Allocator = std::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
// Scaling checks
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
@@ -152,7 +153,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
cutlass::arch::fence_barrier_init();
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
// Allocate tensor memory
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
@@ -520,7 +521,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
// NOTES: warp 0 is waiting for the TMA store
// TODO: do we need 2 SM allocation?
if (epilogue_warp_idx == 1)
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
Allocator().free(0, kNumTmemCols);
}
// To safely destruct all barriers, we need a cluster sync