From 0e49c3353b6658a522cee41c6933fa38032ef08e Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Wed, 27 Aug 2025 09:26:02 +0800
Subject: [PATCH] Refactor compiler version checks and arch flags

---
 csrc/jit/compiler.hpp       | 23 ++++++++++-------------
 csrc/jit/device_runtime.hpp | 12 ++++++++----
 2 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/csrc/jit/compiler.hpp b/csrc/jit/compiler.hpp
index 0a6446f..55bcc60 100644
--- a/csrc/jit/compiler.hpp
+++ b/csrc/jit/compiler.hpp
@@ -140,8 +140,8 @@ class NVCCCompiler final: public Compiler {
         DG_HOST_ASSERT(std::regex_search(output, match, std::regex(R"(release (\d+\.\d+))")));
         std::sscanf(match[1].str().c_str(), "%d.%d", &major, &minor);
         DG_HOST_ASSERT((major > 12 or (major == 12 and minor >= 3)) and "NVCC version should be >= 12.3");
-        if (major < 12 or (major == 12 and minor < 9))
-            printf("Warning: please use at least NVCC 12.9 for the best DeepGEMM performance");
+        if (major == 12 and minor < 9)
+            printf("Warning: please use at least NVCC 12.9 for the best DeepGEMM performance\n");
         return {major, minor};
     }
 
@@ -155,14 +155,12 @@ public:
         signature = fmt::format("NVCC{}.{}", nvcc_major, nvcc_minor);
 
         // The override the compiler flags
-        std::string selected_arch = device_runtime->get_arch();
-        // Compatibility: NVCC < 12.9 may not recognize sm_100f; fallback to sm_100a
-        if (selected_arch == "100f" && (nvcc_major < 12 || (nvcc_major == 12 && nvcc_minor < 9)))
-            selected_arch = "100a";
+        // Only NVCC >= 12.9 supports arch-specific family suffix
+        const auto& arch = device_runtime->get_arch(false, nvcc_major > 12 or nvcc_minor >= 9);
         flags = fmt::format("{} -I{} --gpu-architecture=sm_{} "
                             "--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi "
                             "-cubin -O3 --expt-relaxed-constexpr --expt-extended-lambda",
-                            flags, library_include_path.c_str(), selected_arch);
+                            flags, library_include_path.c_str(), arch);
     }
 
     void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override {
@@ -193,6 +191,7 @@ public:
         int major, minor;
         DG_NVRTC_CHECK(nvrtcVersion(&major, &minor));
         signature = fmt::format("NVRTC{}.{}", major, minor);
+        DG_HOST_ASSERT((major > 12 or (major == 12 and minor >= 3)) and "NVRTC version should be >= 12.3");
 
         // Build include directories list
         std::string include_dirs;
@@ -202,19 +201,17 @@ public:
         // Add PCH support for version 12.8 and above
         // NOTES: PCH is vital for compilation speed
         std::string pch_flags;
-        if (major > 12 or (major == 12 and minor >= 8)) {
+        if (major > 12 or minor >= 8) {
             pch_flags = "--pch ";
             if (get_env("DG_JIT_DEBUG", 0))
                 pch_flags += "--pch-verbose=true ";
         }
 
         // Override the compiler flags
-        std::string selected_arch = device_runtime->get_arch();
-        // Compatibility: NVRTC < 12.9 may not recognize sm_100f; fallback to sm_100a
-        if (selected_arch == "100f" && (major < 12 || (major == 12 && minor < 9)))
-            selected_arch = "100a";
+        // Only NVRTC >= 12.9 supports arch-specific family suffix
+        const auto& arch = device_runtime->get_arch(false, major > 12 or minor >= 9);
         flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {}",
-                            flags, include_dirs, selected_arch, pch_flags);
+                            flags, include_dirs, arch, pch_flags);
     }
 
     void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override {
diff --git a/csrc/jit/device_runtime.hpp b/csrc/jit/device_runtime.hpp
index 310942d..79139d6 100644
--- a/csrc/jit/device_runtime.hpp
+++ b/csrc/jit/device_runtime.hpp
@@ -25,11 +25,15 @@ public:
         return {prop->major, prop->minor};
     }
 
-    std::string get_arch() {
+    std::string get_arch(const bool& number_only = false,
+                         const bool& support_arch_family = false) {
         const auto& [major, minor] = get_arch_pair();
-        if (major == 10 and minor != 1)
-            return "100f";
-        return std::to_string(major * 10 + minor) + "a";
+        if (major == 10 and minor != 1) {
+            if (number_only)
+                return "100";
+            return support_arch_family ? "100f" : "100a";
+        }
+        return std::to_string(major * 10 + minor) + (number_only ? "" : "a");
     }
 
     int get_arch_major() {