[NVIDIA] Fix DGX Spark logic (#38126)

Signed-off-by: johnnynunez <johnnynuca14@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
Signed-off-by: Sathish Sanjeevi <sathish.krishnan.p.s@gmail.com>
Signed-off-by: guillaume_guy <guillaume.guy@airbnb.com>
Signed-off-by: Guillaume Guy <guillaume.c.guy@gmail.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Andreas Karatzas <akaratza@amd.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: Sathish Sanjeevi <SKPsanjeevi@users.noreply.github.com>
Co-authored-by: Guillaume Guy <guillaume.c.guy@gmail.com>
Co-authored-by: guillaume_guy <guillaume.guy@airbnb.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
Johnny
2026-03-27 23:26:07 +01:00
committed by GitHub
parent 384e4d5f48
commit 97d19197bc
3 changed files with 47 additions and 12 deletions

View File

@@ -363,7 +363,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# - sm80 doesn't support fp8 computation # - sm80 doesn't support fp8 computation
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and SM12x, i.e. 12.0 (e.g. RTX 50x0) and 12.1 (e.g. DGX Spark)
cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0;12.1" "${CUDA_ARCHS}")
# marlin arches for other files # marlin arches for other files
cuda_archs_loose_intersection(MARLIN_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}") cuda_archs_loose_intersection(MARLIN_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
@@ -523,12 +523,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif() endif()
# The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require # The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.8 or later # CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}") cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
else() else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a" "${CUDA_ARCHS}") cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif() endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS set(SRCS
@@ -616,12 +616,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif() endif()
endif() endif()
# The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require # The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
# CUDA 12.8 or later # CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}") cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
else() else()
cuda_archs_loose_intersection(FP4_ARCHS "12.0a" "${CUDA_ARCHS}") cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif() endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS set(SRCS
@@ -1050,7 +1050,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# - sm80 doesn't support fp8 computation # - sm80 doesn't support fp8 computation
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and SM12x, i.e. 12.0 (e.g. RTX 50x0) and 12.1 (e.g. DGX Spark)
cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0;12.1" "${CUDA_ARCHS}")
# moe marlin arches for other files # moe marlin arches for other files
cuda_archs_loose_intersection(MARLIN_MOE_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}") cuda_archs_loose_intersection(MARLIN_MOE_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
if (MARLIN_MOE_OTHER_ARCHS) if (MARLIN_MOE_OTHER_ARCHS)

View File

@@ -32,16 +32,16 @@ endif()
message(STATUS "[QUTLASS] QuTLASS is available at ${qutlass_SOURCE_DIR}") message(STATUS "[QUTLASS] QuTLASS is available at ${qutlass_SOURCE_DIR}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0f" "${CUDA_ARCHS}") cuda_archs_loose_intersection(QUTLASS_ARCHS "10.0f;12.0f" "${CUDA_ARCHS}")
else() else()
cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0a;10.3a" "${CUDA_ARCHS}") cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;12.1a;10.0a;10.3a" "${CUDA_ARCHS}")
endif() endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND QUTLASS_ARCHS) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND QUTLASS_ARCHS)
if(QUTLASS_ARCHS MATCHES "10\\.(0a|3a|0f)") if(QUTLASS_ARCHS MATCHES "10\\.(0a|3a|0f)")
set(QUTLASS_TARGET_CC 100) set(QUTLASS_TARGET_CC 100)
elseif(QUTLASS_ARCHS MATCHES "12\\.0a") elseif(QUTLASS_ARCHS MATCHES "12\\.[01][af]?")
set(QUTLASS_TARGET_CC 120) set(QUTLASS_TARGET_CC 120)
else() else()
message(FATAL_ERROR "[QUTLASS] internal error parsing CUDA_ARCHS='${QUTLASS_ARCHS}'.") message(FATAL_ERROR "[QUTLASS] internal error parsing CUDA_ARCHS='${QUTLASS_ARCHS}'.")
@@ -96,7 +96,7 @@ else()
"[QUTLASS] Skipping build: CUDA 12.8 or newer is required (found ${CMAKE_CUDA_COMPILER_VERSION}).") "[QUTLASS] Skipping build: CUDA 12.8 or newer is required (found ${CMAKE_CUDA_COMPILER_VERSION}).")
else() else()
message(STATUS message(STATUS
"[QUTLASS] Skipping build: no supported arch (12.0a / 10.0a) found in " "[QUTLASS] Skipping build: no supported arch (12.0f / 10.0f) found in "
"CUDA_ARCHS='${CUDA_ARCHS}'.") "CUDA_ARCHS='${CUDA_ARCHS}'.")
endif() endif()
endif() endif()

View File

@@ -355,8 +355,11 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
list(REMOVE_DUPLICATES _PTX_ARCHS) list(REMOVE_DUPLICATES _PTX_ARCHS)
list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS) list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
# If x.0a or x.0f is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should # Handle architecture-specific suffixes (a/f) for SRC entries.
# remove x.0a or x.0f from SRC_CUDA_ARCHS and add x.0a or x.0f to _CUDA_ARCHS # First try exact base match (x.y), then cross-suffix match (x.ya / x.yf).
# For 'f' (family) suffix: if no exact/cross match, fall back to major-version
# match — e.g. SRC="12.0f" matches TGT="12.1a" since SM121 is in the SM12x
# family. The output uses TGT's value to preserve the user's compilation flags.
set(_CUDA_ARCHS) set(_CUDA_ARCHS)
foreach(_arch ${_SRC_CUDA_ARCHS}) foreach(_arch ${_SRC_CUDA_ARCHS})
if(_arch MATCHES "[af]$") if(_arch MATCHES "[af]$")
@@ -365,6 +368,38 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
if ("${_base}" IN_LIST TGT_CUDA_ARCHS) if ("${_base}" IN_LIST TGT_CUDA_ARCHS)
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}") list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}")
list(APPEND _CUDA_ARCHS "${_arch}") list(APPEND _CUDA_ARCHS "${_arch}")
elseif("${_base}a" IN_LIST _TGT_CUDA_ARCHS)
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}a")
list(APPEND _CUDA_ARCHS "${_base}a")
elseif("${_base}f" IN_LIST _TGT_CUDA_ARCHS)
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}f")
list(APPEND _CUDA_ARCHS "${_base}f")
elseif(_arch MATCHES "f$")
# Family suffix: match any TGT entry in the same major version family.
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" _src_major "${_base}")
foreach(_tgt ${_TGT_CUDA_ARCHS})
string(REGEX REPLACE "[af]$" "" _tgt_base "${_tgt}")
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" _tgt_major "${_tgt_base}")
if(_tgt_major STREQUAL _src_major)
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_tgt}")
list(APPEND _CUDA_ARCHS "${_tgt}")
break()
endif()
endforeach()
endif()
endif()
endforeach()
# Symmetric handling: if TGT has x.ya/f and SRC has x.y (without suffix),
# preserve TGT's suffix in the output.
set(_tgt_copy ${_TGT_CUDA_ARCHS})
foreach(_arch ${_tgt_copy})
if(_arch MATCHES "[af]$")
string(REGEX REPLACE "[af]$" "" _base "${_arch}")
if ("${_base}" IN_LIST _SRC_CUDA_ARCHS)
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_arch}")
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_base}")
list(APPEND _CUDA_ARCHS "${_arch}")
endif() endif()
endif() endif()
endforeach() endforeach()