[NVIDIA] Fix DGX Spark logic (#38126)
Signed-off-by: johnnynunez <johnnynuca14@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Nick Hill <nickhill123@gmail.com> Signed-off-by: Mark McLoughlin <markmc@redhat.com> Signed-off-by: Andreas Karatzas <akaratza@amd.com> Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com> Signed-off-by: Sathish Sanjeevi <sathish.krishnan.p.s@gmail.com> Signed-off-by: guillaume_guy <guillaume.guy@airbnb.com> Signed-off-by: Guillaume Guy <guillaume.c.guy@gmail.com> Co-authored-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Nick Hill <nickhill123@gmail.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Andreas Karatzas <akaratza@amd.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com> Co-authored-by: Sathish Sanjeevi <SKPsanjeevi@users.noreply.github.com> Co-authored-by: Guillaume Guy <guillaume.c.guy@gmail.com> Co-authored-by: guillaume_guy <guillaume.guy@airbnb.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
@@ -32,16 +32,16 @@ endif()
|
||||
message(STATUS "[QUTLASS] QuTLASS is available at ${qutlass_SOURCE_DIR}")
|
||||
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||
cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0f" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(QUTLASS_ARCHS "10.0f;12.0f" "${CUDA_ARCHS}")
|
||||
else()
|
||||
cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0a;10.3a" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;12.1a;10.0a;10.3a" "${CUDA_ARCHS}")
|
||||
endif()
|
||||
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND QUTLASS_ARCHS)
|
||||
|
||||
if(QUTLASS_ARCHS MATCHES "10\\.(0a|3a|0f)")
|
||||
set(QUTLASS_TARGET_CC 100)
|
||||
elseif(QUTLASS_ARCHS MATCHES "12\\.0a")
|
||||
elseif(QUTLASS_ARCHS MATCHES "12\\.[01][af]?")
|
||||
set(QUTLASS_TARGET_CC 120)
|
||||
else()
|
||||
message(FATAL_ERROR "[QUTLASS] internal error parsing CUDA_ARCHS='${QUTLASS_ARCHS}'.")
|
||||
@@ -96,7 +96,7 @@ else()
|
||||
"[QUTLASS] Skipping build: CUDA 12.8 or newer is required (found ${CMAKE_CUDA_COMPILER_VERSION}).")
|
||||
else()
|
||||
message(STATUS
|
||||
"[QUTLASS] Skipping build: no supported arch (12.0a / 10.0a) found in "
|
||||
"[QUTLASS] Skipping build: no supported arch (12.0f / 10.0f) found in "
|
||||
"CUDA_ARCHS='${CUDA_ARCHS}'.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -355,8 +355,11 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
|
||||
list(REMOVE_DUPLICATES _PTX_ARCHS)
|
||||
list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
|
||||
|
||||
# If x.0a or x.0f is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
|
||||
# remove x.0a or x.0f from SRC_CUDA_ARCHS and add x.0a or x.0f to _CUDA_ARCHS
|
||||
# Handle architecture-specific suffixes (a/f) for SRC entries.
|
||||
# First try exact base match (x.y), then cross-suffix match (x.ya / x.yf).
|
||||
# For 'f' (family) suffix: if no exact/cross match, fall back to major-version
|
||||
# match — e.g. SRC="12.0f" matches TGT="12.1a" since SM121 is in the SM12x
|
||||
# family. The output uses TGT's value to preserve the user's compilation flags.
|
||||
set(_CUDA_ARCHS)
|
||||
foreach(_arch ${_SRC_CUDA_ARCHS})
|
||||
if(_arch MATCHES "[af]$")
|
||||
@@ -365,6 +368,38 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
|
||||
if ("${_base}" IN_LIST TGT_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}")
|
||||
list(APPEND _CUDA_ARCHS "${_arch}")
|
||||
elseif("${_base}a" IN_LIST _TGT_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}a")
|
||||
list(APPEND _CUDA_ARCHS "${_base}a")
|
||||
elseif("${_base}f" IN_LIST _TGT_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}f")
|
||||
list(APPEND _CUDA_ARCHS "${_base}f")
|
||||
elseif(_arch MATCHES "f$")
|
||||
# Family suffix: take the first remaining TGT entry in the same major-version family.
|
||||
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" _src_major "${_base}")
|
||||
foreach(_tgt ${_TGT_CUDA_ARCHS})
|
||||
string(REGEX REPLACE "[af]$" "" _tgt_base "${_tgt}")
|
||||
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" _tgt_major "${_tgt_base}")
|
||||
if(_tgt_major STREQUAL _src_major)
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_tgt}")
|
||||
list(APPEND _CUDA_ARCHS "${_tgt}")
|
||||
break()
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# Symmetric handling: if TGT has x.ya/f and SRC has x.y (without suffix),
|
||||
# preserve TGT's suffix in the output.
|
||||
set(_tgt_copy ${_TGT_CUDA_ARCHS})
|
||||
foreach(_arch ${_tgt_copy})
|
||||
if(_arch MATCHES "[af]$")
|
||||
string(REGEX REPLACE "[af]$" "" _base "${_arch}")
|
||||
if ("${_base}" IN_LIST _SRC_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_arch}")
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_base}")
|
||||
list(APPEND _CUDA_ARCHS "${_arch}")
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
Reference in New Issue
Block a user