[Bugfix] Preserve CUDA arch suffix (a/f) for SM12x — fixes NVFP4 NaN on desktop Blackwell (#37725)
Signed-off-by: Rob Tand <robert.tand@icloud.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
@@ -94,10 +94,10 @@ find_package(Torch REQUIRED)
|
||||
# This check must happen after find_package(Torch) because that's when CMAKE_CUDA_COMPILER_VERSION gets defined
|
||||
if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND
|
||||
CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
|
||||
set(CUDA_SUPPORTED_ARCHS "7.5;8.0;8.6;8.7;8.9;9.0;10.0;11.0;12.0")
|
||||
set(CUDA_SUPPORTED_ARCHS "7.5;8.0;8.6;8.7;8.9;9.0;10.0;11.0;12.0;12.1")
|
||||
elseif(DEFINED CMAKE_CUDA_COMPILER_VERSION AND
|
||||
CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
|
||||
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
|
||||
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0;12.1")
|
||||
else()
|
||||
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0")
|
||||
endif()
|
||||
|
||||
@@ -173,8 +173,10 @@ print(candidates[0] if candidates else '')
|
||||
endfunction()
|
||||
|
||||
# Macro for converting a `gencode` version number to a cmake version number.
|
||||
# Preserves architecture-specific suffixes (a/f) needed for correct
|
||||
# __CUDA_ARCH_FAMILY_SPECIFIC__ definition. E.g. "121a" -> "12.1a".
|
||||
macro(string_to_ver OUT_VER IN_STR)
|
||||
string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
|
||||
string(REGEX REPLACE "\([0-9]+\)\([0-9][af]?\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
|
||||
endmacro()
|
||||
|
||||
#
|
||||
@@ -211,7 +213,7 @@ endmacro()
|
||||
function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS)
|
||||
set(_CUDA_ARCHES)
|
||||
foreach(_ARCH ${CUDA_ARCH_FLAGS})
|
||||
string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
|
||||
string(REGEX MATCH "arch=compute_\([0-9]+[af]?\)" _COMPUTE ${_ARCH})
|
||||
if (_COMPUTE)
|
||||
set(_COMPUTE ${CMAKE_MATCH_1})
|
||||
endif()
|
||||
|
||||
Reference in New Issue
Block a user