Add sm_100f support and make nvcc 13 happy (#157)
Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
This commit is contained in:
@@ -155,7 +155,7 @@ public:
|
||||
signature = fmt::format("NVCC{}.{}", nvcc_major, nvcc_minor);
|
||||
|
||||
// The override the compiler flags
|
||||
flags = fmt::format("{} -I{} --gpu-architecture=sm_{}a "
|
||||
flags = fmt::format("{} -I{} --gpu-architecture=sm_{} "
|
||||
"--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi "
|
||||
"-cubin -O3 --expt-relaxed-constexpr --expt-extended-lambda",
|
||||
flags, library_include_path.c_str(), device_runtime->get_arch());
|
||||
@@ -205,7 +205,7 @@ public:
|
||||
}
|
||||
|
||||
// Override the compiler flags
|
||||
flags = fmt::format("{} {}--gpu-architecture=sm_{}a -default-device {}",
|
||||
flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {}",
|
||||
flags, include_dirs, device_runtime->get_arch(), pch_flags);
|
||||
}
|
||||
|
||||
|
||||
@@ -25,9 +25,12 @@ public:
|
||||
return {prop->major, prop->minor};
|
||||
}
|
||||
|
||||
int get_arch() {
|
||||
std::string get_arch() {
|
||||
const auto& [major, minor] = get_arch_pair();
|
||||
return major * 10 + minor;
|
||||
if (major == 10 && minor != 1) {
|
||||
return "100f";
|
||||
}
|
||||
return std::to_string(major * 10 + minor) + "a";
|
||||
}
|
||||
|
||||
int get_arch_major() {
|
||||
|
||||
@@ -33,6 +33,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
using Allocator = std::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
|
||||
|
||||
// GEMM with accumulation must have FP32 output
|
||||
if constexpr (kWithAccumulation)
|
||||
@@ -169,7 +170,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
cutlass::arch::fence_barrier_init();
|
||||
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
|
||||
// Allocate tensor memory
|
||||
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
||||
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
||||
}
|
||||
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
|
||||
|
||||
@@ -585,7 +586,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
// NOTES: warp 0 is waiting TMA store
|
||||
// TODO: do we need 2 SM allocation?
|
||||
if (epilogue_warp_idx == 1)
|
||||
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
|
||||
Allocator().free(0, kNumTmemCols);
|
||||
}
|
||||
|
||||
// To safely deconstruct all barriers, we need a cluster sync
|
||||
|
||||
@@ -32,6 +32,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
using Allocator = std::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
|
||||
|
||||
// Scaling checks
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
|
||||
@@ -152,7 +153,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
cutlass::arch::fence_barrier_init();
|
||||
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
|
||||
// Allocate tensor memory
|
||||
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
||||
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
||||
}
|
||||
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
|
||||
|
||||
@@ -520,7 +521,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
// NOTES: warp 0 is waiting TMA store
|
||||
// TODO: do we need 2 SM allocation?
|
||||
if (epilogue_warp_idx == 1)
|
||||
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
|
||||
Allocator().free(0, kNumTmemCols);
|
||||
}
|
||||
|
||||
// To safely deconstruct all barriers, we need a cluster sync
|
||||
|
||||
Reference in New Issue
Block a user