diff --git a/CMakeLists.txt b/CMakeLists.txt index afc02f7fb..cf59f18eb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -363,7 +363,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # - sm80 doesn't support fp8 computation # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) - cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0;12.1" "${CUDA_ARCHS}") # marlin arches for other files cuda_archs_loose_intersection(MARLIN_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}") @@ -523,12 +523,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() - # The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require + # The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require # CUDA 12.8 or later if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}") else() - cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}") endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) set(SRCS @@ -616,12 +616,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() - # The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require + # The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require # CUDA 12.8 or later if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}") else() - cuda_archs_loose_intersection(FP4_ARCHS "12.0a" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}") endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) set(SRCS @@ -1050,7 +1050,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # - sm80 doesn't support fp8 computation # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) - cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0;12.1" "${CUDA_ARCHS}") # moe marlin arches for other files cuda_archs_loose_intersection(MARLIN_MOE_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}") if (MARLIN_MOE_OTHER_ARCHS) diff --git a/cmake/external_projects/qutlass.cmake b/cmake/external_projects/qutlass.cmake index 84bb1b00c..273fe754b 100644 --- a/cmake/external_projects/qutlass.cmake +++ b/cmake/external_projects/qutlass.cmake @@ -32,16 +32,16 @@ endif() message(STATUS "[QUTLASS] QuTLASS is available at ${qutlass_SOURCE_DIR}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0f" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(QUTLASS_ARCHS "10.0f;12.0f" "${CUDA_ARCHS}") else() - cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0a;10.3a" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;12.1a;10.0a;10.3a" "${CUDA_ARCHS}") endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND QUTLASS_ARCHS) if(QUTLASS_ARCHS MATCHES "10\\.(0a|3a|0f)") set(QUTLASS_TARGET_CC 100) - elseif(QUTLASS_ARCHS MATCHES "12\\.0a") + elseif(QUTLASS_ARCHS MATCHES "12\\.[01][af]?") set(QUTLASS_TARGET_CC 120) else() message(FATAL_ERROR "[QUTLASS] internal error parsing CUDA_ARCHS='${QUTLASS_ARCHS}'.") @@ -96,7 +96,7 @@ else() "[QUTLASS] Skipping build: CUDA 12.8 or newer is required (found ${CMAKE_CUDA_COMPILER_VERSION}).") else() message(STATUS - "[QUTLASS] Skipping build: no supported arch (12.0a / 10.0a) found in " + "[QUTLASS] Skipping build: no supported arch (12.0f / 10.0f) found in " "CUDA_ARCHS='${CUDA_ARCHS}'.") endif() endif() diff --git a/cmake/utils.cmake b/cmake/utils.cmake index fd3d7e0ae..e95333457 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -355,8 +355,11 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR list(REMOVE_DUPLICATES _PTX_ARCHS) list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS) - # If x.0a or x.0f is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should - # remove x.0a or x.0f from SRC_CUDA_ARCHS and add x.0a or x.0f to _CUDA_ARCHS + # Handle architecture-specific suffixes (a/f) for SRC entries. + # First try exact base match (x.y), then cross-suffix match (x.ya / x.yf). + # For 'f' (family) suffix: if no exact/cross match, fall back to major-version + # match — e.g. SRC="12.0f" matches TGT="12.1a" since SM121 is in the SM12x + # family. The output uses TGT's value to preserve the user's compilation flags. set(_CUDA_ARCHS) foreach(_arch ${_SRC_CUDA_ARCHS}) if(_arch MATCHES "[af]$") @@ -365,6 +368,38 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR if ("${_base}" IN_LIST TGT_CUDA_ARCHS) list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}") list(APPEND _CUDA_ARCHS "${_arch}") + elseif("${_base}a" IN_LIST _TGT_CUDA_ARCHS) + list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}a") + list(APPEND _CUDA_ARCHS "${_base}a") + elseif("${_base}f" IN_LIST _TGT_CUDA_ARCHS) + list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}f") + list(APPEND _CUDA_ARCHS "${_base}f") + elseif(_arch MATCHES "f$") + # Family suffix: match any TGT entry in the same major version family. + string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" _src_major "${_base}") + foreach(_tgt ${_TGT_CUDA_ARCHS}) + string(REGEX REPLACE "[af]$" "" _tgt_base "${_tgt}") + string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" _tgt_major "${_tgt_base}") + if(_tgt_major STREQUAL _src_major) + list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_tgt}") + list(APPEND _CUDA_ARCHS "${_tgt}") + break() + endif() + endforeach() + endif() + endif() + endforeach() + + # Symmetric handling: if TGT has x.ya/f and SRC has x.y (without suffix), + # preserve TGT's suffix in the output. + set(_tgt_copy ${_TGT_CUDA_ARCHS}) + foreach(_arch ${_tgt_copy}) + if(_arch MATCHES "[af]$") + string(REGEX REPLACE "[af]$" "" _base "${_arch}") + if ("${_base}" IN_LIST _SRC_CUDA_ARCHS) + list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_arch}") + list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_base}") + list(APPEND _CUDA_ARCHS "${_arch}") endif() endif() endforeach()