diff --git a/CMakeLists.txt b/CMakeLists.txt
index ff57473cf..fc50b896d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -340,7 +340,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   list(APPEND VLLM_EXT_SRC
     "csrc/quantization/awq/gemm_kernels.cu"
-    "csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
     "csrc/quantization/fp4/nvfp4_quant_entry.cu"
     "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
     "csrc/cutlass_extensions/common.cpp")
@@ -490,132 +489,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
                      " in CUDA target architectures")
   endif()
 
-
-  set(SCALED_MM_3X_ARCHS)
-  # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
-  # CUDA 12.0 or later
-  cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
-    set(SRCS
-      "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu"
-      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu"
-      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu"
-      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu"
-      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu")
-    set_gencode_flags_for_srcs(
-      SRCS "${SRCS}"
-      CUDA_ARCHS "${SCALED_MM_ARCHS}")
-    list(APPEND VLLM_EXT_SRC "${SRCS}")
-    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
-    # Let scaled_mm_c2x know it doesn't need to build these arches
-    list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
-    message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
-  else()
-    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
-      message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
-                     "not >= 12.0, we recommend upgrading to CUDA 12.0 or "
-                     "later if you intend on running FP8 quantized models on "
-                     "Hopper.")
-    else()
-      message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
-                     "in CUDA target architectures")
-    endif()
-  endif()
-
-
-  # The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require
-  # CUDA 12.8 or later
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
-    cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
-  else()
-    cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
-  endif()
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
-    set(SRCS
-      "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu"
-      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu"
-      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu"
-    )
-    set_gencode_flags_for_srcs(
-      SRCS "${SRCS}"
-      CUDA_ARCHS "${SCALED_MM_ARCHS}")
-    list(APPEND VLLM_EXT_SRC "${SRCS}")
-    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
-    # Let scaled_mm_c2x know it doesn't need to build these arches
-    list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
-    message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
-  else()
-    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
-      message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
-                     "not >= 12.8, we recommend upgrading to CUDA 12.8 or "
-                     "later if you intend on running FP8 quantized models on "
-                     "Blackwell.")
-    else()
-      message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found "
-                     "in CUDA target architectures")
-    endif()
-  endif()
-
-
-  # The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
-  # require CUDA 12.8 or later
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
-    cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
-  else()
-    cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
-  endif()
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
-    set(SRCS
-      "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu"
-      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu"
-      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu"
-    )
-    set_gencode_flags_for_srcs(
-      SRCS "${SRCS}"
-      CUDA_ARCHS "${SCALED_MM_ARCHS}")
-    list(APPEND VLLM_EXT_SRC "${SRCS}")
-    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
-    # Let scaled_mm_c2x know it doesn't need to build these arches
-    list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
-    message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
-  else()
-    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
-      message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
-                     "not >= 12.8, we recommend upgrading to CUDA 12.8 or "
-                     "later if you intend on running FP8 quantized models on "
-                     "Blackwell.")
-    else()
-      message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found "
-                     "in CUDA target architectures")
-    endif()
-  endif()
-
-  #
-  # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
-  # kernels for the remaining archs that are not already built for 3x.
-  # (Build 8.9 for FP8)
-  cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
-    "7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}")
-  # subtract out the archs that are already built for 3x
-  list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
-  if (SCALED_MM_2X_ARCHS)
-    set(SRCS "csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu")
-    set_gencode_flags_for_srcs(
-      SRCS "${SRCS}"
-      CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
-    list(APPEND VLLM_EXT_SRC "${SRCS}")
-    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
-    message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
-  else()
-    if (SCALED_MM_3X_ARCHS)
-      message(STATUS "Not building scaled_mm_c2x as all archs are already built"
-                     " for and covered by scaled_mm_c3x")
-    else()
-      message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
-                     "in CUDA target architectures")
-    endif()
-  endif()
-
   # The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
   # CUDA 12.8 or later
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
@@ -693,55 +566,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     set(MLA_ARCHS)
   endif()
 
-  # CUTLASS MoE kernels
-
-  # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works
-  # on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled
-  # if it's possible to compile MoE kernels that use its output.
-  cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
-    set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu")
-    set_gencode_flags_for_srcs(
-      SRCS "${SRCS}"
-      CUDA_ARCHS "${SCALED_MM_ARCHS}")
-    list(APPEND VLLM_EXT_SRC "${SRCS}")
-    list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
-    message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
-  else()
-    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
-      message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
-                     "not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
-                     "if you intend on running FP8 quantized MoE models on Hopper.")
-    else()
-      message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
-                     "in CUDA target architectures.")
-    endif()
-  endif()
-
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
-    cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
-  else()
-    cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
-  endif()
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
-    set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
-    set_gencode_flags_for_srcs(
-      SRCS "${SRCS}"
-      CUDA_ARCHS "${SCALED_MM_ARCHS}")
-    list(APPEND VLLM_EXT_SRC "${SRCS}")
-    list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
-    message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
-  else()
-    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
-      message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
-                     "not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
-                     "if you intend on running FP8 quantized MoE models on Blackwell.")
-    else()
-      message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
-                     "in CUDA target architectures.")
-    endif()
-  endif()
-
   # Expert-specialization MXFP8 blockscaled grouped kernels (SM100+).
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
     cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
@@ -787,36 +611,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
                      "in CUDA target architectures.")
   endif()
 
-  # moe_data.cu is used by all CUTLASS MoE kernels.
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
-    cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
-  else()
-    cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
-  endif()
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
-    set(SRCS "csrc/quantization/w8a8/cutlass/moe/moe_data.cu")
-    set_gencode_flags_for_srcs(
-      SRCS "${SRCS}"
-      CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
-    list(APPEND VLLM_EXT_SRC "${SRCS}")
-    message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
-  else()
-    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
-      message(STATUS "Not building moe_data as CUDA Compiler version is "
-                     "not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
-                     "if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
-    else()
-      message(STATUS "Not building moe_data as no compatible archs found "
-                     "in CUDA target architectures.")
-    endif()
-  endif()
-
-  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
-    cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
-  else()
-    cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
-  endif()
-
   #
   # Machete kernels
@@ -964,7 +758,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   #  _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY)
   #
   set(VLLM_STABLE_EXT_SRC
-    "csrc/libtorch_stable/torch_bindings.cpp")
+    "csrc/libtorch_stable/torch_bindings.cpp"
+    "csrc/cutlass_extensions/common.cpp"
+    "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu")
 
   if(VLLM_GPU_LANG STREQUAL "CUDA")
     list(APPEND VLLM_STABLE_EXT_SRC
@@ -979,6 +775,209 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
       CUDA_ARCHS "${CUDA_ARCHS}")
   endif()
 
+  #
+  # CUTLASS scaled_mm kernels (moved from _C to _C_stable_libtorch)
+  #
+  set(SCALED_MM_3X_ARCHS)
+  # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
+  # CUDA 12.0 or later
+  cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
+    set(SRCS
+      "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu"
+      "csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu"
+      "csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu"
+      "csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu"
+      "csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu")
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${SCALED_MM_ARCHS}")
+    list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
+    # Let scaled_mm_c2x know it doesn't need to build these arches
+    list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
+    message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
+  else()
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
+      message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
+                     "not >= 12.0, we recommend upgrading to CUDA 12.0 or "
+                     "later if you intend on running FP8 quantized models on "
+                     "Hopper.")
+    else()
+      message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
+                     "in CUDA target architectures")
+    endif()
+  endif()
+
+
+  # The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require
+  # CUDA 12.8 or later
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
+    cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
+  else()
+    cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
+  endif()
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
+    set(SRCS
+      "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu"
+      "csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu"
+      "csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu"
+    )
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${SCALED_MM_ARCHS}")
+    list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
+    # Let scaled_mm_c2x know it doesn't need to build these arches
+    list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
+    message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
+  else()
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
+      message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
+                     "not >= 12.8, we recommend upgrading to CUDA 12.8 or "
+                     "later if you intend on running FP8 quantized models on "
+                     "Blackwell.")
+    else()
+      message(STATUS "Not building scaled_mm_c3x_sm120 as no compatible archs found "
+                     "in CUDA target architectures")
    endif()
+  endif()
+
+
+  # The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
+  # require CUDA 12.8 or later
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
+    cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
+  else()
+    cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
+  endif()
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
+    set(SRCS
+      "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu"
+      "csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu"
+      "csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu"
+    )
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${SCALED_MM_ARCHS}")
+    list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
+    # Let scaled_mm_c2x know it doesn't need to build these arches
+    list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
+    message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
+  else()
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
+      message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
+                     "not >= 12.8, we recommend upgrading to CUDA 12.8 or "
+                     "later if you intend on running FP8 quantized models on "
+                     "Blackwell.")
+    else()
+      message(STATUS "Not building scaled_mm_c3x_sm100 as no compatible archs found "
+                     "in CUDA target architectures")
+    endif()
+  endif()
+
+  #
+  # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
+  # kernels for the remaining archs that are not already built for 3x.
+  # (Build 8.9 for FP8)
+  cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
+    "7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}")
+  # subtract out the archs that are already built for 3x
+  list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
+  if (SCALED_MM_2X_ARCHS)
+    set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cu")
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
+    list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
+    message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
+  else()
+    if (SCALED_MM_3X_ARCHS)
+      message(STATUS "Not building scaled_mm_c2x as all archs are already built"
+                     " for and covered by scaled_mm_c3x")
+    else()
+      message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
+                     "in CUDA target architectures")
+    endif()
+  endif()
+
+  #
+  # CUTLASS MoE kernels (moved from _C to _C_stable_libtorch)
+  #
+
+  # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works
+  # on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled
+  # if it's possible to compile MoE kernels that use its output.
+  cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
+    set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu")
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${SCALED_MM_ARCHS}")
+    list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
+    message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
+  else()
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
+      message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
+                     "not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
+                     "if you intend on running FP8 quantized MoE models on Hopper.")
+    else()
+      message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
+                     "in CUDA target architectures.")
+    endif()
+  endif()
+
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
+    cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
+  else()
+    cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
+  endif()
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
+    set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${SCALED_MM_ARCHS}")
+    list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
+    message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
+  else()
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
+      message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
+                     "not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
+                     "if you intend on running FP8 quantized MoE models on Blackwell.")
+    else()
+      message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
+                     "in CUDA target architectures.")
    endif()
+  endif()
+
+  # moe_data.cu is used by all CUTLASS MoE kernels.
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
+    cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
+  else()
+    cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
+  endif()
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
+    set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/moe_data.cu")
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
+    list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
+    message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
+  else()
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
+      message(STATUS "Not building moe_data as CUDA Compiler version is "
+                     "not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
+                     "if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
+    else()
+      message(STATUS "Not building moe_data as no compatible archs found "
+                     "in CUDA target architectures.")
+    endif()
+  endif()
+
   message(STATUS "Enabling C_stable extension.")
   define_extension_target(
     _C_stable_libtorch
@@ -987,6 +986,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     SOURCES ${VLLM_STABLE_EXT_SRC}
     COMPILE_FLAGS ${VLLM_GPU_FLAGS}
     ARCHITECTURES ${VLLM_GPU_ARCHES}
+    INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
     USE_SABI 3
     WITH_SOABI)
 
@@ -1000,6 +1000,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
 
     # Needed to use cuda APIs from C-shim
     target_compile_definitions(_C_stable_libtorch PRIVATE USE_CUDA)
+
+    # Needed by CUTLASS kernels
+    target_compile_definitions(_C_stable_libtorch PRIVATE
+      CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
   endif()
 
   #
diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp
index 91c215071..c32d3a060 100644
--- a/csrc/cutlass_extensions/common.hpp
+++ b/csrc/cutlass_extensions/common.hpp
@@ -6,14 +6,16 @@
 #include
 #include
 
+#include
+
 /**
  * Helper function for checking CUTLASS errors
  */
-#define CUTLASS_CHECK(status)                       \
-  {                                                 \
-    cutlass::Status error = status;                 \
-    TORCH_CHECK(error == cutlass::Status::kSuccess, \
-                cutlassGetStatusString(error));     \
+#define CUTLASS_CHECK(status)                           \
+  {                                                     \
+    cutlass::Status error = status;                     \
+    STD_TORCH_CHECK(error == cutlass::Status::kSuccess, \
+                    cutlassGetStatusString(error));     \
   }
 
 inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
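The CUTLASS_CHECK rewrite above is what lets this shared header compile in the stable-ABI extension: STD_TORCH_CHECK is PyTorch's header-only check macro, so no c10 error machinery is pulled in. A minimal usage sketch (the helper name here is hypothetical):

    #include "cutlass/cutlass.h"
    #include "cutlass_extensions/common.hpp"

    // Sketch: surface a failed CUTLASS call through the stable-ABI check.
    // CUTLASS_CHECK expands to STD_TORCH_CHECK(error == Status::kSuccess, ...)
    // and reports cutlassGetStatusString(error) on failure.
    inline void ensure_cutlass_ok(cutlass::Status status) {
      CUTLASS_CHECK(status);
    }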
diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
index cf79507e1..c2ddcea6d 100644
--- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
+++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
@@ -3,6 +3,14 @@
 #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
 #include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
 
+// This header is shared by both _C (unstable ABI) and _C_stable_libtorch
+// (stable ABI) targets. When compiled under the stable ABI target,
+// TORCH_TARGET_VERSION is defined and torch::Tensor is unavailable, so we
+// use torch::stable::Tensor instead.
+#ifdef TORCH_TARGET_VERSION
+  #include
+#endif
+
 /*
    This file defines custom epilogues for fusing channel scales, token scales,
    bias, and activation zero-points onto a GEMM operation using the
@@ -15,6 +23,12 @@
 
 namespace vllm::c3x {
 
+#ifdef TORCH_TARGET_VERSION
+using TensorType = torch::stable::Tensor;
+#else
+using TensorType = torch::Tensor;
+#endif
+
 using namespace cute;
 
 template
@@ -84,7 +98,7 @@ struct ScaledEpilogueBase {
   // from a tensor. It can handle both row and column, as well as row/column or
   // scalar cases.
   template
-  static auto args_from_tensor(torch::Tensor const& tensor) {
+  static auto args_from_tensor(TensorType const& tensor) {
     using Arguments = typename Descriptor::Arguments;
     auto* data_ptr = static_cast(tensor.data_ptr());
     if constexpr (std::is_same_v> ||
@@ -100,7 +114,7 @@ struct ScaledEpilogueBase {
   // This overload handles the case where there might not be a tensor, in which
   // case a nullptr is passed and a constant (0) is used.
   template
-  static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
+  static auto args_from_tensor(std::optional<TensorType> const& tensor) {
     using Arguments = typename Descriptor::Arguments;
     auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr;
     static_assert(std::is_same_v> ||
@@ -158,8 +172,8 @@ struct ScaledEpilogue
           cutlass::epilogue::fusion::Sm90EVT;
   using ArgumentType = typename EVTCompute::Arguments;
 
-  static ArgumentType prepare_args(torch::Tensor const& a_scales,
-                                   torch::Tensor const& b_scales) {
+  static ArgumentType prepare_args(TensorType const& a_scales,
+                                   TensorType const& b_scales) {
     auto a_args = SUPER::template args_from_tensor(a_scales);
     auto b_args = SUPER::template args_from_tensor(b_scales);
@@ -203,9 +217,9 @@ struct ScaledEpilogueBias
           cutlass::epilogue::fusion::Sm90EVT;
   using ArgumentType = typename EVTCompute::Arguments;
-  static ArgumentType prepare_args(torch::Tensor const& a_scales,
-                                   torch::Tensor const& b_scales,
-                                   torch::Tensor const& bias) {
+  static ArgumentType prepare_args(TensorType const& a_scales,
+                                   TensorType const& b_scales,
+                                   TensorType const& bias) {
     auto a_args = SUPER::template args_from_tensor(a_scales);
     auto b_args = SUPER::template args_from_tensor(b_scales);
     auto bias_args = SUPER::template args_from_tensor(bias);
@@ -246,9 +260,9 @@ struct ScaledEpilogueColumnBias
           cutlass::epilogue::fusion::Sm90EVT;
   using ArgumentType = typename EVTCompute::Arguments;
-  static ArgumentType prepare_args(torch::Tensor const& a_scales,
-                                   torch::Tensor const& b_scales,
-                                   torch::Tensor const& bias) {
+  static ArgumentType prepare_args(TensorType const& a_scales,
+                                   TensorType const& b_scales,
+                                   TensorType const& bias) {
     auto a_args = SUPER::template args_from_tensor(a_scales);
     auto b_args = SUPER::template args_from_tensor(b_scales);
     auto bias_args = SUPER::template args_from_tensor(bias);
@@ -304,10 +318,10 @@ struct ScaledEpilogueBiasAzp
           EVTComputeScaleB, Bias>;
   using ArgumentType = typename EVTCompute::Arguments;
 
-  static ArgumentType prepare_args(torch::Tensor const& a_scales,
-                                   torch::Tensor const& b_scales,
-                                   torch::Tensor const& azp_adj,
-                                   std::optional<torch::Tensor> const& bias) {
+  static ArgumentType prepare_args(TensorType const& a_scales,
+                                   TensorType const& b_scales,
+                                   TensorType const& azp_adj,
+                                   std::optional<TensorType> const& bias) {
     auto a_args = SUPER::template args_from_tensor(a_scales);
     auto b_args = SUPER::template args_from_tensor(b_scales);
     auto bias_args = SUPER::template args_from_tensor(bias);
@@ -380,11 +394,11 @@ struct ScaledEpilogueBiasAzpToken
           EVTComputeScaleB, Bias>;
   using ArgumentType = typename EVTCompute::Arguments;
 
-  static ArgumentType prepare_args(torch::Tensor const& a_scales,
-                                   torch::Tensor const& b_scales,
-                                   torch::Tensor const& azp_adj,
-                                   torch::Tensor const& azp,
-                                   std::optional<torch::Tensor> const& bias) {
+  static ArgumentType prepare_args(TensorType const& a_scales,
+                                   TensorType const& b_scales,
+                                   TensorType const& azp_adj,
+                                   TensorType const& azp,
+                                   std::optional<TensorType> const& bias) {
     auto a_args = SUPER::template args_from_tensor(a_scales);
     auto b_args = SUPER::template args_from_tensor(b_scales);
     auto bias_args = SUPER::template args_from_tensor(bias);
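With the TensorType alias in place, every prepare_args overload above compiles against torch::Tensor in _C and against torch::stable::Tensor in _C_stable_libtorch. A sketch of a call site that works under both ABIs (the wrapper itself is hypothetical):

    #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

    // Sketch: one call site, two ABIs. TensorType resolves at compile time
    // based on whether TORCH_TARGET_VERSION is defined.
    template <typename Epilogue>
    typename Epilogue::ArgumentType make_scale_args(
        vllm::c3x::TensorType const& a_scales,
        vllm::c3x::TensorType const& b_scales) {
      return Epilogue::prepare_args(a_scales, b_scales);
    }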
diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/csrc/libtorch_stable/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
similarity index 91%
rename from csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
rename to csrc/libtorch_stable/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
index ad8c0067d..f6737a73d 100644
--- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
+++ b/csrc/libtorch_stable/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
@@ -1,5 +1,7 @@
 #pragma once
 
+#include
+
 #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
 
 /*
@@ -52,7 +54,7 @@ struct ScaledEpilogueBase {
   // from a tensor. It can handle both row and column, as well as row/column or
   // scalar cases.
   template
-  static auto args_from_tensor(torch::Tensor const& tensor) {
+  static auto args_from_tensor(torch::stable::Tensor const& tensor) {
     using Arguments = typename Descriptor::Arguments;
     auto* data_ptr = static_cast(tensor.data_ptr());
     if constexpr (std::is_same_v> ||
@@ -68,7 +70,8 @@ struct ScaledEpilogueBase {
   // This overload handles the case where there might not be a tensor, in which
   // case a nullptr is passed and a constant (0) is used.
   template
-  static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
+  static auto args_from_tensor(
+      std::optional<torch::stable::Tensor> const& tensor) {
     static_assert(std::is_same_v>);
     using Arguments = typename Descriptor::Arguments;
     auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr;
@@ -117,8 +120,8 @@ struct ScaledEpilogue
           cutlass::epilogue::threadblock::Sm80EVT;
   using ArgumentType = typename EVTCompute::Arguments;
 
-  static ArgumentType prepare_args(torch::Tensor const& a_scales,
-                                   torch::Tensor const& b_scales) {
+  static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
+                                   torch::stable::Tensor const& b_scales) {
     auto a_args = SUPER::template args_from_tensor(a_scales);
     auto b_args = SUPER::template args_from_tensor(b_scales);
@@ -160,9 +163,9 @@ struct ScaledEpilogueBias
   using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT;
   using ArgumentType = typename EVTCompute::Arguments;
-  static ArgumentType prepare_args(torch::Tensor const& a_scales,
-                                   torch::Tensor const& b_scales,
-                                   torch::Tensor const& bias) {
+  static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
+                                   torch::stable::Tensor const& b_scales,
+                                   torch::stable::Tensor const& bias) {
     auto a_args = SUPER::template args_from_tensor(a_scales);
     auto b_args = SUPER::template args_from_tensor(b_scales);
     auto bias_args = SUPER::template args_from_tensor(bias);
@@ -220,10 +223,11 @@ struct ScaledEpilogueBiasAzp
 
   using ArgumentType = typename EVTCompute::Arguments;
 
-  static ArgumentType prepare_args(torch::Tensor const& a_scales,
-                                   torch::Tensor const& b_scales,
-                                   torch::Tensor const& azp_adj,
-                                   std::optional<torch::Tensor> const& bias) {
+  static ArgumentType prepare_args(
+      torch::stable::Tensor const& a_scales,
+      torch::stable::Tensor const& b_scales,
+      torch::stable::Tensor const& azp_adj,
+      std::optional<torch::stable::Tensor> const& bias) {
     auto a_args = SUPER::template args_from_tensor(a_scales);
     auto b_args = SUPER::template args_from_tensor(b_scales);
     auto bias_args = SUPER::template args_from_tensor(bias);
@@ -298,11 +302,11 @@ struct ScaledEpilogueBiasAzpToken
 
   using ArgumentType = typename EVTCompute::Arguments;
 
-  static ArgumentType prepare_args(torch::Tensor const& a_scales,
-                                   torch::Tensor const& b_scales,
-                                   torch::Tensor const& azp_adj,
-                                   torch::Tensor const& azp,
-                                   std::optional<torch::Tensor> const& bias) {
+  static ArgumentType prepare_args(
+      torch::stable::Tensor const& a_scales,
+      torch::stable::Tensor const& b_scales,
+      torch::stable::Tensor const& azp_adj, torch::stable::Tensor const& azp,
+      std::optional<torch::stable::Tensor> const& bias) {
     auto a_args = SUPER::template args_from_tensor(a_scales);
     auto b_args = SUPER::template args_from_tensor(b_scales);
     auto bias_args = SUPER::template args_from_tensor(bias);
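The nullable args_from_tensor overload above is what lets bias and azp remain optional across the ABI boundary: a missing tensor becomes a null pointer, and the epilogue visitor substitutes a constant zero. A standalone sketch of that convention (the include path for the stable Tensor is an assumption):

    #include <optional>
    #include <torch/csrc/stable/tensor.h>  // assumed stable-ABI header

    // Sketch: map "no tensor" to nullptr, as args_from_tensor does; the
    // epilogue then falls back to a compile-time constant (0).
    inline float const* scale_ptr_or_null(
        std::optional<torch::stable::Tensor> const& t) {
      return t ? static_cast<float const*>(t->data_ptr()) : nullptr;
    }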
diff --git a/csrc/libtorch_stable/ops.h b/csrc/libtorch_stable/ops.h
index b74c5c505..914f033b8 100644
--- a/csrc/libtorch_stable/ops.h
+++ b/csrc/libtorch_stable/ops.h
@@ -27,4 +27,61 @@ void per_token_group_quant_int8(const torch::stable::Tensor& input,
                                 torch::stable::Tensor& output_s,
                                 int64_t group_size, double eps, double int8_min,
                                 double int8_max);
+
+bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
+bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
+bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
+
+void cutlass_scaled_mm(torch::stable::Tensor& out,
+                       torch::stable::Tensor const& a,
+                       torch::stable::Tensor const& b,
+                       torch::stable::Tensor const& a_scales,
+                       torch::stable::Tensor const& b_scales,
+                       std::optional<torch::stable::Tensor> const& bias);
+
+void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
+                    torch::stable::Tensor const& a_tensors,
+                    torch::stable::Tensor const& b_tensors,
+                    torch::stable::Tensor const& a_scales,
+                    torch::stable::Tensor const& b_scales,
+                    torch::stable::Tensor const& expert_offsets,
+                    torch::stable::Tensor const& problem_sizes,
+                    torch::stable::Tensor const& a_strides,
+                    torch::stable::Tensor const& b_strides,
+                    torch::stable::Tensor const& c_strides, bool per_act_token,
+                    bool per_out_ch);
+
+void cutlass_scaled_mm_azp(torch::stable::Tensor& out,
+                           torch::stable::Tensor const& a,
+                           torch::stable::Tensor const& b,
+                           torch::stable::Tensor const& a_scales,
+                           torch::stable::Tensor const& b_scales,
+                           torch::stable::Tensor const& azp_adj,
+                           std::optional<torch::stable::Tensor> const& azp,
+                           std::optional<torch::stable::Tensor> const& bias);
+
+void get_cutlass_moe_mm_data(
+    const torch::stable::Tensor& topk_ids,
+    torch::stable::Tensor& expert_offsets,
+    torch::stable::Tensor& problem_sizes1,
+    torch::stable::Tensor& problem_sizes2,
+    torch::stable::Tensor& input_permutation,
+    torch::stable::Tensor& output_permutation, const int64_t num_experts,
+    const int64_t n, const int64_t k,
+    const std::optional<torch::stable::Tensor>& blockscale_offsets,
+    const bool is_gated);
+
+void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
+    const torch::stable::Tensor& expert_first_token_offset,
+    torch::stable::Tensor& problem_sizes1,
+    torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
+    const bool swap_ab);
+
+void get_cutlass_batched_moe_mm_data(
+    torch::stable::Tensor& expert_offsets,
+    torch::stable::Tensor& problem_sizes1,
+    torch::stable::Tensor& problem_sizes2,
+    const torch::stable::Tensor& expert_num_tokens,
+    const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
+    const int64_t k);
 #endif
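These declarations express the whole scaled_mm/MoE surface in stable-ABI types only. A hedged sketch of calling the simplest entry point (the wrapper is hypothetical; shapes and dtypes must already satisfy the checks in dispatch_scaled_mm):

    #include <optional>
    #include "libtorch_stable/ops.h"

    // Sketch: FP8 GEMM with per-token/per-channel scales and no bias.
    // Nothing here touches libtorch's unstable C++ ABI.
    void scaled_mm_no_bias(torch::stable::Tensor& out,
                           torch::stable::Tensor const& a,
                           torch::stable::Tensor const& b,
                           torch::stable::Tensor const& a_scales,
                           torch::stable::Tensor const& b_scales) {
      cutlass_scaled_mm(out, a, b, a_scales, b_scales, /*bias=*/std::nullopt);
    }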
diff --git a/csrc/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh
similarity index 84%
rename from csrc/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh
rename to csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh
index 26de32ce2..ae40c0989 100644
--- a/csrc/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh
+++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh
@@ -2,9 +2,10 @@
 // clang-format will break include orders
 // clang-format off
-#include
+#include
+#include
 
-#include
+#include "libtorch_stable/torch_utils.h"
 
 #include "cutlass/cutlass.h"
@@ -25,14 +26,14 @@ namespace vllm::c3x {
 
 static inline cute::Shape get_problem_shape(
-    torch::Tensor const& a, torch::Tensor const& b) {
+    torch::stable::Tensor const& a, torch::stable::Tensor const& b) {
   int32_t m = a.size(0), n = b.size(1), k = a.size(1);
   return {m, n, k, 1};
 }
 
 template
 void cutlass_gemm_caller(
-    torch::Device device, cute::Shape prob_shape,
+    torch::stable::Device device, cute::Shape prob_shape,
     typename GemmKernel::MainloopArguments mainloop_args,
     typename GemmKernel::EpilogueArguments epilogue_args,
     typename GemmKernel::TileSchedulerArguments scheduler = {}) {
@@ -50,19 +51,20 @@ void cutlass_gemm_caller(
   CUTLASS_CHECK(gemm_op.can_implement(args));
 
   size_t workspace_size = gemm_op.get_workspace_size(args);
-  auto const workspace_options =
-      torch::TensorOptions().dtype(torch::kUInt8).device(device);
-  auto workspace = torch::empty(workspace_size, workspace_options);
+  auto workspace =
+      torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
+                           std::nullopt, device);
 
-  auto stream = at::cuda::getCurrentCUDAStream(device.index());
+  auto stream = get_current_cuda_stream(device.index());
 
   cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
   CUTLASS_CHECK(status);
 }
 
 template
-void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
-                         torch::Tensor const& b,
+void cutlass_gemm_caller(torch::stable::Tensor& out,
+                         torch::stable::Tensor const& a,
+                         torch::stable::Tensor const& b,
                          EpilogueArgs&&... epilogue_params) {
   using ElementAB = typename Gemm::ElementAB;
   using ElementC = typename Gemm::ElementC;
diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh
similarity index 100%
rename from csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh
rename to csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh
diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu
similarity index 51%
rename from csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu
rename to csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu
index 4cd38f497..bc088cf63 100644
--- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu
+++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu
@@ -4,13 +4,12 @@
 
 namespace vllm {
 
-void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
-                                     torch::Tensor const& b,
-                                     torch::Tensor const& a_scales,
-                                     torch::Tensor const& b_scales,
-                                     torch::Tensor const& azp_adj,
-                                     std::optional<torch::Tensor> const& azp,
-                                     std::optional<torch::Tensor> const& bias) {
+void cutlass_scaled_mm_azp_sm90_int8(
+    torch::stable::Tensor& out, torch::stable::Tensor const& a,
+    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
+    torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
+    std::optional<torch::stable::Tensor> const& azp,
+    std::optional<torch::stable::Tensor> const& bias) {
   if (azp) {
     return cutlass_scaled_mm_sm90_int8_epilogue<
         c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,
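cutlass_gemm_caller's two host-side ATen dependencies, torch::empty for the workspace and at::cuda::getCurrentCUDAStream, are replaced by stable equivalents above. A sketch of just the allocation half (header paths are assumptions; get_current_cuda_stream is vLLM's own helper from libtorch_stable/torch_utils.h):

    #include <cstddef>
    #include <optional>
    #include <torch/csrc/stable/tensor.h>          // assumed stable-ABI headers
    #include <torch/csrc/stable/ops.h>
    #include <torch/headeronly/core/ScalarType.h>  // assumed ScalarType header

    // Sketch: byte-typed CUTLASS workspace allocated through the stable ABI,
    // mirroring the torch::stable::empty call in cutlass_gemm_caller.
    torch::stable::Tensor alloc_workspace(size_t bytes,
                                          torch::stable::Device device) {
      return torch::stable::empty(bytes, torch::headeronly::ScalarType::Byte,
                                  std::nullopt, device);
    }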
#include "cuda_utils.h" #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" @@ -130,10 +132,10 @@ struct cutlass_3x_gemm_fp8_blockwise { }; template -void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { +void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a, + torch::stable::Tensor const& b, + torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales) { static constexpr bool swap_ab = Gemm::swap_ab; using GemmKernel = typename Gemm::GemmKernel; using StrideA = typename Gemm::GemmKernel::StrideA; @@ -200,11 +202,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, } template -void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, - torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { +void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::stable::Tensor& out, + torch::stable::Tensor const& a, + torch::stable::Tensor const& b, + torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales) { int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu new file mode 100644 index 000000000..7ceb0697d --- /dev/null +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu @@ -0,0 +1,22 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_blockwise_sm120_fp8( + torch::stable::Tensor& out, torch::stable::Tensor const& a, + torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales) { + if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) { + cutlass_gemm_blockwise_sm120_fp8_dispatch( + out, a, b, a_scales, b_scales); + + } else { + STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half); + cutlass_gemm_blockwise_sm120_fp8_dispatch( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh similarity index 91% rename from csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh rename to csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh index 0701cb586..14de9b1e4 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh @@ -1,5 +1,7 @@ #pragma once +#include + #include "cuda_utils.h" #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" @@ -138,10 +140,10 @@ struct sm120_blockwise_fp8_config_M64 { }; template -void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { +void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a, + 
diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
similarity index 91%
rename from csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
rename to csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
index 0701cb586..14de9b1e4 100644
--- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
+++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
@@ -1,5 +1,7 @@
 #pragma once
 
+#include
+
 #include "cuda_utils.h"
 #include "cutlass/cutlass.h"
 #include "cutlass/numeric_types.h"
@@ -138,10 +140,10 @@ struct sm120_blockwise_fp8_config_M64 {
 };
 
 template
-void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
-                                   torch::Tensor const& b,
-                                   torch::Tensor const& a_scales,
-                                   torch::Tensor const& b_scales) {
+void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
+                                   torch::stable::Tensor const& b,
+                                   torch::stable::Tensor const& a_scales,
+                                   torch::stable::Tensor const& b_scales) {
   using GemmKernel = typename Gemm::GemmKernel;
   using StrideA = typename Gemm::GemmKernel::StrideA;
   using StrideB = typename Gemm::GemmKernel::StrideB;
@@ -196,11 +198,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
 }
 
 template
-void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out,
-                                               torch::Tensor const& a,
-                                               torch::Tensor const& b,
-                                               torch::Tensor const& a_scales,
-                                               torch::Tensor const& b_scales) {
+void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out,
+                                               torch::stable::Tensor const& a,
+                                               torch::stable::Tensor const& b,
+                                               torch::stable::Tensor const& a_scales,
+                                               torch::stable::Tensor const& b_scales) {
   int M = a.size(0);
 
   if (M <= 256) {
     using Gemm = typename sm120_blockwise_fp8_config_M64::Gemm;
diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu
new file mode 100644
index 000000000..d3318c487
--- /dev/null
+++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu
@@ -0,0 +1,23 @@
+
+#include "scaled_mm_kernels.hpp"
+#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
+#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
+
+namespace vllm {
+
+void cutlass_scaled_mm_blockwise_sm90_fp8(
+    torch::stable::Tensor& out, torch::stable::Tensor const& a,
+    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
+    torch::stable::Tensor const& b_scales) {
+  if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
+    cutlass_gemm_blockwise_sm90_fp8_dispatch(
+        out, a, b, a_scales, b_scales);
+
+  } else {
+    STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
+    cutlass_gemm_blockwise_sm90_fp8_dispatch(
+        out, a, b, a_scales, b_scales);
+  }
+}
+
+}  // namespace vllm
diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
similarity index 87%
rename from csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
rename to csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
index c40d49966..cf62e81fd 100644
--- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
+++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
@@ -1,5 +1,7 @@
 #pragma once
 
+#include
+
 #include "cutlass/cutlass.h"
 #include "cutlass/numeric_types.h"
@@ -101,10 +103,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
 };
 
 template
-void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
-                                   torch::Tensor const& b,
-                                   torch::Tensor const& a_scales,
-                                   torch::Tensor const& b_scales) {
+void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
+                                   torch::stable::Tensor const& b,
+                                   torch::stable::Tensor const& a_scales,
+                                   torch::stable::Tensor const& b_scales) {
   using GemmKernel = typename Gemm::GemmKernel;
   using StrideA = typename Gemm::GemmKernel::StrideA;
   using StrideB = typename Gemm::GemmKernel::StrideB;
@@ -120,7 +122,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
 
   int32_t m = a.size(0), n = b.size(1), k = a.size(1);
 
-  TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
+  STD_TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
"m must be divisible by 4"); StrideA a_stride; StrideB b_stride; @@ -161,11 +163,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, } template -void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out, - torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { +void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::stable::Tensor& out, + torch::stable::Tensor const& a, + torch::stable::Tensor const& b, + torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales) { // TODO: better heuristics cutlass_gemm_caller_blockwise, diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp similarity index 55% rename from csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp rename to csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp index 2204a4925..adb3de50f 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp @@ -1,52 +1,57 @@ -#include +#include +#include #include "cuda_utils.h" #include "cutlass_extensions/common.hpp" template -void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a, - torch::Tensor const& b, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - std::optional const& bias, +void dispatch_scaled_mm(torch::stable::Tensor& c, + torch::stable::Tensor const& a, + torch::stable::Tensor const& b, + torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales, + std::optional const& bias, Fp8Func fp8_func, Int8Func int8_func, BlockwiseFunc blockwise_func) { - TORCH_CHECK(a_scales.dtype() == torch::kFloat32); - TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + STD_TORCH_CHECK(a_scales.scalar_type() == + torch::headeronly::ScalarType::Float); + STD_TORCH_CHECK(b_scales.scalar_type() == + torch::headeronly::ScalarType::Float); int M = a.size(0), N = b.size(1), K = a.size(1); if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) { // Standard per-tensor/per-token/per-channel scaling - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - if (a.dtype() == torch::kFloat8_e4m3fn) { + STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + if (a.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn) { fp8_func(c, a, b, a_scales, b_scales, bias); } else { - TORCH_CHECK(a.dtype() == torch::kInt8); + STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char); if constexpr (!std::is_same_v) { int8_func(c, a, b, a_scales, b_scales, bias); } else { int32_t version_num = get_sm_version_num(); - TORCH_CHECK( + STD_TORCH_CHECK( false, "Int8 not supported on SM", version_num, ". 
       }
     }
   } else {
-    TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
-    TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
+    STD_TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
+    STD_TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
     int32_t version_num = get_sm_version_num();
     if (version_num >= 90) {
-      TORCH_CHECK(
+      STD_TORCH_CHECK(
           a.size(0) == a_scales.size(0) &&
              cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
          "a_scale_group_shape must be [1, 128].");
-      TORCH_CHECK(
+      STD_TORCH_CHECK(
          cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
              cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
          "b_scale_group_shape must be [128, 128].");
     }
-    TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
+    STD_TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
     blockwise_func(c, a, b, a_scales, b_scales);
   }
 }
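dispatch_scaled_mm above routes on scale shapes: scales of length 1, M, or N take the per-tensor/per-token/per-channel fp8/int8 paths, while 2-D scales shaped [M, ceil(K/128)] for A and [ceil(K/128), ceil(N/128)] for B take the blockwise path. A standalone restatement of that shape rule (ceil_div mirrors cuda_utils::ceil_div):

    #include <cstdint>

    // Sketch of the blockwise-scale shape check in dispatch_scaled_mm:
    // a_scales must be [M x ceil(K/128)], b_scales [ceil(K/128) x ceil(N/128)].
    inline int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

    inline bool scales_are_blockwise(int64_t M, int64_t N, int64_t K,
                                     int64_t a_rows, int64_t a_cols,
                                     int64_t b_rows, int64_t b_cols) {
      return a_rows == M && a_cols == ceil_div(K, 128) &&
             b_rows == ceil_div(K, 128) && b_cols == ceil_div(N, 128);
    }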
diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp
new file mode 100644
index 000000000..a3a3eb3e2
--- /dev/null
+++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp
@@ -0,0 +1,52 @@
+#pragma once
+
+#include
+
+namespace vllm {
+
+void cutlass_scaled_mm_sm90_fp8(
+    torch::stable::Tensor& out, torch::stable::Tensor const& a,
+    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
+    torch::stable::Tensor const& b_scales,
+    std::optional<torch::stable::Tensor> const& bias);
+
+void cutlass_scaled_mm_sm90_int8(
+    torch::stable::Tensor& out, torch::stable::Tensor const& a,
+    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
+    torch::stable::Tensor const& b_scales,
+    std::optional<torch::stable::Tensor> const& bias);
+
+void cutlass_scaled_mm_azp_sm90_int8(
+    torch::stable::Tensor& out, torch::stable::Tensor const& a,
+    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
+    torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
+    std::optional<torch::stable::Tensor> const& azp,
+    std::optional<torch::stable::Tensor> const& bias);
+
+void cutlass_scaled_mm_blockwise_sm90_fp8(
+    torch::stable::Tensor& out, torch::stable::Tensor const& a,
+    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
+    torch::stable::Tensor const& b_scales);
+
+void cutlass_scaled_mm_sm100_fp8(
+    torch::stable::Tensor& out, torch::stable::Tensor const& a,
+    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
+    torch::stable::Tensor const& b_scales,
+    std::optional<torch::stable::Tensor> const& bias);
+
+void cutlass_scaled_mm_sm120_fp8(
+    torch::stable::Tensor& out, torch::stable::Tensor const& a,
+    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
+    torch::stable::Tensor const& b_scales,
+    std::optional<torch::stable::Tensor> const& bias);
+
+void cutlass_scaled_mm_blockwise_sm100_fp8(
+    torch::stable::Tensor& out, torch::stable::Tensor const& a,
+    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
+    torch::stable::Tensor const& b_scales);
+
+void cutlass_scaled_mm_blockwise_sm120_fp8(
+    torch::stable::Tensor& out, torch::stable::Tensor const& a,
+    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
+    torch::stable::Tensor const& b_scales);
+}  // namespace vllm
diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu
new file mode 100644
index 000000000..e910103c4
--- /dev/null
+++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu
@@ -0,0 +1,24 @@
+#include "scaled_mm_kernels.hpp"
+#include "scaled_mm_sm100_fp8_dispatch.cuh"
+
+namespace vllm {
+
+void cutlass_scaled_mm_sm100_fp8(
+    torch::stable::Tensor& out, torch::stable::Tensor const& a,
+    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
+    torch::stable::Tensor const& b_scales,
+    std::optional<torch::stable::Tensor> const& bias) {
+  STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+  if (bias) {
+    STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
+                    "currently bias dtype must match output dtype ",
+                    out.scalar_type());
+    return cutlass_scaled_mm_sm100_fp8_epilogue(out, a, b, a_scales,
+                                                b_scales, *bias);
+  } else {
+    return cutlass_scaled_mm_sm100_fp8_epilogue(out, a, b, a_scales,
+                                                b_scales);
+  }
+}
+
+}  // namespace vllm
diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh
similarity index 90%
rename from csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh
rename to csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh
index 311cd4bd4..46c36d13e 100644
--- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh
+++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh
@@ -1,5 +1,7 @@
 #pragma once
 
+#include
+
 #include "scaled_mm.cuh"
 #include "cutlass_gemm_caller.cuh"
 #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
@@ -192,8 +194,9 @@ struct sm100_fp8_config_M16_swap_ab {
 };
 
 template
-void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
-                                   torch::Tensor const& b,
+void cutlass_gemm_caller_sm100_fp8(torch::stable::Tensor& out,
+                                   torch::stable::Tensor const& a,
+                                   torch::stable::Tensor const& b,
                                    EpilogueArgs&&... epilogue_params) {
   static constexpr bool swap_ab = Gemm::swap_ab;
   using ElementAB = typename Gemm::ElementAB;
@@ -237,15 +240,15 @@ void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
 
 template
-inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
-                                            torch::Tensor const& a,
-                                            torch::Tensor const& b,
-                                            torch::Tensor const& a_scales,
-                                            torch::Tensor const& b_scales,
-                                            EpilogueArgs&&... args) {
+inline void cutlass_gemm_sm100_fp8_dispatch(
+    torch::stable::Tensor& out, torch::stable::Tensor const& a,
+    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
+    torch::stable::Tensor const& b_scales, EpilogueArgs&&... args) {
   static_assert(std::is_same());
-  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
-  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
+  STD_TORCH_CHECK(a.scalar_type() ==
+                  torch::headeronly::ScalarType::Float8_e4m3fn);
+  STD_TORCH_CHECK(b.scalar_type() ==
+                  torch::headeronly::ScalarType::Float8_e4m3fn);
 
   using Cutlass3xGemmDefault = typename sm100_fp8_config_default
-void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
-                                          torch::Tensor const& a,
-                                          torch::Tensor const& b,
-                                          torch::Tensor const& a_scales,
-                                          torch::Tensor const& b_scales,
+void cutlass_scaled_mm_sm100_fp8_epilogue(torch::stable::Tensor& out,
+                                          torch::stable::Tensor const& a,
+                                          torch::stable::Tensor const& b,
+                                          torch::stable::Tensor const& a_scales,
+                                          torch::stable::Tensor const& b_scales,
                                           EpilogueArgs&&... epilogue_args) {
-  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
-  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
+  STD_TORCH_CHECK(a.scalar_type() ==
+                  torch::headeronly::ScalarType::Float8_e4m3fn);
+  STD_TORCH_CHECK(b.scalar_type() ==
+                  torch::headeronly::ScalarType::Float8_e4m3fn);
 
-  if (out.dtype() == torch::kBFloat16) {
+  if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
     return cutlass_gemm_sm100_fp8_dispatch(
         out, a, b, a_scales, b_scales,
         std::forward(epilogue_args)...);
   } else {
-    TORCH_CHECK(out.dtype() == torch::kFloat16);
+    STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
     return cutlass_gemm_sm100_fp8_dispatch(
         out, a, b, a_scales, b_scales,
diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu
new file mode 100644
index 000000000..fb84faa2a
--- /dev/null
+++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu
@@ -0,0 +1,25 @@
+#include "scaled_mm_kernels.hpp"
+#include "scaled_mm_sm120_fp8_dispatch.cuh"
+#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
+
+namespace vllm {
+
+void cutlass_scaled_mm_sm120_fp8(
+    torch::stable::Tensor& out, torch::stable::Tensor const& a,
+    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
+    torch::stable::Tensor const& b_scales,
+    std::optional<torch::stable::Tensor> const& bias) {
+  STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+  if (bias) {
+    STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
+                    "currently bias dtype must match output dtype ",
+                    out.scalar_type());
+    return cutlass_scaled_mm_sm120_fp8_epilogue(
+        out, a, b, a_scales, b_scales, *bias);
+  } else {
+    return cutlass_scaled_mm_sm120_fp8_epilogue(
+        out, a, b, a_scales, b_scales);
+  }
+}
+
+}  // namespace vllm
diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh
similarity index 88%
rename from csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh
rename to csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh
index 37846a87b..245f5c10f 100644
--- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh
+++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh
@@ -1,5 +1,7 @@
 #pragma once
 
+#include
+
 #include "scaled_mm.cuh"
 #include "cutlass_gemm_caller.cuh"
 
@@ -138,13 +140,15 @@ struct sm120_fp8_config_M16 {
 
 template typename Epilogue,
          typename... EpilogueArgs>
-inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
-                                            torch::Tensor const& a,
-                                            torch::Tensor const& b,
+inline void cutlass_gemm_sm120_fp8_dispatch(torch::stable::Tensor& out,
+                                            torch::stable::Tensor const& a,
+                                            torch::stable::Tensor const& b,
                                             EpilogueArgs&&... args) {
   static_assert(std::is_same());
-  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
-  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
+  STD_TORCH_CHECK(a.scalar_type() ==
+                  torch::headeronly::ScalarType::Float8_e4m3fn);
+  STD_TORCH_CHECK(b.scalar_type() ==
+                  torch::headeronly::ScalarType::Float8_e4m3fn);
 
   int M = a.size(0);
 
@@ -177,19 +181,21 @@ inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
 
 template